/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.univariatefeatureselector;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModelData;
import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.stats.anovatest.ANOVATest;
import org.apache.flink.ml.stats.chisqtest.ChiSqTest;
import org.apache.flink.ml.stats.fvaluetest.FValueTest;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

public class UnivariateFeatureSelector
implements Estimator<UnivariateFeatureSelector, UnivariateFeatureSelectorModel>,
UnivariateFeatureSelectorParams<UnivariateFeatureSelector> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public UnivariateFeatureSelector() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public UnivariateFeatureSelectorModel fit(Table ... inputs) {
        Table output;
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        String featuresCol = this.getFeaturesCol();
        String labelCol = this.getLabelCol();
        String featureType = this.getFeatureType();
        String labelType = this.getLabelType();
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        if ("categorical".equals(featureType) && "categorical".equals(labelType)) {
            output = ((ChiSqTest)((ChiSqTest)((ChiSqTest)new ChiSqTest().setFeaturesCol(featuresCol)).setLabelCol(labelCol)).setFlatten(true)).transform(inputs[0])[0];
        } else if ("continuous".equals(featureType) && "categorical".equals(labelType)) {
            output = ((ANOVATest)((ANOVATest)((ANOVATest)new ANOVATest().setFeaturesCol(featuresCol)).setLabelCol(labelCol)).setFlatten(true)).transform(inputs[0])[0];
        } else if ("continuous".equals(featureType) && "continuous".equals(labelType)) {
            output = ((FValueTest)((FValueTest)((FValueTest)new FValueTest().setFeaturesCol(featuresCol)).setLabelCol(labelCol)).setFlatten(true)).transform(inputs[0])[0];
        } else {
            throw new IllegalArgumentException(String.format("Unsupported combination: featureType=%s, labelType=%s.", featureType, labelType));
        }
        SingleOutputStreamOperator modelData = tEnv.toDataStream(output).transform("selectIndicesFromPValues", TypeInformation.of(UnivariateFeatureSelectorModelData.class), (OneInputStreamOperator)new SelectIndicesFromPValuesOperator(this.getSelectionMode(), this.getActualSelectionThreshold())).setParallelism(1);
        UnivariateFeatureSelectorModel model = new UnivariateFeatureSelectorModel().setModelData(tEnv.fromDataStream((DataStream)modelData));
        ParamUtils.updateExistingParams((WithParams)model, this.getParamMap());
        return model;
    }

    private double getActualSelectionThreshold() {
        Double threshold = this.getSelectionThreshold();
        if (threshold == null) {
            String selectionMode = this.getSelectionMode();
            threshold = "numTopFeatures".equals(selectionMode) ? Double.valueOf(50.0) : ("percentile".equals(selectionMode) ? Double.valueOf(0.1) : Double.valueOf(0.05));
        } else if ("numTopFeatures".equals(this.getSelectionMode())) {
            Preconditions.checkArgument((threshold >= 1.0 && (double)threshold.intValue() == threshold ? 1 : 0) != 0, (String)"SelectionThreshold needs to be a positive Integer for selection mode numTopFeatures, but got %s.", (Object[])new Object[]{threshold});
        } else {
            Preconditions.checkArgument((threshold >= 0.0 && threshold <= 1.0 ? 1 : 0) != 0, (String)"SelectionThreshold needs to be in the range [0, 1] for selection mode %s, but got %s.", (Object[])new Object[]{this.getSelectionMode(), threshold});
        }
        return threshold;
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static UnivariateFeatureSelector load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (UnivariateFeatureSelector)ReadWriteUtils.loadStageParam((String)path);
    }

    private static class SelectIndicesFromPValuesOperator
    extends AbstractStreamOperator<UnivariateFeatureSelectorModelData>
    implements OneInputStreamOperator<Row, UnivariateFeatureSelectorModelData>,
    BoundedOneInput {
        private final String selectionMode;
        private final double threshold;
        private List<Tuple2<Double, Integer>> pValuesAndIndices;
        private ListState<Tuple2<Double, Integer>> pValuesAndIndicesState;

        public SelectIndicesFromPValuesOperator(String selectionMode, double threshold) {
            this.selectionMode = selectionMode;
            this.threshold = threshold;
        }

        public void endInput() {
            ArrayList indices = new ArrayList();
            switch (this.selectionMode) {
                case "numTopFeatures": {
                    this.pValuesAndIndices.sort(Comparator.comparingDouble(t -> (Double)t.f0).thenComparingInt(t -> (Integer)t.f1));
                    IntStream.range(0, Math.min(this.pValuesAndIndices.size(), (int)this.threshold)).forEach(i -> indices.add((Integer)this.pValuesAndIndices.get((int)i).f1));
                    break;
                }
                case "percentile": {
                    this.pValuesAndIndices.sort(Comparator.comparingDouble(t -> (Double)t.f0).thenComparingInt(t -> (Integer)t.f1));
                    IntStream.range(0, Math.min(this.pValuesAndIndices.size(), (int)((double)this.pValuesAndIndices.size() * this.threshold))).forEach(i -> indices.add((Integer)this.pValuesAndIndices.get((int)i).f1));
                    break;
                }
                case "fpr": {
                    this.pValuesAndIndices.stream().filter(x -> (Double)x.f0 < this.threshold).forEach(x -> indices.add((Integer)x.f1));
                    break;
                }
                case "fdr": {
                    this.pValuesAndIndices.sort(Comparator.comparingDouble(t -> (Double)t.f0).thenComparingInt(t -> (Integer)t.f1));
                    int maxIndex = -1;
                    for (int i2 = 0; i2 < this.pValuesAndIndices.size(); ++i2) {
                        if (!((Double)this.pValuesAndIndices.get((int)i2).f0 < this.threshold / (double)this.pValuesAndIndices.size() * (double)(i2 + 1))) continue;
                        maxIndex = Math.max(maxIndex, i2);
                    }
                    if (maxIndex < 0) break;
                    this.pValuesAndIndices.sort(Comparator.comparingDouble(t -> (Double)t.f0).thenComparingInt(t -> (Integer)t.f1));
                    IntStream.range(0, maxIndex + 1).forEach(i -> indices.add((Integer)this.pValuesAndIndices.get((int)i).f1));
                    break;
                }
                case "fwe": {
                    this.pValuesAndIndices.stream().filter(x -> (Double)x.f0 < this.threshold / (double)this.pValuesAndIndices.size()).forEach(x -> indices.add((Integer)x.f1));
                    break;
                }
                default: {
                    throw new RuntimeException("Unknown Selection Mode: " + this.selectionMode);
                }
            }
            UnivariateFeatureSelectorModelData modelData = new UnivariateFeatureSelectorModelData(indices.stream().mapToInt(Integer::intValue).toArray());
            this.output.collect((Object)new StreamRecord((Object)modelData));
        }

        public void processElement(StreamRecord<Row> record) {
            Row row = (Row)record.getValue();
            double pValue = (Double)row.getField("pValue");
            int featureIndex = (Integer)row.getField("featureIndex");
            this.pValuesAndIndices.add((Tuple2<Double, Integer>)Tuple2.of((Object)pValue, (Object)featureIndex));
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            this.pValuesAndIndicesState = context.getOperatorStateStore().getListState(new ListStateDescriptor("pValuesAndIndices", Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.DOUBLE, Types.INT})));
            this.pValuesAndIndices = IteratorUtils.toList(((Iterable)this.pValuesAndIndicesState.get()).iterator());
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            this.pValuesAndIndicesState.update(this.pValuesAndIndices);
        }
    }
}

