package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.cost.PartialAggregationStatsEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.VariableStatsEstimate;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.relational.Expressions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/TestPushPartialAggregationThroughExchange.class */
/**
 * Rule tests for {@link PushPartialAggregationThroughExchange}.
 *
 * <p>Every plan under test is an aggregation whose source is a single-distribution exchange over a values node.
 * The tests vary the {@code partial_aggregation_strategy} and {@code use_partial_aggregation_history} session
 * properties and the statistics reported for the values or aggregation node, then assert whether the rule fires
 * and, when it does, the shape of the rewritten plan.
 */
public class TestPushPartialAggregationThroughExchange extends BaseRuleTest {
    /** Session property selecting the partial aggregation strategy (the tests use AUTOMATIC and NEVER). */
    private static final String PARTIAL_AGGREGATION_STRATEGY = "partial_aggregation_strategy";

    /** Session property enabled by the tests that attach partial-aggregation statistics to the aggregation node. */
    private static final String USE_PARTIAL_AGGREGATION_HISTORY = "use_partial_aggregation_history";

    public TestPushPartialAggregationThroughExchange() {
        // No additional plugins are required by these tests.
        super(new Plugin[0]);
    }

    @Test
    public void testPartialAggregationAdded() {
        tester().assertThat(newRule())
                .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
                .on(TestPushPartialAggregationThroughExchange::globalPartialSumOverExchange)
                .matches(PlanMatchPattern.exchange(
                        PlanMatchPattern.project(
                                PlanMatchPattern.aggregation(
                                        ImmutableMap.of("SUM", PlanMatchPattern.functionCall("sum", ImmutableList.of("a"))),
                                        AggregationNode.Step.PARTIAL,
                                        PlanMatchPattern.values("a")))));
    }

    @Test
    public void testNoPartialAggregationWhenDisabled() {
        tester().assertThat(newRule())
                .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "NEVER")
                .on(TestPushPartialAggregationThroughExchange::globalPartialSumOverExchange)
                .doesNotFire();
    }

    @Test
    public void testNoPartialAggregationWhenReductionBelowThreshold() {
        // Confident stats on the values node: 1000 output rows plus the shared statistics for grouping key b.
        tester().assertThat(newRule())
                .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
                .on(planBuilder -> sumGroupedByBOverExchange(planBuilder, AggregationNode.Step.SINGLE))
                .overrideStats("values", valuesStats(true))
                .doesNotFire();
    }

    @Test
    public void testNoPartialAggregationWhenReductionBelowThresholdUsingPartialAggregationStats() {
        // Partial-aggregation estimate (1000, 800, 10, 10): no reduction in the second (presumably row-count) pair.
        tester().assertThat(newRule())
                .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
                .setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true")
                .on(TestPushPartialAggregationThroughExchange::constructAggregation)
                .overrideStats("aggregation", aggregationStats(new PartialAggregationStatsEstimate(1000.0d, 800.0d, 10.0d, 10.0d)))
                .doesNotFire();
    }

    @Test
    public void testNoPartialAggregationWhenReductionAboveThresholdUsingPartialAggregationStats() {
        // Partial-aggregation estimate (1000, 300, 10, 10).
        // NOTE(review): the first pair shows a reduction (1000 -> 300) but the second pair does not (10 -> 10), and
        // the rule is expected not to fire. The two row-reduction tests below vary only the last value and flip the
        // outcome, which suggests the second pair is what drives the decision. If so, this test name may be stale
        // -- confirm against the rule implementation.
        tester().assertThat(newRule())
                .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
                .setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true")
                .on(TestPushPartialAggregationThroughExchange::constructAggregation)
                .overrideStats("aggregation", aggregationStats(new PartialAggregationStatsEstimate(1000.0d, 300.0d, 10.0d, 10.0d)))
                .doesNotFire();
    }

    @Test
    public void testNoPartialAggregationWhenRowReductionBelowThreshold() {
        // Partial-aggregation estimate (0, 300, 10, 8): the second pair only shrinks from 10 to 8.
        tester().assertThat(newRule())
                .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
                .setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true")
                .on(TestPushPartialAggregationThroughExchange::constructAggregation)
                .overrideStats("aggregation", aggregationStats(new PartialAggregationStatsEstimate(0.0d, 300.0d, 10.0d, 8.0d)))
                .doesNotFire();
    }

    @Test
    public void testPartialAggregationWhenRowReductionAboveThreshold() {
        // Partial-aggregation estimate (0, 300, 10, 1): the second pair shrinks from 10 to 1, so the rule fires and
        // the plan becomes a final sum over a partial sum over the exchange.
        tester().assertThat(newRule())
                .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
                .setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true")
                .on(TestPushPartialAggregationThroughExchange::constructAggregation)
                .overrideStats("aggregation", aggregationStats(new PartialAggregationStatsEstimate(0.0d, 300.0d, 10.0d, 1.0d)))
                .matches(PlanMatchPattern.aggregation(
                        ImmutableMap.of("sum", PlanMatchPattern.functionCall("sum", ImmutableList.of("sum0"))),
                        PlanMatchPattern.aggregation(
                                ImmutableMap.of("sum0", PlanMatchPattern.functionCall("sum", ImmutableList.of("a"))),
                                PlanMatchPattern.exchange(PlanMatchPattern.values("a", "b")))));
    }

    @Test
    public void testPartialAggregationEnabledWhenNotConfident() {
        // Same values-node statistics as testNoPartialAggregationWhenReductionBelowThreshold, but not confident;
        // the input aggregation here is already PARTIAL.
        tester().assertThat(newRule())
                .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
                .on(planBuilder -> sumGroupedByBOverExchange(planBuilder, AggregationNode.Step.PARTIAL))
                .overrideStats("values", valuesStats(false))
                .matches(PlanMatchPattern.exchange(
                        PlanMatchPattern.project(
                                PlanMatchPattern.aggregation(
                                        ImmutableMap.of("SUM", PlanMatchPattern.functionCall("sum", ImmutableList.of("a"))),
                                        AggregationNode.Step.PARTIAL,
                                        PlanMatchPattern.values("a", "b")))));
    }

    /** Creates a fresh instance of the rule under test, bound to this test's function manager. */
    private Rule<?> newRule() {
        return new PushPartialAggregationThroughExchange(getFunctionManager());
    }

    /**
     * Builds a global, PARTIAL-step {@code SUM(a)} aggregation over a single-distribution exchange of
     * {@code values(a)}. The variable {@code a} uses the plan builder's default type.
     */
    private static AggregationNode globalPartialSumOverExchange(PlanBuilder planBuilder) {
        VariableReferenceExpression a = planBuilder.variable("a");
        return planBuilder.aggregation(aggregationBuilder -> {
            aggregationBuilder
                    .source(planBuilder.exchange(exchangeBuilder -> {
                        exchangeBuilder
                                .addSource(planBuilder.values(a))
                                .addInputsSet(a)
                                .singleDistributionPartitioningScheme(a);
                    }))
                    .addAggregation(planBuilder.variable("SUM", DoubleType.DOUBLE), planBuilder.rowExpression("SUM(a)"))
                    .globalGrouping()
                    .step(AggregationNode.Step.PARTIAL);
        });
    }

    /**
     * Builds {@code SUM(a)} grouped by {@code b} over a single-distribution exchange of a values node with the
     * fixed id {@code "values"}, so {@code overrideStats("values", ...)} can target it.
     *
     * @param planBuilder builder used to create the plan nodes
     * @param step aggregation step to assign to the resulting aggregation node
     * @return the aggregation node at the root of the plan
     */
    private static AggregationNode sumGroupedByBOverExchange(PlanBuilder planBuilder, AggregationNode.Step step) {
        VariableReferenceExpression a = planBuilder.variable("a", DoubleType.DOUBLE);
        VariableReferenceExpression b = planBuilder.variable("b", DoubleType.DOUBLE);
        return planBuilder.aggregation(aggregationBuilder -> {
            aggregationBuilder
                    .source(planBuilder.exchange(exchangeBuilder -> {
                        exchangeBuilder
                                .addSource(planBuilder.values(new PlanNodeId("values"), a, b))
                                .addInputsSet(a, b)
                                .singleDistributionPartitioningScheme(a, b);
                    }))
                    .addAggregation(planBuilder.variable("SUM", DoubleType.DOUBLE), planBuilder.rowExpression("SUM(a)"))
                    .singleGroupingSet(b)
                    .step(step);
        });
    }

    /**
     * Builds the plan shared by the partial-aggregation-history tests: {@code sum(a)} grouped by {@code b} over a
     * single-distribution exchange of {@code values(a, b)}.
     *
     * <p>The aggregation carries the fixed id {@code "aggregation"} so {@code overrideStats("aggregation", ...)} can
     * target it. No step is set explicitly. NOTE(review): presumably the builder's default is SINGLE, since the
     * row-reduction test expects the rule to split this node into two aggregations -- confirm.
     */
    private static AggregationNode constructAggregation(PlanBuilder planBuilder) {
        VariableReferenceExpression a = planBuilder.variable("a", DoubleType.DOUBLE);
        VariableReferenceExpression b = planBuilder.variable("b", DoubleType.DOUBLE);
        List<VariableReferenceExpression> exchangeOutputs = ImmutableList.of(a, b);
        return planBuilder.aggregation(aggregationBuilder -> {
            aggregationBuilder
                    .source(planBuilder.exchange(exchangeBuilder -> {
                        exchangeBuilder
                                .addSource(planBuilder.values(new PlanNodeId("values"), a, b))
                                .addInputsSet(a, b)
                                .singleDistributionPartitioningScheme(exchangeOutputs);
                    }))
                    .addAggregation(planBuilder.variable("sum", DoubleType.DOUBLE), planBuilder.rowExpression("sum(a)"))
                    .singleGroupingSet(b)
                    .setPlanNodeId(new PlanNodeId("aggregation"));
        });
    }

    /**
     * Column statistics for the grouping key {@code b}, shared by every stats override in this class.
     * NOTE(review): the constructor arguments are presumably (low, high, nulls fraction, average row size,
     * distinct values count) -- confirm against {@link VariableStatsEstimate}.
     */
    private static VariableStatsEstimate bStats() {
        return new VariableStatsEstimate(0.0d, 100.0d, 0.0d, 8.0d, 800.0d);
    }

    /**
     * Stats override for the {@code "values"} node: 1000 output rows plus {@link #bStats()} for {@code b}.
     *
     * @param confident whether the estimate is marked as confident
     */
    private static PlanNodeStatsEstimate valuesStats(boolean confident) {
        return PlanNodeStatsEstimate.builder()
                .setOutputRowCount(1000.0d)
                .addVariableStatistics(Expressions.variable("b", DoubleType.DOUBLE), bStats())
                .setConfident(confident)
                .build();
    }

    /**
     * Confident stats override for the {@code "aggregation"} node, carrying {@link #bStats()} for {@code b} and the
     * given partial-aggregation estimate.
     *
     * <p>NOTE(review): the estimate's constructor arguments are presumably (input bytes, output bytes, input rows,
     * output rows) -- confirm against {@link PartialAggregationStatsEstimate}.
     */
    private static PlanNodeStatsEstimate aggregationStats(PartialAggregationStatsEstimate partialAggregationStats) {
        return PlanNodeStatsEstimate.builder()
                .addVariableStatistics(Expressions.variable("b", DoubleType.DOUBLE), bStats())
                .setConfident(true)
                .setPartialAggregationStatsEstimate(partialAggregationStats)
                .build();
    }
}
