/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.stringindexer;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.feature.stringindexer.StringIndexerModel;
import org.apache.flink.ml.feature.stringindexer.StringIndexerModelData;
import org.apache.flink.ml.feature.stringindexer.StringIndexerParams;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

public class StringIndexer
implements Estimator<StringIndexer, StringIndexerModel>,
StringIndexerParams<StringIndexer> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public StringIndexer() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static StringIndexer load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (StringIndexer)ReadWriteUtils.loadStageParam((String)path);
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    public StringIndexerModel fit(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        String[] inputCols = this.getInputCols();
        String[] outputCols = this.getOutputCols();
        Preconditions.checkArgument((inputCols.length == outputCols.length ? 1 : 0) != 0);
        if (this.getMaxIndexNum() < Integer.MAX_VALUE) {
            Preconditions.checkArgument((boolean)this.getStringOrderType().equals("frequencyDesc"), (Object)("Setting " + StringIndexer.MAX_INDEX_NUM.name + " smaller than INT.MAX only works when " + StringIndexer.STRING_ORDER_TYPE.name + " is set as " + "frequencyDesc" + "."));
        }
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        SingleOutputStreamOperator localCountedString = tEnv.toDataStream(inputs[0]).transform("countStringOperator", Types.OBJECT_ARRAY((TypeInformation)Types.MAP((TypeInformation)Types.STRING, (TypeInformation)Types.LONG)), (OneInputStreamOperator)new CountStringOperator(inputCols));
        DataStream countedString = DataStreamUtils.reduce((DataStream)localCountedString, (ReduceFunction & Serializable)(value1, value2) -> {
            for (int i = 0; i < ((Map[])value1).length; ++i) {
                for (Map.Entry stringAndCnt : value2[i].entrySet()) {
                    value1[i].compute((String)stringAndCnt.getKey(), (k, v) -> v == null ? (Long)stringAndCnt.getValue() : v + (Long)stringAndCnt.getValue());
                }
            }
            return value1;
        }, (TypeInformation)Types.OBJECT_ARRAY((TypeInformation)Types.MAP((TypeInformation)Types.STRING, (TypeInformation)Types.LONG)));
        SingleOutputStreamOperator modelData = countedString.map((MapFunction)new ModelGenerator(this.getStringOrderType(), this.getMaxIndexNum()));
        modelData.getTransformation().setParallelism(1);
        StringIndexerModel model = new StringIndexerModel().setModelData(tEnv.fromDataStream((DataStream)modelData));
        ParamUtils.updateExistingParams((WithParams)model, this.paramMap);
        return model;
    }

    private static class ModelGenerator
    implements MapFunction<Map<String, Long>[], StringIndexerModelData> {
        private final String stringOrderType;
        private final int maxIndexNum;

        public ModelGenerator(String stringOrderType, int maxIndexNum) {
            this.stringOrderType = stringOrderType;
            this.maxIndexNum = maxIndexNum;
        }

        public StringIndexerModelData map(Map<String, Long>[] value) {
            int numCols = value.length;
            String[][] stringArrays = new String[numCols][];
            ArrayList<Tuple2> stringsAndCnts = new ArrayList<Tuple2>();
            for (int i = 0; i < numCols; ++i) {
                stringsAndCnts.clear();
                stringsAndCnts.ensureCapacity(value[i].size());
                for (Map.Entry<String, Long> entry : value[i].entrySet()) {
                    stringsAndCnts.add(Tuple2.of((Object)entry.getKey(), (Object)entry.getValue()));
                }
                switch (this.stringOrderType) {
                    case "alphabetAsc": {
                        stringsAndCnts.sort(Comparator.comparing(valAndCnt -> (String)valAndCnt.f0));
                        break;
                    }
                    case "alphabetDesc": {
                        stringsAndCnts.sort((valAndCnt1, valAndCnt2) -> -((String)valAndCnt1.f0).compareTo((String)valAndCnt2.f0));
                        break;
                    }
                    case "frequencyAsc": {
                        stringsAndCnts.sort(Comparator.comparing(valAndCnt -> (Long)valAndCnt.f1));
                        break;
                    }
                    case "frequencyDesc": {
                        stringsAndCnts.sort((valAndCnt1, valAndCnt2) -> -((Long)valAndCnt1.f1).compareTo((Long)valAndCnt2.f1));
                        if (stringsAndCnts.size() <= this.maxIndexNum) break;
                        ArrayList<Tuple2> frequentStringsAndCnts = new ArrayList<Tuple2>();
                        frequentStringsAndCnts.ensureCapacity(this.maxIndexNum - 1);
                        for (int indexId = 0; indexId < this.maxIndexNum - 1; ++indexId) {
                            frequentStringsAndCnts.add((Tuple2)stringsAndCnts.get(indexId));
                        }
                        stringsAndCnts = frequentStringsAndCnts;
                        break;
                    }
                    case "arbitrary": {
                        break;
                    }
                    default: {
                        throw new UnsupportedOperationException("Unsupported " + StringIndexerParams.STRING_ORDER_TYPE + " type: " + this.stringOrderType + ".");
                    }
                }
                stringArrays[i] = (String[])stringsAndCnts.stream().map((? super T x) -> (String)x.f0).toArray(String[]::new);
            }
            return new StringIndexerModelData(stringArrays);
        }
    }

    private static class CountStringOperator
    extends AbstractStreamOperator<Map<String, Long>[]>
    implements OneInputStreamOperator<Row, Map<String, Long>[]>,
    BoundedOneInput {
        private final String[] inputCols;
        private Map<String, Long>[] stringCntByColumn;
        private ListState<Map<String, Long>[]> stringCntByColumnState;

        public CountStringOperator(String[] inputCols) {
            this.inputCols = inputCols;
            this.stringCntByColumn = new HashMap[inputCols.length];
            for (int i = 0; i < this.stringCntByColumn.length; ++i) {
                this.stringCntByColumn[i] = new HashMap<String, Long>();
            }
        }

        public void endInput() {
            this.output.collect((Object)new StreamRecord(this.stringCntByColumn));
            this.stringCntByColumnState.clear();
        }

        public void processElement(StreamRecord<Row> element) {
            Row r = (Row)element.getValue();
            for (int i = 0; i < this.inputCols.length; ++i) {
                String stringVal;
                Object objVal = r.getField(this.inputCols[i]);
                if (null == objVal) continue;
                if (objVal instanceof String) {
                    stringVal = (String)objVal;
                } else if (objVal instanceof Number) {
                    stringVal = String.valueOf(objVal);
                } else {
                    throw new RuntimeException("The input column only supports string and numeric type.");
                }
                this.stringCntByColumn[i].compute(stringVal, (k, v) -> v == null ? 1L : v + 1L);
            }
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            this.stringCntByColumnState = context.getOperatorStateStore().getListState(new ListStateDescriptor("stringCntByColumnState", Types.OBJECT_ARRAY((TypeInformation)Types.MAP((TypeInformation)Types.STRING, (TypeInformation)Types.LONG))));
            OperatorStateUtils.getUniqueElement(this.stringCntByColumnState, (String)"stringCntByColumnState").ifPresent(x -> {
                this.stringCntByColumn = x;
            });
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            this.stringCntByColumnState.update(Collections.singletonList(this.stringCntByColumn));
        }
    }
}

