package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.SessionTestUtils;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsAndCosts;
import io.trino.metadata.AbstractMockMetadata;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TestingPlannerContext;
import io.trino.sql.planner.assertions.PlanAssert;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.testing.TestingSession;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests {@link PushProjectionThroughJoin#pushProjectionThroughJoin}.
 *
 * <p>A projection above an inner join whose assignments each depend on only one join input is
 * rewritten into per-side projections below the join. The rewrite is expected to return empty
 * in two cases: the projection needs both inputs, or it sits above an outer join.
 */
public class TestPushProjectionThroughJoin {
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
    private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT));
    private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BigintType.BIGINT));

    /**
     * Verifies that {@code a3 := -a2, b2 := -b1} above {@code left INNER JOIN right ON a1 = b1}
     * is split into one projection per join input, with exact outputs {@code a3, b2}.
     *
     * <p>The expected plan shows two details. First, the left-side {@code a2 := -a0} projection
     * is inlined into {@code a3 := -(-a0)}. Second, each pushed projection also keeps its join
     * key ({@code a1} / {@code b1}).
     */
    @Test
    public void testPushesProjectionThroughJoin() {
        PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
        PlanBuilder p = new PlanBuilder(idAllocator, TestingPlannerContext.PLANNER_CONTEXT, SessionTestUtils.TEST_SESSION);
        Symbol a0 = p.symbol("a0");
        Symbol a1 = p.symbol("a1");
        Symbol a2 = p.symbol("a2");
        Symbol a3 = p.symbol("a3");
        Symbol b0 = p.symbol("b0");
        Symbol b1 = p.symbol("b1");
        Symbol b2 = p.symbol("b2");

        // Left input: values(a0, a1) -> identity projection -> (a2 := -a0, a1 := a1)
        PlanNode leftValues = p.values(a0, a1);
        PlanNode leftIdentity = p.project(Assignments.builder().putIdentity(a0).putIdentity(a1).build(), leftValues);
        PlanNode left = p.project(Assignments.of(a2, negation(a0), a1, a1.toSymbolReference()), leftIdentity);
        // Right input: values(b0, b1)
        PlanNode right = p.values(b0, b1);
        PlanNode join = p.join(JoinType.INNER, left, right, new JoinNode.EquiJoinClause(a1, b1));
        // Each assignment references a single side of the join, so the projection is pushable.
        ProjectNode project = p.project(Assignments.of(a3, negation(a2), b2, negation(b1)), join);

        Session session = TestingSession.testSessionBuilder().build();
        Optional<PlanNode> rewritten = PushProjectionThroughJoin.pushProjectionThroughJoin(project, Lookup.noLookup(), idAllocator);
        Assertions.assertThat(rewritten).isPresent();

        PlanMatchPattern expectedLeft = PlanMatchPattern.strictProject(
                ImmutableMap.of(
                        "a3", PlanMatchPattern.expression(new Call(NEGATION_BIGINT, ImmutableList.of(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BigintType.BIGINT, "a0")))))),
                        "a1", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "a1"))),
                PlanMatchPattern.strictProject(
                        ImmutableMap.of(
                                "a0", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "a0")),
                                "a1", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "a1"))),
                        PlanMatchPattern.values("a0", "a1")));
        PlanMatchPattern expectedRight = PlanMatchPattern.strictProject(
                ImmutableMap.of(
                        "b2", PlanMatchPattern.expression(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BigintType.BIGINT, "b1")))),
                        "b1", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "b1"))),
                PlanMatchPattern.values("b0", "b1"));

        PlanAssert.assertPlan(
                session,
                AbstractMockMetadata.dummyMetadata(),
                FunctionManager.createTestingFunctionManager(),
                node -> PlanNodeStatsEstimate.unknown(),
                new Plan(rewritten.get(), StatsAndCosts.empty()),
                Lookup.noLookup(),
                PlanMatchPattern.join(JoinType.INNER, builder -> builder
                        .equiCriteria(ImmutableList.of(symbolAliases -> new JoinNode.EquiJoinClause(new Symbol(BigintType.BIGINT, "a1"), new Symbol(BigintType.BIGINT, "b1"))))
                        .left(expectedLeft)
                        .right(expectedRight))
                        .withExactOutputs("a3", "b2"));
    }

    /**
     * Verifies that {@code c := a + b}, which reads columns from both join inputs, is not pushed
     * through the inner join.
     */
    @Test
    public void testDoesNotPushStraddlingProjection() {
        PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), TestingPlannerContext.PLANNER_CONTEXT, SessionTestUtils.TEST_SESSION);
        Symbol a = p.symbol("a");
        Symbol b = p.symbol("b");
        ProjectNode project = p.project(
                Assignments.of(p.symbol("c"), new Call(ADD_BIGINT, ImmutableList.of(a.toSymbolReference(), b.toSymbolReference()))),
                p.join(JoinType.INNER, p.values(a), p.values(b)));

        Assertions.assertThat(PushProjectionThroughJoin.pushProjectionThroughJoin(project, Lookup.noLookup(), new PlanNodeIdAllocator())).isEmpty();
    }

    /**
     * Verifies that {@code c := -a} is not pushed when the join is a LEFT join, even though the
     * expression reads only the left input.
     */
    @Test
    public void testDoesNotPushProjectionThroughOuterJoin() {
        PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), TestingPlannerContext.PLANNER_CONTEXT, SessionTestUtils.TEST_SESSION);
        Symbol a = p.symbol("a");
        ProjectNode project = p.project(
                Assignments.of(p.symbol("c"), negation(a)),
                p.join(JoinType.LEFT, p.values(a), p.values(p.symbol("b"))));

        Assertions.assertThat(PushProjectionThroughJoin.pushProjectionThroughJoin(project, Lookup.noLookup(), new PlanNodeIdAllocator())).isEmpty();
    }

    /** Builds a BIGINT negation call over a reference to {@code symbol}. */
    private static Call negation(Symbol symbol) {
        return new Call(NEGATION_BIGINT, ImmutableList.of(symbol.toSymbolReference()));
    }
}
