package com.facebook.presto.sql.planner.assertions;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.ExpressionTreeUtils;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.OrderBy;
import com.facebook.presto.sql.tree.SortItem;
import com.google.common.base.Preconditions;
import com.google.common.collect.Streams;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * {@link RvalueMatcher} that finds the single aggregation in an {@link AggregationNode} whose
 * function name, arguments, filter, DISTINCT flag, ORDER BY and (optionally) mask match an
 * expected {@link FunctionCall}.
 */
public class AggregationFunctionMatcher implements RvalueMatcher {
    private final ExpectedValueProvider<FunctionCall> callMaker;
    private final Optional<Symbol> mask;

    /**
     * Creates a matcher that expects the aggregation to have no mask.
     *
     * @param expectedValueProvider produces the expected function call once symbol aliases are known
     * @throws NullPointerException if {@code expectedValueProvider} is null
     */
    public AggregationFunctionMatcher(ExpectedValueProvider<FunctionCall> expectedValueProvider) {
        this.callMaker = Objects.requireNonNull(expectedValueProvider, "functionCall is null");
        this.mask = Optional.empty();
    }

    /**
     * Creates a matcher that expects the aggregation to be masked by the given (aliased) symbol.
     *
     * @param expectedValueProvider produces the expected function call once symbol aliases are known
     * @param symbol alias of the expected mask; resolved through {@link SymbolAliases} at match time
     * @throws NullPointerException if either argument is null
     */
    public AggregationFunctionMatcher(ExpectedValueProvider<FunctionCall> expectedValueProvider, Symbol symbol) {
        this.callMaker = Objects.requireNonNull(expectedValueProvider, "functionCall is null");
        this.mask = Optional.of(Objects.requireNonNull(symbol, "mask is null"));
    }

    /**
     * Returns the output variable of the aggregation matching the expected call.
     *
     * @return the matching aggregation's output variable, or empty if {@code planNode} is not an
     *     {@link AggregationNode} or no aggregation matches
     * @throws IllegalStateException if more than one aggregation matches
     */
    @Override
    public Optional<VariableReferenceExpression> getAssignedVariable(PlanNode planNode, Session session, Metadata metadata, SymbolAliases symbolAliases) {
        Optional<VariableReferenceExpression> result = Optional.empty();
        if (!(planNode instanceof AggregationNode)) {
            return result;
        }
        AggregationNode aggregationNode = (AggregationNode) planNode;
        FunctionCall expectedCall = this.callMaker.getExpectedValue(symbolAliases);
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> assignment : aggregationNode.getAggregations().entrySet()) {
            if (verifyAggregation(metadata.getFunctionAndTypeManager(), assignment.getValue(), expectedCall, resolveMask(symbolAliases))) {
                // More than one match means the expected pattern is ambiguous for this node.
                Preconditions.checkState(!result.isPresent(), "Ambiguous function calls in %s", aggregationNode);
                result = Optional.of(assignment.getKey());
            }
        }
        return result;
    }

    /** Maps the expected mask alias (if any) to the symbol name it is bound to in {@code symbolAliases}. */
    private Optional<Symbol> resolveMask(SymbolAliases symbolAliases) {
        return this.mask.map(symbol -> new Symbol(symbolAliases.get(symbol.getName()).getName()));
    }

    /**
     * Checks one aggregation against the expected call. Checks run in a fixed order and stop at
     * the first mismatch, so later checks (e.g. argument conversion) only run when earlier ones pass.
     */
    private static boolean verifyAggregation(FunctionAndTypeManager functionAndTypeManager, AggregationNode.Aggregation aggregation, FunctionCall expectedCall, Optional<Symbol> expectedMask) {
        String actualName = functionAndTypeManager.getFunctionMetadata(aggregation.getFunctionHandle()).getName().getObjectName();
        if (!actualName.equalsIgnoreCase(expectedCall.getName().getSuffix())) {
            return false;
        }
        if (aggregation.getArguments().size() != expectedCall.getArguments().size()) {
            return false;
        }
        boolean argumentsMatch = Streams.zip(
                        aggregation.getArguments().stream(),
                        expectedCall.getArguments().stream(),
                        (rowExpression, expression) -> isEquivalent(Optional.of(expression), Optional.of(rowExpression)))
                .allMatch(Boolean::booleanValue);
        if (!argumentsMatch) {
            return false;
        }
        return isEquivalent(expectedCall.getFilter(), aggregation.getFilter())
                && expectedCall.isDistinct() == aggregation.isDistinct()
                && verifyAggregationOrderBy(aggregation.getOrderBy(), expectedCall.getOrderBy())
                && maskMatch(expectedMask, aggregation.getMask());
    }

    /** Both absent matches; exactly one present does not; both present are compared item by item. */
    private static boolean verifyAggregationOrderBy(Optional<OrderingScheme> actual, Optional<OrderBy> expected) {
        if (actual.isPresent() && expected.isPresent()) {
            return verifyAggregationOrderBy(actual.get(), expected.get());
        }
        return actual.isPresent() == expected.isPresent();
    }

    /** Compares sort keys and sort orders position by position. */
    private static boolean verifyAggregationOrderBy(OrderingScheme orderingScheme, OrderBy orderBy) {
        if (orderingScheme.getOrderByVariables().size() != orderBy.getSortItems().size()) {
            return false;
        }
        for (int i = 0; i < orderBy.getSortItems().size(); i++) {
            VariableReferenceExpression variable = orderingScheme.getOrderByVariables().get(i);
            SortItem sortItem = orderBy.getSortItems().get(i);
            if (!sortItem.getSortKey().equals(ExpressionTreeUtils.createSymbolReference(variable))
                    || !PlannerUtils.toSortOrder(sortItem).equals(orderingScheme.getOrdering(variable))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compares an expected AST expression with an actual row expression by converting the latter
     * to a symbol reference. Both absent counts as equivalent; exactly one absent does not.
     *
     * @throws IllegalArgumentException if the actual expression is not a {@link VariableReferenceExpression}
     */
    private static boolean isEquivalent(Optional<Expression> expected, Optional<RowExpression> actual) {
        if (!actual.isPresent() || !expected.isPresent()) {
            return actual.isPresent() == expected.isPresent();
        }
        Preconditions.checkArgument(actual.get() instanceof VariableReferenceExpression, "can only process variableReference");
        return expected.get().equals(ExpressionTreeUtils.createSymbolReference((VariableReferenceExpression) actual.get()));
    }

    /** Both absent matches; both present match when names are equal; otherwise no match. */
    private static boolean maskMatch(Optional<Symbol> expected, Optional<VariableReferenceExpression> actual) {
        if (expected.isPresent() && actual.isPresent()) {
            return expected.get().getName().equals(actual.get().getName());
        }
        return expected.isPresent() == actual.isPresent();
    }

    @Override
    public String toString() {
        return this.callMaker.toString();
    }
}
