package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.Optional;
import org.testng.annotations.Test;

/**
 * Rule tests for {@link PullConstantsAboveGroupBy}.
 *
 * <p>The expected plans below describe the rewrite. Constant-valued grouping keys are removed from a
 * single grouping set and re-projected above the aggregation. The rule must not fire in these cases:
 * <ul>
 *   <li>no grouping key is constant</li>
 *   <li>every grouping key is constant</li>
 *   <li>the aggregation has several grouping sets</li>
 *   <li>the rule is disabled through session properties</li>
 * </ul>
 */
public class TestPullConstantsAboveGroupBy extends BaseRuleTest {
    public TestPullConstantsAboveGroupBy() {
        // Register no additional plugins with the rule tester.
        super(new Plugin[0]);
    }

    /** A grouping set that contains no constant-valued key leaves nothing to pull up. */
    @Test
    public void testNoConstGroupingKeysDoesNotFire() {
        tester().assertThat(new PullConstantsAboveGroupBy())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.project(
                                Assignments.builder()
                                        .put(p.variable("COL"), p.rowExpression("COL"))
                                        .put(p.variable("CONST_COL"), p.rowExpression("1"))
                                        .build(),
                                p.values(p.variable("COL"))))
                        .addAggregation(p.variable("AVG", DoubleType.DOUBLE), p.rowExpression("avg(COL)"))
                        .singleGroupingSet(p.variable("COL"))))
                .doesNotFire();
    }

    /** The rewrite only applies to a single grouping set; two grouping sets (one global) are left alone. */
    @Test
    public void testMultipleGroupingSetsDoesNotFire() {
        tester().assertThat(new PullConstantsAboveGroupBy())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.project(
                                Assignments.builder()
                                        .put(p.variable("COL"), p.rowExpression("COL"))
                                        .put(p.variable("CONST_COL"), p.rowExpression("1"))
                                        .build(),
                                p.values(p.variable("COL"))))
                        .addAggregation(p.variable("AVG", DoubleType.DOUBLE), p.rowExpression("avg(COL)"))
                        // Two grouping sets; set index 0 is the global (empty) grouping set.
                        .groupingSets(AggregationNode.groupingSets(ImmutableList.of(p.variable("COL")), 2, ImmutableSet.of(0)))))
                .doesNotFire();
    }

    /**
     * With the enabling session properties switched off, the rule must not rewrite a plan it would
     * otherwise transform (same input as {@link #testSingleConstColumn()}).
     */
    @Test
    public void testRuleDisabledDoesNotFire() throws Exception {
        // This tester is not managed by BaseRuleTest, so close it here to release its resources.
        try (RuleTester disabledTester = new RuleTester(
                ImmutableList.of(),
                ImmutableMap.of(
                        "optimize_constant_grouping_keys", "false",
                        "rewrite_expression_with_constant_expression", "false"))) {
            disabledTester.assertThat(new PullConstantsAboveGroupBy())
                    .on(p -> p.aggregation(ab -> ab
                            .source(p.project(
                                    Assignments.builder()
                                            .put(p.variable("COL"), p.rowExpression("COL"))
                                            .put(p.variable("CONST_COL"), p.rowExpression("1"))
                                            .build(),
                                    p.values(p.variable("COL"))))
                            .addAggregation(p.variable("AVG", DoubleType.DOUBLE), p.rowExpression("avg(COL)"))
                            .singleGroupingSet(p.variable("CONST_COL"), p.variable("COL"))))
                    .doesNotFire();
        }
    }

    /** If every grouping key is constant, removing them would change the grouping, so the rule skips it. */
    @Test
    public void testOnlyConstantKeysDoesNotFire() {
        tester().assertThat(new PullConstantsAboveGroupBy())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.project(
                                Assignments.builder()
                                        .put(p.variable("COL"), p.rowExpression("COL"))
                                        .put(p.variable("CONST_COL"), p.rowExpression("1"))
                                        .build(),
                                p.values(p.variable("COL"))))
                        .addAggregation(p.variable("AVG", DoubleType.DOUBLE), p.rowExpression("avg(COL)"))
                        .singleGroupingSet(p.variable("CONST_COL"))))
                .doesNotFire();
    }

    /** A single constant key is dropped from the grouping set and re-projected above the aggregation. */
    @Test
    public void testSingleConstColumn() {
        tester().assertThat(new PullConstantsAboveGroupBy())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.project(
                                Assignments.builder()
                                        .put(p.variable("COL"), p.rowExpression("COL"))
                                        .put(p.variable("CONST_COL"), p.rowExpression("1"))
                                        .build(),
                                p.values(p.variable("COL"))))
                        .addAggregation(p.variable("AVG", DoubleType.DOUBLE), p.rowExpression("avg(COL)"))
                        .singleGroupingSet(p.variable("CONST_COL"), p.variable("COL"))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of("CONST_COL", PlanMatchPattern.expression("1")),
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.singleGroupingSet("COL"),
                                ImmutableMap.of(Optional.of("AVG"), PlanMatchPattern.functionCall("avg", ImmutableList.of("COL"))),
                                ImmutableMap.of(),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.project(
                                        ImmutableMap.of("CONST_COL", PlanMatchPattern.expression("1")),
                                        PlanMatchPattern.values("COL")))));
    }

    /** Several constant keys are all dropped from the grouping set and re-projected above the aggregation. */
    @Test
    public void testMultipleConstCols() {
        tester().assertThat(new PullConstantsAboveGroupBy())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.project(
                                Assignments.builder()
                                        .put(p.variable("COL"), p.rowExpression("COL"))
                                        .put(p.variable("CONST_COL1"), p.rowExpression("1"))
                                        .put(p.variable("CONST_COL2"), p.rowExpression("2"))
                                        .build(),
                                p.values(p.variable("COL"))))
                        .addAggregation(p.variable("AVG", DoubleType.DOUBLE), p.rowExpression("avg(COL)"))
                        .singleGroupingSet(p.variable("CONST_COL1"), p.variable("COL"), p.variable("CONST_COL2"))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "CONST_COL1", PlanMatchPattern.expression("1"),
                                "CONST_COL2", PlanMatchPattern.expression("2")),
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.singleGroupingSet("COL"),
                                ImmutableMap.of(Optional.of("AVG"), PlanMatchPattern.functionCall("avg", ImmutableList.of("COL"))),
                                ImmutableMap.of(),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.project(
                                        ImmutableMap.of(
                                                "CONST_COL1", PlanMatchPattern.expression("1"),
                                                "CONST_COL2", PlanMatchPattern.expression("2")),
                                        PlanMatchPattern.values("COL")))));
    }
}
