/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.kbinsdiscretizer;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel;
import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModelData;
import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KBinsDiscretizer
implements Estimator<KBinsDiscretizer, KBinsDiscretizerModel>,
KBinsDiscretizerParams<KBinsDiscretizer> {
    private static final Logger LOG = LoggerFactory.getLogger(KBinsDiscretizer.class);
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public KBinsDiscretizer() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public KBinsDiscretizerModel fit(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        String inputCol = this.getInputCol();
        final String strategy = this.getStrategy();
        final int numBins = this.getNumBins();
        SingleOutputStreamOperator inputData = tEnv.toDataStream(inputs[0]).map((MapFunction & Serializable)value -> ((Vector)value.getField(inputCol)).toDense());
        Object preprocessedData = strategy.equals("uniform") ? inputData.transform("reduceInEachPartition", inputData.getType(), (OneInputStreamOperator)new MinMaxScaler.MinMaxReduceFunctionOperator()).transform("reduceInFinalPartition", inputData.getType(), (OneInputStreamOperator)new MinMaxScaler.MinMaxReduceFunctionOperator()).setParallelism(1) : DataStreamUtils.sample((DataStream)inputData, (int)this.getSubSamples(), (long)this.getClass().getName().hashCode());
        DataStream modelData = DataStreamUtils.mapPartition((DataStream)preprocessedData, (MapPartitionFunction)new MapPartitionFunction<DenseVector, KBinsDiscretizerModelData>(){

            public void mapPartition(Iterable<DenseVector> iterable, Collector<KBinsDiscretizerModelData> collector) {
                double[][] binEdges;
                ArrayList list = new ArrayList();
                iterable.iterator().forEachRemaining(list::add);
                if (list.size() == 0) {
                    throw new RuntimeException("The training set is empty.");
                }
                switch (strategy) {
                    case "uniform": {
                        binEdges = KBinsDiscretizer.findBinEdgesWithUniformStrategy(list, numBins);
                        break;
                    }
                    case "quantile": {
                        binEdges = KBinsDiscretizer.findBinEdgesWithQuantileStrategy(list, numBins);
                        break;
                    }
                    case "kmeans": {
                        binEdges = KBinsDiscretizer.findBinEdgesWithKMeansStrategy(list, numBins);
                        break;
                    }
                    default: {
                        throw new UnsupportedOperationException("Unsupported " + KBinsDiscretizerParams.STRATEGY + " type: " + strategy + ".");
                    }
                }
                collector.collect((Object)new KBinsDiscretizerModelData(binEdges));
            }
        });
        modelData.getTransformation().setParallelism(1);
        KBinsDiscretizerModel model = new KBinsDiscretizerModel().setModelData(tEnv.fromDataStream(modelData));
        ParamUtils.updateExistingParams((WithParams)model, this.getParamMap());
        return model;
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static KBinsDiscretizer load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (KBinsDiscretizer)ReadWriteUtils.loadStageParam((String)path);
    }

    private static double[][] findBinEdgesWithUniformStrategy(List<DenseVector> input, int numBins) {
        DenseVector minVector = input.get(0);
        DenseVector maxVector = input.get(1);
        int numColumns = minVector.size();
        double[][] binEdges = new double[numColumns][];
        for (int columnId = 0; columnId < numColumns; ++columnId) {
            double max;
            double min = minVector.get(columnId);
            if (min == (max = maxVector.get(columnId))) {
                LOG.warn("Feature " + columnId + " is constant and the output will all be zero.");
                binEdges[columnId] = new double[]{Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY};
                continue;
            }
            double width = (max - min) / (double)numBins;
            binEdges[columnId] = new double[numBins + 1];
            binEdges[columnId][0] = min;
            for (int edgeId = 1; edgeId < numBins + 1; ++edgeId) {
                binEdges[columnId][edgeId] = binEdges[columnId][edgeId - 1] + width;
            }
        }
        return binEdges;
    }

    private static double[][] findBinEdgesWithQuantileStrategy(List<DenseVector> input, int numBins) {
        int numColumns = input.get(0).size();
        int numData = input.size();
        double[][] binEdges = new double[numColumns][];
        double[] features = new double[numData];
        for (int columnId = 0; columnId < numColumns; ++columnId) {
            int i;
            double edge;
            int binEdgeId;
            double[] tempBinEdges;
            for (int i2 = 0; i2 < numData; ++i2) {
                features[i2] = input.get(i2).get(columnId);
            }
            Arrays.sort(features);
            if (features[0] == features[numData - 1]) {
                LOG.warn("Feature " + columnId + " is constant and the output will all be zero.");
                binEdges[columnId] = new double[]{Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY};
                continue;
            }
            if (features.length > numBins) {
                double width = 1.0 * (double)features.length / (double)numBins;
                tempBinEdges = new double[numBins + 1];
                for (binEdgeId = 0; binEdgeId < numBins; ++binEdgeId) {
                    tempBinEdges[binEdgeId] = features[(int)((double)binEdgeId * width)];
                }
                tempBinEdges[numBins] = features[numData - 1];
            } else {
                tempBinEdges = features;
            }
            HashMap<Double, Integer> edgesAndCnt = new HashMap<Double, Integer>(numBins);
            double[] dArray = tempBinEdges;
            binEdgeId = dArray.length;
            for (int j = 0; j < binEdgeId; ++j) {
                edge = dArray[j];
                edgesAndCnt.put(edge, edgesAndCnt.getOrDefault(edge, 0) + 1);
            }
            ArrayList<Double> edges = new ArrayList<Double>();
            for (Map.Entry edgeAndCnt : edgesAndCnt.entrySet()) {
                edge = (Double)edgeAndCnt.getKey();
                int cnt = (Integer)edgeAndCnt.getValue();
                edges.add(edge);
                if (cnt <= 1) continue;
                edges.add(edge);
            }
            tempBinEdges = edges.stream().mapToDouble(Double::doubleValue).toArray();
            Arrays.sort(tempBinEdges);
            for (i = 1; i < tempBinEdges.length - 1; ++i) {
                if (tempBinEdges[i] != tempBinEdges[i - 1]) continue;
                tempBinEdges[i] = (tempBinEdges[i + 1] + tempBinEdges[i - 1]) / 2.0;
            }
            if (tempBinEdges[i] == tempBinEdges[i - 1]) {
                tempBinEdges[i - 1] = (tempBinEdges[i] + tempBinEdges[i - 2]) / 2.0;
            }
            binEdges[columnId] = tempBinEdges;
        }
        return binEdges;
    }

    private static double[][] findBinEdgesWithKMeansStrategy(List<DenseVector> input, int numBins) {
        int numColumns = input.get(0).size();
        int numData = input.size();
        double[][] binEdges = new double[numColumns][numBins + 1];
        double[] features = new double[numData];
        double[] kMeansCentroids = new double[numBins];
        double[] sumByCluster = new double[numBins];
        for (int columnId = 0; columnId < numColumns; ++columnId) {
            for (int i = 0; i < numData; ++i) {
                features[i] = input.get(i).get(columnId);
            }
            Arrays.sort(features);
            if (features[0] == features[numData - 1]) {
                LOG.warn("Feature " + columnId + " is constant and the output will all be zero.");
                binEdges[columnId] = new double[]{Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY};
                continue;
            }
            HashSet<Double> distinctFeatureValues = new HashSet<Double>(numBins + 1);
            for (double feature : features) {
                distinctFeatureValues.add(feature);
                if (distinctFeatureValues.size() >= numBins + 1) break;
            }
            if (distinctFeatureValues.size() <= numBins) {
                double min = features[0];
                double max = features[features.length - 1];
                double width = (max - min) / (double)numBins;
                binEdges[columnId] = new double[numBins + 1];
                binEdges[columnId][0] = min;
                for (int edgeId = 1; edgeId < numBins + 1; ++edgeId) {
                    binEdges[columnId][edgeId] = binEdges[columnId][edgeId - 1] + width;
                }
                continue;
            }
            double width = 1.0 * (double)features.length / (double)numBins;
            for (int clusterId = 0; clusterId < numBins; ++clusterId) {
                kMeansCentroids[clusterId] = features[(int)((double)clusterId * width)];
            }
            double tolerance = 1.0E-4;
            int maxIterations = 300;
            double oldLoss = Double.MAX_VALUE;
            double relativeLoss = Double.MAX_VALUE;
            int[] countByCluster = new int[numBins];
            for (int iter = 0; iter < 300 && relativeLoss > 1.0E-4; ++iter) {
                double loss = 0.0;
                for (double featureValue : features) {
                    double minDistance = Math.abs(kMeansCentroids[0] - featureValue);
                    int clusterId = 0;
                    for (int i = 1; i < kMeansCentroids.length; ++i) {
                        double distance = Math.abs(kMeansCentroids[i] - featureValue);
                        if (!(distance < minDistance)) continue;
                        minDistance = distance;
                        clusterId = i;
                    }
                    int n = clusterId;
                    countByCluster[n] = countByCluster[n] + 1;
                    int n2 = clusterId;
                    sumByCluster[n2] = sumByCluster[n2] + featureValue;
                    loss += minDistance;
                }
                for (int clusterId = 0; clusterId < kMeansCentroids.length; ++clusterId) {
                    kMeansCentroids[clusterId] = sumByCluster[clusterId] / (double)countByCluster[clusterId];
                }
                relativeLoss = Math.abs((loss /= (double)features.length) - oldLoss);
                oldLoss = loss;
                Arrays.fill(sumByCluster, 0.0);
                Arrays.fill(countByCluster, 0);
            }
            Arrays.sort(kMeansCentroids);
            binEdges[columnId] = new double[numBins + 1];
            binEdges[columnId][0] = features[0];
            binEdges[columnId][numBins] = features[features.length - 1];
            for (int binEdgeId = 1; binEdgeId < numBins; ++binEdgeId) {
                binEdges[columnId][binEdgeId] = (kMeansCentroids[binEdgeId - 1] + kMeansCentroids[binEdgeId]) / 2.0;
            }
        }
        return binEdges;
    }
}

