package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.airlift.testing.Closeables;
import com.facebook.presto.Session;
import com.facebook.presto.SessionTestUtils;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.cost.CachingCostProvider;
import com.facebook.presto.cost.CachingStatsProvider;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.CostProvider;
import com.facebook.presto.cost.PlanCostEstimate;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.expressions.RowExpressionNodeInliner;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.LogicalPropertiesProvider;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.assertions.RowExpressionVerifier;
import com.facebook.presto.sql.planner.assertions.SymbolAliases;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.testing.LocalQueryRunner;
import com.facebook.presto.testing.TestingSession;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.io.Closeable;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

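/**
 * Tests for {@link ReorderJoins.JoinEnumerator}: partition generation, join creation for a given
 * partitioning, and extraction of equi-join clauses and join filters from predicates.
 * Recovered from com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.class.
 */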
public class TestJoinEnumerator {
    private LocalQueryRunner queryRunner;
    private Metadata metadata;
    private DeterminismEvaluator determinismEvaluator;
    private FunctionResolution functionResolution;
    private PlanBuilder planBuilder;
    private TestingRowExpressionTranslator rowExpressionTranslator;
    private Session session;

    /**
     * Builds the fixtures shared by all tests in this class: a testing session, a
     * {@link LocalQueryRunner} created from it, and the helpers derived from the runner's metadata.
     */
    @BeforeClass
    public void setUp() {
        session = TestingSession.testSessionBuilder().build();
        queryRunner = new LocalQueryRunner(session);

        // Everything below is derived from the runner's metadata, so read it once.
        Metadata runnerMetadata = queryRunner.getMetadata();
        metadata = runnerMetadata;
        determinismEvaluator = new RowExpressionDeterminismEvaluator(runnerMetadata);
        functionResolution = new FunctionResolution(
                runnerMetadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
        planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), runnerMetadata);
        rowExpressionTranslator = new TestingRowExpressionTranslator(runnerMetadata);
    }

    /**
     * Closes the query runner and clears the field that referenced it.
     * Declared with {@code alwaysRun = true}.
     */
    @AfterClass(alwaysRun = true)
    public void tearDown() {
        Closeables.closeAllRuntimeException(queryRunner);
        queryRunner = null;
    }

    /**
     * Checks the output of {@code ReorderJoins.JoinEnumerator.generatePartitions} for 4 and for 3
     * sources against the expected sets of source indexes. Every expected partition contains index 0.
     */
    @Test
    public void testGeneratePartitions() {
        // Plain varargs instead of an explicit raw ImmutableSet[] keeps the element type checked.
        Assert.assertEquals(
                ReorderJoins.JoinEnumerator.generatePartitions(4),
                ImmutableSet.of(
                        ImmutableSet.of(0),
                        ImmutableSet.of(0, 1),
                        ImmutableSet.of(0, 2),
                        ImmutableSet.of(0, 3),
                        ImmutableSet.of(0, 1, 2),
                        ImmutableSet.of(0, 1, 3),
                        ImmutableSet.of(0, 2, 3)));

        Assert.assertEquals(
                ReorderJoins.JoinEnumerator.generatePartitions(3),
                ImmutableSet.of(
                        ImmutableSet.of(0),
                        ImmutableSet.of(0, 1),
                        ImmutableSet.of(0, 2)));
    }

    /**
     * Builds a multi-join over two values sources whose filter is {@code TRUE_CONSTANT}, asks the
     * enumerator to create a join for the partition {@code {0}}, and expects no plan node and an
     * infinite cost estimate.
     */
    @Test
    public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() {
        // Named "p" so it does not shadow the planBuilder field.
        PlanBuilder p = new PlanBuilder(SessionTestUtils.TEST_SESSION, new PlanNodeIdAllocator(), queryRunner.getMetadata());
        VariableReferenceExpression a1 = p.variable("A1");
        VariableReferenceExpression b1 = p.variable("B1");

        ReorderJoins.MultiJoinNode multiJoinNode = new ReorderJoins.MultiJoinNode(
                new LinkedHashSet<>(ImmutableList.of(p.values(a1), p.values(b1))),
                LogicalRowExpressions.TRUE_CONSTANT,
                ImmutableList.of(a1, b1),
                Assignments.of());

        ReorderJoins.JoinEnumerator joinEnumerator = new ReorderJoins.JoinEnumerator(
                new CostComparator(1.0d, 1.0d, 1.0d),
                multiJoinNode.getFilter(),
                createContext(),
                determinismEvaluator,
                functionResolution,
                metadata);

        ReorderJoins.JoinEnumerationResult actual = joinEnumerator.createJoinAccordingToPartitioning(
                multiJoinNode.getSources(),
                multiJoinNode.getOutputVariables(),
                ImmutableSet.of(0));

        Assert.assertFalse(actual.getPlanNode().isPresent());
        Assert.assertEquals(actual.getCost(), PlanCostEstimate.infinite());
    }

    /**
     * Feeds predicates over four BIGINT variables (a, b, c, d) to {@code assertJoinCondition}
     * together with the left and right variable sets, and states for each case which part is
     * expected as an equi-join clause and which as a join filter ({@code null} meaning "none").
     */
    @Test
    public void testJoinClauseAndFilterInference() {
        // Typed map: the values are handed to APIs that require Type, not Object.
        Map<String, Type> types = ImmutableMap.<String, Type>builder()
                .put("a", BigintType.BIGINT)
                .put("b", BigintType.BIGINT)
                .put("c", BigintType.BIGINT)
                .put("d", BigintType.BIGINT)
                .build();

        VariableReferenceExpression a = Expressions.variable("a", types.get("a"));
        VariableReferenceExpression b = Expressions.variable("b", types.get("b"));
        VariableReferenceExpression c = Expressions.variable("c", types.get("c"));
        VariableReferenceExpression d = Expressions.variable("d", types.get("d"));

        // Upper-case aliases used in the expected expressions map to the lower-case variables.
        SymbolAliases.Builder aliasBuilder = SymbolAliases.builder();
        aliasBuilder.put("A", new SymbolReference("a"));
        aliasBuilder.put("B", new SymbolReference("b"));
        aliasBuilder.put("C", new SymbolReference("c"));
        aliasBuilder.put("D", new SymbolReference("d"));
        SymbolAliases aliases = aliasBuilder.build();

        // Cases expected to yield equi-join clauses only.
        assertJoinCondition(aliases, toRowExpressionList(types, "a = b"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B", null);
        assertJoinCondition(aliases, toRowExpressionList(types, "a = b", "c = d"), ImmutableSet.of(a, c), ImmutableSet.of(b, d), "A = B AND C = D", null);
        assertJoinCondition(aliases, toRowExpressionList(types, "a = b + c"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B + C", null);
        assertJoinCondition(aliases, toRowExpressionList(types, "a = b + c"), ImmutableSet.of(b, c), ImmutableSet.of(a), "A = B + C", null);
        assertJoinCondition(aliases, toRowExpressionList(types, "a = b + c + 1"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B + C + 1", null);
        assertJoinCondition(aliases, toRowExpressionList(types, "a = b + c + 1"), ImmutableSet.of(b, c), ImmutableSet.of(a), "A = B + C + 1", null);

        // Cases expected to yield join filters only.
        assertJoinCondition(aliases, toRowExpressionList(types, "a + b = c"), ImmutableSet.of(a), ImmutableSet.of(b, c), null, "A + B = C");
        assertJoinCondition(aliases, toRowExpressionList(types, "a + b = 1"), ImmutableSet.of(a), ImmutableSet.of(b), null, "A + B = 1");

        // Mixed case: two equi-join clauses plus one join filter.
        assertJoinCondition(aliases, toRowExpressionList(types, "a = ABS(b)", "a = ceil(b-c)", "b = c + 10"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = abs(B) AND A = ceil(B-C)", "B = C + 10");
    }

    /**
     * Translates each SQL expression string into a {@link RowExpression} using the shared
     * translator and the given name-to-type map.
     *
     * @param map variable names mapped to their types, passed to the translator
     * @param strArr SQL expression strings, translated in order
     * @return the translated expressions, in the same order as {@code strArr}
     */
    private List<RowExpression> toRowExpressionList(Map<String, Type> map, String... strArr) {
        return Arrays.stream(strArr)
                .map(sql -> rowExpressionTranslator.translate(sql, map))
                .collect(Collectors.toList());
    }

    /**
     * Extracts join conditions from the given predicates and verifies them against the expected
     * expressions.
     *
     * <p>The extracted equi-join clauses are converted to row expressions, have the new left and
     * right assignments inlined, and are AND-ed together; the extracted join filters are AND-ed
     * together as well. Each combined expression is then checked with a
     * {@link RowExpressionVerifier} against the corresponding expected SQL string.
     *
     * @param symbolAliases aliases used by the verifier to resolve names in the expected strings
     * @param list join predicates to analyse
     * @param set variables passed to the enumerator as the left side
     * @param set2 variables passed to the enumerator as the right side
     * @param str expected combined equi-join clause, or {@code null} if none is expected
     * @param str2 expected combined join filter, or {@code null} if none is expected
     */
    private void assertJoinCondition(SymbolAliases symbolAliases, List<RowExpression> list, Set<VariableReferenceExpression> set, Set<VariableReferenceExpression> set2, String str, String str2) {
        RowExpressionVerifier verifier = new RowExpressionVerifier(symbolAliases, metadata, session);
        ReorderJoins.JoinEnumerator joinEnumerator = new ReorderJoins.JoinEnumerator(
                new CostComparator(1.0d, 1.0d, 1.0d),
                LogicalRowExpressions.TRUE_CONSTANT,
                createContext(),
                determinismEvaluator,
                functionResolution,
                metadata);
        ReorderJoins.JoinEnumerator.JoinCondition joinConditions =
                joinEnumerator.extractJoinConditions(list, set, set2, new VariableAllocator());

        // Typed Optional so get() yields a RowExpression rather than Object.
        Optional<RowExpression> equiJoinExpression = joinConditions.getJoinClauses().stream()
                .map(clause -> JoinNodeUtils.toRowExpression(clause, functionResolution))
                .map(expression -> RowExpressionNodeInliner.replaceExpression(expression, joinConditions.getNewLeftAssignments()))
                .map(expression -> RowExpressionNodeInliner.replaceExpression(expression, joinConditions.getNewRightAssignments()))
                .reduce((left, right) -> LogicalRowExpressions.and(left, right));
        if (equiJoinExpression.isPresent()) {
            Assert.assertNotNull(str);
            Assert.assertTrue(verifier.process(PlanBuilder.expression(str), equiJoinExpression.get()));
        } else {
            Assert.assertNull(str);
        }

        Optional<RowExpression> joinFilter = joinConditions.getJoinFilters().stream()
                .reduce((left, right) -> LogicalRowExpressions.and(left, right));
        if (joinFilter.isPresent()) {
            Assert.assertNotNull(str2);
            Assert.assertTrue(verifier.process(PlanBuilder.expression(str2), joinFilter.get()));
        } else {
            Assert.assertNull(str2);
        }
    }

    /**
     * Creates a fresh {@link Rule.Context} for a single enumerator instance.
     *
     * <p>Each call gets its own id allocator, variable allocator, and caching stats/cost providers
     * built from the query runner's calculators and default session. Lookups are
     * {@link Lookup#noLookup()}, the timeout check is a no-op, warnings go to
     * {@link WarningCollector#NOOP}, and no logical properties provider is supplied.
     *
     * @return a context backed by the shared query runner
     */
    private Rule.Context createContext() {
        final PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator();
        final VariableAllocator variableAllocator = new VariableAllocator();
        final CachingStatsProvider cachingStatsProvider = new CachingStatsProvider(
                this.queryRunner.getStatsCalculator(),
                Optional.empty(),
                Lookup.noLookup(),
                this.queryRunner.getDefaultSession(),
                TypeProvider.viewOf(variableAllocator.getVariables()));
        final CachingCostProvider cachingCostProvider = new CachingCostProvider(
                this.queryRunner.getCostCalculator(),
                cachingStatsProvider,
                Optional.empty(),
                this.queryRunner.getDefaultSession());

        // @Override on every method so a change to Rule.Context is caught at compile time.
        return new Rule.Context() {
            @Override
            public Lookup getLookup() {
                return Lookup.noLookup();
            }

            @Override
            public PlanNodeIdAllocator getIdAllocator() {
                return planNodeIdAllocator;
            }

            @Override
            public VariableAllocator getVariableAllocator() {
                return variableAllocator;
            }

            @Override
            public Session getSession() {
                return TestJoinEnumerator.this.queryRunner.getDefaultSession();
            }

            @Override
            public StatsProvider getStatsProvider() {
                return cachingStatsProvider;
            }

            @Override
            public CostProvider getCostProvider() {
                return cachingCostProvider;
            }

            @Override
            public void checkTimeoutNotExhausted() {
                // Intentionally empty: this context imposes no timeout.
            }

            @Override
            public WarningCollector getWarningCollector() {
                return WarningCollector.NOOP;
            }

            @Override
            public Optional<LogicalPropertiesProvider> getLogicalPropertiesProvider() {
                return Optional.empty();
            }
        };
    }
}
