/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.stats.anovatest;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.FDistribution;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.stats.anovatest.ANOVATestParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

public class ANOVATest
implements AlgoOperator<ANOVATest>,
ANOVATestParams<ANOVATest> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public ANOVATest() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public Table[] transform(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        String featuresCol = this.getFeaturesCol();
        String labelCol = this.getLabelCol();
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        SingleOutputStreamOperator inputData = tEnv.toDataStream(inputs[0]).map((MapFunction & Serializable)row -> {
            Number number = (Number)row.getField(labelCol);
            Preconditions.checkNotNull((Object)number, (String)"Input data must contain label value.");
            return new Tuple2((Object)((Vector)row.getField(featuresCol)), (Object)number.doubleValue());
        }, Types.TUPLE((TypeInformation[])new TypeInformation[]{VectorTypeInfo.INSTANCE, Types.DOUBLE}));
        DataStream streamWithANOVA = DataStreamUtils.aggregate((DataStream)inputData, (AggregateFunction)new ANOVAAggregator(), (TypeInformation)Types.OBJECT_ARRAY((TypeInformation)Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.DOUBLE, Types.DOUBLE, Types.MAP((TypeInformation)Types.DOUBLE, (TypeInformation)Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.DOUBLE, Types.LONG}))})), (TypeInformation)Types.LIST((TypeInformation)Types.ROW((TypeInformation[])new TypeInformation[]{Types.INT, Types.DOUBLE, Types.LONG, Types.DOUBLE})));
        return new Table[]{this.convertToTable(tEnv, (DataStream<List<Row>>)streamWithANOVA, this.getFlatten())};
    }

    private Table convertToTable(StreamTableEnvironment tEnv, DataStream<List<Row>> datastream, boolean flatten) {
        if (flatten) {
            SingleOutputStreamOperator output = datastream.flatMap((FlatMapFunction & Serializable)(list, collector) -> list.forEach(arg_0 -> ((Collector)collector).collect(arg_0))).setParallelism(1).returns(Types.ROW((TypeInformation[])new TypeInformation[]{Types.INT, Types.DOUBLE, Types.LONG, Types.DOUBLE}));
            return tEnv.fromDataStream((DataStream)output).as("featureIndex", new String[]{"pValue", "degreeOfFreedom", "fValue"});
        }
        SingleOutputStreamOperator output = datastream.map((MapFunction)new MapFunction<List<Row>, Tuple3<DenseVector, long[], DenseVector>>(){

            public Tuple3<DenseVector, long[], DenseVector> map(List<Row> rows) {
                int numOfFeatures = rows.size();
                DenseVector pValues = new DenseVector(numOfFeatures);
                DenseVector fValues = new DenseVector(numOfFeatures);
                long[] degrees = new long[numOfFeatures];
                for (int i = 0; i < numOfFeatures; ++i) {
                    Row row = rows.get(i);
                    pValues.set(i, ((Double)row.getField(1)).doubleValue());
                    degrees[i] = (Long)row.getField(2);
                    fValues.set(i, ((Double)row.getField(3)).doubleValue());
                }
                return Tuple3.of((Object)pValues, (Object)degrees, (Object)fValues);
            }
        });
        return tEnv.fromDataStream((DataStream)output).as("pValues", new String[]{"degreesOfFreedom", "fValues"});
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static ANOVATest load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (ANOVATest)ReadWriteUtils.loadStageParam((String)path);
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    private static class ANOVAAggregator
    implements AggregateFunction<Tuple2<Vector, Double>, Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[], List<Row>> {
        private ANOVAAggregator() {
        }

        public Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] createAccumulator() {
            return new Tuple3[0];
        }

        public Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] add(Tuple2<Vector, Double> featuresAndLabel, Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] acc) {
            int i;
            Vector features = (Vector)featuresAndLabel.f0;
            double label = (Double)featuresAndLabel.f1;
            int numOfFeatures = features.size();
            if (acc.length == 0) {
                acc = new Tuple3[features.size()];
                for (i = 0; i < numOfFeatures; ++i) {
                    acc[i] = Tuple3.of((Object)0.0, (Object)0.0, new HashMap());
                }
            }
            for (i = 0; i < numOfFeatures; ++i) {
                double featureValue = features.get(i);
                Tuple2 tuple2 = acc[i];
                Double.valueOf((Double)tuple2.f0 + featureValue);
                tuple2.f0 = tuple2.f0;
                tuple2 = acc[i];
                Double.valueOf((Double)tuple2.f1 + featureValue * featureValue);
                tuple2.f1 = tuple2.f1;
                if (((HashMap)acc[i].f2).containsKey(label)) {
                    tuple2 = (Tuple2)((HashMap)acc[i].f2).get(label);
                    Double.valueOf((Double)tuple2.f0 + featureValue);
                    tuple2.f0 = tuple2.f0;
                    tuple2 = (Tuple2)((HashMap)acc[i].f2).get(label);
                    Long.valueOf((Long)tuple2.f1 + 1L);
                    tuple2.f1 = tuple2.f1;
                    continue;
                }
                ((HashMap)acc[i].f2).put(label, Tuple2.of((Object)featureValue, (Object)1L));
            }
            return acc;
        }

        public List<Row> getResult(Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] acc) {
            ArrayList<Row> results = new ArrayList<Row>();
            for (int i = 0; i < acc.length; ++i) {
                Tuple3<Double, Long, Double> resultOfANOVA = this.computeANOVA((Double)acc[i].f0, (Double)acc[i].f1, (HashMap)acc[i].f2);
                results.add(Row.of((Object[])new Object[]{i, resultOfANOVA.f0, resultOfANOVA.f1, resultOfANOVA.f2}));
            }
            return results;
        }

        public Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] merge(Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] acc1, Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] acc2) {
            if (acc1.length == 0) {
                return acc2;
            }
            if (acc2.length == 0) {
                return acc1;
            }
            IntStream.range(0, acc1.length).forEach(i -> {
                Tuple3 tuple3 = acc2[i];
                tuple3.f0 = (Double)tuple3.f0 + (Double)acc1[i].f0;
                tuple3 = acc2[i];
                tuple3.f1 = (Double)tuple3.f1 + (Double)acc1[i].f1;
                ((HashMap)acc1[i].f2).forEach((k, v) -> {
                    if (((HashMap)acc2[i].f2).containsKey(k)) {
                        Tuple2 tuple2 = (Tuple2)((HashMap)acc2[i].f2).get(k);
                        tuple2.f0 = (Double)tuple2.f0 + (Double)v.f0;
                        tuple2 = (Tuple2)((HashMap)acc2[i].f2).get(k);
                        tuple2.f1 = (Long)tuple2.f1 + (Long)v.f1;
                    } else {
                        ((HashMap)acc2[i].f2).put(k, v);
                    }
                });
            });
            return acc2;
        }

        private Tuple3<Double, Long, Double> computeANOVA(double sum, double sumOfSq, HashMap<Double, Tuple2<Double, Long>> summary) {
            long numOfClasses = summary.size();
            long numOfSamples = summary.values().stream().mapToLong(t -> (Long)t.f1).sum();
            double sqSum = sum * sum;
            double ssTot = sumOfSq - sqSum / (double)numOfSamples;
            double totalSqSum = 0.0;
            for (Tuple2<Double, Long> t2 : summary.values()) {
                totalSqSum += (Double)t2.f0 * (Double)t2.f0 / (double)((Long)t2.f1).longValue();
            }
            double sumOfSqBetween = totalSqSum - sqSum / (double)numOfSamples;
            double sumOfSqWithin = ssTot - sumOfSqBetween;
            long degreeOfFreedomBetween = numOfClasses - 1L;
            Preconditions.checkArgument((degreeOfFreedomBetween > 0L ? 1 : 0) != 0, (Object)"Num of classes should be positive.");
            long degreeOfFreedomWithin = numOfSamples - numOfClasses;
            Preconditions.checkArgument((degreeOfFreedomWithin > 0L ? 1 : 0) != 0, (Object)"Num of samples should be greater than num of classes.");
            double meanSqBetween = sumOfSqBetween / (double)degreeOfFreedomBetween;
            double meanSqWithin = sumOfSqWithin / (double)degreeOfFreedomWithin;
            double fValue = meanSqBetween / meanSqWithin;
            FDistribution fd = new FDistribution((double)degreeOfFreedomBetween, (double)degreeOfFreedomWithin);
            double pValue = 1.0 - fd.cumulativeProbability(fValue);
            long degreeOfFreedom = degreeOfFreedomBetween + degreeOfFreedomWithin;
            return Tuple3.of((Object)pValue, (Object)degreeOfFreedom, (Object)fValue);
        }
    }
}

