/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.imputer;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.util.QuantileSummary;
import org.apache.flink.ml.feature.imputer.ImputerModel;
import org.apache.flink.ml.feature.imputer.ImputerModelData;
import org.apache.flink.ml.feature.imputer.ImputerParams;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.types.AbstractDataType;
import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;

/**
 * An Estimator that computes, for each configured input column, a surrogate value that an {@link
 * ImputerModel} uses to replace missing values. Values that are null, NaN, or equal to the
 * configured missing value are ignored when computing surrogates.
 *
 * <p>Supported strategies:
 *
 * <ul>
 *   <li>{@code "mean"}: the arithmetic mean of the valid values of the column.
 *   <li>{@code "median"}: the median estimated with a {@link QuantileSummary} built with the
 *       configured relative error.
 *   <li>{@code "most_frequent"}: the most frequent valid value; ties are resolved to the smallest
 *       value.
 * </ul>
 */
public class Imputer implements Estimator<Imputer, ImputerModel>, ImputerParams<Imputer> {

    /** Error raised when a single column has no value that can contribute to its surrogate. */
    private static final String NO_VALID_VALUE_MESSAGE =
            "Surrogate cannot be computed. All the values in column [%s] are null, NaN or missingValue.";

    /** Error raised when no input column contains any valid value at all. */
    private static final String EMPTY_TRAINING_SET_MESSAGE =
            "The training set is empty or does not contains valid data.";

    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    public Imputer() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Computes one surrogate per input column using the configured strategy.
     *
     * @param inputs exactly one table containing the configured input columns
     * @return an {@link ImputerModel} whose model data is a single {@code surrogates} column of
     *     type {@code MAP<STRING, DOUBLE>}, carrying this estimator's parameters
     * @throws IllegalArgumentException if not exactly one input is given, or if the numbers of
     *     input and output columns differ
     * @throws RuntimeException if the configured strategy is not supported
     */
    @Override
    public ImputerModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        Preconditions.checkArgument(
                getInputCols().length == getOutputCols().length,
                "Num of input columns and output columns are inconsistent.");

        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        DataStream<Row> inputData = tEnv.toDataStream(inputs[0]);

        DataStream<ImputerModelData> modelData;
        switch (getStrategy()) {
            case "mean":
                modelData =
                        DataStreamUtils.aggregate(
                                inputData,
                                new MeanStrategyAggregator(getInputCols(), getMissingValue()),
                                Types.MAP(Types.STRING, Types.TUPLE(Types.DOUBLE, Types.LONG)),
                                ImputerModelData.TYPE_INFO);
                break;
            case "median":
                modelData =
                        DataStreamUtils.aggregate(
                                inputData,
                                new MedianStrategyAggregator(
                                        getInputCols(), getMissingValue(), getRelativeError()),
                                Types.MAP(Types.STRING, TypeInformation.of(QuantileSummary.class)),
                                ImputerModelData.TYPE_INFO);
                break;
            case "most_frequent":
                modelData =
                        DataStreamUtils.aggregate(
                                inputData,
                                new MostFrequentStrategyAggregator(
                                        getInputCols(), getMissingValue()),
                                Types.MAP(Types.STRING, Types.MAP(Types.DOUBLE, Types.LONG)),
                                ImputerModelData.TYPE_INFO);
                break;
            default:
                throw new RuntimeException("Unsupported strategy of Imputer: " + getStrategy());
        }

        Schema schema =
                Schema.newBuilder()
                        .column(
                                "surrogates",
                                DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()))
                        .build();
        ImputerModel model =
                new ImputerModel().setModelData(tEnv.fromDataStream(modelData, schema));
        ParamUtils.updateExistingParams(model, getParamMap());
        return model;
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Loads an {@link Imputer} whose parameters were previously written by {@link #save(String)}.
     * The table environment is not used by this method.
     */
    public static Imputer load(StreamTableEnvironment tEnv, String path) throws IOException {
        return ReadWriteUtils.loadStageParam(path);
    }

    /**
     * Reads column {@code col} of {@code row} as a {@code Double}.
     *
     * <p>The raw field is converted through its string form, so a field whose {@code toString()}
     * cannot be parsed by {@link Double#valueOf(String)} raises a {@link NumberFormatException}.
     *
     * @return the parsed value, or {@code null} if the field is null, NaN, or equal to {@code
     *     missingValue}
     */
    private static Double parseValidValue(Row row, String col, double missingValue) {
        Object rawValue = row.getField(col);
        if (rawValue == null) {
            return null;
        }
        Double value = Double.valueOf(rawValue.toString());
        // Double#equals compares doubleToLongBits: a NaN missingValue matches NaN inputs,
        // while 0.0 and -0.0 are treated as different values.
        if (value.equals(missingValue) || Double.isNaN(value)) {
            return null;
        }
        return value;
    }

    /** Aggregator computing, per column, the most frequent valid value (smallest on ties). */
    private static class MostFrequentStrategyAggregator
            implements AggregateFunction<Row, Map<String, Map<Double, Long>>, ImputerModelData> {
        private final String[] columnNames;
        private final double missingValue;

        public MostFrequentStrategyAggregator(String[] columnNames, double missingValue) {
            this.columnNames = columnNames;
            this.missingValue = missingValue;
        }

        /** Creates one empty value-to-count map per column. */
        @Override
        public Map<String, Map<Double, Long>> createAccumulator() {
            Map<String, Map<Double, Long>> accumulators = new HashMap<>();
            for (String col : columnNames) {
                accumulators.put(col, new HashMap<>());
            }
            return accumulators;
        }

        /** Increments the count of each valid value of {@code row}, column by column. */
        @Override
        public Map<String, Map<Double, Long>> add(
                Row row, Map<String, Map<Double, Long>> accumulators) {
            accumulators.forEach(
                    (col, counts) -> {
                        Double value = parseValidValue(row, col, missingValue);
                        if (value != null) {
                            counts.merge(value, 1L, Long::sum);
                        }
                    });
            return accumulators;
        }

        /**
         * Picks the most frequent value of every column.
         *
         * @throws IllegalStateException if no column contains any valid value
         * @throws FlinkRuntimeException if some column contains no valid value
         */
        @Override
        public ImputerModelData getResult(Map<String, Map<Double, Long>> map) {
            boolean hasValidData = map.values().stream().anyMatch(counts -> !counts.isEmpty());
            Preconditions.checkState(hasValidData, EMPTY_TRAINING_SET_MESSAGE);

            Map<String, Double> surrogates = new HashMap<>();
            map.forEach(
                    (col, counts) -> {
                        // Without this check the column would silently get a NaN surrogate;
                        // fail the same way the median strategy does instead.
                        if (counts.isEmpty()) {
                            throw new FlinkRuntimeException(
                                    String.format(NO_VALID_VALUE_MESSAGE, col));
                        }
                        surrogates.put(col, mostFrequentValue(counts));
                    });
            return new ImputerModelData(surrogates);
        }

        /** Returns the key with the highest count; among equal counts, the smallest key wins. */
        private static double mostFrequentValue(Map<Double, Long> counts) {
            long maxCnt = Long.MIN_VALUE;
            double value = Double.NaN;
            for (Map.Entry<Double, Long> entry : counts.entrySet()) {
                long cnt = entry.getValue();
                if (cnt < maxCnt) {
                    continue;
                }
                double candidate = entry.getKey();
                value = cnt == maxCnt ? Math.min(candidate, value) : candidate;
                maxCnt = cnt;
            }
            return value;
        }

        /** Adds every count of {@code acc1} into {@code acc2} and returns {@code acc2}. */
        @Override
        public Map<String, Map<Double, Long>> merge(
                Map<String, Map<Double, Long>> acc1, Map<String, Map<Double, Long>> acc2) {
            Preconditions.checkArgument(acc1.size() == acc2.size());
            acc1.forEach(
                    (col, counts) -> {
                        Map<Double, Long> target = acc2.get(col);
                        counts.forEach((value, cnt) -> target.merge(value, cnt, Long::sum));
                    });
            return acc2;
        }
    }

    /** Aggregator estimating, per column, the median of the valid values. */
    private static class MedianStrategyAggregator
            implements AggregateFunction<Row, Map<String, QuantileSummary>, ImputerModelData> {
        private final String[] columnNames;
        private final double missingValue;
        private final double relativeError;

        public MedianStrategyAggregator(
                String[] columnNames, double missingValue, double relativeError) {
            this.columnNames = columnNames;
            this.missingValue = missingValue;
            this.relativeError = relativeError;
        }

        /** Creates one {@link QuantileSummary} with the configured relative error per column. */
        @Override
        public Map<String, QuantileSummary> createAccumulator() {
            Map<String, QuantileSummary> summaries = new HashMap<>();
            for (String col : columnNames) {
                summaries.put(col, new QuantileSummary(relativeError));
            }
            return summaries;
        }

        /** Inserts each valid value of {@code row} into the summary of its column. */
        @Override
        public Map<String, QuantileSummary> add(Row row, Map<String, QuantileSummary> summaries) {
            summaries.forEach(
                    (col, summary) -> {
                        Double value = parseValidValue(row, col, missingValue);
                        if (value != null) {
                            // NOTE(review): the return value of insert is discarded, which assumes
                            // insert updates this summary in place; confirm against
                            // QuantileSummary#insert.
                            summary.insert(value);
                        }
                    });
            return summaries;
        }

        /**
         * Queries the 0.5 quantile of every column's compressed summary.
         *
         * @throws FlinkRuntimeException if some column contains no valid value
         */
        @Override
        public ImputerModelData getResult(Map<String, QuantileSummary> summaries) {
            Map<String, Double> surrogates = new HashMap<>();
            summaries.forEach(
                    (col, summary) -> {
                        QuantileSummary compressed = summary.compress();
                        if (compressed.isEmpty()) {
                            throw new FlinkRuntimeException(
                                    String.format(NO_VALID_VALUE_MESSAGE, col));
                        }
                        surrogates.put(col, compressed.query(0.5));
                    });
            return new ImputerModelData(surrogates);
        }

        /** Replaces each summary of {@code acc2} by its merge with the matching one of {@code acc1}. */
        @Override
        public Map<String, QuantileSummary> merge(
                Map<String, QuantileSummary> acc1, Map<String, QuantileSummary> acc2) {
            Preconditions.checkArgument(acc1.size() == acc2.size());
            acc1.forEach(
                    (col, summary1) -> {
                        QuantileSummary summary2 = acc2.get(col).compress();
                        acc2.put(col, summary2.merge(summary1.compress()));
                    });
            return acc2;
        }
    }

    /** Aggregator computing, per column, the mean of the valid values. */
    private static class MeanStrategyAggregator
            implements AggregateFunction<Row, Map<String, Tuple2<Double, Long>>, ImputerModelData> {
        private final String[] columnNames;
        private final double missingValue;

        public MeanStrategyAggregator(String[] columnNames, double missingValue) {
            this.columnNames = columnNames;
            this.missingValue = missingValue;
        }

        /** Creates one (sum, count) pair initialised to (0.0, 0) per column. */
        @Override
        public Map<String, Tuple2<Double, Long>> createAccumulator() {
            Map<String, Tuple2<Double, Long>> accumulators = new HashMap<>();
            for (String col : columnNames) {
                accumulators.put(col, Tuple2.of(0.0, 0L));
            }
            return accumulators;
        }

        /** Adds each valid value of {@code row} to its column's running sum and count. */
        @Override
        public Map<String, Tuple2<Double, Long>> add(
                Row row, Map<String, Tuple2<Double, Long>> accumulators) {
            accumulators.forEach(
                    (col, sumAndNum) -> {
                        Double value = parseValidValue(row, col, missingValue);
                        if (value != null) {
                            sumAndNum.f0 += value;
                            sumAndNum.f1 += 1;
                        }
                    });
            return accumulators;
        }

        /**
         * Divides every column's sum by its count.
         *
         * @throws IllegalStateException if no column contains any valid value
         * @throws FlinkRuntimeException if some column contains no valid value
         */
        @Override
        public ImputerModelData getResult(Map<String, Tuple2<Double, Long>> map) {
            // Check every column rather than an arbitrary first HashMap entry.
            boolean hasValidData = map.values().stream().anyMatch(sumAndNum -> sumAndNum.f1 > 0);
            Preconditions.checkState(hasValidData, EMPTY_TRAINING_SET_MESSAGE);

            Map<String, Double> surrogates = new HashMap<>();
            map.forEach(
                    (col, sumAndNum) -> {
                        // A zero count would otherwise yield a NaN surrogate (0.0 / 0).
                        if (sumAndNum.f1 == 0) {
                            throw new FlinkRuntimeException(
                                    String.format(NO_VALID_VALUE_MESSAGE, col));
                        }
                        surrogates.put(col, sumAndNum.f0 / sumAndNum.f1);
                    });
            return new ImputerModelData(surrogates);
        }

        /** Adds every (sum, count) pair of {@code acc1} into {@code acc2} and returns {@code acc2}. */
        @Override
        public Map<String, Tuple2<Double, Long>> merge(
                Map<String, Tuple2<Double, Long>> acc1, Map<String, Tuple2<Double, Long>> acc2) {
            Preconditions.checkArgument(acc1.size() == acc2.size());
            acc1.forEach(
                    (col, sumAndNum) -> {
                        Tuple2<Double, Long> target = acc2.get(col);
                        target.f0 += sumAndNum.f0;
                        target.f1 += sumAndNum.f1;
                    });
            return acc2;
        }
    }
}

