package io.substrait.isthmus;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Project;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Checks whether a Substrait {@link Aggregate} can be converted to Calcite as-is, and provides
 * {@link PreCalciteAggregateTransformer} to rewrite one that cannot.
 *
 * <p>An aggregate is considered valid when every measure argument, measure sort expression,
 * pre-measure filter and grouping expression is a single-segment struct field reference, and the
 * field offsets within each grouping appear in non-decreasing order.
 */
public class PreCalciteAggregateValidator {

    /**
     * Rewrites an {@link Aggregate} so that its measures and groupings only contain simple field
     * references, by moving the referenced expressions into a {@link Project} placed beneath the
     * aggregate.
     *
     * <p>An instance is single-use and stateful: every projected expression is appended to
     * {@link #newExpressions} and is referenced by the next output offset.
     */
    public static class PreCalciteAggregateTransformer {
        /** Expressions to append, in order, to the project inserted beneath the aggregate. */
        private final List<Expression> newExpressions = new ArrayList<>();

        /** Output offset the next projected expression will be referenced by. */
        private int expressionOffset;

        private PreCalciteAggregateTransformer(Aggregate aggregate) {
            // Projected expressions are numbered starting right after the input's own fields.
            // NOTE(review): assumes the Project emits all input fields followed by its
            // expressions (no output remap) — confirm against Project semantics.
            this.expressionOffset = aggregate.getInput().getRecordType().fields().size();
        }

        /**
         * Builds an equivalent aggregate that only uses simple field references.
         *
         * <p>Non-field-reference measure arguments, measure sort expressions and pre-measure
         * filters are moved into a new {@link Project} over the original input. All grouping
         * expressions are projected unconditionally.
         *
         * @param aggregate the aggregate to rewrite
         * @return a copy of {@code aggregate} whose input is the new project and whose measures
         *     and groupings reference that project's output fields
         * @throws IllegalArgumentException if a measure argument is not an {@link Expression}
         */
        public static Aggregate transformToValidCalciteAggregate(Aggregate aggregate) {
            PreCalciteAggregateTransformer transformer = new PreCalciteAggregateTransformer(aggregate);

            List<Aggregate.Measure> newMeasures =
                    aggregate.getMeasures().stream()
                            .map(transformer::updateMeasure)
                            .collect(Collectors.toList());

            // Groupings must be rewritten BEFORE the project is built. updateGrouping appends the
            // grouping expressions to newExpressions, and the project has to include them.
            // Otherwise the rewritten grouping references point past the project's output.
            List<Aggregate.Grouping> newGroupings =
                    aggregate.getGroupings().stream()
                            .map(transformer::updateGrouping)
                            .collect(Collectors.toList());

            Project preAggregateProject =
                    Project.builder()
                            .input(aggregate.getInput())
                            .expressions(transformer.newExpressions)
                            .build();

            return Aggregate.builder()
                    .from(aggregate)
                    .input(preAggregateProject)
                    .measures(newMeasures)
                    .groupings(newGroupings)
                    .build();
        }

        /**
         * Rewrites a measure: projects out every argument, sort expression and pre-measure filter
         * that is not already a simple field reference.
         */
        private Aggregate.Measure updateMeasure(Aggregate.Measure measure) {
            AggregateFunctionInvocation function = measure.getFunction();

            List<Expression> newArguments =
                    function.arguments().stream()
                            .map(this::projectOutNonFieldReference)
                            .collect(Collectors.toList());

            List<Expression.SortField> newSorts =
                    function.sort().stream()
                            .map(sortField ->
                                    Expression.SortField.builder()
                                            .from(sortField)
                                            .expr(projectOutNonFieldReference(sortField.expr()))
                                            .build())
                            .collect(Collectors.toList());

            AggregateFunctionInvocation newFunction =
                    AggregateFunctionInvocation.builder()
                            .from(function)
                            .arguments(newArguments)
                            .sort(newSorts)
                            .build();

            return Aggregate.Measure.builder()
                    .function(newFunction)
                    .preMeasureFilter(
                            measure.getPreMeasureFilter().map(this::projectOutNonFieldReference))
                    .build();
        }

        /**
         * Rewrites a grouping by projecting out every expression, including ones that are already
         * simple field references.
         */
        private Aggregate.Grouping updateGrouping(Aggregate.Grouping grouping) {
            // Projecting unconditionally hands out fresh, increasing offsets, so the resulting
            // references are always in the order isValidCalciteGrouping requires.
            List<Expression> newGroupingExpressions =
                    grouping.getExpressions().stream()
                            .map(this::projectOut)
                            .collect(Collectors.toList());
            return Aggregate.Grouping.builder().expressions(newGroupingExpressions).build();
        }

        /**
         * Projects out a function argument unless it is already a simple field reference.
         *
         * @throws IllegalArgumentException if {@code functionArg} is not an {@link Expression}
         */
        private Expression projectOutNonFieldReference(FunctionArg functionArg) {
            if (!(functionArg instanceof Expression)) {
                throw new IllegalArgumentException("cannot handle non-expression argument for aggregate");
            }
            return projectOutNonFieldReference((Expression) functionArg);
        }

        /** Returns {@code expression} unchanged if it is a simple field reference, else projects it out. */
        private Expression projectOutNonFieldReference(Expression expression) {
            return isSimpleFieldReference(expression) ? expression : projectOut(expression);
        }

        /**
         * Appends {@code expression} to the pending project and returns a struct field reference
         * to the offset it will occupy, carrying the expression's type.
         */
        private Expression projectOut(Expression expression) {
            newExpressions.add(expression);
            int offset = expressionOffset++;
            return ImmutableFieldReference.builder()
                    .addSegments(FieldReference.StructField.of(offset))
                    .type(expression.getType())
                    .build();
        }
    }

    /**
     * Returns whether {@code aggregate} can be converted to Calcite without rewriting.
     *
     * <p>This requires every measure to reference its arguments, sort expressions and pre-measure
     * filter by simple field references. Every grouping must consist of simple field references
     * with non-decreasing offsets.
     */
    public static boolean isValidCalciteAggregate(Aggregate aggregate) {
        return aggregate.getMeasures().stream().allMatch(PreCalciteAggregateValidator::isValidCalciteMeasure)
                && aggregate.getGroupings().stream().allMatch(PreCalciteAggregateValidator::isValidCalciteGrouping);
    }

    /** True when all arguments, sort expressions and the optional filter are simple field references. */
    private static boolean isValidCalciteMeasure(Aggregate.Measure measure) {
        AggregateFunctionInvocation function = measure.getFunction();
        return function.arguments().stream().allMatch(PreCalciteAggregateValidator::isSimpleFieldReference)
                && function.sort().stream().allMatch(sortField -> isSimpleFieldReference(sortField.expr()))
                && measure.getPreMeasureFilter()
                        .map(PreCalciteAggregateValidator::isSimpleFieldReference)
                        .orElse(true);
    }

    /** True when all grouping expressions are simple field references with non-decreasing offsets. */
    private static boolean isValidCalciteGrouping(Aggregate.Grouping grouping) {
        List<Expression> expressions = grouping.getExpressions();
        if (!expressions.stream().allMatch(PreCalciteAggregateValidator::isSimpleFieldReference)) {
            return false;
        }
        // NOTE(review): the ordering requirement is presumably because Calcite's group set does
        // not preserve the listed order of grouping fields — confirm.
        List<Integer> offsets =
                expressions.stream()
                        .map(expression -> getFieldRefOffset((FieldReference) expression))
                        .collect(Collectors.toList());
        return isOrdered(offsets);
    }

    /**
     * Returns whether {@code functionArg} is a {@link FieldReference} consisting of exactly one
     * {@link FieldReference.StructField} segment.
     *
     * <p>Kept public for backward compatibility with existing callers.
     */
    public static boolean isSimpleFieldReference(FunctionArg functionArg) {
        if (!(functionArg instanceof FieldReference)) {
            return false;
        }
        FieldReference fieldReference = (FieldReference) functionArg;
        return fieldReference.segments().size() == 1
                && fieldReference.segments().get(0) instanceof FieldReference.StructField;
    }

    /** Offset of the single struct-field segment; callers must first check isSimpleFieldReference. */
    private static int getFieldRefOffset(FieldReference fieldReference) {
        return ((FieldReference.StructField) fieldReference.segments().get(0)).offset();
    }

    /** Returns whether {@code list} is sorted in non-decreasing order (equal neighbours allowed). */
    private static boolean isOrdered(List<Integer> list) {
        for (int i = 1; i < list.size(); i++) {
            if (list.get(i - 1) > list.get(i)) {
                return false;
            }
        }
        return true;
    }
}
