package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Plugin;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinType;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.class */
/**
 * Rule tests for {@link TransformCorrelatedGlobalAggregationWithProjection}.
 *
 * <p>The negative tests check that the rule stays inert unless the plan is a correlated join
 * whose subquery is exactly one projection over a global (ungrouped) aggregation. The positive
 * tests check the rewritten shape: a projection over an aggregation whose source is a LEFT join
 * between the unique-id-tagged input and the decorrelated subquery source.
 */
public class TestTransformCorrelatedGlobalAggregationWithProjection extends BaseRuleTest {
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
    private static final ResolvedFunction ADD_INTEGER = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(IntegerType.INTEGER, IntegerType.INTEGER));
    private static final ResolvedFunction SUBTRACT_INTEGER = FUNCTIONS.resolveOperator(OperatorType.SUBTRACT, ImmutableList.of(IntegerType.INTEGER, IntegerType.INTEGER));

    public TestTransformCorrelatedGlobalAggregationWithProjection() {
        // No extra plugins are needed by these tests.
        super();
    }

    /**
     * Creates a fresh instance of the rule under test bound to the tester's planner context.
     */
    private TransformCorrelatedGlobalAggregationWithProjection createRule() {
        return new TransformCorrelatedGlobalAggregationWithProjection(tester().getPlannerContext());
    }

    /** A plan that contains no correlated join at all must be left untouched. */
    @Test
    public void testDoesNotFireOnPlanWithoutCorrelatedJoinNode() {
        tester().assertThat(createRule())
                .on(planBuilder -> planBuilder.values(planBuilder.symbol("a")))
                .doesNotFire();
    }

    /** A correlated join whose subquery is a plain VALUES node (no aggregation) is not rewritten. */
    @Test
    public void testDoesNotFireOnCorrelatedWithoutAggregation() {
        tester().assertThat(createRule())
                .on(planBuilder -> planBuilder.correlatedJoin(
                        ImmutableList.of(planBuilder.symbol("corr")),
                        planBuilder.values(planBuilder.symbol("corr")),
                        planBuilder.values(planBuilder.symbol("a"))))
                .doesNotFire();
    }

    /** A correlated join node with an empty correlation list is not rewritten. */
    @Test
    public void testDoesNotFireOnUncorrelated() {
        tester().assertThat(createRule())
                .on(planBuilder -> planBuilder.correlatedJoin(
                        ImmutableList.of(),
                        planBuilder.values(planBuilder.symbol("a")),
                        planBuilder.values(planBuilder.symbol("b"))))
                .doesNotFire();
    }

    /** An aggregation grouped by a key (here {@code b}) is not a global aggregation, so the rule must not fire. */
    @Test
    public void testDoesNotFireOnCorrelatedWithNonScalarAggregation() {
        tester().assertThat(createRule())
                .on(planBuilder -> planBuilder.correlatedJoin(
                        ImmutableList.of(planBuilder.symbol("corr")),
                        planBuilder.values(planBuilder.symbol("corr")),
                        planBuilder.aggregation(aggregationBuilder -> {
                            aggregationBuilder
                                    .source(planBuilder.values(planBuilder.symbol("a"), planBuilder.symbol("b")))
                                    .addAggregation(
                                            planBuilder.symbol("sum"),
                                            PlanBuilder.aggregation("sum", ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "a"))),
                                            ImmutableList.of(BigintType.BIGINT))
                                    .singleGroupingSet(planBuilder.symbol("b"));
                        })))
                .doesNotFire();
    }

    /** Two stacked projections above the global aggregation are outside the pattern this rule matches. */
    @Test
    public void testDoesNotFireOnMultipleProjections() {
        tester().assertThat(createRule())
                .on(planBuilder -> planBuilder.correlatedJoin(
                        ImmutableList.of(planBuilder.symbol("corr")),
                        planBuilder.values(planBuilder.symbol("corr")),
                        planBuilder.project(
                                Assignments.of(
                                        planBuilder.symbol("expr_2", IntegerType.INTEGER),
                                        new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "expr"), new Constant(IntegerType.INTEGER, 1L)))),
                                planBuilder.project(
                                        Assignments.of(
                                                planBuilder.symbol("expr", IntegerType.INTEGER),
                                                new Call(ADD_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "sum"), new Constant(IntegerType.INTEGER, 1L)))),
                                        planBuilder.aggregation(aggregationBuilder -> {
                                            aggregationBuilder
                                                    .source(planBuilder.values(planBuilder.symbol("a"), planBuilder.symbol("b")))
                                                    .addAggregation(
                                                            planBuilder.symbol("sum"),
                                                            PlanBuilder.aggregation("sum", ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "a"))),
                                                            ImmutableList.of(BigintType.BIGINT))
                                                    .globalGrouping();
                                        })))))
                .doesNotFire();
    }

    /** A global aggregation with no projection above it is not matched by this (projection-specific) rule. */
    @Test
    public void testDoesNotFireOnSubqueryWithoutProjection() {
        tester().assertThat(createRule())
                .on(planBuilder -> planBuilder.correlatedJoin(
                        ImmutableList.of(planBuilder.symbol("corr")),
                        planBuilder.values(planBuilder.symbol("corr")),
                        planBuilder.aggregation(aggregationBuilder -> {
                            aggregationBuilder
                                    .source(planBuilder.values(planBuilder.symbol("a"), planBuilder.symbol("b")))
                                    .addAggregation(
                                            planBuilder.symbol("sum"),
                                            PlanBuilder.aggregation("sum", ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "a"))),
                                            ImmutableList.of(BigintType.BIGINT))
                                    .globalGrouping();
                        })))
                .doesNotFire();
    }

    /**
     * A single projection over a global aggregation is rewritten into a projection over an
     * aggregation on top of a LEFT join; the join's right side carries a constant-TRUE
     * {@code non_null} marker.
     */
    @Test
    public void testRewritesOnSubqueryWithProjection() {
        tester().assertThat(createRule())
                .on(planBuilder -> planBuilder.correlatedJoin(
                        ImmutableList.of(planBuilder.symbol("corr")),
                        planBuilder.values(planBuilder.symbol("corr")),
                        planBuilder.project(
                                Assignments.of(
                                        planBuilder.symbol("expr", IntegerType.INTEGER),
                                        new Call(ADD_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "sum"), new Constant(IntegerType.INTEGER, 1L)))),
                                planBuilder.aggregation(aggregationBuilder -> {
                                    aggregationBuilder
                                            .source(planBuilder.values(planBuilder.symbol("a"), planBuilder.symbol("b")))
                                            .addAggregation(
                                                    planBuilder.symbol("sum"),
                                                    PlanBuilder.aggregation("sum", ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "a"))),
                                                    ImmutableList.of(BigintType.BIGINT))
                                            .globalGrouping();
                                }))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "corr", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "corr")),
                                "expr", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "sum_1"), new Constant(IntegerType.INTEGER, 1L))))),
                        PlanMatchPattern.aggregation(
                                ImmutableMap.of("sum_1", PlanMatchPattern.aggregationFunction("sum", ImmutableList.of("a"))),
                                PlanMatchPattern.join(JoinType.LEFT, builder -> {
                                    builder
                                            .left(PlanMatchPattern.assignUniqueId("unique", PlanMatchPattern.values(ImmutableMap.of("corr", 0))))
                                            .right(PlanMatchPattern.project(
                                                    ImmutableMap.of("non_null", PlanMatchPattern.expression(Booleans.TRUE)),
                                                    PlanMatchPattern.values(ImmutableMap.of("a", 0, "b", 1))));
                                }))));
    }

    /**
     * A DISTINCT-style inner aggregation (grouped on {@code a}) below a non-equality correlated
     * filter ({@code b > corr}). The expected plan has the filter moved into the LEFT join and
     * the inner grouping re-applied above the join on {@code corr, unique, non_null, a}.
     */
    @Test
    public void testRewritesOnSubqueryWithDistinct() {
        tester().assertThat(createRule())
                .on(planBuilder -> planBuilder.correlatedJoin(
                        ImmutableList.of(planBuilder.symbol("corr")),
                        planBuilder.values(planBuilder.symbol("corr")),
                        planBuilder.project(
                                Assignments.of(
                                        planBuilder.symbol("expr_sum", IntegerType.INTEGER),
                                        new Call(ADD_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "sum"), new Constant(IntegerType.INTEGER, 1L))),
                                        planBuilder.symbol("expr_count", IntegerType.INTEGER),
                                        new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "count"), new Constant(IntegerType.INTEGER, 1L)))),
                                // Distinct lambda names are required: a nested lambda may not redeclare
                                // a parameter of its enclosing lambda.
                                planBuilder.aggregation(outerBuilder -> {
                                    outerBuilder
                                            .addAggregation(
                                                    planBuilder.symbol("sum"),
                                                    PlanBuilder.aggregation("sum", ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "a"))),
                                                    ImmutableList.of(BigintType.BIGINT))
                                            .addAggregation(
                                                    planBuilder.symbol("count"),
                                                    PlanBuilder.aggregation("count", ImmutableList.<Expression>of()),
                                                    ImmutableList.of())
                                            .globalGrouping()
                                            .source(planBuilder.aggregation(innerBuilder -> {
                                                innerBuilder
                                                        .singleGroupingSet(planBuilder.symbol("a"))
                                                        .source(planBuilder.filter(
                                                                new Comparison(Comparison.Operator.GREATER_THAN, new Reference(BigintType.BIGINT, "b"), new Reference(BigintType.BIGINT, "corr")),
                                                                planBuilder.values(planBuilder.symbol("a"), planBuilder.symbol("b"))));
                                            }));
                                }))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "corr", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "corr")),
                                "expr_sum", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "sum_agg"), new Constant(IntegerType.INTEGER, 1L)))),
                                "expr_count", PlanMatchPattern.expression(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "count_agg"), new Constant(IntegerType.INTEGER, 1L))))),
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.singleGroupingSet("corr", "unique"),
                                ImmutableMap.of(
                                        Optional.of("sum_agg"), PlanMatchPattern.aggregationFunction("sum", ImmutableList.of("a")),
                                        Optional.of("count_agg"), PlanMatchPattern.aggregationFunction("count", ImmutableList.of())),
                                ImmutableList.of(),
                                ImmutableList.of("non_null"),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.aggregation(
                                        PlanMatchPattern.singleGroupingSet("corr", "unique", "non_null", "a"),
                                        ImmutableMap.of(),
                                        Optional.empty(),
                                        AggregationNode.Step.SINGLE,
                                        PlanMatchPattern.join(JoinType.LEFT, builder -> {
                                            builder
                                                    .filter(new Comparison(Comparison.Operator.GREATER_THAN, new Reference(BigintType.BIGINT, "b"), new Reference(BigintType.BIGINT, "corr")))
                                                    .left(PlanMatchPattern.assignUniqueId("unique", PlanMatchPattern.values("corr")))
                                                    .right(PlanMatchPattern.project(
                                                            ImmutableMap.of("non_null", PlanMatchPattern.expression(Booleans.TRUE)),
                                                            PlanMatchPattern.filter(Booleans.TRUE, PlanMatchPattern.values("a", "b"))));
                                        })))));
    }

    /**
     * Same shape as {@link #testRewritesOnSubqueryWithDistinct()} but with an equality
     * correlation ({@code b = corr}). Here the expected plan keeps the inner grouping
     * (on {@code a, b}) on the join's right side instead of re-applying it above the join.
     */
    @Test
    public void testRewritesOnSubqueryWithDecorrelatableDistinct() {
        tester().assertThat(createRule())
                .on(planBuilder -> planBuilder.correlatedJoin(
                        ImmutableList.of(planBuilder.symbol("corr")),
                        planBuilder.values(planBuilder.symbol("corr")),
                        planBuilder.project(
                                Assignments.of(
                                        planBuilder.symbol("expr_sum", IntegerType.INTEGER),
                                        new Call(ADD_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "sum"), new Constant(IntegerType.INTEGER, 1L))),
                                        planBuilder.symbol("expr_count", IntegerType.INTEGER),
                                        new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "count"), new Constant(IntegerType.INTEGER, 1L)))),
                                // Distinct lambda names are required: a nested lambda may not redeclare
                                // a parameter of its enclosing lambda.
                                planBuilder.aggregation(outerBuilder -> {
                                    outerBuilder
                                            .addAggregation(
                                                    planBuilder.symbol("sum"),
                                                    PlanBuilder.aggregation("sum", ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "a"))),
                                                    ImmutableList.of(BigintType.BIGINT))
                                            .addAggregation(
                                                    planBuilder.symbol("count"),
                                                    PlanBuilder.aggregation("count", ImmutableList.<Expression>of()),
                                                    ImmutableList.of())
                                            .globalGrouping()
                                            .source(planBuilder.aggregation(innerBuilder -> {
                                                innerBuilder
                                                        .singleGroupingSet(planBuilder.symbol("a"))
                                                        .source(planBuilder.filter(
                                                                new Comparison(Comparison.Operator.EQUAL, new Reference(BigintType.BIGINT, "b"), new Reference(BigintType.BIGINT, "corr")),
                                                                planBuilder.values(planBuilder.symbol("a"), planBuilder.symbol("b"))));
                                            }));
                                }))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "corr", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "corr")),
                                "expr_sum", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "sum_agg"), new Constant(IntegerType.INTEGER, 1L)))),
                                "expr_count", PlanMatchPattern.expression(new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "count_agg"), new Constant(IntegerType.INTEGER, 1L))))),
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.singleGroupingSet("corr", "unique"),
                                ImmutableMap.of(
                                        Optional.of("sum_agg"), PlanMatchPattern.aggregationFunction("sum", ImmutableList.of("a")),
                                        Optional.of("count_agg"), PlanMatchPattern.aggregationFunction("count", ImmutableList.of())),
                                ImmutableList.of(),
                                ImmutableList.of("non_null"),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.join(JoinType.LEFT, builder -> {
                                    builder
                                            .filter(new Comparison(Comparison.Operator.EQUAL, new Reference(BigintType.BIGINT, "b"), new Reference(BigintType.BIGINT, "corr")))
                                            .left(PlanMatchPattern.assignUniqueId("unique", PlanMatchPattern.values("corr")))
                                            .right(PlanMatchPattern.project(
                                                    ImmutableMap.of("non_null", PlanMatchPattern.expression(Booleans.TRUE)),
                                                    PlanMatchPattern.aggregation(
                                                            PlanMatchPattern.singleGroupingSet("a", "b"),
                                                            ImmutableMap.of(),
                                                            Optional.empty(),
                                                            AggregationNode.Step.SINGLE,
                                                            PlanMatchPattern.filter(Booleans.TRUE, PlanMatchPattern.values("a", "b")))));
                                }))));
    }

    /**
     * An aggregation that already carries a mask symbol: the expected plan combines the
     * existing {@code mask} with the join's {@code non_null} marker via AND into {@code new_mask}.
     */
    @Test
    public void testWithPreexistingMask() {
        tester().assertThat(createRule())
                .on(planBuilder -> planBuilder.correlatedJoin(
                        ImmutableList.of(planBuilder.symbol("corr")),
                        planBuilder.values(planBuilder.symbol("corr")),
                        planBuilder.project(
                                Assignments.of(
                                        planBuilder.symbol("expr", IntegerType.INTEGER),
                                        new Call(ADD_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "sum"), new Constant(IntegerType.INTEGER, 1L)))),
                                planBuilder.aggregation(aggregationBuilder -> {
                                    aggregationBuilder
                                            .source(planBuilder.values(planBuilder.symbol("a"), planBuilder.symbol("mask", BooleanType.BOOLEAN)))
                                            .addAggregation(
                                                    planBuilder.symbol("sum"),
                                                    PlanBuilder.aggregation("sum", ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "a"))),
                                                    ImmutableList.<Type>of(BigintType.BIGINT),
                                                    planBuilder.symbol("mask", BooleanType.BOOLEAN))
                                            .globalGrouping();
                                }))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "corr", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "corr")),
                                "expr", PlanMatchPattern.expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "sum_1"), new Constant(IntegerType.INTEGER, 1L))))),
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.singleGroupingSet("unique", "corr"),
                                ImmutableMap.of(Optional.of("sum_1"), PlanMatchPattern.aggregationFunction("sum", ImmutableList.of("a"))),
                                ImmutableList.of(),
                                ImmutableList.of("new_mask"),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.project(
                                        ImmutableMap.of("new_mask", PlanMatchPattern.expression(new Logical(
                                                Logical.Operator.AND,
                                                ImmutableList.of(new Reference(BooleanType.BOOLEAN, "mask"), new Reference(BooleanType.BOOLEAN, "non_null"))))),
                                        PlanMatchPattern.join(JoinType.LEFT, builder -> {
                                            builder
                                                    .left(PlanMatchPattern.assignUniqueId("unique", PlanMatchPattern.values(ImmutableMap.of("corr", 0))))
                                                    .right(PlanMatchPattern.project(
                                                            ImmutableMap.of("non_null", PlanMatchPattern.expression(Booleans.TRUE)),
                                                            PlanMatchPattern.values(ImmutableMap.of("a", 0, "mask", 1))));
                                        })))));
    }

    /**
     * A {@code bool_or} aggregation wrapped in {@code COALESCE(..., FALSE)}. Unlike the other
     * rewrites, the expected plan has no {@code non_null} projection on the join's right side
     * and an empty mask list on the aggregation.
     */
    @Test
    public void testRewritesOnSubqueryWithBoolOr() {
        tester().assertThat(createRule())
                .on(planBuilder -> planBuilder.correlatedJoin(
                        ImmutableList.of(planBuilder.symbol("corr", BigintType.BIGINT)),
                        planBuilder.values(planBuilder.symbol("corr", BigintType.BIGINT)),
                        planBuilder.project(
                                Assignments.of(
                                        planBuilder.symbol("exists", BooleanType.BOOLEAN),
                                        new Coalesce(new Reference(BooleanType.BOOLEAN, "aggrbool"), Booleans.FALSE)),
                                planBuilder.aggregation(aggregationBuilder -> {
                                    aggregationBuilder
                                            .source(planBuilder.values(planBuilder.symbol("subquery", BooleanType.BOOLEAN)))
                                            .addAggregation(
                                                    planBuilder.symbol("aggrbool", BooleanType.BOOLEAN),
                                                    PlanBuilder.aggregation("bool_or", ImmutableList.<Expression>of(new Reference(BooleanType.BOOLEAN, "subquery"))),
                                                    ImmutableList.of(BooleanType.BOOLEAN))
                                            .globalGrouping();
                                }))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "corr", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "corr")),
                                "exists", PlanMatchPattern.expression(new Coalesce(new Reference(BooleanType.BOOLEAN, "aggrbool"), Booleans.FALSE))),
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.singleGroupingSet("unique", "corr"),
                                ImmutableMap.of(Optional.of("aggrbool"), PlanMatchPattern.aggregationFunction("bool_or", ImmutableList.of("subquery"))),
                                ImmutableList.of(),
                                ImmutableList.of(),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.join(JoinType.LEFT, builder -> {
                                    builder
                                            .left(PlanMatchPattern.assignUniqueId("unique", PlanMatchPattern.values(ImmutableMap.of("corr", 0))))
                                            .right(PlanMatchPattern.values(ImmutableMap.of("subquery", 0)));
                                }))));
    }
}
