/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.idf;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.feature.idf.IDFModel;
import org.apache.flink.ml.feature.idf.IDFModelData;
import org.apache.flink.ml.feature.idf.IDFParams;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Preconditions;

public class IDF
implements Estimator<IDF, IDFModel>,
IDFParams<IDF> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public IDF() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public IDFModel fit(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        String inputCol = this.getInputCol();
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        SingleOutputStreamOperator inputData = tEnv.toDataStream(inputs[0]).map((MapFunction & Serializable)value -> (Vector)value.getField(inputCol), (TypeInformation)VectorTypeInfo.INSTANCE);
        DataStream modelData = DataStreamUtils.aggregate((DataStream)inputData, (AggregateFunction)new IDFAggregator(this.getMinDocFreq()));
        IDFModel model = new IDFModel().setModelData(tEnv.fromDataStream(modelData));
        ParamUtils.updateExistingParams((WithParams)model, this.getParamMap());
        return model;
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static IDF load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (IDF)ReadWriteUtils.loadStageParam((String)path);
    }

    private static class IDFAggregator
    implements AggregateFunction<Vector, Tuple2<Long, DenseVector>, IDFModelData> {
        private final int minDocFreq;

        public IDFAggregator(int minDocFreq) {
            this.minDocFreq = minDocFreq;
        }

        public Tuple2<Long, DenseVector> createAccumulator() {
            return Tuple2.of((Object)0L, (Object)new DenseVector(new double[0]));
        }

        public Tuple2<Long, DenseVector> add(Vector vector, Tuple2<Long, DenseVector> numDocsAndDocFreq) {
            if ((Long)numDocsAndDocFreq.f0 == 0L) {
                numDocsAndDocFreq.f1 = new DenseVector(vector.size());
            }
            Tuple2<Long, DenseVector> tuple2 = numDocsAndDocFreq;
            tuple2.f0 = (Long)tuple2.f0 + 1L;
            double[] values = vector instanceof SparseVector ? ((SparseVector)vector).values : ((DenseVector)vector).values;
            for (int i = 0; i < values.length; ++i) {
                values[i] = values[i] > 0.0 ? 1.0 : 0.0;
            }
            BLAS.axpy((double)1.0, (Vector)vector, (DenseVector)((DenseVector)numDocsAndDocFreq.f1));
            return numDocsAndDocFreq;
        }

        public IDFModelData getResult(Tuple2<Long, DenseVector> numDocsAndDocFreq) {
            long numDocs = (Long)numDocsAndDocFreq.f0;
            DenseVector docFreq = (DenseVector)numDocsAndDocFreq.f1;
            Preconditions.checkState((numDocs > 0L ? 1 : 0) != 0, (Object)"The training set is empty.");
            long[] filteredDocFreq = new long[docFreq.size()];
            double[] df = docFreq.values;
            double[] idf = new double[df.length];
            for (int i = 0; i < idf.length; ++i) {
                if (!(df[i] >= (double)this.minDocFreq)) continue;
                idf[i] = Math.log((double)(numDocs + 1L) / (df[i] + 1.0));
                filteredDocFreq[i] = (long)df[i];
            }
            return new IDFModelData(Vectors.dense((double[])idf), filteredDocFreq, numDocs);
        }

        public Tuple2<Long, DenseVector> merge(Tuple2<Long, DenseVector> numDocsAndDocFreq1, Tuple2<Long, DenseVector> numDocsAndDocFreq2) {
            if ((Long)numDocsAndDocFreq1.f0 == 0L) {
                return numDocsAndDocFreq2;
            }
            if ((Long)numDocsAndDocFreq2.f0 == 0L) {
                return numDocsAndDocFreq1;
            }
            Tuple2<Long, DenseVector> tuple2 = numDocsAndDocFreq2;
            tuple2.f0 = (Long)tuple2.f0 + (Long)numDocsAndDocFreq1.f0;
            BLAS.axpy((double)1.0, (Vector)((Vector)numDocsAndDocFreq1.f1), (DenseVector)((DenseVector)numDocsAndDocFreq2.f1));
            return numDocsAndDocFreq2;
        }
    }
}

