/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.clustering.agglomerativeclustering;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.common.distance.DistanceMeasure;
import org.apache.flink.ml.common.window.Windows;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.VectorWithNorm;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SideOutputDataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.windowing.ProcessAllWindowFunction;
import org.apache.flink.streaming.api.windowing.windows.Window;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.table.types.AbstractDataType;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;
import org.apache.flink.util.Preconditions;

/**
 * An {@link AlgoOperator} that performs agglomerative (bottom-up hierarchical) clustering.
 *
 * <p>The input table is split into windows according to the configured windows parameter, and
 * every window is clustered independently on the vectors stored in the features column. Merging
 * stops either once {@code getNumClusters()} clusters remain or once the next merge distance
 * exceeds {@code getDistanceThreshold()}; exactly one of these two parameters must be set.
 * Supported linkages are {@code single}, {@code complete}, {@code average} and {@code ward}
 * (ward requires the {@code euclidean} distance measure).
 *
 * <p>{@link #transform} returns two tables:
 *
 * <ol>
 *   <li>the input rows with an extra INT prediction column holding the cluster id; ids are
 *       numbered 0, 1, 2, ... in order of first appearance within the window;
 *   <li>the merge information with columns {@code clusterId1}, {@code clusterId2}, {@code
 *       distance} and {@code sizeOfMergedCluster}, ordered by ascending distance. Ids below the
 *       number of points in the window denote single points; the i-th emitted merge creates the
 *       cluster with id {@code numPoints + i}.
 * </ol>
 */
public class AgglomerativeClustering
        implements AlgoOperator<AgglomerativeClustering>,
                AgglomerativeClusteringParams<AgglomerativeClustering> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    public AgglomerativeClustering() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Clusters the rows of the single input table window by window.
     *
     * @param inputs exactly one table containing the features column
     * @return the clustered rows and the merge information (see the class documentation)
     * @throws IllegalArgumentException if not exactly one input is given, if both or neither of
     *     numClusters / distanceThreshold are set, or if ward linkage is combined with a
     *     non-euclidean distance measure
     */
    @Override
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        Integer numCluster = getNumClusters();
        Double distanceThreshold = getDistanceThreshold();
        // Exactly one stopping criterion must be configured.
        Preconditions.checkArgument(
                (numCluster == null) != (distanceThreshold == null),
                "One of param numCluster and distanceThreshold should be null.");
        if (getLinkage().equals("ward")) {
            String distanceMeasure = getDistanceMeasure();
            Preconditions.checkArgument(
                    distanceMeasure.equals("euclidean"),
                    distanceMeasure
                            + " was provided as distance measure while linkage was ward. Ward only works with euclidean.");
        }

        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);

        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
                        ArrayUtils.addAll(
                                inputTypeInfo.getFieldTypes(), new TypeInformation<?>[] {Types.INT}),
                        ArrayUtils.addAll(
                                inputTypeInfo.getFieldNames(), new String[] {getPredictionCol()}));

        OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag =
                new OutputTag<Tuple4<Integer, Integer, Double, Integer>>("MERGE_INFO") {};

        SingleOutputStreamOperator<Row> output =
                DataStreamUtils.windowAllAndProcess(
                        dataStream,
                        getWindows(),
                        new LocalAgglomerativeClusteringFunction<>(
                                getFeaturesCol(),
                                getLinkage(),
                                getDistanceMeasure(),
                                numCluster,
                                distanceThreshold,
                                getComputeFullTree(),
                                mergeInfoOutputTag,
                                outputTypeInfo));

        Schema schema =
                Schema.newBuilder()
                        .fromResolvedSchema(inputs[0].getResolvedSchema())
                        .column(getPredictionCol(), DataTypes.INT())
                        .build();
        Table outputTable = tEnv.fromDataStream(output, schema);

        SideOutputDataStream<Tuple4<Integer, Integer, Double, Integer>> mergeInfo =
                output.getSideOutput(mergeInfoOutputTag);
        mergeInfo.getTransformation().setParallelism(1);
        Table mergeInfoTable =
                tEnv.fromDataStream(mergeInfo)
                        .as("clusterId1", "clusterId2", "distance", "sizeOfMergedCluster");

        return new Table[] {outputTable, mergeInfoTable};
    }

    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /** Loads an operator previously written by {@link #save}; {@code tEnv} is not used. */
    public static AgglomerativeClustering load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        return ReadWriteUtils.loadStageParam(path);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Clusters all rows of one window with the nearest-neighbor-chain algorithm, emits every row
     * with its cluster id appended, and writes the merge steps to a side output.
     */
    private static class LocalAgglomerativeClusteringFunction<W extends Window>
            extends ProcessAllWindowFunction<Row, Row, W> implements ResultTypeQueryable<Row> {
        private final String featuresCol;
        private final String linkage;
        private final DistanceMeasure distanceMeasure;
        // Exactly one of numCluster / distanceThreshold is non-null (checked in transform()).
        private final Integer numCluster;
        private final Double distanceThreshold;
        private final boolean computeFullTree;
        private final OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag;
        private final RowTypeInfo outputTypeInfo;

        public LocalAgglomerativeClusteringFunction(
                String featuresCol,
                String linkage,
                String distanceMeasureName,
                Integer numCluster,
                Double distanceThreshold,
                boolean computeFullTree,
                OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag,
                RowTypeInfo outputTypeInfo) {
            this.featuresCol = featuresCol;
            this.linkage = linkage;
            this.numCluster = numCluster;
            this.distanceThreshold = distanceThreshold;
            this.computeFullTree = computeFullTree;
            this.mergeInfoOutputTag = mergeInfoOutputTag;
            this.distanceMeasure = DistanceMeasure.getInstance(distanceMeasureName);
            this.outputTypeInfo = outputTypeInfo;
        }

        @Override
        public void process(Context context, Iterable<Row> values, Collector<Row> output) {
            List<Row> inputList = new ArrayList<>();
            for (Row value : values) {
                inputList.add(value);
            }
            int numDataPoints = inputList.size();
            if (numDataPoints == 0) {
                return;
            }

            DistanceMatrix distanceMatrix = computePairwiseDistances(inputList);

            HashSet<Integer> nodeLabels = new HashSet<>(numDataPoints);
            for (int i = 0; i < numDataPoints; i++) {
                nodeLabels.add(i);
            }

            List<Tuple4<Integer, Integer, Integer, Double>> nnChain =
                    nnChainCore(nodeLabels, distanceMatrix, linkage);
            // The nn-chain emits merges in chain order, not distance order. Sort them (List.sort
            // is stable) and renumber merged clusters so the i-th merge creates numDataPoints + i.
            nnChain.sort(Comparator.comparingDouble(item -> item.f3));
            reOrderNnChain(nnChain);
            // Sizes must be indexed by the renumbered ids, so they are rebuilt after reordering.
            int[] clusterSizes = computeClusterSizes(nnChain, numDataPoints);

            int stoppedIdx = 0;
            if (distanceThreshold != null) {
                for (Tuple4<Integer, Integer, Integer, Double> mergeItem : nnChain) {
                    if (mergeItem.f3 <= distanceThreshold) {
                        stoppedIdx++;
                    }
                }
            } else {
                // A window may hold fewer points than numCluster; then no merge is applied
                // instead of passing a negative index to subList.
                stoppedIdx = Math.max(0, numDataPoints - numCluster);
            }

            int[] clusterIds = label(nnChain.subList(0, stoppedIdx), numDataPoints);

            // Renumber root labels to 0, 1, 2, ... in order of first appearance.
            Map<Integer, Integer> remappedClusterIds = new HashMap<>();
            for (int i = 0; i < clusterIds.length; i++) {
                Integer remapped = remappedClusterIds.get(clusterIds[i]);
                if (remapped == null) {
                    remapped = remappedClusterIds.size();
                    remappedClusterIds.put(clusterIds[i], remapped);
                }
                clusterIds[i] = remapped;
            }

            for (int i = 0; i < numDataPoints; i++) {
                output.collect(Row.join(inputList.get(i), Row.of(clusterIds[i])));
            }

            int numMergesToOutput = computeFullTree ? nnChain.size() : stoppedIdx;
            for (int i = 0; i < numMergesToOutput; i++) {
                Tuple4<Integer, Integer, Integer, Double> mergeItem = nnChain.get(i);
                int cid1 = Math.min(mergeItem.f0, mergeItem.f1);
                int cid2 = Math.max(mergeItem.f0, mergeItem.f1);
                context.output(
                        mergeInfoOutputTag,
                        Tuple4.of(
                                cid1,
                                cid2,
                                mergeItem.f3,
                                clusterSizes[cid1] + clusterSizes[cid2]));
            }
        }

        /**
         * Builds the distance matrix over {@code 2 * rows.size() - 1} nodes and fills in the
         * point-to-point distances; the slots of merged clusters are filled by nnChainCore.
         */
        private DistanceMatrix computePairwiseDistances(List<Row> rows) {
            int numDataPoints = rows.size();
            // Build each VectorWithNorm once instead of once per pair.
            VectorWithNorm[] points = new VectorWithNorm[numDataPoints];
            for (int i = 0; i < numDataPoints; i++) {
                Vector features = rows.get(i).getFieldAs(featuresCol);
                points[i] = new VectorWithNorm(features);
            }
            DistanceMatrix distanceMatrix = new DistanceMatrix(numDataPoints * 2 - 1);
            for (int i = 0; i < numDataPoints; i++) {
                for (int j = i + 1; j < numDataPoints; j++) {
                    distanceMatrix.set(i, j, distanceMeasure.distance(points[i], points[j]));
                }
            }
            return distanceMatrix;
        }

        /**
         * Renumbers the merged clusters of a distance-sorted chain in place: the i-th merge gets
         * id {@code nnChain.size() + 1 + i} (stored in f2), and later references to the old
         * label in f0/f1 are rewritten to the new id.
         */
        private static void reOrderNnChain(List<Tuple4<Integer, Integer, Integer, Double>> nnChain) {
            int nextClusterId = nnChain.size() + 1;
            Map<Integer, Integer> nodeMapping = new HashMap<>();
            for (Tuple4<Integer, Integer, Integer, Double> t : nnChain) {
                t.f0 = nodeMapping.getOrDefault(t.f0, t.f0);
                t.f1 = nodeMapping.getOrDefault(t.f1, t.f1);
                nodeMapping.put(t.f2, nextClusterId);
                t.f2 = nextClusterId;
                nextClusterId++;
            }
        }

        /**
         * Returns the number of points in every node of a chain already renumbered by
         * reOrderNnChain: leaves have size 1, and each merge's size is the sum of its children.
         */
        private static int[] computeClusterSizes(
                List<Tuple4<Integer, Integer, Integer, Double>> nnChain, int numDataPoints) {
            int[] sizes = new int[2 * numDataPoints - 1];
            Arrays.fill(sizes, 0, numDataPoints, 1);
            for (Tuple4<Integer, Integer, Integer, Double> t : nnChain) {
                sizes[t.f2] = sizes[t.f0] + sizes[t.f1];
            }
            return sizes;
        }

        /**
         * Applies the given (renumbered) merges and returns, for each of the {@code numDataPoints}
         * points, the label of the root cluster it ends up in.
         */
        private static int[] label(
                List<Tuple4<Integer, Integer, Integer, Double>> nnChains, int numDataPoints) {
            UnionFind unionFind = new UnionFind(numDataPoints);
            for (Tuple4<Integer, Integer, Integer, Double> t : nnChains) {
                unionFind.union(unionFind.find(t.f0), unionFind.find(t.f1));
            }
            int[] clusterIds = new int[numDataPoints];
            for (int i = 0; i < clusterIds.length; i++) {
                clusterIds[i] = unionFind.find(i);
            }
            return clusterIds;
        }

        /**
         * Runs the nearest-neighbor-chain loop until one cluster remains.
         *
         * <p>Starting from an active cluster, the chain repeatedly appends the nearest active
         * cluster of its last element until the element three positions back equals the new end,
         * i.e. the last two clusters are mutual nearest neighbours; those two are then merged
         * under a new label and the distances to the merged cluster are updated with
         * computeClusterDistances. Mutates {@code nodeLabels} and {@code distanceMatrix}.
         *
         * @return merges as (clusterA, clusterB, newLabel, distance), in the order they were made
         */
        private static List<Tuple4<Integer, Integer, Integer, Double>> nnChainCore(
                HashSet<Integer> nodeLabels, DistanceMatrix distanceMatrix, String linkage) {
            int numDataPoints = nodeLabels.size();
            int nextClusterId = numDataPoints;
            List<Tuple4<Integer, Integer, Integer, Double>> nnChain =
                    new ArrayList<>(numDataPoints);
            List<Integer> chain = new ArrayList<>();
            // Sizes indexed by the labels assigned here (chain order); only used for linkage.
            int[] size = new int[numDataPoints * 2 - 1];
            Arrays.fill(size, 0, numDataPoints, 1);

            while (nodeLabels.size() > 1) {
                int a;
                int b;
                if (chain.size() <= 3) {
                    // Start a fresh chain from an arbitrary active cluster.
                    Iterator<Integer> iterator = nodeLabels.iterator();
                    a = iterator.next();
                    chain.clear();
                    chain.add(a);
                    b = iterator.next();
                } else {
                    // Resume the previous chain: drop the merged pair and continue from the
                    // element before it.
                    int chainSize = chain.size();
                    a = chain.get(chainSize - 4);
                    b = chain.get(chainSize - 3);
                    chain.remove(chainSize - 1);
                    chain.remove(chainSize - 2);
                    chain.remove(chainSize - 3);
                }
                while (chain.size() < 3 || chain.get(chain.size() - 3) != a) {
                    double minDistance = Double.MAX_VALUE;
                    int c = -1;
                    for (int x : nodeLabels) {
                        if (x == a) {
                            continue;
                        }
                        double dax = distanceMatrix.get(a, x);
                        if (dax < minDistance) {
                            c = x;
                            minDistance = dax;
                        }
                    }
                    // On ties prefer the previous chain element to avoid cycling.
                    if (minDistance == distanceMatrix.get(a, b) && nodeLabels.contains(b)) {
                        c = b;
                    }
                    b = a;
                    a = c;
                    chain.add(a);
                }
                int mergedNodeLabel = nextClusterId++;
                nnChain.add(Tuple4.of(a, b, mergedNodeLabel, distanceMatrix.get(a, b)));
                nodeLabels.remove(a);
                nodeLabels.remove(b);
                size[mergedNodeLabel] = size[a] + size[b];
                for (int x : nodeLabels) {
                    double d =
                            computeClusterDistances(
                                    distanceMatrix.get(a, x),
                                    distanceMatrix.get(b, x),
                                    distanceMatrix.get(a, b),
                                    size[a],
                                    size[b],
                                    size[x],
                                    linkage);
                    distanceMatrix.set(x, mergedNodeLabel, d);
                }
                nodeLabels.add(mergedNodeLabel);
            }
            return nnChain;
        }

        /**
         * Distance between cluster k and the cluster formed by merging clusters i and j.
         *
         * @param dik distance between i and k
         * @param djk distance between j and k
         * @param dij distance between i and j
         * @param si size of i
         * @param sj size of j
         * @param sk size of k
         * @throws UnsupportedOperationException for an unknown linkage
         */
        private static double computeClusterDistances(
                double dik, double djk, double dij, int si, int sj, int sk, String linkage) {
            switch (linkage) {
                case "single":
                    return Math.min(dik, djk);
                case "complete":
                    return Math.max(dik, djk);
                case "average":
                    return ((double) si * dik + (double) sj * djk) / (double) (si + sj);
                case "ward":
                    return Math.sqrt(
                            ((double) (si + sk) * dik * dik
                                            + (double) (sj + sk) * djk * djk
                                            - (double) sk * dij * dij)
                                    / (double) (si + sj + sk));
                default:
                    throw new UnsupportedOperationException(
                            "Unsupported "
                                    + AgglomerativeClusteringParams.LINKAGE
                                    + " type: "
                                    + linkage
                                    + ".");
            }
        }

        @Override
        public TypeInformation<Row> getProducedType() {
            return outputTypeInfo;
        }

        /**
         * Symmetric distance matrix without diagonal over {@code n} nodes, stored as a condensed
         * upper triangle: entry (i, j) with i &lt; j lives at {@code n*i - i*(i+1)/2 + (j-i-1)}.
         */
        private static class DistanceMatrix {
            private final double[] distances;
            private final int n;

            public DistanceMatrix(int n) {
                // long arithmetic: n * (n - 1) overflows int well before the array limit.
                long numEntries = (long) n * (n - 1) / 2;
                if (numEntries > Integer.MAX_VALUE) {
                    throw new IllegalArgumentException(
                            "Too many data points in one window for agglomerative clustering: a "
                                    + "distance matrix over "
                                    + n
                                    + " nodes does not fit into a Java array.");
                }
                this.distances = new double[(int) numEntries];
                this.n = n;
            }

            public void set(int i, int j, double value) {
                distances[offset(i, j)] = value;
            }

            public double get(int i, int j) {
                return distances[offset(i, j)];
            }

            private int offset(int i, int j) {
                int smallIdx = Math.min(i, j);
                int bigIdx = Math.max(i, j);
                return (int) ((2L * n - 1 - smallIdx) * smallIdx / 2 + (bigIdx - smallIdx - 1));
            }
        }

        /**
         * Union-find over {@code 2 * numDataPoints - 1} nodes where each union attaches both
         * roots to a freshly allocated label numDataPoints, numDataPoints + 1, ...
         */
        private static class UnionFind {
            private final int[] parent;
            private int nextLabel;

            public UnionFind(int numDataPoints) {
                this.parent = new int[2 * numDataPoints - 1];
                Arrays.fill(parent, -1);
                this.nextLabel = numDataPoints;
            }

            /** Merges two roots under the next unused label. */
            public void union(int m, int n) {
                parent[m] = nextLabel;
                parent[n] = nextLabel;
                nextLabel++;
            }

            /** Returns the root of {@code node}, compressing the path to it. */
            public int find(int node) {
                int root = node;
                while (parent[root] != -1) {
                    root = parent[root];
                }
                while (node != root) {
                    int next = parent[node];
                    parent[node] = root;
                    node = next;
                }
                return root;
            }
        }
    }
}

