/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.stats.fvaluetest;

import java.io.IOException;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.math3.distribution.FDistribution;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.stats.fvaluetest.FValueTestParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/**
 * An AlgoOperator that runs a univariate F-value test of each feature against the label.
 *
 * <p>For every feature, the operator computes the sample Pearson correlation {@code r} between that
 * feature and the label. It converts {@code r} into the F-statistic {@code r^2 / (1 - r^2) * (n - 2)}
 * and reports the upper-tail p-value of an F-distribution with (1, n - 2) degrees of freedom. Here
 * {@code n} is the number of input rows, so at least 3 rows are required.
 *
 * <p>The output schema depends on {@code getFlatten()}:
 *
 * <ul>
 *   <li>{@code true}: one row per feature with columns (featureIndex, pValue, degreeOfFreedom,
 *       fValue).
 *   <li>{@code false}: rows with columns (pValues, degreesOfFreedom, fValues), where each column
 *       holds one entry per feature.
 * </ul>
 */
public class FValueTest implements AlgoOperator<FValueTest>, FValueTestParams<FValueTest> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    public FValueTest() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Runs the F-value test on the single input table.
     *
     * @param inputs exactly one table containing the features column and the numeric label column
     * @return a one-element array holding the result table (see the class documentation)
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     */
    @Override
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);

        final String featuresCol = getFeaturesCol();
        final String labelCol = getLabelCol();
        final String broadcastSummaryKey = "broadcastSummaryKey";

        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        // Project every row onto (features, label as double); rows without a label are rejected.
        DataStream<Tuple2<Vector, Double>> inputData =
                tEnv.toDataStream(inputs[0])
                        .map(
                                row -> {
                                    Number number = (Number) row.getField(labelCol);
                                    Preconditions.checkNotNull(
                                            number, "Input data must contain label value.");
                                    return Tuple2.of(
                                            (Vector) row.getField(featuresCol),
                                            number.doubleValue());
                                })
                        .returns(Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE));

        // Global statistics: (count, label mean, label std, feature means, feature stds).
        DataStream<Tuple5<Long, Double, Double, DenseVector, DenseVector>> summaries =
                DataStreamUtils.aggregate(inputData, new SummaryAggregator());

        // Per-partition partial sample covariances between each feature and the label.
        DataStream<DenseVector> covarianceInEachPartition =
                BroadcastUtils.withBroadcastStream(
                        Collections.singletonList(inputData),
                        Collections.singletonMap(broadcastSummaryKey, summaries),
                        inputList -> {
                            @SuppressWarnings("unchecked")
                            DataStream<Tuple2<Vector, Double>> input =
                                    (DataStream<Tuple2<Vector, Double>>) inputList.get(0);
                            return DataStreamUtils.mapPartition(
                                    input, new CalCovarianceOperator(broadcastSummaryKey));
                        });

        // Sum the partial covariances into the full sample covariance vector.
        DataStream<DenseVector> reducedCovariance =
                DataStreamUtils.reduce(
                        covarianceInEachPartition,
                        (sums1, sums2) -> {
                            BLAS.axpy(1.0, sums1, sums2);
                            return sums2;
                        });

        DataStream<Tuple4<Integer, Double, Long, Double>> result =
                BroadcastUtils.withBroadcastStream(
                        Collections.singletonList(reducedCovariance),
                        Collections.singletonMap(broadcastSummaryKey, summaries),
                        inputList -> {
                            @SuppressWarnings("unchecked")
                            DataStream<DenseVector> input =
                                    (DataStream<DenseVector>) inputList.get(0);
                            return DataStreamUtils.mapPartition(
                                    input, new CalFValueOperator(broadcastSummaryKey));
                        });

        return new Table[] {convertToTable(tEnv, result, getFlatten())};
    }

    /**
     * Converts per-feature (featureIndex, pValue, degreesOfFreedom, fValue) records into a table.
     *
     * <p>Static so that the anonymous function below does not capture the operator instance.
     *
     * @param tEnv the table environment used to build the table
     * @param dataStream one record per feature
     * @param flatten whether to emit one row per feature or vectors holding all features
     * @return the result table
     */
    private static Table convertToTable(
            StreamTableEnvironment tEnv,
            DataStream<Tuple4<Integer, Double, Long, Double>> dataStream,
            boolean flatten) {
        if (flatten) {
            return tEnv.fromDataStream(dataStream)
                    .as("featureIndex", "pValue", "degreeOfFreedom", "fValue");
        }

        DataStream<Tuple3<DenseVector, long[], DenseVector>> output =
                DataStreamUtils.mapPartition(
                        dataStream,
                        new MapPartitionFunction<
                                Tuple4<Integer, Double, Long, Double>,
                                Tuple3<DenseVector, long[], DenseVector>>() {
                            @Override
                            public void mapPartition(
                                    Iterable<Tuple4<Integer, Double, Long, Double>> iterable,
                                    Collector<Tuple3<DenseVector, long[], DenseVector>>
                                            collector) {
                                @SuppressWarnings("unchecked")
                                List<Tuple4<Integer, Double, Long, Double>> rows =
                                        IteratorUtils.toList(iterable.iterator());
                                int numOfFeatures = rows.size();

                                DenseVector pValues = new DenseVector(numOfFeatures);
                                long[] degrees = new long[numOfFeatures];
                                DenseVector fValues = new DenseVector(numOfFeatures);

                                for (Tuple4<Integer, Double, Long, Double> row : rows) {
                                    // Place values by their feature index instead of relying
                                    // on the order in which records arrive.
                                    int featureIndex = row.f0;
                                    Preconditions.checkState(
                                            featureIndex >= 0 && featureIndex < numOfFeatures,
                                            "Unexpected feature index %s for %s features.",
                                            featureIndex,
                                            numOfFeatures);
                                    pValues.set(featureIndex, row.f1);
                                    degrees[featureIndex] = row.f2;
                                    fValues.set(featureIndex, row.f3);
                                }
                                collector.collect(Tuple3.of(pValues, degrees, fValues));
                            }
                        });

        return tEnv.fromDataStream(output).as("pValues", "degreesOfFreedom", "fValues");
    }

    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Loads an FValueTest whose parameters were written by {@link #save(String)}.
     *
     * @param tEnv unused; kept for API symmetry with other stages
     * @param path the directory the stage metadata was saved to
     * @return the restored operator
     * @throws IOException if the metadata cannot be read
     */
    public static FValueTest load(StreamTableEnvironment tEnv, String path) throws IOException {
        return ReadWriteUtils.loadStageParam(path);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Computes, in a single pass, the row count, label mean and sample standard deviation, and
     * per-feature means and sample standard deviations.
     *
     * <p>The accumulator holds (count, sum of labels, sum of squared labels, per-feature sums,
     * per-feature sums of squares). The result holds (count, label mean, label std, feature
     * means, feature stds).
     */
    private static class SummaryAggregator
            implements AggregateFunction<
                    Tuple2<Vector, Double>,
                    Tuple5<Long, Double, Double, DenseVector, DenseVector>,
                    Tuple5<Long, Double, Double, DenseVector, DenseVector>> {

        @Override
        public Tuple5<Long, Double, Double, DenseVector, DenseVector> createAccumulator() {
            // The vectors are resized on the first added row, once the feature count is known.
            return Tuple5.of(
                    0L,
                    0.0,
                    0.0,
                    new DenseVector(new double[0]),
                    new DenseVector(new double[0]));
        }

        @Override
        public Tuple5<Long, Double, Double, DenseVector, DenseVector> add(
                Tuple2<Vector, Double> featuresAndLabel,
                Tuple5<Long, Double, Double, DenseVector, DenseVector> summary) {
            Vector features = featuresAndLabel.f0;
            double label = featuresAndLabel.f1;
            int numOfFeatures = features.size();

            if (summary.f0 == 0L) {
                summary.f3 = new DenseVector(numOfFeatures);
                summary.f4 = new DenseVector(numOfFeatures);
            }
            Preconditions.checkArgument(
                    numOfFeatures == summary.f3.size(),
                    "Input %s features, but FValueTest is expecting %s features.",
                    numOfFeatures,
                    summary.f3.size());

            summary.f0 += 1L;
            summary.f1 += label;
            summary.f2 += label * label;
            BLAS.axpy(1.0, features, summary.f3);
            for (int i = 0; i < numOfFeatures; i++) {
                double value = features.get(i);
                summary.f4.values[i] += value * value;
            }
            return summary;
        }

        @Override
        public Tuple5<Long, Double, Double, DenseVector, DenseVector> getResult(
                Tuple5<Long, Double, Double, DenseVector, DenseVector> summary) {
            final long numRows = summary.f0;
            Preconditions.checkState(numRows > 0L, "The training set is empty.");
            int numOfFeatures = summary.f3.size();

            double labelMean = summary.f1 / numRows;
            double labelStd = sampleStd(summary.f2, labelMean, numRows);

            DenseVector means = new DenseVector(numOfFeatures);
            DenseVector stds = new DenseVector(numOfFeatures);
            for (int i = 0; i < numOfFeatures; i++) {
                double mean = summary.f3.get(i) / numRows;
                means.values[i] = mean;
                stds.values[i] = sampleStd(summary.f4.get(i), mean, numRows);
            }
            return Tuple5.of(numRows, labelMean, labelStd, means, stds);
        }

        /**
         * Returns the sample (n - 1 denominator) standard deviation from a sum of squares.
         *
         * <p>{@code E[x^2] - mean^2} can come out slightly negative through floating-point
         * cancellation when the true variance is ~0. It is clamped at 0 so that {@code sqrt} does
         * not return NaN in that case.
         */
        private static double sampleStd(double sumOfSquares, double mean, long count) {
            double variance = (sumOfSquares / count - mean * mean) * count / (count - 1L);
            return Math.sqrt(Math.max(0.0, variance));
        }

        @Override
        public Tuple5<Long, Double, Double, DenseVector, DenseVector> merge(
                Tuple5<Long, Double, Double, DenseVector, DenseVector> summary1,
                Tuple5<Long, Double, Double, DenseVector, DenseVector> summary2) {
            // Empty accumulators still hold zero-length vectors, so they must not be axpy'd.
            if (summary1.f0 == 0L) {
                return summary2;
            }
            if (summary2.f0 == 0L) {
                return summary1;
            }
            summary2.f0 += summary1.f0;
            summary2.f1 += summary1.f1;
            summary2.f2 += summary1.f2;
            BLAS.axpy(1.0, summary1.f3, summary2.f3);
            BLAS.axpy(1.0, summary1.f4, summary2.f4);
            return summary2;
        }
    }

    /**
     * Turns the reduced covariance vector into one (featureIndex, pValue, degreesOfFreedom,
     * fValue) record per feature, using the broadcast summary statistics.
     */
    private static class CalFValueOperator
            extends RichMapPartitionFunction<DenseVector, Tuple4<Integer, Double, Long, Double>> {
        private final String broadcastKey;

        private CalFValueOperator(String broadcastKey) {
            this.broadcastKey = broadcastKey;
        }

        @Override
        public void mapPartition(
                Iterable<DenseVector> iterable,
                Collector<Tuple4<Integer, Double, Long, Double>> collector) {
            Tuple5<Long, Double, Double, DenseVector, DenseVector> summaries =
                    getRuntimeContext()
                            .<Tuple5<Long, Double, Double, DenseVector, DenseVector>>
                                    getBroadcastVariable(broadcastKey)
                            .get(0);
            int expectedNumOfFeatures = summaries.f4.size();

            // The reduced stream carries the covariance vector; take the first element.
            DenseVector sumVector = null;
            for (DenseVector vector : iterable) {
                sumVector = vector;
                break;
            }
            if (sumVector == null) {
                // A partition without the reduced vector has nothing to emit.
                return;
            }

            Preconditions.checkArgument(
                    sumVector.size() == expectedNumOfFeatures,
                    "Input %s features, but FValueTest is expecting %s features.",
                    sumVector.size(),
                    expectedNumOfFeatures);

            long numSamples = summaries.f0;
            long degreesOfFreedom = numSamples - 2L;
            // FDistribution requires strictly positive degrees of freedom.
            Preconditions.checkArgument(
                    degreesOfFreedom > 0L,
                    "FValueTest requires at least 3 samples, but got %s.",
                    numSamples);
            FDistribution fDistribution = new FDistribution(1.0, degreesOfFreedom);

            double labelStd = summaries.f2;
            DenseVector featureStds = summaries.f4;
            for (int i = 0; i < expectedNumOfFeatures; i++) {
                double covariance = sumVector.get(i);
                double corr = covariance / (labelStd * featureStds.get(i));
                double fValue = corr * corr / (1.0 - corr * corr) * degreesOfFreedom;
                // Upper-tail probability of observing an F-statistic at least this large.
                double pValue = 1.0 - fDistribution.cumulativeProbability(fValue);
                collector.collect(Tuple4.of(i, pValue, degreesOfFreedom, fValue));
            }
        }
    }

    /**
     * Computes this partition's contribution to the sample covariance between each feature and
     * the label: {@code sum((y - meanY) * (x_i - meanX_i)) / (n - 1)}.
     */
    private static class CalCovarianceOperator
            extends RichMapPartitionFunction<Tuple2<Vector, Double>, DenseVector> {
        private final String broadcastKey;

        private CalCovarianceOperator(String broadcastKey) {
            this.broadcastKey = broadcastKey;
        }

        @Override
        public void mapPartition(
                Iterable<Tuple2<Vector, Double>> iterable, Collector<DenseVector> collector) {
            Tuple5<Long, Double, Double, DenseVector, DenseVector> summaries =
                    getRuntimeContext()
                            .<Tuple5<Long, Double, Double, DenseVector, DenseVector>>
                                    getBroadcastVariable(broadcastKey)
                            .get(0);
            int expectedNumOfFeatures = summaries.f3.size();
            double labelMean = summaries.f1;
            DenseVector featureMeans = summaries.f3;

            DenseVector sumVector = new DenseVector(expectedNumOfFeatures);
            for (Tuple2<Vector, Double> featuresAndLabel : iterable) {
                Vector features = featuresAndLabel.f0;
                Preconditions.checkArgument(
                        features.size() == expectedNumOfFeatures,
                        "Input %s features, but FValueTest is expecting %s features.",
                        features.size(),
                        expectedNumOfFeatures);

                double yDiff = featuresAndLabel.f1 - labelMean;
                if (yDiff == 0.0) {
                    // A zero label deviation adds nothing to any covariance term.
                    continue;
                }
                for (int i = 0; i < expectedNumOfFeatures; i++) {
                    sumVector.values[i] += yDiff * (features.get(i) - featureMeans.get(i));
                }
            }
            BLAS.scal(1.0 / (summaries.f0 - 1L), sumVector);
            collector.collect(sumVector);
        }
    }
}

