package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.ExpressionMatcher;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Optional;
import java.util.function.Function;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/TestRewriteAggregationIfToFilter.class */
public class TestRewriteAggregationIfToFilter extends BaseRuleTest {
    /**
     * Builds the test fixture, handing an empty plugin array to the {@link BaseRuleTest} constructor.
     */
    public TestRewriteAggregationIfToFilter() {
        super(new Plugin[] {});
    }

    /**
     * The aggregated column {@code a} is a plain boolean comparison (no {@code IF}), so the rule
     * is expected to leave the plan alone.
     */
    @Test
    public void testDoesNotFireForNonIf() {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", BooleanType.BOOLEAN);
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    return p.aggregation(agg -> agg
                            .globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr"), p.rowExpression("count(a)"))
                            .source(p.project(
                                    PlanBuilder.assignment(a, p.rowExpression("ds > '2021-07-01'")),
                                    p.values(ds))));
                })
                .doesNotFire();
    }

    /**
     * The aggregated column is an {@code IF} that carries an explicit ELSE value ({@code 2});
     * the rule is expected not to fire for it.
     */
    @Test
    public void testDoesNotFireForIfWithElse() {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a");
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    return p.aggregation(agg -> agg
                            .globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr"), p.rowExpression("count(a)"))
                            .source(p.project(
                                    PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-07-01', 1, 2)")),
                                    p.values(ds))));
                })
                .doesNotFire();
    }

    /**
     * The rule is expected not to fire when the {@code IF} references {@code random()}, whether
     * it sits in the value position (first plan) or in the condition (second plan).
     */
    @Test
    public void testDoesNotFireForNonDeterministicFunction() {
        // random() as the IF's value
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", DoubleType.DOUBLE);
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    return p.aggregation(agg -> agg
                            .globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr"), p.rowExpression("sum(a)"))
                            .source(p.project(
                                    PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-07-01', random())")),
                                    p.values(ds))));
                })
                .doesNotFire();

        // random() inside the IF's condition
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", BigintType.BIGINT);
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    return p.aggregation(agg -> agg
                            .globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr"), p.rowExpression("sum(a)"))
                            .source(p.project(
                                    PlanBuilder.assignment(a, p.rowExpression("IF(random() > DOUBLE '0.1', 1)")),
                                    p.values(ds))));
                })
                .doesNotFire();
    }

    /**
     * For both unwrap strategies, {@code count} over {@code IF(ds > '2021-07-01', 1)} is expected to
     * count a freshly projected literal {@code 1}. The aggregation is tied to the extracted
     * {@code greater_than} predicate, and a filter on that predicate sits above the projection.
     */
    @Test
    public void testFireCount() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        return p.aggregation(agg -> agg
                                .globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr"), p.rowExpression("count(a)"))
                                .source(p.project(
                                        PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-07-01', 1)")),
                                        p.values(ds))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr"), PlanMatchPattern.functionCall("count", ImmutableList.of("expr_0"))),
                            ImmutableMap.of(new Symbol("expr"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "a", PlanMatchPattern.expression("IF(ds > '2021-07-01', 1)"),
                                                    "greater_than", PlanMatchPattern.expression("ds > '2021-07-01'"),
                                                    "expr_0", PlanMatchPattern.expression("1")),
                                            PlanMatchPattern.values("ds")))));
        }
    }

    /**
     * Checks that {@code count(IF(ds > '2021-07-01', 1))} is rewritten to count a projected literal
     * {@code 1}, tied to the {@code greater_than} predicate, under both unwrap strategies.
     *
     * <p>NOTE(review): the input plan and expected shape are identical to {@code testFireCount};
     * only the pattern alias for the projected literal differs, so this adds no extra coverage.
     */
    @Test
    public void testUnwrapIf() {
        String[] strategies = {"unwrap_if_safe", "unwrap_if"};
        for (int i = 0; i < strategies.length; i++) {
            String strategy = strategies[i];
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        return p.aggregation(agg -> agg
                                .globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr"), p.rowExpression("count(a)"))
                                .source(p.project(
                                        PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-07-01', 1)")),
                                        p.values(ds))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr"), PlanMatchPattern.functionCall("count", ImmutableList.of("expr0"))),
                            ImmutableMap.of(new Symbol("expr"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "a", PlanMatchPattern.expression("IF(ds > '2021-07-01', 1)"),
                                                    "greater_than", PlanMatchPattern.expression("ds > '2021-07-01'"),
                                                    "expr0", PlanMatchPattern.expression("1")),
                                            PlanMatchPattern.values("ds")))));
        }
    }

    /**
     * For both unwrap strategies, {@code MIN(IF(ds > '2021-06-01', column0))} is expected to become
     * {@code min} over an unwrapped copy of {@code column0} ({@code column0_0}). The aggregation is
     * tied to the {@code greater_than} predicate, and a filter on that predicate sits above the
     * projection.
     */
    @Test
    public void testFireMin() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
                        return p.aggregation(agg -> agg
                                .globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("MIN(a)"))
                                .source(p.project(
                                        PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                        p.values(ds, column0))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("min", ImmutableList.of("column0_0"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "a", PlanMatchPattern.expression("IF(ds > '2021-06-01', column0)"),
                                                    "greater_than", PlanMatchPattern.expression("ds > '2021-06-01'"),
                                                    "column0_0", PlanMatchPattern.expression("column0")),
                                            PlanMatchPattern.values("ds", "column0")))));
        }
    }

    /**
     * For both unwrap strategies, {@code MAX(IF(ds > '2021-06-01', column0))} is expected to become
     * {@code max} over an unwrapped copy of {@code column0} ({@code column0_0}). The aggregation is
     * tied to the {@code greater_than} predicate, and a filter on that predicate sits above the
     * projection.
     */
    @Test
    public void testFireMax() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
                        return p.aggregation(agg -> agg
                                .globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("MAX(a)"))
                                .source(p.project(
                                        PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                        p.values(ds, column0))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("max", ImmutableList.of("column0_0"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "a", PlanMatchPattern.expression("IF(ds > '2021-06-01', column0)"),
                                                    "greater_than", PlanMatchPattern.expression("ds > '2021-06-01'"),
                                                    "column0_0", PlanMatchPattern.expression("column0")),
                                            PlanMatchPattern.values("ds", "column0")))));
        }
    }

    /**
     * For both unwrap strategies, {@code ARBITRARY(IF(ds > '2021-06-01', column0))} is expected to
     * become {@code arbitrary} over an unwrapped copy of {@code column0} ({@code column0_0}). The
     * aggregation is tied to the {@code greater_than} predicate.
     */
    @Test
    public void testFireArbitrary() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
                        return p.aggregation(agg -> agg
                                .globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("ARBITRARY(a)"))
                                .source(p.project(
                                        PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                        p.values(ds, column0))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("arbitrary", ImmutableList.of("column0_0"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "a", PlanMatchPattern.expression("IF(ds > '2021-06-01', column0)"),
                                                    "greater_than", PlanMatchPattern.expression("ds > '2021-06-01'"),
                                                    "column0_0", PlanMatchPattern.expression("column0")),
                                            PlanMatchPattern.values("ds", "column0")))));
        }
    }

    /**
     * For both unwrap strategies, {@code SUM(IF(ds > '2021-06-01', column0))} is expected to become
     * {@code sum} over an unwrapped copy of {@code column0} ({@code column0_0}). The aggregation is
     * tied to the {@code greater_than} predicate.
     */
    @Test
    public void testFireSum() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
                        return p.aggregation(agg -> agg
                                .globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("SUM(a)"))
                                .source(p.project(
                                        PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                        p.values(ds, column0))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("sum", ImmutableList.of("column0_0"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "a", PlanMatchPattern.expression("IF(ds > '2021-06-01', column0)"),
                                                    "greater_than", PlanMatchPattern.expression("ds > '2021-06-01'"),
                                                    "column0_0", PlanMatchPattern.expression("column0")),
                                            PlanMatchPattern.values("ds", "column0")))));
        }
    }

    /**
     * {@code MAX_BY} over an {@code IF}-produced column is expected to be left unchanged by the rule.
     */
    @Test
    public void testDoesNotFireForMaxBy() {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a");
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
                    return p.aggregation(agg -> agg
                            .globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("MAX_BY(a, a)"))
                            .source(p.project(
                                    PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                    p.values(ds, column0))));
                })
                .doesNotFire();
    }

    /**
     * {@code MIN_BY} over an {@code IF}-produced column is expected to be left unchanged by the rule.
     */
    @Test
    public void testDoesNotFireForMinBy() {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a");
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
                    return p.aggregation(agg -> agg
                            .globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("MIN_BY(a, a)"))
                            .source(p.project(
                                    PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                    p.values(ds, column0))));
                })
                .doesNotFire();
    }

    /**
     * Two counts over {@code IF}s with different conditions. Each aggregation is expected to be tied
     * to its own extracted predicate ({@code greater_than}, {@code greater_than_0}). The filter above
     * the projection keeps rows satisfying either predicate.
     */
    @Test
    public void testFireTwoAggregations() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression b = p.variable("b");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        return p.aggregation(agg -> agg
                                .globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("count(a)"))
                                .addAggregation(p.variable("expr1"), p.rowExpression("count(b)"))
                                .source(p.project(
                                        PlanBuilder.assignment(
                                                a, p.rowExpression("IF(ds > '2021-07-01', 1)"),
                                                b, p.rowExpression("IF(ds > '2021-06-01', 2)")),
                                        p.values(ds))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(
                                    Optional.of("expr0"), PlanMatchPattern.functionCall("count", ImmutableList.of("expr")),
                                    Optional.of("expr1"), PlanMatchPattern.functionCall("count", ImmutableList.of("expr_1"))),
                            ImmutableMap.of(
                                    new Symbol("expr0"), new Symbol("greater_than"),
                                    new Symbol("expr1"), new Symbol("greater_than_0")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than or greater_than_0",
                                    PlanMatchPattern.project(
                                            // Six assignments exceed the ImmutableMap.of overloads; use a typed builder.
                                            ImmutableMap.<String, ExpressionMatcher>builder()
                                                    .put("a", PlanMatchPattern.expression("IF(ds > '2021-07-01', 1)"))
                                                    .put("b", PlanMatchPattern.expression("IF(ds > '2021-06-01', 2)"))
                                                    .put("greater_than", PlanMatchPattern.expression("ds > '2021-07-01'"))
                                                    .put("expr", PlanMatchPattern.expression("1"))
                                                    .put("greater_than_0", PlanMatchPattern.expression("ds > '2021-06-01'"))
                                                    .put("expr_1", PlanMatchPattern.expression("2"))
                                                    .build(),
                                            PlanMatchPattern.values("ds")))));
        }
    }

    /**
     * {@code MIN} and {@code MAX} over the same {@code IF} column are expected to share a single
     * unwrapped input ({@code column0_0}) and a single {@code greater_than} predicate.
     */
    @Test
    public void testFireTwoAggregationsWithSharedInput() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
                        return p.aggregation(agg -> agg
                                .globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("MIN(a)"))
                                .addAggregation(p.variable("expr1"), p.rowExpression("MAX(a)"))
                                .source(p.project(
                                        PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                        p.values(ds, column0))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(
                                    Optional.of("expr0"), PlanMatchPattern.functionCall("min", ImmutableList.of("column0_0")),
                                    Optional.of("expr1"), PlanMatchPattern.functionCall("max", ImmutableList.of("column0_0"))),
                            ImmutableMap.of(
                                    new Symbol("expr0"), new Symbol("greater_than"),
                                    new Symbol("expr1"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "a", PlanMatchPattern.expression("IF(ds > '2021-06-01', column0)"),
                                                    "greater_than", PlanMatchPattern.expression("ds > '2021-06-01'"),
                                                    "column0_0", PlanMatchPattern.expression("column0")),
                                            PlanMatchPattern.values("ds", "column0")))));
        }
    }

    /**
     * Only {@code count(a)} reads an {@code IF}, so only it is expected to be tied to a predicate.
     * {@code count(b)} over the plain {@code ds} column keeps reading {@code b}, and the filter
     * above the projection is {@code true}.
     */
    @Test
    public void testFireForOneOfTwoAggregations() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression b = p.variable("b");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        return p.aggregation(agg -> agg
                                .globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("count(a)"))
                                .addAggregation(p.variable("expr1"), p.rowExpression("count(b)"))
                                .source(p.project(
                                        PlanBuilder.assignment(
                                                a, p.rowExpression("IF(ds > '2021-07-01', 1)"),
                                                b, p.rowExpression("ds")),
                                        p.values(ds))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(
                                    Optional.of("expr0"), PlanMatchPattern.functionCall("count", ImmutableList.of("expr")),
                                    Optional.of("expr1"), PlanMatchPattern.functionCall("count", ImmutableList.of("b"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "true",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "a", PlanMatchPattern.expression("IF(ds > '2021-07-01', 1)"),
                                                    "b", PlanMatchPattern.expression("ds"),
                                                    "greater_than", PlanMatchPattern.expression("ds > '2021-07-01'"),
                                                    "expr", PlanMatchPattern.expression("1")),
                                            PlanMatchPattern.values("ds")))));
        }
    }

    /**
     * For every strategy, including {@code unwrap_if}, the {@code IF} guarding
     * {@code arrayColumn[1]} is expected to stay intact. {@code SUM} still reads
     * {@code arrayElement}, and only the {@code CARDINALITY} predicate is extracted. (Presumably the
     * subscript must not run unguarded on empty arrays; confirm against the rule.)
     */
    @Test
    public void testArrayOffset() {
        for (String strategy : ImmutableList.of("filter_with_if", "unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression arrayColumn = p.variable("arrayColumn", new ArrayType(BigintType.BIGINT));
                        VariableReferenceExpression arrayElement = p.variable("arrayElement", BigintType.BIGINT);
                        return p.aggregation(agg -> agg
                                .globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("SUM(arrayElement)"))
                                .source(p.project(
                                        PlanBuilder.assignment(arrayElement, p.rowExpression("IF(CARDINALITY(arrayColumn) > 0, arrayColumn[1])")),
                                        p.values(arrayColumn))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("SUM", ImmutableList.of("arrayElement"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "arrayElement", PlanMatchPattern.expression("IF(CARDINALITY(arrayColumn) > 0, arrayColumn[1])"),
                                                    "greater_than", PlanMatchPattern.expression("CARDINALITY(arrayColumn) > 0")),
                                            PlanMatchPattern.values("arrayColumn")))));
        }
    }

    /**
     * Division inside an {@code IF}, checked in two cases.
     *
     * <ul>
     *   <li>When the condition guards the divisor ({@code b != 0}), every strategy is expected to
     *       keep the {@code IF} and only extract the {@code not_equal} predicate.</li>
     *   <li>When the condition is unrelated ({@code ds > '2021-07-01'}), {@code unwrap_if} is
     *       expected to project the bare {@code a / b}.</li>
     * </ul>
     */
    @Test
    public void testDivide() {
        for (String strategy : ImmutableList.of("filter_with_if", "unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a", BigintType.BIGINT);
                        VariableReferenceExpression b = p.variable("b", BigintType.BIGINT);
                        VariableReferenceExpression result = p.variable("result", BigintType.BIGINT);
                        return p.aggregation(agg -> agg
                                .globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("SUM(result)"))
                                .source(p.project(
                                        PlanBuilder.assignment(result, p.rowExpression("IF(b != 0, a / b)")),
                                        p.values(a, b))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("SUM", ImmutableList.of("result"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("not_equal")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "not_equal",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "result", PlanMatchPattern.expression("IF(b != 0, a / b)"),
                                                    "not_equal", PlanMatchPattern.expression("b != 0")),
                                            PlanMatchPattern.values("a", "b")))));
        }

        // Condition unrelated to the divisor: unwrap_if is expected to project the bare division.
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "unwrap_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", BigintType.BIGINT);
                    VariableReferenceExpression b = p.variable("b", BigintType.BIGINT);
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    VariableReferenceExpression result = p.variable("result", BigintType.BIGINT);
                    return p.aggregation(agg -> agg
                            .globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("SUM(result)"))
                            .source(p.project(
                                    PlanBuilder.assignment(result, p.rowExpression("IF(ds > '2021-07-01', a / b)")),
                                    p.values(ds, a, b))));
                })
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("SUM", ImmutableList.of("result"))),
                        ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                        Optional.empty(),
                        AggregationNode.Step.FINAL,
                        PlanMatchPattern.filter(
                                "greater_than",
                                PlanMatchPattern.project(
                                        ImmutableMap.of(
                                                "result", PlanMatchPattern.expression("a / b"),
                                                "greater_than", PlanMatchPattern.expression("ds > '2021-07-01'")),
                                        PlanMatchPattern.values("ds", "a", "b")))));
    }

    /**
     * Under {@code unwrap_if}, two {@code IF}s behave differently. The divisor-guarded
     * {@code IF(b != 0, a / b)} is expected to be kept ({@code count(result0)}). {@code IF(b > 0, b)}
     * is expected to be unwrapped to {@code b_0}. Each aggregation is tied to its own predicate,
     * and the filter ORs them.
     */
    @Test
    public void testUnwrapIfForOneOfTwoAggregations() {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "unwrap_if")
                .on(p -> {
                    VariableReferenceExpression result0 = p.variable("result0", BigintType.BIGINT);
                    VariableReferenceExpression result1 = p.variable("result1", BigintType.BIGINT);
                    VariableReferenceExpression a = p.variable("a", BigintType.BIGINT);
                    VariableReferenceExpression b = p.variable("b", BigintType.BIGINT);
                    return p.aggregation(agg -> agg
                            .globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("count(result0)"))
                            .addAggregation(p.variable("expr1"), p.rowExpression("count(result1)"))
                            .source(p.project(
                                    PlanBuilder.assignment(
                                            result0, p.rowExpression("IF(b != 0, a / b)"),
                                            result1, p.rowExpression("IF(b > 0, b)")),
                                    p.values(a, b))));
                })
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.of(
                                Optional.of("expr0"), PlanMatchPattern.functionCall("count", ImmutableList.of("result0")),
                                Optional.of("expr1"), PlanMatchPattern.functionCall("count", ImmutableList.of("b_0"))),
                        ImmutableMap.of(
                                new Symbol("expr0"), new Symbol("not_equal"),
                                new Symbol("expr1"), new Symbol("greater_than")),
                        Optional.empty(),
                        AggregationNode.Step.FINAL,
                        PlanMatchPattern.filter(
                                "greater_than or not_equal",
                                PlanMatchPattern.project(
                                        ImmutableMap.of(
                                                "result0", PlanMatchPattern.expression("IF(b != 0, a / b)"),
                                                "result1", PlanMatchPattern.expression("IF(b > 0, b)"),
                                                "b_0", PlanMatchPattern.expression("b"),
                                                "not_equal", PlanMatchPattern.expression("b != 0"),
                                                "greater_than", PlanMatchPattern.expression("b > 0")),
                                        PlanMatchPattern.values("a", "b")))));
    }

    /**
     * Runs one plan, {@code SUM(IF(column0 > 1, column0))}, under each strategy:
     *
     * <ul>
     *   <li>{@code disabled}: the rule does not fire.</li>
     *   <li>{@code filter_with_if} and {@code unwrap_if_safe}: the aggregation keeps summing
     *       {@code a} and is tied to the {@code greater_than} predicate.</li>
     *   <li>{@code unwrap_if}: the aggregation sums an unwrapped copy of {@code column0}
     *       ({@code column0_0}).</li>
     * </ul>
     */
    @Test
    public void testRewriteStrategies() {
        Function<PlanBuilder, PlanNode> plan = p -> {
            VariableReferenceExpression a = p.variable("a");
            VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
            return p.aggregation(agg -> agg
                    .globalGrouping()
                    .step(AggregationNode.Step.FINAL)
                    .addAggregation(p.variable("expr0"), p.rowExpression("SUM(a)"))
                    .source(p.project(
                            PlanBuilder.assignment(a, p.rowExpression("IF(column0 > 1, column0)")),
                            p.values(column0))));
        };

        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "disabled")
                .on(plan)
                .doesNotFire();

        // These two strategies are expected to produce the same plan, so check them together.
        for (String strategy : ImmutableList.of("filter_with_if", "unwrap_if_safe")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(plan)
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("sum", ImmutableList.of("a"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "a", PlanMatchPattern.expression("IF(column0 > 1, column0)"),
                                                    "greater_than", PlanMatchPattern.expression("column0 > 1")),
                                            PlanMatchPattern.values("column0")))));
        }

        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "unwrap_if")
                .on(plan)
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("sum", ImmutableList.of("column0_0"))),
                        ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                        Optional.empty(),
                        AggregationNode.Step.FINAL,
                        PlanMatchPattern.filter(
                                "greater_than",
                                PlanMatchPattern.project(
                                        ImmutableMap.of(
                                                "a", PlanMatchPattern.expression("IF(column0 > 1, column0)"),
                                                "greater_than", PlanMatchPattern.expression("column0 > 1"),
                                                "column0_0", PlanMatchPattern.expression("column0")),
                                        PlanMatchPattern.values("column0")))));
    }

    /**
     * Checks the rewrite of {@code SUM(a)} where {@code a} is {@code CAST(IF(ds > '2021-06-01', column0) AS bigint)}.
     *
     * <p>Under {@code filter_with_if}, the aggregation keeps reading {@code a}. It gets a mask on a new
     * {@code greater_than} projection of the IF condition. Under {@code unwrap_if_safe} and {@code unwrap_if},
     * the aggregation argument instead becomes a new {@code cast} projection, {@code CAST(column0 AS bigint)}.
     * The IF is removed from inside the CAST, and the same mask is applied.
     */
    @Test
    public void testCast() {
        // One input plan, shared by every strategy below. It is rebuilt from the PlanBuilder handed in on each use.
        Function<PlanBuilder, PlanNode> castOverIfPlan = builder -> {
            VariableReferenceExpression a = builder.variable("a");
            VariableReferenceExpression ds = builder.variable("ds", VarcharType.VARCHAR);
            VariableReferenceExpression column0 = builder.variable("column0", BigintType.BIGINT);
            return builder.aggregation(agg -> agg
                    .globalGrouping()
                    .step(AggregationNode.Step.FINAL)
                    .addAggregation(builder.variable("expr0"), builder.rowExpression("SUM(a)"))
                    .source(builder.project(
                            PlanBuilder.assignment(a, builder.rowExpression("CAST(IF(ds > '2021-06-01', column0) AS bigint)")),
                            builder.values(ds, column0))));
        };

        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(castOverIfPlan)
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("sum", ImmutableList.of("a"))),
                        ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                        Optional.empty(),
                        AggregationNode.Step.FINAL,
                        PlanMatchPattern.filter(
                                "greater_than",
                                PlanMatchPattern.project(
                                        ImmutableMap.of(
                                                "a", PlanMatchPattern.expression("CAST(IF(ds > '2021-06-01', column0) as bigint)"),
                                                "greater_than", PlanMatchPattern.expression("ds > '2021-06-01'")),
                                        PlanMatchPattern.values("ds", "column0")))));

        // Both unwrapping strategies are expected to produce the same plan for this input.
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(castOverIfPlan)
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("sum", ImmutableList.of("cast"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "a", PlanMatchPattern.expression("CAST(IF(ds > '2021-06-01', column0) as bigint)"),
                                                    "greater_than", PlanMatchPattern.expression("ds > '2021-06-01'"),
                                                    "cast", PlanMatchPattern.expression("CAST(column0 AS bigint)")),
                                            PlanMatchPattern.values("ds", "column0")))));
        }
    }
}
