/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.sql.calcite.aggregation.builtin;

import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.util.Optionality;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FloatSumAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.builtin.SimpleSqlAggregator;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.UnsupportedSQLQueryException;

/**
 * SQL aggregator that plans {@code SUM} calls onto Druid's long, float and double
 * sum aggregator factories.
 *
 * <p>The Calcite function exposed by {@link #calciteFunction()} is a custom
 * {@code SUM} operator ({@link DruidSumAggFunction}) whose
 * {@link SqlSplittableAggFunction} implementation is {@link DruidSumSplitter}.
 */
public class SumSqlAggregator
extends SimpleSqlAggregator {
    /** Shared instance of the custom {@code SUM} aggregate function. */
    private static final SqlAggFunction DRUID_SUM = new DruidSumAggFunction();

    /**
     * Returns the Calcite aggregate function handled by this aggregator.
     *
     * @return the shared {@code SUM} function instance
     */
    @Override
    public SqlAggFunction calciteFunction() {
        return DRUID_SUM;
    }

    /**
     * Builds the aggregation for a {@code SUM} call.
     *
     * @param name          output name handed to the aggregator factory
     * @param aggregateCall the Calcite aggregate call; its return type selects the sum variant
     * @param macroTable    expression macro table handed to the aggregator factory
     * @param fieldName     input field name handed to the aggregator factory
     * @return the aggregation, or {@code null} when the call's return type has no
     *         corresponding Druid column type
     * @throws UnsupportedSQLQueryException if the resolved value type is not LONG, FLOAT or DOUBLE
     */
    @Override
    Aggregation getAggregation(String name, AggregateCall aggregateCall, ExprMacroTable macroTable, String fieldName) {
        ColumnType valueType = Calcites.getColumnTypeForRelDataType(aggregateCall.getType());
        if (valueType == null) {
            // No Druid column type for this SQL type: report "cannot plan" via null.
            return null;
        }
        return Aggregation.create(createSumAggregatorFactory(valueType.getType(), name, fieldName, macroTable));
    }

    /**
     * Creates the sum aggregator factory matching the given value type.
     *
     * @param aggregationType value type of the sum; LONG, FLOAT and DOUBLE are supported
     * @param name            output name of the aggregator
     * @param fieldName       input field name of the aggregator
     * @param macroTable      expression macro table passed to the factory
     * @return a long, float or double sum aggregator factory
     * @throws UnsupportedSQLQueryException if {@code aggregationType} is any other type
     */
    static AggregatorFactory createSumAggregatorFactory(ValueType aggregationType, String name, String fieldName, ExprMacroTable macroTable) {
        // The third constructor argument of each factory is intentionally left null.
        switch (aggregationType) {
            case LONG:
                return new LongSumAggregatorFactory(name, fieldName, null, macroTable);
            case FLOAT:
                return new FloatSumAggregatorFactory(name, fieldName, null, macroTable);
            case DOUBLE:
                return new DoubleSumAggregatorFactory(name, fieldName, null, macroTable);
            default:
                throw new UnsupportedSQLQueryException("Sum aggregation is not supported for '%s' type", aggregationType);
        }
    }

    /**
     * Sum splitter whose top-split merge function is {@link #DRUID_SUM} and whose
     * {@link #singleton} casts the input reference to the aggregate's return type
     * when the two types differ.
     */
    private static class DruidSumSplitter
    extends SqlSplittableAggFunction.AbstractSumSplitter {
        /** Sole instance; the constructor is private. */
        public static final DruidSumSplitter INSTANCE = new DruidSumSplitter();

        private DruidSumSplitter() {
        }

        /**
         * Builds the expression that stands in for the aggregate: a reference to the
         * aggregate's first argument, cast to the aggregate's type if necessary.
         *
         * @param rexBuilder    builder used to create the input reference and the cast
         * @param inputRowType  row type of the aggregate's input
         * @param aggregateCall the aggregate call; its first argument index is used
         * @return the input reference, wrapped in a cast when its type differs from
         *         the aggregate call's type
         */
        @Override
        public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType, AggregateCall aggregateCall) {
            int arg = aggregateCall.getArgList().get(0);
            RelDataTypeField field = inputRowType.getFieldList().get(arg);
            RexInputRef inputRef = rexBuilder.makeInputRef(field.getType(), arg);
            // The aggregate's return type can differ from the input column type;
            // cast so the replacement expression keeps the aggregate's type.
            if (!aggregateCall.getType().equals(field.getType())) {
                return rexBuilder.makeCast(aggregateCall.getType(), inputRef);
            }
            return inputRef;
        }

        /**
         * Returns the function used to merge partial sums in the top split.
         *
         * @return the shared {@code SUM} function instance
         */
        @Override
        protected SqlAggFunction getMergeAggFunctionOfTopSplit() {
            return DRUID_SUM;
        }
    }

    /**
     * {@code SUM} aggregate function declared with {@link SqlKind#SUM},
     * {@link ReturnTypes#AGG_SUM} return-type inference and a
     * {@link OperandTypes#NUMERIC} operand checker. It unwraps to
     * {@link DruidSumSplitter} when asked for a {@link SqlSplittableAggFunction}.
     */
    private static class DruidSumAggFunction
    extends SqlAggFunction {
        public DruidSumAggFunction() {
            super("SUM", null, SqlKind.SUM, ReturnTypes.AGG_SUM, null, OperandTypes.NUMERIC, SqlFunctionCategory.NUMERIC, false, false, Optionality.FORBIDDEN);
        }

        /**
         * Exposes {@link DruidSumSplitter#INSTANCE} for
         * {@code SqlSplittableAggFunction.class}; every other request is delegated
         * to the superclass.
         *
         * @param clazz the requested interface or class
         * @param <T>   the requested type
         * @return the splitter instance, or whatever the superclass returns for {@code clazz}
         */
        @Override
        public <T> T unwrap(Class<T> clazz) {
            if (clazz == SqlSplittableAggFunction.class) {
                return clazz.cast(DruidSumSplitter.INSTANCE);
            }
            return super.unwrap(clazz);
        }
    }
}

