/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.onehotencoder;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel;
import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderParams;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

public class OneHotEncoder
implements Estimator<OneHotEncoder, OneHotEncoderModel>,
OneHotEncoderParams<OneHotEncoder> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public OneHotEncoder() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public OneHotEncoderModel fit(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        Preconditions.checkArgument((boolean)this.getHandleInvalid().equals("error"));
        String[] inputCols = this.getInputCols();
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        SingleOutputStreamOperator localMaxIndices = tEnv.toDataStream(inputs[0]).transform("ExtractInputValueAndFindMaxIndexOperator", (TypeInformation)ObjectArrayTypeInfo.getInfoFor((TypeInformation)BasicTypeInfo.INT_TYPE_INFO), (OneInputStreamOperator)new ExtractInputValueAndFindMaxIndexOperator(inputCols));
        SingleOutputStreamOperator modelData = localMaxIndices.transform("GenerateModelDataOperator", (TypeInformation)TupleTypeInfo.getBasicTupleTypeInfo((Class[])new Class[]{Integer.class, Integer.class}), (OneInputStreamOperator)new GenerateModelDataOperator(inputCols.length)).setParallelism(1);
        OneHotEncoderModel model = new OneHotEncoderModel().setModelData(tEnv.fromDataStream((DataStream)modelData));
        ParamUtils.updateExistingParams((WithParams)model, this.paramMap);
        return model;
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static OneHotEncoder load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (OneHotEncoder)ReadWriteUtils.loadStageParam((String)path);
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    private static Integer[] initMaxIndices(int length) {
        Object[] indices = new Integer[length];
        Arrays.fill(indices, (Object)Integer.MIN_VALUE);
        return indices;
    }

    private static class GenerateModelDataOperator
    extends AbstractStreamOperator<Tuple2<Integer, Integer>>
    implements OneInputStreamOperator<Integer[], Tuple2<Integer, Integer>>,
    BoundedOneInput {
        private final int inputColsNum;
        private ListState<Integer[]> maxIndicesState;
        private Integer[] maxIndices;

        private GenerateModelDataOperator(int inputColsNum) {
            this.inputColsNum = inputColsNum;
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            ObjectArrayTypeInfo type = ObjectArrayTypeInfo.getInfoFor((TypeInformation)BasicTypeInfo.INT_TYPE_INFO);
            this.maxIndicesState = context.getOperatorStateStore().getListState(new ListStateDescriptor("maxIndices", (TypeInformation)type));
            this.maxIndices = OperatorStateUtils.getUniqueElement(this.maxIndicesState, (String)"maxIndices").orElse(OneHotEncoder.initMaxIndices(this.inputColsNum));
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            this.maxIndicesState.update(Collections.singletonList(this.maxIndices));
        }

        public void processElement(StreamRecord<Integer[]> streamRecord) {
            Integer[] indices = (Integer[])streamRecord.getValue();
            for (int i = 0; i < this.maxIndices.length; ++i) {
                if (indices[i] <= this.maxIndices[i]) continue;
                this.maxIndices[i] = indices[i];
            }
        }

        public void endInput() {
            for (int i = 0; i < this.maxIndices.length; ++i) {
                this.output.collect((Object)new StreamRecord((Object)Tuple2.of((Object)i, (Object)this.maxIndices[i])));
            }
        }
    }

    private static class ExtractInputValueAndFindMaxIndexOperator
    extends AbstractStreamOperator<Integer[]>
    implements OneInputStreamOperator<Row, Integer[]>,
    BoundedOneInput {
        private final String[] inputCols;
        private ListState<Integer[]> maxIndicesState;
        private Integer[] maxIndices;

        private ExtractInputValueAndFindMaxIndexOperator(String[] inputCols) {
            this.inputCols = inputCols;
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            ObjectArrayTypeInfo type = ObjectArrayTypeInfo.getInfoFor((TypeInformation)BasicTypeInfo.INT_TYPE_INFO);
            this.maxIndicesState = context.getOperatorStateStore().getListState(new ListStateDescriptor("maxIndices", (TypeInformation)type));
            this.maxIndices = OperatorStateUtils.getUniqueElement(this.maxIndicesState, (String)"maxIndices").orElse(OneHotEncoder.initMaxIndices(this.inputCols.length));
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            this.maxIndicesState.update(Collections.singletonList(this.maxIndices));
        }

        public void processElement(StreamRecord<Row> streamRecord) {
            Row row = (Row)streamRecord.getValue();
            for (int i = 0; i < this.inputCols.length; ++i) {
                Number number = (Number)row.getField(this.inputCols[i]);
                int value = number.intValue();
                if ((double)value != number.doubleValue()) {
                    throw new IllegalArgumentException(String.format("Value %s cannot be parsed as indexed integer.", number));
                }
                Preconditions.checkArgument((value >= 0 ? 1 : 0) != 0, (Object)"Negative value not supported.");
                if (value <= this.maxIndices[i]) continue;
                this.maxIndices[i] = value;
            }
        }

        public void endInput() {
            this.output.collect((Object)new StreamRecord((Object)this.maxIndices));
        }
    }
}

