/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.countvectorizer;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModelData;
import org.apache.flink.ml.feature.countvectorizer.CountVectorizerParams;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Preconditions;

public class CountVectorizer
implements Estimator<CountVectorizer, CountVectorizerModel>,
CountVectorizerParams<CountVectorizer> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public CountVectorizer() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public CountVectorizerModel fit(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        double minDF = this.getMinDF();
        double maxDF = this.getMaxDF();
        if (minDF >= 1.0 && maxDF >= 1.0 || minDF < 1.0 && maxDF < 1.0) {
            Preconditions.checkArgument((maxDF >= minDF ? 1 : 0) != 0, (Object)"maxDF must be >= minDF.");
        }
        String inputCol = this.getInputCol();
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        SingleOutputStreamOperator inputData = tEnv.toDataStream(inputs[0]).map((MapFunction & Serializable)value -> (String[])value.getField(inputCol));
        DataStream modelData = DataStreamUtils.aggregate((DataStream)inputData, (AggregateFunction)new VocabularyAggregator(this.getMinDF(), this.getMaxDF(), this.getVocabularySize()), (TypeInformation)Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.LONG, Types.MAP((TypeInformation)Types.STRING, (TypeInformation)Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.LONG, Types.LONG}))}), (TypeInformation)TypeInformation.of(CountVectorizerModelData.class));
        CountVectorizerModel model = new CountVectorizerModel().setModelData(tEnv.fromDataStream(modelData));
        ParamUtils.updateExistingParams((WithParams)model, this.getParamMap());
        return model;
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static CountVectorizer load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (CountVectorizer)ReadWriteUtils.loadStageParam((String)path);
    }

    private static class VocabularyAggregator
    implements AggregateFunction<String[], Tuple2<Long, Map<String, Tuple2<Long, Long>>>, CountVectorizerModelData> {
        private final double minDF;
        private final double maxDF;
        private final int vocabularySize;

        public VocabularyAggregator(double minDF, double maxDF, int vocabularySize) {
            this.minDF = minDF;
            this.maxDF = maxDF;
            this.vocabularySize = vocabularySize;
        }

        public Tuple2<Long, Map<String, Tuple2<Long, Long>>> createAccumulator() {
            return Tuple2.of((Object)0L, new HashMap());
        }

        public Tuple2<Long, Map<String, Tuple2<Long, Long>>> add(String[] terms, Tuple2<Long, Map<String, Tuple2<Long, Long>>> vocabAccumulator) {
            HashMap<String, Long> wc = new HashMap<String, Long>();
            Arrays.stream(terms).forEach(x -> {
                if (wc.containsKey(x)) {
                    wc.put((String)x, (Long)wc.get(x) + 1L);
                } else {
                    wc.put((String)x, 1L);
                }
            });
            Map counts = (Map)vocabAccumulator.f1;
            wc.forEach((w, c) -> {
                if (counts.containsKey(w)) {
                    Tuple2 tuple2 = (Tuple2)counts.get(w);
                    tuple2.f0 = (Long)tuple2.f0 + c;
                    tuple2 = (Tuple2)counts.get(w);
                    tuple2.f1 = (Long)tuple2.f1 + 1L;
                } else {
                    counts.put(w, Tuple2.of((Object)c, (Object)1L));
                }
            });
            Tuple2<Long, Map<String, Tuple2<Long, Long>>> tuple2 = vocabAccumulator;
            tuple2.f0 = (Long)tuple2.f0 + 1L;
            return vocabAccumulator;
        }

        public CountVectorizerModelData getResult(Tuple2<Long, Map<String, Tuple2<Long, Long>>> vocabAccumulator) {
            boolean filteringRequired;
            Preconditions.checkState(((Long)vocabAccumulator.f0 > 0L ? 1 : 0) != 0, (Object)"The training set is empty.");
            boolean bl = filteringRequired = !((Double)CountVectorizerParams.MIN_DF.defaultValue).equals(this.minDF) || !((Double)CountVectorizerParams.MAX_DF.defaultValue).equals(this.maxDF);
            if (filteringRequired) {
                long rowNum = (Long)vocabAccumulator.f0;
                double actualMinDF = this.minDF >= 1.0 ? this.minDF : this.minDF * (double)rowNum;
                double actualMaxDF = this.maxDF >= 1.0 ? this.maxDF : this.maxDF * (double)rowNum;
                Preconditions.checkState((actualMaxDF >= actualMinDF ? 1 : 0) != 0, (Object)"maxDF must be >= minDF.");
                vocabAccumulator.f1 = ((Map)vocabAccumulator.f1).entrySet().stream().filter(x -> (double)((Long)((Tuple2)x.getValue()).f1).longValue() >= actualMinDF && (double)((Long)((Tuple2)x.getValue()).f1).longValue() <= actualMaxDF).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
            }
            ArrayList list = new ArrayList(((Map)vocabAccumulator.f1).entrySet());
            list.sort((o1, o2) -> ((Long)((Tuple2)o2.getValue()).f1).compareTo((Long)((Tuple2)o1.getValue()).f1));
            List vocabulary = list.stream().map(Map.Entry::getKey).collect(Collectors.toList());
            String[] topTerms = vocabulary.subList(0, Math.min(vocabulary.size(), this.vocabularySize)).toArray(new String[0]);
            return new CountVectorizerModelData(topTerms);
        }

        public Tuple2<Long, Map<String, Tuple2<Long, Long>>> merge(Tuple2<Long, Map<String, Tuple2<Long, Long>>> acc1, Tuple2<Long, Map<String, Tuple2<Long, Long>>> acc2) {
            if ((Long)acc1.f0 == 0L) {
                return acc2;
            }
            if ((Long)acc2.f0 == 0L) {
                return acc1;
            }
            Tuple2<Long, Map<String, Tuple2<Long, Long>>> tuple2 = acc2;
            tuple2.f0 = (Long)tuple2.f0 + (Long)acc1.f0;
            ((Map)acc1.f1).forEach((term, counts) -> {
                if (((Map)acc2.f1).containsKey(term)) {
                    ((Map)acc2.f1).put(term, Tuple2.of((Object)((Long)counts.f0 + (Long)((Tuple2)((Map)acc2.f1).get((Object)term)).f0), (Object)((Long)counts.f1 + (Long)((Tuple2)((Map)acc2.f1).get((Object)term)).f1)));
                } else {
                    ((Map)acc2.f1).put(term, counts);
                }
            });
            return acc2;
        }
    }
}

