/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.clustering.kmeans;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.IterationConfig;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.ReplayableDataStreamList;
import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.clustering.kmeans.KMeansModel;
import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
import org.apache.flink.ml.clustering.kmeans.KMeansParams;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.distance.DistanceMeasure;
import org.apache.flink.ml.common.iteration.ForwardInputsOfLastRound;
import org.apache.flink.ml.common.iteration.TerminateOnMaxIter;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.VectorWithNorm;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.VectorWithNormSerializer;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

public class KMeans
implements Estimator<KMeans, KMeansModel>,
KMeansParams<KMeans> {
    /** Parameter values of this estimator, keyed by their parameter definition. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /**
     * Creates a KMeans estimator and lets {@link ParamUtils#initializeMapWithDefaultValues}
     * populate its parameter map.
     */
    public KMeans() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Trains a K-means model on the single input table.
     *
     * <p>Each row's features column is converted to a dense vector, initial centroids are drawn
     * with {@link #selectRandomCentroids}, and a bounded iteration driven by
     * {@code KMeansIterationBody} is started. The first output stream of that iteration becomes
     * the model data of the returned model.
     *
     * @param inputs exactly one table that contains the configured features column
     * @return a model holding the resulting model data, with its existing parameters updated
     *     from this estimator's parameter map
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     */
    public KMeansModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);

        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        // Resolve the column name up front so the lambda captures a String, not this estimator.
        final String featuresCol = getFeaturesCol();
        DataStream<DenseVector> points =
                tEnv.toDataStream(inputs[0])
                        .map(row -> ((Vector) row.getField(featuresCol)).toDense());

        DataStream<DenseVector[]> initCentroids = selectRandomCentroids(points, getK(), getSeed());

        IterationConfig config =
                IterationConfig.newBuilder()
                        .setOperatorLifeCycle(IterationConfig.OperatorLifeCycle.ALL_ROUND)
                        .build();

        IterationBody body =
                new KMeansIterationBody(
                        getMaxIter(), DistanceMeasure.getInstance(getDistanceMeasure()));

        // Centroids are the iteration variable; the points are passed as non-replayed data.
        DataStream<KMeansModelData> finalModelData =
                Iterations.iterateBoundedStreamsUntilTermination(
                                DataStreamList.of(initCentroids),
                                ReplayableDataStreamList.notReplay(points),
                                config,
                                body)
                        .get(0);

        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
        KMeansModel model = new KMeansModel().setModelData(finalModelDataTable);
        ParamUtils.updateExistingParams(model, paramMap);
        return model;
    }

    /**
     * Saves this estimator by handing it to {@link ReadWriteUtils#saveMetadata}.
     *
     * @param path location forwarded unchanged to the metadata writer
     * @throws IOException if the underlying write fails
     */
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Restores a KMeans estimator through {@link ReadWriteUtils#loadStageParam}.
     *
     * @param tEnv table environment; not used by this method
     * @param path location forwarded unchanged to the loader
     * @return the loaded estimator
     * @throws IOException if the underlying read fails
     */
    public static KMeans load(StreamTableEnvironment tEnv, String path) throws IOException {
        KMeans estimator = (KMeans) ReadWriteUtils.loadStageParam(path);
        return estimator;
    }

    /**
     * Exposes the parameter map backing this estimator.
     *
     * @return the live (not copied) parameter map of this instance
     */
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Selects initial centroids from the data and emits them as one array.
     *
     * <p>The data is passed through {@link DataStreamUtils#sample} with {@code k} and
     * {@code seed}; the sampled vectors are then gathered by a map-partition step whose
     * parallelism is forced to 1, so every sample ends up in the same array.
     *
     * @param data the vectors to choose centroids from
     * @param k sample size handed to the sampler (presumably the number of clusters)
     * @param seed seed handed to the sampler
     * @return a stream whose elements are arrays of the sampled vectors
     */
    public static DataStream<DenseVector[]> selectRandomCentroids(
            DataStream<DenseVector> data, int k, long seed) {
        DataStream<DenseVector> samples = DataStreamUtils.sample(data, k, seed);

        DataStream<DenseVector[]> resultStream =
                DataStreamUtils.mapPartition(
                        samples,
                        new MapPartitionFunction<DenseVector, DenseVector[]>() {

                            public void mapPartition(
                                    Iterable<DenseVector> iterable,
                                    Collector<DenseVector[]> collector) {
                                ArrayList<DenseVector> list = new ArrayList<>();
                                iterable.iterator().forEachRemaining(list::add);
                                collector.collect(list.toArray(new DenseVector[0]));
                            }
                        });

        // A single partition guarantees that all sampled vectors are packed into one array.
        resultStream.getTransformation().setParallelism(1);
        return resultStream;
    }

    /**
     * Two-input operator that stores every data point (input 1) and the current centroids
     * (input 2). Whenever the epoch watermark is incremented it assigns each stored point to its
     * closest centroid and emits, per centroid, the number of assigned points and the sum of
     * their vectors.
     */
    private static class CentroidsUpdateAccumulator
            extends AbstractStreamOperator<Tuple2<Integer[], DenseVector[]>>
            implements TwoInputStreamOperator<
                            DenseVector, DenseVector[], Tuple2<Integer[], DenseVector[]>>,
                    IterationListener<Tuple2<Integer[], DenseVector[]>> {
        private final DistanceMeasure distanceMeasure;

        /** Centroids of the current epoch; expected to hold at most one array at a time. */
        private ListState<DenseVector[]> centroids;

        /** All data points received on input 1, wrapped as {@link VectorWithNorm}. */
        private ListStateWithCache<VectorWithNorm> points;

        /**
         * @param distanceMeasure measure used to find the closest centroid for a point
         */
        public CentroidsUpdateAccumulator(DistanceMeasure distanceMeasure) {
            this.distanceMeasure = distanceMeasure;
        }

        /** Registers the "centroids" operator list state and creates the cached point state. */
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);

            TypeInformation<DenseVector[]> type =
                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
            centroids =
                    context.getOperatorStateStore()
                            .getListState(new ListStateDescriptor<>("centroids", type));

            points =
                    new ListStateWithCache<>(
                            new VectorWithNormSerializer(),
                            getContainingTask(),
                            getRuntimeContext(),
                            context,
                            config.getOperatorID());
        }

        /** Snapshots the operator and, explicitly, the cached point state. */
        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            points.snapshotState(context);
        }

        /** Stores an incoming data point. */
        public void processElement1(StreamRecord<DenseVector> streamRecord) throws Exception {
            points.add(new VectorWithNorm(streamRecord.getValue()));
        }

        /**
         * Stores the centroids for the current epoch.
         *
         * @throws IllegalStateException if centroids are already present
         */
        public void processElement2(StreamRecord<DenseVector[]> streamRecord) throws Exception {
            Preconditions.checkState(!centroids.get().iterator().hasNext());
            centroids.add(streamRecord.getValue());
        }

        /**
         * Computes the per-centroid counts and vector sums over all stored points, emits them as
         * a single record and clears the stored centroids.
         *
         * <p>The record is emitted through the operator's {@code output}; the {@code out}
         * collector is not used.
         *
         * @throws NullPointerException if no centroids have been stored for this epoch
         */
        public void onEpochWatermarkIncremented(
                int epochWatermark,
                IterationListener.Context context,
                Collector<Tuple2<Integer[], DenseVector[]>> out)
                throws Exception {
            DenseVector[] centroidValues =
                    Objects.requireNonNull(
                            OperatorStateUtils.getUniqueElement(centroids, "centroids")
                                    .orElse(null),
                            "No centroids available at epoch watermark " + epochWatermark);

            int numCentroids = centroidValues.length;
            VectorWithNorm[] centroidsWithNorm = new VectorWithNorm[numCentroids];
            DenseVector[] newCentroids = new DenseVector[numCentroids];
            Integer[] counts = new Integer[numCentroids];
            Arrays.fill(counts, 0);
            for (int i = 0; i < numCentroids; i++) {
                centroidsWithNorm[i] = new VectorWithNorm(centroidValues[i]);
                // Accumulator for the vector sum of the points assigned to centroid i.
                newCentroids[i] = new DenseVector(centroidValues[i].size());
            }

            for (VectorWithNorm point : points.get()) {
                int closestCentroidId = distanceMeasure.findClosest(centroidsWithNorm, point);
                BLAS.axpy(1.0, point.vector, newCentroids[closestCentroidId]);
                counts[closestCentroidId]++;
            }

            output.collect(new StreamRecord<>(Tuple2.of(counts, newCentroids)));

            // Make room for the centroids of the next epoch (see the check in processElement2).
            centroids.clear();
        }

        /** Releases both the stored centroids and the cached points. */
        public void onIterationTerminated(
                IterationListener.Context context,
                Collector<Tuple2<Integer[], DenseVector[]>> collector) {
            centroids.clear();
            points.clear();
        }
    }

    /**
     * Converts aggregated per-centroid point counts ({@code f0}) and vector sums ({@code f1})
     * into model data: every sum is scaled in place by the reciprocal of its count, and the
     * counts become the weight vector.
     */
    private static class ModelDataGenerator
            implements MapFunction<Tuple2<Integer[], DenseVector[]>, KMeansModelData> {
        private ModelDataGenerator() {}

        /**
         * @param tuple2 counts in {@code f0} and vector sums in {@code f1}; {@code f1} is
         *     modified in place
         * @return model data built from the scaled vectors and the counts as weights
         */
        public KMeansModelData map(Tuple2<Integer[], DenseVector[]> tuple2) throws Exception {
            Integer[] counts = tuple2.f0;
            DenseVector[] sums = tuple2.f1;
            double[] weights = new double[counts.length];

            for (int idx = 0; idx < counts.length; idx++) {
                int count = counts[idx];
                // NOTE(review): a count of 0 gives a scale factor of Infinity here; presumably
                // that turns an all-zero sum into NaN coordinates - confirm against BLAS.scal.
                BLAS.scal(1.0 / (double) count, sums[idx]);
                weights[idx] = count;
            }

            return new KMeansModelData(sums, new DenseVector(weights));
        }
    }

    /**
     * Merges two partial results element-wise: point counts ({@code f0}) are added and vector
     * sums ({@code f1}) are accumulated, both into the first tuple.
     */
    private static class CentroidsUpdateReducer
            implements ReduceFunction<Tuple2<Integer[], DenseVector[]>> {
        private CentroidsUpdateReducer() {}

        /**
         * @param tuple2 partial result that receives the merged values (modified in place)
         * @param t1 partial result that is added onto {@code tuple2}
         * @return {@code tuple2}, now holding the combined counts and sums
         */
        public Tuple2<Integer[], DenseVector[]> reduce(
                Tuple2<Integer[], DenseVector[]> tuple2, Tuple2<Integer[], DenseVector[]> t1)
                throws Exception {
            for (int i = 0; i < tuple2.f0.length; i++) {
                // Store the summed count back; computing the sum without assigning it would
                // silently drop the counts contributed by t1.
                tuple2.f0[i] += t1.f0[i];
                BLAS.axpy(1.0, t1.f1[i], tuple2.f1[i]);
            }
            return tuple2;
        }
    }

    /**
     * Iteration body for one K-means round: it assigns points to the current centroids,
     * aggregates the per-centroid counts and sums, and derives new model data from them.
     */
    private static class KMeansIterationBody implements IterationBody {
        private final int maxIterationNum;
        private final DistanceMeasure distanceMeasure;

        /**
         * @param maxIterationNum value handed to {@link TerminateOnMaxIter}
         * @param distanceMeasure measure used when assigning points to centroids
         */
        public KMeansIterationBody(int maxIterationNum, DistanceMeasure distanceMeasure) {
            this.maxIterationNum = maxIterationNum;
            this.distanceMeasure = distanceMeasure;
        }

        /**
         * Builds the dataflow of one round.
         *
         * @param variableStreams stream list whose first entry carries the current centroids
         * @param dataStreams stream list whose first entry carries the data points
         * @return a result built from the new centroids (first argument, presumably the feedback
         *     streams), the model data forwarded by {@link ForwardInputsOfLastRound} (second
         *     argument, presumably the outputs) and the termination criteria stream
         */
        public IterationBodyResult process(
                DataStreamList variableStreams, DataStreamList dataStreams) {
            DataStream<DenseVector[]> centroids = variableStreams.get(0);
            DataStream<DenseVector> points = dataStreams.get(0);

            DataStream<Integer> terminationCriteria =
                    centroids.flatMap(new TerminateOnMaxIter<>(maxIterationNum));

            // Every parallel accumulator instance receives all centroids via broadcast.
            DataStream<Tuple2<Integer[], DenseVector[]>> centroidIdAndPoints =
                    points.connect(centroids.broadcast())
                            .transform(
                                    "CentroidsUpdateAccumulator",
                                    new TupleTypeInfo<>(
                                            BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO,
                                            ObjectArrayTypeInfo.getInfoFor(
                                                    DenseVectorTypeInfo.INSTANCE)),
                                    new CentroidsUpdateAccumulator(distanceMeasure));

            DataStreamUtils.setManagedMemoryWeight(centroidIdAndPoints, 100L);

            // Each accumulator instance emits one record per epoch, so a count window sized to
            // the operator's parallelism gathers exactly one round of partial results.
            int parallelism = centroidIdAndPoints.getParallelism();
            DataStream<KMeansModelData> newModelData =
                    centroidIdAndPoints
                            .countWindowAll(parallelism)
                            .reduce(new CentroidsUpdateReducer())
                            .map(new ModelDataGenerator());

            DataStream<DenseVector[]> newCentroids =
                    newModelData.map(x -> x.centroids).setParallelism(1);

            DataStream<KMeansModelData> finalModelData =
                    newModelData.flatMap(new ForwardInputsOfLastRound<>());

            return new IterationBodyResult(
                    DataStreamList.of(newCentroids),
                    DataStreamList.of(finalModelData),
                    terminationCriteria);
        }
    }
}

