/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.stats.chisqtest;

import java.io.IOException;
import java.io.Serializable;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.stats.chisqtest.ChiSqTestParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

public class ChiSqTest
implements AlgoOperator<ChiSqTest>,
ChiSqTestParams<ChiSqTest> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public ChiSqTest() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public Table[] transform(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        String bcCategoricalMarginsKey = "bcCategoricalMarginsKey";
        String bcLabelMarginsKey = "bcLabelMarginsKey";
        String featuresCol = this.getFeaturesCol();
        String labelCol = this.getLabelCol();
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        SingleOutputStreamOperator indexAndFeatureAndLabel = tEnv.toDataStream(inputs[0]).flatMap((FlatMapFunction)new ExtractIndexAndFeatureAndLabel(featuresCol, labelCol));
        SingleOutputStreamOperator observedFreq = indexAndFeatureAndLabel.keyBy(Tuple3::hashCode).transform("GenerateObservedFrequencies", Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.INT, Types.DOUBLE, Types.DOUBLE, Types.LONG}), (OneInputStreamOperator)new GenerateObservedFrequencies());
        SingleOutputStreamOperator filledObservedFreq = observedFreq.transform("filledObservedFreq", Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.INT, Types.DOUBLE, Types.DOUBLE, Types.LONG}), (OneInputStreamOperator)new FillFrequencyTable()).setParallelism(1);
        SingleOutputStreamOperator categoricalMargins = observedFreq.keyBy((KeySelector & Serializable)tuple -> new Tuple2((Object)((Integer)tuple.f0), (Object)((Double)tuple.f1)).hashCode()).transform("AggregateCategoricalMargins", Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.INT, Types.DOUBLE, Types.LONG}), (OneInputStreamOperator)new AggregateCategoricalMargins());
        SingleOutputStreamOperator labelMargins = observedFreq.keyBy((KeySelector & Serializable)tuple -> new Tuple2((Object)((Integer)tuple.f0), (Object)((Double)tuple.f2)).hashCode()).transform("AggregateLabelMargins", Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.INT, Types.DOUBLE, Types.LONG}), (OneInputStreamOperator)new AggregateLabelMargins());
        Function<List, DataStream> function = dataStreams -> {
            DataStream stream = (DataStream)dataStreams.get(0);
            return stream.map((MapFunction)new ChiSqFunc("bcCategoricalMarginsKey", "bcLabelMarginsKey"), Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.INT, Types.DOUBLE, Types.INT}));
        };
        HashMap bcMap = new HashMap<String, DataStream<?>>((DataStream)categoricalMargins, (DataStream)labelMargins){
            final /* synthetic */ DataStream val$categoricalMargins;
            final /* synthetic */ DataStream val$labelMargins;
            {
                this.val$categoricalMargins = dataStream;
                this.val$labelMargins = dataStream2;
                this.put("bcCategoricalMarginsKey", this.val$categoricalMargins);
                this.put("bcLabelMarginsKey", this.val$labelMargins);
            }
        };
        DataStream categoricalStatistics = BroadcastUtils.withBroadcastStream(Collections.singletonList(filledObservedFreq), (Map)bcMap, function);
        boolean flatten = this.getFlatten();
        RowTypeInfo outputTypeInfo = flatten ? new RowTypeInfo(new TypeInformation[]{Types.INT, Types.DOUBLE, Types.INT, Types.DOUBLE}, new String[]{"featureIndex", "pValue", "degreeOfFreedom", "statistic"}) : new RowTypeInfo(new TypeInformation[]{DenseVectorTypeInfo.INSTANCE, Types.PRIMITIVE_ARRAY((TypeInformation)Types.INT), DenseVectorTypeInfo.INSTANCE}, new String[]{"pValues", "degreesOfFreedom", "statistics"});
        SingleOutputStreamOperator chiSqTestResult = categoricalStatistics.transform("chiSqTestResult", (TypeInformation)outputTypeInfo, (OneInputStreamOperator)new AggregateChiSqFunc(flatten)).setParallelism(1);
        return new Table[]{tEnv.fromDataStream((DataStream)chiSqTestResult)};
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static ChiSqTest load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (ChiSqTest)ReadWriteUtils.loadStageParam((String)path);
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    private static class AggregateChiSqFunc
    extends AbstractStreamOperator<Row>
    implements OneInputStreamOperator<Tuple3<Integer, Double, Integer>, Row>,
    BoundedOneInput {
        private final boolean flatten;
        private Map<Integer, Tuple2<Double, Integer>> index2Statistic = new HashMap<Integer, Tuple2<Double, Integer>>();
        private ListState<Map<Integer, Tuple2<Double, Integer>>> index2StatisticState;

        private AggregateChiSqFunc(boolean flatten) {
            this.flatten = flatten;
        }

        public void endInput() {
            if (this.flatten) {
                this.endInputWithFlatten();
            } else {
                this.endInputWithoutFlatten();
            }
        }

        private void endInputWithFlatten() {
            for (Map.Entry<Integer, Tuple2<Double, Integer>> entry : this.index2Statistic.entrySet()) {
                int index = entry.getKey();
                Tuple3<Double, Integer, Double> pValueAndDofAndStatistic = AggregateChiSqFunc.computePValueAndScale(entry.getValue());
                this.output.collect((Object)new StreamRecord((Object)Row.of((Object[])new Object[]{index, pValueAndDofAndStatistic.f0, pValueAndDofAndStatistic.f1, pValueAndDofAndStatistic.f2})));
            }
        }

        private void endInputWithoutFlatten() {
            int size = this.index2Statistic.size();
            DenseVector pValueScaledVector = new DenseVector(size);
            DenseVector statisticScaledVector = new DenseVector(size);
            int[] dofArray = new int[size];
            for (Map.Entry<Integer, Tuple2<Double, Integer>> entry : this.index2Statistic.entrySet()) {
                int index = entry.getKey();
                Tuple3<Double, Integer, Double> pValueAndDofAndStatistic = AggregateChiSqFunc.computePValueAndScale(entry.getValue());
                pValueScaledVector.set(index, ((Double)pValueAndDofAndStatistic.f0).doubleValue());
                statisticScaledVector.set(index, ((Double)pValueAndDofAndStatistic.f2).doubleValue());
                dofArray[index] = (Integer)pValueAndDofAndStatistic.f1;
            }
            this.output.collect((Object)new StreamRecord((Object)Row.of((Object[])new Object[]{pValueScaledVector, dofArray, statisticScaledVector})));
        }

        private static Tuple3<Double, Integer, Double> computePValueAndScale(Tuple2<Double, Integer> statisticAndDof) {
            Double statistic = (Double)statisticAndDof.f0;
            Integer dof = (Integer)statisticAndDof.f1;
            double pValue = 1.0;
            if (dof == 0) {
                statistic = 0.0;
            } else {
                pValue = 1.0 - new ChiSquaredDistribution((double)dof.intValue()).cumulativeProbability(statistic.doubleValue());
            }
            double pValueScaled = new BigDecimal(pValue).setScale(11, RoundingMode.HALF_UP).doubleValue();
            double statisticScaled = new BigDecimal(statistic).setScale(11, RoundingMode.HALF_UP).doubleValue();
            return Tuple3.of((Object)pValueScaled, (Object)dof, (Object)statisticScaled);
        }

        public void processElement(StreamRecord<Tuple3<Integer, Double, Integer>> element) {
            Tuple3 indexAndStatisticAndDof = (Tuple3)element.getValue();
            Integer index = (Integer)indexAndStatisticAndDof.f0;
            Double partialStatistic = (Double)indexAndStatisticAndDof.f1;
            Integer dof = (Integer)indexAndStatisticAndDof.f2;
            this.index2Statistic.merge(index, (Tuple2<Double, Integer>)new Tuple2((Object)partialStatistic, (Object)dof), (thisOne, otherOne) -> {
                Tuple2 tuple2 = thisOne;
                tuple2.f0 = (Double)tuple2.f0 + (Double)otherOne.f0;
                return thisOne;
            });
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            this.index2StatisticState = context.getOperatorStateStore().getListState(new ListStateDescriptor("index2StatisticState", Types.MAP((TypeInformation)Types.INT, (TypeInformation)Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.DOUBLE, Types.INT}))));
            OperatorStateUtils.getUniqueElement(this.index2StatisticState, (String)"index2StatisticState").ifPresent(x -> {
                this.index2Statistic = x;
            });
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            this.index2StatisticState.update(Collections.singletonList(this.index2Statistic));
        }
    }

    private static class ChiSqFunc
    extends RichMapFunction<Tuple4<Integer, Double, Double, Long>, Tuple3<Integer, Double, Integer>> {
        private final String bcCategoricalMarginsKey;
        private final String bcLabelMarginsKey;
        private final Map<Tuple2<Integer, Double>, Long> categoricalMargins = new HashMap<Tuple2<Integer, Double>, Long>();
        private final Map<Tuple2<Integer, Double>, Long> labelMargins = new HashMap<Tuple2<Integer, Double>, Long>();
        double sampleSize = 0.0;
        int numLabels = 0;
        HashMap<Integer, Integer> index2NumCategories = new HashMap();

        public ChiSqFunc(String bcCategoricalMarginsKey, String bcLabelMarginsKey) {
            this.bcCategoricalMarginsKey = bcCategoricalMarginsKey;
            this.bcLabelMarginsKey = bcLabelMarginsKey;
        }

        public Tuple3<Integer, Double, Integer> map(Tuple4<Integer, Double, Double, Long> v) {
            if (this.categoricalMargins.isEmpty()) {
                List categoricalMarginList = this.getRuntimeContext().getBroadcastVariable(this.bcCategoricalMarginsKey);
                List labelMarginList = this.getRuntimeContext().getBroadcastVariable(this.bcLabelMarginsKey);
                for (Tuple3 indexAndFeatureAndCount : categoricalMarginList) {
                    this.index2NumCategories.merge((Integer)indexAndFeatureAndCount.f0, 1, Integer::sum);
                }
                this.numLabels = (int)labelMarginList.stream().map((? super T x) -> (Double)x.f1).distinct().count();
                for (Tuple3 indexAndFeatureAndCount : categoricalMarginList) {
                    this.categoricalMargins.put((Tuple2<Integer, Double>)new Tuple2((Object)((Integer)indexAndFeatureAndCount.f0), (Object)((Double)indexAndFeatureAndCount.f1)), (Long)indexAndFeatureAndCount.f2);
                }
                HashMap<Integer, Double> sampleSizeCount = new HashMap<Integer, Double>();
                Integer tmpKey = null;
                for (Tuple3 indexAndLabelAndCount : labelMarginList) {
                    Integer index = (Integer)indexAndLabelAndCount.f0;
                    if (tmpKey == null) {
                        tmpKey = index;
                        sampleSizeCount.put(index, 0.0);
                    }
                    sampleSizeCount.computeIfPresent(index, (k, count) -> count + (double)((Long)indexAndLabelAndCount.f2).longValue());
                    this.labelMargins.put((Tuple2<Integer, Double>)new Tuple2((Object)index, (Object)((Double)indexAndLabelAndCount.f1)), (Long)indexAndLabelAndCount.f2);
                }
                Optional sampleSizeOpt = sampleSizeCount.values().stream().reduce(Double::sum);
                Preconditions.checkArgument((boolean)sampleSizeOpt.isPresent());
                this.sampleSize = (Double)sampleSizeOpt.get();
            }
            Integer index = (Integer)v.f0;
            int dof = (this.index2NumCategories.get(index) - 1) * (this.numLabels - 1);
            Tuple2 category = new Tuple2((Object)((Integer)v.f0), (Object)((Double)v.f1));
            Tuple2 indexAndLabelKey = new Tuple2((Object)((Integer)v.f0), (Object)((Double)v.f2));
            Long theCategoricalMargin = this.categoricalMargins.get(category);
            Long theLabelMargin = this.labelMargins.get(indexAndLabelKey);
            Long observed = (Long)v.f3;
            double expected = (double)(theLabelMargin * theCategoricalMargin) / this.sampleSize;
            double categoricalStatistic = this.pearsonFunc(observed.longValue(), expected);
            return new Tuple3((Object)index, (Object)categoricalStatistic, (Object)dof);
        }

        private double pearsonFunc(double observed, double expected) {
            double dev = observed - expected;
            return dev * dev / expected;
        }
    }

    private static class AggregateLabelMargins
    extends AbstractStreamOperator<Tuple3<Integer, Double, Long>>
    implements OneInputStreamOperator<Tuple4<Integer, Double, Double, Long>, Tuple3<Integer, Double, Long>>,
    BoundedOneInput {
        private Map<Tuple2<Integer, Double>, Long> labelMarginsMap = new HashMap<Tuple2<Integer, Double>, Long>();
        private ListState<Map<Tuple2<Integer, Double>, Long>> labelMarginsMapState;

        private AggregateLabelMargins() {
        }

        public void endInput() {
            for (Tuple2<Integer, Double> key : this.labelMarginsMap.keySet()) {
                Long labelMargin = this.labelMarginsMap.get(key);
                this.output.collect((Object)new StreamRecord((Object)new Tuple3((Object)((Integer)key.f0), (Object)((Double)key.f1), (Object)labelMargin)));
            }
            this.labelMarginsMapState.clear();
        }

        public void processElement(StreamRecord<Tuple4<Integer, Double, Double, Long>> element) {
            Tuple4 indexAndFeatureAndLabelAndCnt = (Tuple4)element.getValue();
            Long observedFreq = (Long)indexAndFeatureAndLabelAndCnt.f3;
            Tuple2 key = new Tuple2((Object)((Integer)indexAndFeatureAndLabelAndCnt.f0), (Object)((Double)indexAndFeatureAndLabelAndCnt.f2));
            this.labelMarginsMap.compute((Tuple2<Integer, Double>)key, (k, v) -> v == null ? observedFreq : v + observedFreq);
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            this.labelMarginsMapState = context.getOperatorStateStore().getListState(new ListStateDescriptor("labelMarginsMapState", Types.MAP((TypeInformation)Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.INT, Types.DOUBLE}), (TypeInformation)Types.LONG)));
            OperatorStateUtils.getUniqueElement(this.labelMarginsMapState, (String)"labelMarginsMapState").ifPresent(x -> {
                this.labelMarginsMap = x;
            });
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            this.labelMarginsMapState.update(Collections.singletonList(this.labelMarginsMap));
        }
    }

    private static class AggregateCategoricalMargins
    extends AbstractStreamOperator<Tuple3<Integer, Double, Long>>
    implements OneInputStreamOperator<Tuple4<Integer, Double, Double, Long>, Tuple3<Integer, Double, Long>>,
    BoundedOneInput {
        private Map<Tuple2<Integer, Double>, Long> categoricalMarginsMap = new HashMap<Tuple2<Integer, Double>, Long>();
        private ListState<Map<Tuple2<Integer, Double>, Long>> categoricalMarginsMapState;

        private AggregateCategoricalMargins() {
        }

        public void endInput() {
            for (Tuple2<Integer, Double> key : this.categoricalMarginsMap.keySet()) {
                Long categoricalMargin = this.categoricalMarginsMap.get(key);
                this.output.collect((Object)new StreamRecord((Object)new Tuple3((Object)((Integer)key.f0), (Object)((Double)key.f1), (Object)categoricalMargin)));
            }
            this.categoricalMarginsMap.clear();
        }

        public void processElement(StreamRecord<Tuple4<Integer, Double, Double, Long>> element) {
            Tuple4 indexAndCategoryAndLabelAndCnt = (Tuple4)element.getValue();
            Tuple2 key = new Tuple2((Object)((Integer)indexAndCategoryAndLabelAndCnt.f0), (Object)((Double)indexAndCategoryAndLabelAndCnt.f1));
            Long observedFreq = (Long)indexAndCategoryAndLabelAndCnt.f3;
            this.categoricalMarginsMap.compute((Tuple2<Integer, Double>)key, (k, v) -> v == null ? observedFreq : v + observedFreq);
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            this.categoricalMarginsMapState = context.getOperatorStateStore().getListState(new ListStateDescriptor("categoricalMarginsMapState", Types.MAP((TypeInformation)Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.INT, Types.DOUBLE}), (TypeInformation)Types.LONG)));
            OperatorStateUtils.getUniqueElement(this.categoricalMarginsMapState, (String)"categoricalMarginsMapState").ifPresent(x -> {
                this.categoricalMarginsMap = x;
            });
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            this.categoricalMarginsMapState.update(Collections.singletonList(this.categoricalMarginsMap));
        }
    }

    private static class FillFrequencyTable
    extends AbstractStreamOperator<Tuple4<Integer, Double, Double, Long>>
    implements OneInputStreamOperator<Tuple4<Integer, Double, Double, Long>, Tuple4<Integer, Double, Double, Long>>,
    BoundedOneInput {
        private Map<Tuple2<Integer, Double>, List<Tuple2<Double, Long>>> valuesMap = new HashMap<Tuple2<Integer, Double>, List<Tuple2<Double, Long>>>();
        private HashSet<Double> distinctLabels = new HashSet();
        private ListState<Map<Tuple2<Integer, Double>, List<Tuple2<Double, Long>>>> valuesMapState;
        private ListState<List<Double>> distinctLabelsState;

        private FillFrequencyTable() {
        }

        public void endInput() {
            for (Map.Entry<Tuple2<Integer, Double>, List<Tuple2<Double, Long>>> entry : this.valuesMap.entrySet()) {
                List<Tuple2<Double, Long>> labelAndCountList = entry.getValue();
                Tuple2<Integer, Double> categoricalKey = entry.getKey();
                List existingLabels = labelAndCountList.stream().map(v -> (Double)v.f0).collect(Collectors.toList());
                for (Double d : this.distinctLabels) {
                    if (existingLabels.contains(d)) continue;
                    Tuple2 generatedLabelCount = new Tuple2((Object)d, (Object)0L);
                    labelAndCountList.add((Tuple2<Double, Long>)generatedLabelCount);
                }
                for (Tuple2 tuple2 : labelAndCountList) {
                    this.output.collect((Object)new StreamRecord((Object)new Tuple4((Object)((Integer)categoricalKey.f0), (Object)((Double)categoricalKey.f1), (Object)((Double)tuple2.f0), (Object)((Long)tuple2.f1))));
                }
            }
            this.valuesMapState.clear();
            this.distinctLabelsState.clear();
        }

        public void processElement(StreamRecord<Tuple4<Integer, Double, Double, Long>> element) {
            Tuple4 indexAndCategoryAndLabelAndCount = (Tuple4)element.getValue();
            Tuple2 key = new Tuple2((Object)((Integer)indexAndCategoryAndLabelAndCount.f0), (Object)((Double)indexAndCategoryAndLabelAndCount.f1));
            Tuple2 labelAndCount = new Tuple2((Object)((Double)indexAndCategoryAndLabelAndCount.f2), (Object)((Long)indexAndCategoryAndLabelAndCount.f3));
            List<Tuple2<Double, Long>> labelAndCountList = this.valuesMap.get(key);
            if (labelAndCountList == null) {
                ArrayList<Tuple2> value = new ArrayList<Tuple2>();
                value.add(labelAndCount);
                this.valuesMap.put((Tuple2<Integer, Double>)key, value);
            } else {
                labelAndCountList.add((Tuple2<Double, Long>)labelAndCount);
            }
            this.distinctLabels.add((Double)indexAndCategoryAndLabelAndCount.f2);
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            this.valuesMapState = context.getOperatorStateStore().getListState(new ListStateDescriptor("valuesMapState", Types.MAP((TypeInformation)Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.INT, Types.DOUBLE}), (TypeInformation)Types.LIST((TypeInformation)Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.DOUBLE, Types.LONG})))));
            this.distinctLabelsState = context.getOperatorStateStore().getListState(new ListStateDescriptor("distinctLabelsState", Types.LIST((TypeInformation)Types.DOUBLE)));
            OperatorStateUtils.getUniqueElement(this.valuesMapState, (String)"valuesMapState").ifPresent(x -> {
                this.valuesMap = x;
            });
            OperatorStateUtils.getUniqueElement(this.distinctLabelsState, (String)"distinctLabelsState").ifPresent(x -> {
                this.distinctLabels = new HashSet(x);
            });
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            this.valuesMapState.update(Collections.singletonList(this.valuesMap));
            this.distinctLabelsState.update(Collections.singletonList(new ArrayList<Double>(this.distinctLabels)));
        }
    }

    private static class GenerateObservedFrequencies
    extends AbstractStreamOperator<Tuple4<Integer, Double, Double, Long>>
    implements OneInputStreamOperator<Tuple3<Integer, Double, Double>, Tuple4<Integer, Double, Double, Long>>,
    BoundedOneInput {
        private Map<Tuple3<Integer, Double, Double>, Long> cntMap = new HashMap<Tuple3<Integer, Double, Double>, Long>();
        private ListState<Map<Tuple3<Integer, Double, Double>, Long>> cntMapState;

        private GenerateObservedFrequencies() {
        }

        public void endInput() {
            for (Tuple3<Integer, Double, Double> key : this.cntMap.keySet()) {
                Long count = this.cntMap.get(key);
                this.output.collect((Object)new StreamRecord((Object)new Tuple4((Object)((Integer)key.f0), (Object)((Double)key.f1), (Object)((Double)key.f2), (Object)count)));
            }
            this.cntMapState.clear();
        }

        public void processElement(StreamRecord<Tuple3<Integer, Double, Double>> element) {
            Tuple3 indexAndCategoryAndLabel = (Tuple3)element.getValue();
            this.cntMap.compute((Tuple3<Integer, Double, Double>)indexAndCategoryAndLabel, (k, v) -> v == null ? 1L : v + 1L);
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            this.cntMapState = context.getOperatorStateStore().getListState(new ListStateDescriptor("cntMapState", Types.MAP((TypeInformation)Types.TUPLE((TypeInformation[])new TypeInformation[]{Types.INT, Types.DOUBLE, Types.DOUBLE}), (TypeInformation)Types.LONG)));
            OperatorStateUtils.getUniqueElement(this.cntMapState, (String)"cntMapState").ifPresent(x -> {
                this.cntMap = x;
            });
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            this.cntMapState.update(Collections.singletonList(this.cntMap));
        }
    }

    private static class ExtractIndexAndFeatureAndLabel
    extends RichFlatMapFunction<Row, Tuple3<Integer, Double, Double>> {
        private final String featuresCol;
        private final String labelCol;

        public ExtractIndexAndFeatureAndLabel(String featuresCol, String labelCol) {
            this.featuresCol = featuresCol;
            this.labelCol = labelCol;
        }

        public void flatMap(Row row, Collector<Tuple3<Integer, Double, Double>> collector) {
            Double label = ((Number)row.getFieldAs(this.labelCol)).doubleValue();
            Vector features = (Vector)row.getFieldAs(this.featuresCol);
            for (int i = 0; i < features.size(); ++i) {
                collector.collect((Object)Tuple3.of((Object)i, (Object)features.get(i), (Object)label));
            }
        }
    }
}

