package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Plugin;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.PlanNode;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

/**
 * Rule tests for {@link PushInequalityFilterExpressionBelowJoinRuleSet}.
 *
 * <p>Both rules of the rule set are exercised:
 * <ul>
 * <li>{@code pushJoinInequalityFilterExpressionBelowJoinRule}, which acts on the filter of an inner {@link JoinNode}.</li>
 * <li>{@code pushParentInequalityFilterExpressionBelowJoinRule}, which acts on a filter placed directly above an inner join.</li>
 * </ul>
 *
 * <p>In the expected plans, a compound operand that references only the right-hand source (for example {@code b + 1}) is
 * computed by a projection over that source. The predicate then refers to the projected symbol instead. The remaining tests
 * pin the inputs for which the rules are expected not to fire.
 */
public class TestPushInequalityFilterExpressionBelowJoinRuleSet
        extends BaseRuleTest
{
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
    // BIGINT + BIGINT operator, used both to build the input plans and to describe the expected projections.
    private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT));

    private PushInequalityFilterExpressionBelowJoinRuleSet ruleSet;

    public TestPushInequalityFilterExpressionBelowJoinRuleSet()
    {
        // Hand the base rule tester an explicitly empty plugin list.
        super(new Plugin[] {});
    }

    @BeforeAll
    public void setUpBeforeClass()
    {
        ruleSet = new PushInequalityFilterExpressionBelowJoinRuleSet();
    }

    @Test
    public void testExpressionNotPushedDownToLeftJoinSource()
    {
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    PlanNode left = p.values(a);
                    PlanNode right = p.values(b);
                    // The compound operand (a + 1) uses the left source only, so nothing is expected to change.
                    Expression filter = comparison(Comparison.Operator.LESS_THAN, add(a, 1), b.toSymbolReference());
                    return p.join(JoinType.INNER, left, right, filter);
                })
                .doesNotFire();
    }

    @Test
    public void testJoinFilterExpressionPushedDownToRightJoinSource()
    {
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    PlanNode left = p.values(a);
                    PlanNode right = p.values(b);
                    Expression filter = comparison(Comparison.Operator.LESS_THAN, add(b, 1), a.toSymbolReference());
                    return p.join(JoinType.INNER, left, right, filter);
                })
                .matches(PlanMatchPattern.join(JoinType.INNER, builder -> builder
                        .filter(new Comparison(Comparison.Operator.LESS_THAN, bigintReference("expr"), bigintReference("a")))
                        .left(PlanMatchPattern.values("a"))
                        .right(PlanMatchPattern.project(
                                ImmutableMap.of("expr", PlanMatchPattern.expression(addBigint(bigintReference("b"), 1))),
                                PlanMatchPattern.values("b")))));
    }

    @Test
    public void testManyJoinFilterExpressionsPushedDownToRightJoinSource()
    {
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    PlanNode left = p.values(a);
                    PlanNode right = p.values(b);
                    Expression filter = Logical.and(
                            comparison(Comparison.Operator.LESS_THAN, add(b, 1), a.toSymbolReference()),
                            comparison(Comparison.Operator.GREATER_THAN, add(b, 10), a.toSymbolReference()));
                    return p.join(JoinType.INNER, left, right, filter);
                })
                .matches(PlanMatchPattern.join(JoinType.INNER, builder -> builder
                        .filter(new Logical(Logical.Operator.AND, ImmutableList.of(
                                new Comparison(Comparison.Operator.LESS_THAN, bigintReference("expr_less"), bigintReference("a")),
                                new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("expr_greater"), bigintReference("a")))))
                        .left(PlanMatchPattern.values("a"))
                        .right(PlanMatchPattern.project(
                                ImmutableMap.of(
                                        "expr_less", PlanMatchPattern.expression(addBigint(bigintReference("b"), 1)),
                                        "expr_greater", PlanMatchPattern.expression(addBigint(bigintReference("b"), 10))),
                                PlanMatchPattern.values("b")))));
    }

    @Test
    public void testOnlyRightJoinFilterExpressionPushedDownToRightJoinSource()
    {
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    PlanNode left = p.values(a);
                    PlanNode right = p.values(b);
                    Expression filter = comparison(Comparison.Operator.LESS_THAN, add(b, 1), add(a, 2));
                    return p.join(JoinType.INNER, left, right, filter);
                })
                // Only b + 1 is expected to move into a projection; a + 2 stays inside the join filter.
                .matches(PlanMatchPattern.join(JoinType.INNER, builder -> builder
                        .filter(new Comparison(Comparison.Operator.LESS_THAN, bigintReference("expr"), addBigint(bigintReference("a"), 2)))
                        .left(PlanMatchPattern.values("a"))
                        .right(PlanMatchPattern.project(
                                ImmutableMap.of("expr", PlanMatchPattern.expression(addBigint(bigintReference("b"), 1))),
                                PlanMatchPattern.values("b")))));
    }

    @Test
    public void testParentFilterExpressionNotPushedDownToLeftJoinSource()
    {
        tester().assertThat(ruleSet.pushParentInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    Expression predicate = comparison(Comparison.Operator.LESS_THAN, add(a, 1), b.toSymbolReference());
                    PlanNode join = p.join(JoinType.INNER, p.values(a), p.values(b));
                    return p.filter(predicate, join);
                })
                .doesNotFire();
    }

    @Test
    public void testParentFilterExpressionPushedDownToRightJoinSource()
    {
        tester().assertThat(ruleSet.pushParentInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    Expression predicate = comparison(Comparison.Operator.LESS_THAN, add(b, 1), a.toSymbolReference());
                    PlanNode join = p.join(JoinType.INNER, p.values(a), p.values(b));
                    return p.filter(predicate, join);
                })
                .matches(PlanMatchPattern.project(
                        PlanMatchPattern.filter(
                                new Comparison(Comparison.Operator.LESS_THAN, bigintReference("expr"), bigintReference("a")),
                                PlanMatchPattern.join(JoinType.INNER, builder -> builder
                                        .left(PlanMatchPattern.values("a"))
                                        .right(PlanMatchPattern.project(
                                                ImmutableMap.of("expr", PlanMatchPattern.expression(addBigint(bigintReference("b"), 1))),
                                                PlanMatchPattern.values("b")))))));
    }

    @Test
    public void testManyParentFilterExpressionsPushedDownToRightJoinSource()
    {
        tester().assertThat(ruleSet.pushParentInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    Expression predicate = Logical.and(
                            comparison(Comparison.Operator.LESS_THAN, add(b, 1), a.toSymbolReference()),
                            comparison(Comparison.Operator.GREATER_THAN, add(b, 10), a.toSymbolReference()));
                    PlanNode join = p.join(JoinType.INNER, p.values(a), p.values(b));
                    return p.filter(predicate, join);
                })
                .matches(PlanMatchPattern.project(
                        PlanMatchPattern.filter(
                                new Logical(Logical.Operator.AND, ImmutableList.of(
                                        new Comparison(Comparison.Operator.LESS_THAN, bigintReference("expr_less"), bigintReference("a")),
                                        new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("expr_greater"), bigintReference("a")))),
                                PlanMatchPattern.join(JoinType.INNER, builder -> builder
                                        .left(PlanMatchPattern.values("a"))
                                        .right(PlanMatchPattern.project(
                                                ImmutableMap.of(
                                                        "expr_less", PlanMatchPattern.expression(addBigint(bigintReference("b"), 1)),
                                                        "expr_greater", PlanMatchPattern.expression(addBigint(bigintReference("b"), 10))),
                                                PlanMatchPattern.values("b")))))));
    }

    @Test
    public void testOnlyParentFilterExpressionExposedInaJoin()
    {
        tester().assertThat(ruleSet.pushParentInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    Expression parentPredicate = comparison(Comparison.Operator.LESS_THAN, add(b, 1), a.toSymbolReference());
                    PlanNode left = p.values(a);
                    PlanNode right = p.values(b);
                    Expression joinFilter = comparison(Comparison.Operator.LESS_THAN, add(b, 2), a.toSymbolReference());
                    PlanNode join = p.join(JoinType.INNER, left, right, joinFilter);
                    return p.filter(parentPredicate, join);
                })
                // Both operands are projected on the right side, but the join is expected to output
                // only the projected symbol that the parent filter needs (not join_expression).
                .matches(PlanMatchPattern.project(
                        PlanMatchPattern.filter(
                                new Comparison(Comparison.Operator.LESS_THAN, bigintReference("parent_expression"), bigintReference("a")),
                                PlanMatchPattern.join(JoinType.INNER, builder -> builder
                                                .filter(new Comparison(Comparison.Operator.LESS_THAN, bigintReference("join_expression"), bigintReference("a")))
                                                .left(PlanMatchPattern.values("a"))
                                                .right(PlanMatchPattern.project(
                                                        ImmutableMap.of(
                                                                "join_expression", PlanMatchPattern.expression(addBigint(bigintReference("b"), 2)),
                                                                "parent_expression", PlanMatchPattern.expression(addBigint(bigintReference("b"), 1))),
                                                        PlanMatchPattern.values("b"))))
                                        .withExactOutputs("a", "b", "parent_expression"))));
    }

    @Test
    public void testNoExpression()
    {
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    PlanNode left = p.values(a);
                    PlanNode right = p.values(b);
                    // Both operands are plain symbol references, so there is no compound operand to project.
                    Expression filter = comparison(Comparison.Operator.LESS_THAN, a.toSymbolReference(), b.toSymbolReference());
                    return p.join(JoinType.INNER, left, right, filter);
                })
                .doesNotFire();
    }

    @Test
    public void testNotSupportedExpression()
    {
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    PlanNode left = p.values(a);
                    PlanNode right = p.values(b);
                    // A negated IDENTICAL comparison of two plain symbols; the rule is expected not to fire.
                    Expression filter = IrExpressions.not(
                            FUNCTIONS.getMetadata(),
                            comparison(Comparison.Operator.IDENTICAL, a.toSymbolReference(), b.toSymbolReference()));
                    return p.join(JoinType.INNER, left, right, filter);
                })
                .doesNotFire();
    }

    /**
     * Builds a binary comparison {@code left <operator> right}.
     */
    private static Comparison comparison(Comparison.Operator operator, Expression left, Expression right)
    {
        return new Comparison(operator, left, right);
    }

    /**
     * Builds {@code symbol + value} using the BIGINT addition operator. It is used to construct the input plans.
     */
    private static Call add(Symbol symbol, long value)
    {
        return addBigint(symbol.toSymbolReference(), value);
    }

    /**
     * Builds a BIGINT-typed symbol reference. It is used when describing expected plans.
     */
    private static Reference bigintReference(String name)
    {
        return new Reference(BigintType.BIGINT, name);
    }

    /**
     * Builds {@code operand + value}, where {@code value} is a BIGINT constant.
     */
    private static Call addBigint(Expression operand, long value)
    {
        return new Call(ADD_BIGINT, ImmutableList.of(operand, new Constant(BigintType.BIGINT, value)));
    }
}
