package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.BasePlanTest;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.tree.SortItem;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Optional;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/TestMergePartialAggregationsWithFilter.class */
/**
 * Plan-shape tests for the rewrite controlled by the
 * {@code merge_aggregations_with_and_without_filter} system property.
 *
 * <p>Each test plans a query that computes the same aggregate both with and without a
 * {@code FILTER} clause and asserts the resulting plan against an expected pattern. In the
 * "rewritten" patterns below, the partial aggregation computes the unfiltered aggregate once with
 * the filter expression added as a grouping key, and a projection of the form
 * {@code IF(expr, partial, null)} derives the filtered partial value that feeds the final
 * aggregation. In the "not rewritten" patterns, the filtered aggregate is instead computed
 * separately and tied to the filter expression through a mask symbol.
 */
public class TestMergePartialAggregationsWithFilter extends BasePlanTest {
    /** System property toggled by these tests. */
    private static final String MERGE_AGGREGATIONS_PROPERTY = "merge_aggregations_with_and_without_filter";

    /** System property that every test session sets to {@code AUTOMATIC}. */
    private static final String PARTIAL_AGGREGATION_STRATEGY_PROPERTY = "partial_aggregation_strategy";

    /**
     * Builds a session derived from the query runner's default session with the merge property set
     * to the given value and the partial aggregation strategy set to {@code AUTOMATIC}.
     *
     * @param enabled value written (as {@code "true"} or {@code "false"}) to the merge property
     * @return the configured session
     */
    private Session sessionWithMergeOptimization(boolean enabled) {
        return Session.builder(getQueryRunner().getDefaultSession())
                .setSystemProperty(MERGE_AGGREGATIONS_PROPERTY, Boolean.toString(enabled))
                .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY_PROPERTY, "AUTOMATIC")
                .build();
    }

    /** Returns a session with the merge property set to {@code "true"}. */
    private Session enableOptimization() {
        return sessionWithMergeOptimization(true);
    }

    /** Returns a session with the merge property set to {@code "false"}. */
    private Session disableOptimization() {
        return sessionWithMergeOptimization(false);
    }

    /**
     * With the property enabled, a grouped {@code sum} with and without a filter is expected to
     * plan as a single partial {@code sum} grouped by {@code (partkey, expr)}, followed by an
     * {@code IF(expr, partialSum, null)} projection feeding the final aggregation.
     */
    @Test
    public void testOptimizationApplied() {
        assertPlan("SELECT partkey, sum(quantity), sum(quantity) filter (where orderkey > 0) from lineitem group by partkey", enableOptimization(), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey"), ImmutableMap.of(Optional.of("finalSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("partialSum")), Optional.of("maskFinalSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("maskPartialSum"))), ImmutableMap.of(), Optional.empty(), AggregationNode.Step.FINAL, PlanMatchPattern.project(ImmutableMap.of("maskPartialSum", PlanMatchPattern.expression("IF(expr, partialSum, null)")), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey", "expr"), ImmutableMap.of(Optional.of("partialSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("quantity"))), ImmutableMap.of(), Optional.empty(), AggregationNode.Step.PARTIAL, PlanMatchPattern.anyTree(PlanMatchPattern.project(ImmutableMap.of("expr", PlanMatchPattern.expression("orderkey > 0")), PlanMatchPattern.tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity"))))))))), false);
    }

    /**
     * With the property disabled, the same query is expected to keep two partial {@code sum}
     * aggregations grouped only by {@code partkey}, with {@code maskPartialSum} masked by
     * {@code expr}.
     */
    @Test
    public void testOptimizationDisabled() {
        assertPlan("SELECT partkey, sum(quantity), sum(quantity) filter (where orderkey > 0) from lineitem group by partkey", disableOptimization(), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey"), ImmutableMap.of(Optional.of("finalSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("partialSum")), Optional.of("maskFinalSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("maskPartialSum"))), ImmutableMap.of(), Optional.empty(), AggregationNode.Step.FINAL, PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey"), ImmutableMap.of(Optional.of("partialSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("quantity")), Optional.of("maskPartialSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("quantity"))), ImmutableMap.of(new Symbol("maskPartialSum"), new Symbol("expr")), Optional.empty(), AggregationNode.Step.PARTIAL, PlanMatchPattern.project(ImmutableMap.of("expr", PlanMatchPattern.expression("orderkey > 0")), PlanMatchPattern.tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity"))))))), false);
    }

    /**
     * With the property enabled and two aggregate functions ({@code sum} and {@code avg}) each used
     * with and without the same filter, both are expected to be rewritten: one partial per function
     * grouped by {@code (partkey, expr)}, and one {@code IF} projection per function.
     */
    @Test
    public void testMultipleAggregations() {
        assertPlan("SELECT partkey, sum(quantity), sum(quantity) filter (where orderkey > 0), avg(quantity), avg(quantity) filter (where orderkey > 0) from lineitem group by partkey", enableOptimization(), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey"), ImmutableMap.of(Optional.of("finalSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("partialSum")), Optional.of("maskFinalSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("maskPartialSum")), Optional.of("finalAvg"), PlanMatchPattern.functionCall("avg", ImmutableList.of("partialAvg")), Optional.of("maskFinalAvg"), PlanMatchPattern.functionCall("avg", ImmutableList.of("maskPartialAvg"))), ImmutableMap.of(), Optional.empty(), AggregationNode.Step.FINAL, PlanMatchPattern.project(ImmutableMap.of("maskPartialSum", PlanMatchPattern.expression("IF(expr, partialSum, null)"), "maskPartialAvg", PlanMatchPattern.expression("IF(expr, partialAvg, null)")), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey", "expr"), ImmutableMap.of(Optional.of("partialSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("quantity")), Optional.of("partialAvg"), PlanMatchPattern.functionCall("avg", ImmutableList.of("quantity"))), ImmutableMap.of(), Optional.empty(), AggregationNode.Step.PARTIAL, PlanMatchPattern.anyTree(PlanMatchPattern.project(ImmutableMap.of("expr", PlanMatchPattern.expression("orderkey > 0")), PlanMatchPattern.tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity"))))))))), false);
    }

    /**
     * With the property enabled and a filtered aggregation nested inside another filtered
     * aggregation, both levels are expected to be rewritten: the inner partial groups by
     * {@code (partkey, suppkey, expr)} and the outer partial groups by {@code (partkey, expr_2)},
     * each followed by its own {@code IF} projection.
     */
    @Test
    public void testAggregationsMultipleLevel() {
        assertPlan("select partkey, avg(sum), avg(sum) filter (where suppkey > 0), avg(filtersum) from (select partkey, suppkey, sum(quantity) sum, sum(quantity) filter (where orderkey > 0) filtersum from lineitem group by partkey, suppkey) t group by partkey", enableOptimization(), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey"), ImmutableMap.of(Optional.of("finalAvg"), PlanMatchPattern.functionCall("avg", ImmutableList.of("partialAvg")), Optional.of("maskFinalAvg"), PlanMatchPattern.functionCall("avg", ImmutableList.of("maskPartialAvg")), Optional.of("finalFilterAvg"), PlanMatchPattern.functionCall("avg", ImmutableList.of("partialFilterAvg"))), ImmutableMap.of(), Optional.empty(), AggregationNode.Step.FINAL, PlanMatchPattern.project(ImmutableMap.of("maskPartialAvg", PlanMatchPattern.expression("IF(expr_2, partialAvg, null)")), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey", "expr_2"), ImmutableMap.of(Optional.of("partialAvg"), PlanMatchPattern.functionCall("avg", ImmutableList.of("finalSum")), Optional.of("partialFilterAvg"), PlanMatchPattern.functionCall("avg", ImmutableList.of("maskFinalSum"))), ImmutableMap.of(), Optional.empty(), AggregationNode.Step.PARTIAL, PlanMatchPattern.anyTree(PlanMatchPattern.project(ImmutableMap.of("expr_2", PlanMatchPattern.expression("suppkey > 0")), PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey", "suppkey"), ImmutableMap.of(Optional.of("finalSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("partialSum")), Optional.of("maskFinalSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("maskPartialSum"))), ImmutableMap.of(), Optional.empty(), AggregationNode.Step.FINAL, PlanMatchPattern.project(ImmutableMap.of("maskPartialSum", PlanMatchPattern.expression("IF(expr, partialSum, null)")), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey", "suppkey", "expr"), ImmutableMap.of(Optional.of("partialSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("quantity"))), ImmutableMap.of(), Optional.empty(), AggregationNode.Step.PARTIAL, PlanMatchPattern.anyTree(PlanMatchPattern.project(ImmutableMap.of("expr", PlanMatchPattern.expression("orderkey > 0")), PlanMatchPattern.tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity", "suppkey", "suppkey"))))))))))))))), false);
    }

    /**
     * With the property enabled but no {@code GROUP BY}, the expected plan is the non-rewritten
     * shape: a global partial aggregation computing both sums, with {@code maskPartialSum} masked
     * by {@code expr}.
     */
    @Test
    public void testGlobalOptimization() {
        assertPlan("SELECT sum(quantity), sum(quantity) filter (where orderkey > 0) from lineitem", enableOptimization(), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.globalAggregation(), ImmutableMap.of(Optional.of("finalSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("partialSum")), Optional.of("maskFinalSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("maskPartialSum"))), ImmutableMap.of(), Optional.empty(), AggregationNode.Step.FINAL, PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.globalAggregation(), ImmutableMap.of(Optional.of("partialSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("quantity")), Optional.of("maskPartialSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("quantity"))), ImmutableMap.of(new Symbol("maskPartialSum"), new Symbol("expr")), Optional.empty(), AggregationNode.Step.PARTIAL, PlanMatchPattern.project(ImmutableMap.of("expr", PlanMatchPattern.expression("orderkey > 0")), PlanMatchPattern.tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "quantity", "quantity"))))))), false);
    }

    /**
     * With the property enabled but the aggregates carrying an {@code ORDER BY}, the expected plan
     * is a single-step ({@code SINGLE}) aggregation in which {@code array_agg_filter} is masked by
     * {@code expr}; no partial/final split or {@code IF} projection is expected.
     */
    @Test
    public void testHasOrderBy() {
        assertPlan("select partkey, array_agg(suppkey order by suppkey), array_agg(suppkey order by suppkey) filter (where orderkey > 0) from lineitem group by partkey", enableOptimization(), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey"), ImmutableMap.of(Optional.of("array_agg"), PlanMatchPattern.functionCall("array_agg", (List<String>) ImmutableList.of("suppkey"), (List<PlanMatchPattern.Ordering>) ImmutableList.of(PlanMatchPattern.sort("suppkey", SortItem.Ordering.ASCENDING, SortItem.NullOrdering.LAST))), Optional.of("array_agg_filter"), PlanMatchPattern.functionCall("array_agg", (List<String>) ImmutableList.of("suppkey"), (List<PlanMatchPattern.Ordering>) ImmutableList.of(PlanMatchPattern.sort("suppkey", SortItem.Ordering.ASCENDING, SortItem.NullOrdering.LAST)))), ImmutableMap.of(new Symbol("array_agg_filter"), new Symbol("expr")), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.anyTree(PlanMatchPattern.project(ImmutableMap.of("expr", PlanMatchPattern.expression("orderkey > 0")), PlanMatchPattern.tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "suppkey", "suppkey")))))), false);
    }

    /**
     * With the property enabled and {@code GROUPING SETS}, the rewritten shape is expected: the
     * partial aggregation's grouping keys are {@code partkey$gid}, {@code groupid} and
     * {@code expr}, and an {@code IF(expr, partialSum, null)} projection feeds the final
     * aggregation.
     */
    @Test
    public void testGroupingSets() {
        assertPlan("SELECT partkey, sum(quantity), sum(quantity) filter (where orderkey > 0) from lineitem group by grouping sets((), (partkey))", enableOptimization(), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(new PlanMatchPattern.GroupingSetDescriptor(ImmutableList.of("partkey$gid", "groupid"), 2, ImmutableSet.of(0)), ImmutableMap.of(Optional.of("finalSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("partialSum")), Optional.of("maskFinalSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("maskPartialSum"))), ImmutableMap.of(), Optional.of(new Symbol("groupid")), AggregationNode.Step.FINAL, PlanMatchPattern.project(ImmutableMap.of("maskPartialSum", PlanMatchPattern.expression("IF(expr, partialSum, null)")), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(new PlanMatchPattern.GroupingSetDescriptor(ImmutableList.of("partkey$gid", "groupid", "expr"), 2, ImmutableSet.of(0)), ImmutableMap.of(Optional.of("partialSum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("quantity"))), ImmutableMap.of(), Optional.of(new Symbol("groupid")), AggregationNode.Step.PARTIAL, PlanMatchPattern.anyTree(PlanMatchPattern.groupingSet(ImmutableList.of(ImmutableList.of(), ImmutableList.of("partkey")), ImmutableMap.of("quantity", "quantity", "expr", "expr"), "groupid", ImmutableMap.of("partkey$gid", PlanMatchPattern.expression("partkey")), PlanMatchPattern.project(ImmutableMap.of("expr", PlanMatchPattern.expression("orderkey > 0")), PlanMatchPattern.tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity")))))))))), false);
    }

    /**
     * With the property enabled and {@code count(*)} used with and without a filter, the expected
     * plan is the non-rewritten shape: two partial counts grouped by {@code partkey}, with
     * {@code maskPartialCnt} masked by {@code expr}.
     *
     * <p>NOTE(review): the test name suggests the rewrite is skipped because {@code count} is
     * invoked on null input, so a null-producing {@code IF} projection would not be equivalent;
     * confirm against the optimizer rule.
     */
    @Test
    public void testCalledOnNull() {
        assertPlan("SELECT partkey, count(*), count(*) filter (where orderkey > 0) from lineitem group by partkey", enableOptimization(), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey"), ImmutableMap.of(Optional.of("finalCnt"), PlanMatchPattern.functionCall("count", ImmutableList.of("partialCnt")), Optional.of("maskFinalCnt"), PlanMatchPattern.functionCall("count", ImmutableList.of("maskPartialCnt"))), ImmutableMap.of(), Optional.empty(), AggregationNode.Step.FINAL, PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("partkey"), ImmutableMap.of(Optional.of("partialCnt"), PlanMatchPattern.functionCall("count", ImmutableList.of()), Optional.of("maskPartialCnt"), PlanMatchPattern.functionCall("count", ImmutableList.of())), ImmutableMap.of(new Symbol("maskPartialCnt"), new Symbol("expr")), Optional.empty(), AggregationNode.Step.PARTIAL, PlanMatchPattern.project(ImmutableMap.of("expr", PlanMatchPattern.expression("orderkey > 0")), PlanMatchPattern.tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey"))))))), false);
    }
}
