/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.classification.naivebayes;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel;
import org.apache.flink.ml.classification.naivebayes.NaiveBayesModelData;
import org.apache.flink.ml.classification.naivebayes.NaiveBayesParams;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.types.AbstractDataType;
import org.apache.flink.table.types.DataType;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

public class NaiveBayes
implements Estimator<NaiveBayes, NaiveBayesModel>,
NaiveBayesParams<NaiveBayes> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public NaiveBayes() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public NaiveBayesModel fit(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        String featuresCol = this.getFeaturesCol();
        String labelCol = this.getLabelCol();
        double smoothing = this.getSmoothing();
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        SingleOutputStreamOperator input = tEnv.toDataStream(inputs[0]).map((MapFunction & Serializable)row -> {
            Number number = (Number)row.getField(labelCol);
            Preconditions.checkNotNull((Object)number, (String)"Input data should contain label value.");
            Preconditions.checkArgument(((double)number.intValue() == number.doubleValue() ? 1 : 0) != 0, (Object)"Label value should be indexed number.");
            return new Tuple2((Object)((Vector)row.getField(featuresCol)), (Object)number.doubleValue());
        }, Types.TUPLE((TypeInformation[])new TypeInformation[]{VectorTypeInfo.INSTANCE, Types.DOUBLE}));
        SingleOutputStreamOperator feature = input.flatMap((FlatMapFunction)new ExtractFeatureFunction());
        DataStream featureWeight = DataStreamUtils.mapPartition((DataStream)feature.keyBy((KeySelector & Serializable)value -> new Tuple2((Object)((Double)value.f0), (Object)((Integer)value.f1)).hashCode()), (MapPartitionFunction)new GenerateFeatureWeightMapFunction(), (TypeInformation)Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.DOUBLE, Types.INT, Types.MAP((TypeInformation)Types.DOUBLE, (TypeInformation)Types.DOUBLE), Types.INT}));
        DataStream aggregatedArrays = DataStreamUtils.mapPartition((DataStream)featureWeight.keyBy((KeySelector & Serializable)value -> (Double)value.f0), (MapPartitionFunction)new AggregateIntoArrayFunction(), (TypeInformation)Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.DOUBLE, Types.INT, Types.OBJECT_ARRAY((TypeInformation)Types.MAP((TypeInformation)Types.DOUBLE, (TypeInformation)Types.DOUBLE))}));
        DataStream modelData = DataStreamUtils.mapPartition((DataStream)aggregatedArrays, (MapPartitionFunction)new GenerateModelFunction(smoothing), NaiveBayesModelData.TYPE_INFO);
        modelData.getTransformation().setParallelism(1);
        Schema schema = Schema.newBuilder().column("theta", (AbstractDataType)DataTypes.ARRAY((DataType)DataTypes.ARRAY((DataType)DataTypes.MAP((DataType)DataTypes.DOUBLE(), (DataType)DataTypes.DOUBLE())))).column("piArray", (AbstractDataType)DataTypes.of((TypeInformation)DenseVectorTypeInfo.INSTANCE)).column("labels", (AbstractDataType)DataTypes.of((TypeInformation)DenseVectorTypeInfo.INSTANCE)).build();
        NaiveBayesModel model = new NaiveBayesModel().setModelData(tEnv.fromDataStream(modelData, schema));
        ParamUtils.updateExistingParams((WithParams)model, this.paramMap);
        return model;
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static NaiveBayes load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (NaiveBayes)ReadWriteUtils.loadStageParam((String)path);
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    private static class GenerateModelFunction
    implements MapPartitionFunction<Tuple3<Double, Integer, Map<Double, Double>[]>, NaiveBayesModelData> {
        private final double smoothing;

        private GenerateModelFunction(double smoothing) {
            this.smoothing = smoothing;
        }

        public void mapPartition(Iterable<Tuple3<Double, Integer, Map<Double, Double>[]>> iterable, Collector<NaiveBayesModelData> collector) {
            ArrayList list = new ArrayList();
            iterable.iterator().forEachRemaining(list::add);
            int featureSize = ((Map[])((Tuple3)list.get((int)0)).f2).length;
            for (Tuple3 tup : list) {
                Preconditions.checkArgument((featureSize == ((Map[])tup.f2).length ? 1 : 0) != 0, (Object)"Feature vectors should be of equal length.");
            }
            double[] numDocs = new double[featureSize];
            HashSet[] categoryNumbers = new HashSet[featureSize];
            for (int i = 0; i < featureSize; ++i) {
                categoryNumbers[i] = new HashSet();
            }
            for (Tuple3 tup : list) {
                for (int i = 0; i < featureSize; ++i) {
                    int n = i;
                    numDocs[n] = numDocs[n] + (double)((Integer)tup.f1).intValue();
                    categoryNumbers[i].addAll(((Map[])tup.f2)[i].keySet());
                }
            }
            int[] categoryNumber = new int[featureSize];
            double piLog = 0.0;
            int numLabels = list.size();
            for (int i = 0; i < featureSize; ++i) {
                categoryNumber[i] = categoryNumbers[i].size();
                piLog += numDocs[i];
            }
            piLog = Math.log(piLog + (double)numLabels * this.smoothing);
            HashMap[][] theta = new HashMap[numLabels][featureSize];
            double[] piArray = new double[numLabels];
            double[] labels = new double[numLabels];
            for (int i = 0; i < numLabels; ++i) {
                Map[] param = (Map[])((Tuple3)list.get((int)i)).f2;
                for (int j = 0; j < featureSize; ++j) {
                    HashMap<Double, Double> squareData = new HashMap<Double, Double>();
                    double thetaLog = Math.log((double)((Integer)((Tuple3)list.get((int)i)).f1).intValue() * 1.0 + this.smoothing * (double)categoryNumber[j]);
                    for (Double cate : categoryNumbers[j]) {
                        double value = param[j].getOrDefault(cate, 0.0);
                        squareData.put(cate, Math.log(value + this.smoothing) - thetaLog);
                    }
                    theta[i][j] = squareData;
                }
                labels[i] = (Double)((Tuple3)list.get((int)i)).f0;
                double weightSum = (Integer)((Tuple3)list.get((int)i)).f1 * featureSize;
                piArray[i] = Math.log(weightSum + this.smoothing) - piLog;
            }
            NaiveBayesModelData modelData = new NaiveBayesModelData(theta, Vectors.dense((double[])piArray), Vectors.dense((double[])labels));
            collector.collect((Object)modelData);
        }
    }

    private static class AggregateIntoArrayFunction
    implements MapPartitionFunction<Tuple4<Double, Integer, Map<Double, Double>, Integer>, Tuple3<Double, Integer, Map<Double, Double>[]>> {
        private AggregateIntoArrayFunction() {
        }

        public void mapPartition(Iterable<Tuple4<Double, Integer, Map<Double, Double>, Integer>> iterable, Collector<Tuple3<Double, Integer, Map<Double, Double>[]>> collector) {
            HashMap<Double, List> map = new HashMap<Double, List>();
            for (Tuple4<Double, Integer, Map<Double, Double>, Integer> value : iterable) {
                map.computeIfAbsent((Double)value.f0, x -> new ArrayList()).add(value);
            }
            for (List list : map.values()) {
                int maxDocNum;
                int featureSize = list.stream().map(x -> (Integer)x.f1).max(Integer::compareTo).orElse(-1) + 1;
                int minDocNum = list.stream().map(x -> (Integer)x.f3).min(Integer::compareTo).orElse(Integer.MAX_VALUE);
                Preconditions.checkArgument((minDocNum == (maxDocNum = list.stream().map(x -> (Integer)x.f3).max(Integer::compareTo).orElse(Integer.MIN_VALUE).intValue()) ? 1 : 0) != 0, (Object)"Feature vectors should be of equal length.");
                HashMap<Double, Integer> numMap = new HashMap<Double, Integer>();
                HashMap<Double, Map[]> featureWeightMap = new HashMap<Double, Map[]>();
                for (Tuple4 value : list) {
                    Map[] featureWeight = featureWeightMap.computeIfAbsent((Double)value.f0, x -> new HashMap[featureSize]);
                    numMap.put((Double)value.f0, (Integer)value.f3);
                    featureWeight[((Integer)value.f1).intValue()] = (Map)value.f2;
                }
                Iterator<Object> iterator = featureWeightMap.keySet().iterator();
                while (iterator.hasNext()) {
                    double key = (Double)iterator.next();
                    collector.collect((Object)new Tuple3((Object)key, (Object)((Integer)numMap.get(key)), (Object)((Map[])featureWeightMap.get(key))));
                }
            }
        }
    }

    private static class GenerateFeatureWeightMapFunction
    implements MapPartitionFunction<Tuple3<Double, Integer, Double>, Tuple4<Double, Integer, Map<Double, Double>, Integer>> {
        private GenerateFeatureWeightMapFunction() {
        }

        public void mapPartition(Iterable<Tuple3<Double, Integer, Double>> iterable, Collector<Tuple4<Double, Integer, Map<Double, Double>, Integer>> collector) {
            ArrayList list = new ArrayList();
            iterable.iterator().forEachRemaining(list::add);
            HashMap<Tuple2, Map> accMap = new HashMap<Tuple2, Map>();
            HashMap<Tuple2, Integer> numMap = new HashMap<Tuple2, Integer>();
            for (Tuple3 tuple3 : list) {
                Tuple2 key = new Tuple2((Object)((Double)tuple3.f0), (Object)((Integer)tuple3.f1));
                Map acc = accMap.computeIfAbsent(key, x -> new HashMap());
                acc.put((Double)tuple3.f2, acc.getOrDefault(tuple3.f2, 0.0) + 1.0);
                numMap.put(key, numMap.getOrDefault(key, 0) + 1);
            }
            for (Map.Entry entry : accMap.entrySet()) {
                collector.collect((Object)new Tuple4((Object)((Double)((Tuple2)entry.getKey()).f0), (Object)((Integer)((Tuple2)entry.getKey()).f1), (Object)((Map)entry.getValue()), (Object)((Integer)numMap.get(entry.getKey()))));
            }
        }
    }

    private static class ExtractFeatureFunction
    implements FlatMapFunction<Tuple2<Vector, Double>, Tuple3<Double, Integer, Double>> {
        private ExtractFeatureFunction() {
        }

        public void flatMap(Tuple2<Vector, Double> value, Collector<Tuple3<Double, Integer, Double>> collector) {
            Preconditions.checkNotNull((Object)((Double)value.f1));
            for (int i = 0; i < ((Vector)value.f0).size(); ++i) {
                collector.collect((Object)new Tuple3((Object)((Double)value.f1), (Object)i, (Object)((Vector)value.f0).get(i)));
            }
        }
    }
}

