/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.clustering.kmeans;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.serialization.Encoder;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
import org.apache.flink.ml.clustering.kmeans.OnlineKMeansParams;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.distance.DistanceMeasure;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.VectorWithNorm;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

public class OnlineKMeans
implements Estimator<OnlineKMeans, OnlineKMeansModel>,
OnlineKMeansParams<OnlineKMeans> {
    private final Map<Param<?>, Object> paramMap = new HashMap();
    private Table initModelDataTable;

    /**
     * Creates an estimator and fills its parameter map through
     * {@link ParamUtils#initializeMapWithDefaultValues}.
     */
    public OnlineKMeans() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Builds the unbounded training pipeline and returns a model backed by the continuously
     * updated model data stream.
     *
     * <p>The features column of the single input table is mapped to {@link DenseVector}s, the
     * initial model data stream is pinned to parallelism 1, and both are handed to
     * {@link Iterations#iterateUnboundedStreams}. The first output stream of the iteration is
     * converted back to a table and set as the model data of the returned model, which also
     * receives this estimator's parameters.
     *
     * @param inputs exactly one table containing the features column
     * @return a model whose model data table is fed by the iteration output
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     * @throws NullPointerException if no initial model data has been set via
     *     {@link #setInitialModelData(Table)} or {@link #load}
     */
    public OnlineKMeansModel fit(Table... inputs) {
        Preconditions.checkArgument(
                inputs.length == 1, "OnlineKMeans.fit expects exactly one input table.");
        // Fail fast with the same message save() uses instead of an opaque NPE raised from
        // inside KMeansModelData.getModelDataStream.
        Preconditions.checkNotNull(
                initModelDataTable, "Initial Model Data Table should have been set.");

        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        DataStream<DenseVector> points =
                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));

        DataStream<KMeansModelData> initModelData =
                KMeansModelData.getModelDataStream(initModelDataTable);
        initModelData.getTransformation().setParallelism(1);

        IterationBody body =
                new OnlineKMeansIterationBody(
                        DistanceMeasure.getInstance(getDistanceMeasure()),
                        getK(),
                        getDecayFactor(),
                        getGlobalBatchSize());

        DataStream<KMeansModelData> onlineModelData =
                Iterations.iterateUnboundedStreams(
                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
                        .get(0);

        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
        ParamUtils.updateExistingParams(model, paramMap);
        return model;
    }

    /**
     * Writes this stage's metadata and its initial model data under {@code path}.
     *
     * @param path target location passed to {@link ReadWriteUtils}
     * @throws IOException declared by the underlying save helpers
     */
    public void save(String path) throws IOException {
        Preconditions.checkNotNull(
                initModelDataTable, "Initial Model Data Table should have been set.");

        ReadWriteUtils.saveMetadata(this, path);

        DataStream<KMeansModelData> initModelData =
                KMeansModelData.getModelDataStream(initModelDataTable);
        ReadWriteUtils.saveModelData(initModelData, path, new KMeansModelData.ModelDataEncoder());
    }

    /**
     * Restores an estimator from {@code path}: the stage parameters are loaded first, then the
     * stored model data is read back and used as the initial model data table.
     *
     * @param tEnv table environment used to read the model data
     * @param path location previously written by {@link #save(String)}
     * @return the restored estimator
     * @throws IOException declared by the underlying load helpers
     */
    public static OnlineKMeans load(StreamTableEnvironment tEnv, String path) throws IOException {
        OnlineKMeans estimator = (OnlineKMeans) ReadWriteUtils.loadStageParam(path);
        Table storedModelData =
                ReadWriteUtils.loadModelData(tEnv, path, new KMeansModelData.ModelDataDecoder());
        estimator.initModelDataTable = storedModelData;
        return estimator;
    }

    /**
     * Returns the parameter map held by this estimator (the internal instance, not a copy).
     */
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Stores the table that supplies the initial model data used by {@link #fit} and
     * {@link #save}.
     *
     * @param initModelDataTable table holding the initial model data
     * @return this estimator, to allow call chaining
     */
    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
        this.initModelDataTable = initModelDataTable;
        return this;
    }

    /**
     * Map function that reads the configured features column from each row and converts the
     * contained {@link Vector} to a {@link DenseVector}.
     */
    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
        /** Name of the row field that holds the feature vector. */
        private final String featuresCol;

        private FeaturesExtractor(String featuresCol) {
            this.featuresCol = featuresCol;
        }

        public DenseVector map(Row row) {
            Vector features = (Vector) row.getField(featuresCol);
            return features.toDense();
        }
    }

    /**
     * Two-input operator that updates a model with one local batch of points.
     *
     * <p>Input 1 delivers batches of points, input 2 delivers model data. Both are buffered in
     * operator list state. Whenever a model and at least one batch are available, the model and
     * the first buffered batch are consumed, the centroids and weights are updated in place, and
     * the resulting {@link KMeansModelData} is emitted.
     */
    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
        private final DistanceMeasure distanceMeasure;
        private final int k;
        private final double decayFactor;

        /** Batches of points that have arrived but have not been consumed yet. */
        private ListState<DenseVector[]> localBatchDataState;

        /** Model data waiting to be combined with a batch. */
        private ListState<KMeansModelData> modelDataState;

        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, int k, double decayFactor) {
            this.distanceMeasure = distanceMeasure;
            this.k = k;
            this.decayFactor = decayFactor;
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);

            TypeInformation<DenseVector[]> batchType =
                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
            localBatchDataState =
                    context.getOperatorStateStore()
                            .getListState(new ListStateDescriptor<>("localBatch", batchType));

            modelDataState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
        }

        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
            localBatchDataState.add(pointsRecord.getValue());
            alignAndComputeModelData();
        }

        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
                throws Exception {
            KMeansModelData received = modelDataRecord.getValue();
            Preconditions.checkArgument(received.centroids.length == k);
            modelDataState.add(received);
            alignAndComputeModelData();
        }

        /**
         * Consumes one model and one batch, if both are present, and emits the updated model.
         *
         * <p>Each point is assigned to its closest centroid; per-centroid sums and counts are
         * accumulated. All weights are scaled by {@code decayFactor / parallelism}. For every
         * centroid that received points, its weight grows by the point count and the centroid
         * becomes {@code (1 - lambda) * centroid + (lambda / count) * sum} with
         * {@code lambda = count / newWeight}.
         */
        private void alignAndComputeModelData() throws Exception {
            // Proceed only once both a model and at least one batch have been buffered.
            if (!(modelDataState.get().iterator().hasNext()
                    && localBatchDataState.get().iterator().hasNext())) {
                return;
            }

            KMeansModelData currentModel =
                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
            DenseVector[] centroids = currentModel.centroids;
            VectorWithNorm[] normedCentroids = toVectorsWithNorm(centroids);
            DenseVector weights = currentModel.weights;
            modelDataState.clear();

            // Take the oldest buffered batch and write the remaining ones back to state.
            List<DenseVector[]> bufferedBatches =
                    IteratorUtils.toList(localBatchDataState.get().iterator());
            DenseVector[] batch = bufferedBatches.remove(0);
            localBatchDataState.update(bufferedBatches);

            int dim = centroids[0].size();
            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();

            int[] counts = new int[k];
            DenseVector[] sums = new DenseVector[k];
            for (int c = 0; c < k; c++) {
                sums[c] = new DenseVector(dim);
            }

            for (DenseVector point : batch) {
                int closest = distanceMeasure.findClosest(normedCentroids, new VectorWithNorm(point));
                counts[closest]++;
                BLAS.axpy(1.0, point, sums[closest]);
            }

            BLAS.scal(decayFactor / parallelism, weights);
            for (int c = 0; c < k; c++) {
                if (counts[c] != 0) {
                    DenseVector centroid = centroids[c];
                    weights.values[c] += counts[c];
                    double lambda = counts[c] / weights.values[c];

                    BLAS.scal(1.0 - lambda, centroid);
                    BLAS.axpy(lambda / counts[c], sums[c], centroid);
                }
            }

            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
        }

        /** Wraps every vector in a {@link VectorWithNorm}, keeping the original order. */
        private static VectorWithNorm[] toVectorsWithNorm(DenseVector[] vectors) {
            VectorWithNorm[] wrapped = new VectorWithNorm[vectors.length];
            for (int idx = 0; idx < wrapped.length; idx++) {
                wrapped[idx] = new VectorWithNorm(vectors[idx]);
            }
            return wrapped;
        }
    }

    /**
     * Reduce function that merges two model data instances into one.
     *
     * <p>Each merged centroid is the weight-weighted average of the two input centroids, and
     * each merged weight is the sum of the two input weights. The arrays of the first argument
     * are updated in place and reused for the result.
     */
    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
        private ModelDataGlobalReducer() {}

        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
            DenseVector[] mergedCentroids = modelData.centroids;
            DenseVector mergedWeights = modelData.weights;
            DenseVector[] incomingCentroids = newModelData.centroids;
            DenseVector incomingWeights = newModelData.weights;

            int numClusters = incomingCentroids.length;
            int numFeatures = incomingCentroids[0].size();

            for (int c = 0; c < numClusters; c++) {
                double currentWeight = mergedWeights.values[c];
                double addedWeight = incomingWeights.values[c];
                // Lower-bound the denominator so two zero weights do not divide by zero.
                double denominator = Math.max(currentWeight + addedWeight, 1.0E-16);

                for (int d = 0; d < numFeatures; d++) {
                    mergedCentroids[c].values[d] =
                            (mergedCentroids[c].values[d] * currentWeight
                                            + incomingCentroids[c].values[d] * addedWeight)
                                    / denominator;
                }
                mergedWeights.values[c] = currentWeight + addedWeight;
            }

            return new KMeansModelData(mergedCentroids, mergedWeights);
        }
    }

    /**
     * Iteration body wiring one round of online K-Means: points are grouped into batches, each
     * subtask updates the broadcast model with its batch, and the per-subtask results are merged
     * in a count window of size {@code parallelism}. The merged model is returned as the feedback
     * variable stream; the incoming model stream is returned as the iteration output.
     */
    private static class OnlineKMeansIterationBody implements IterationBody {
        private final DistanceMeasure distanceMeasure;
        private final int k;
        private final double decayFactor;
        private final int batchSize;

        public OnlineKMeansIterationBody(
                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
            this.distanceMeasure = distanceMeasure;
            this.k = k;
            this.decayFactor = decayFactor;
            this.batchSize = batchSize;
        }

        public IterationBodyResult process(
                DataStreamList variableStreams, DataStreamList dataStreams) {
            DataStream<KMeansModelData> modelData = variableStreams.get(0);
            DataStream<DenseVector> points = dataStreams.get(0);

            int parallelism = points.getParallelism();
            Preconditions.checkState(
                    parallelism <= batchSize,
                    "There are more subtasks in the training process than the number of elements in each batch. Some subtasks might be idling forever.");

            // Per-subtask update: every subtask sees the full (broadcast) model and one batch.
            DataStream<KMeansModelData> locallyUpdated =
                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
                            .connect(modelData.broadcast())
                            .transform(
                                    "ModelDataLocalUpdater",
                                    TypeInformation.of(KMeansModelData.class),
                                    new ModelDataLocalUpdater(distanceMeasure, k, decayFactor))
                            .setParallelism(parallelism);

            // Global merge: combine one result per subtask into a single model.
            DataStream<KMeansModelData> newModelData =
                    locallyUpdated.countWindowAll(parallelism).reduce(new ModelDataGlobalReducer());

            DataStreamList feedback = DataStreamList.of(newModelData);
            DataStreamList outputs = DataStreamList.of(modelData);
            return new IterationBodyResult(feedback, outputs);
        }
    }
}

