/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.randomsplitter;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.feature.randomsplitter.RandomSplitterParams;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SideOutputDataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.OutputTag;
import org.apache.flink.util.Preconditions;

public class RandomSplitter
implements AlgoOperator<RandomSplitter>,
RandomSplitterParams<RandomSplitter> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public RandomSplitter() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public Table[] transform(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        RowTypeInfo outputTypeInfo = TableUtils.getRowTypeInfo((ResolvedSchema)inputs[0].getResolvedSchema());
        Double[] weights = this.getWeights();
        OutputTag[] outputTags = new OutputTag[weights.length - 1];
        for (int i = 0; i < outputTags.length; ++i) {
            outputTags[i] = new OutputTag<Row>("outputTag_" + i, (TypeInformation)outputTypeInfo){};
        }
        long seed = this.getSeed();
        SingleOutputStreamOperator results = tEnv.toDataStream(inputs[0]).transform("SplitterOperator", (TypeInformation)outputTypeInfo, (OneInputStreamOperator)new SplitterOperator(outputTags, weights, seed));
        Table[] outputTables = new Table[weights.length];
        outputTables[0] = tEnv.fromDataStream((DataStream)results);
        for (int i = 0; i < outputTags.length; ++i) {
            SideOutputDataStream dataStream = results.getSideOutput(outputTags[i]);
            outputTables[i + 1] = tEnv.fromDataStream((DataStream)dataStream);
        }
        return outputTables;
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static RandomSplitter load(StreamTableEnvironment env, String path) throws IOException {
        return (RandomSplitter)ReadWriteUtils.loadStageParam((String)path);
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    private static class SplitterOperator
    extends AbstractStreamOperator<Row>
    implements OneInputStreamOperator<Row, Row> {
        private Random random;
        private final long initSeed;
        OutputTag<Row>[] outputTag;
        final double[] fractions;

        public SplitterOperator(OutputTag<Row>[] outputTag, Double[] weights, long initSeed) {
            this.initSeed = initSeed;
            this.outputTag = outputTag;
            this.fractions = new double[weights.length];
            double weightSum = 0.0;
            for (Double weight : weights) {
                weightSum += weight.doubleValue();
            }
            double currentSum = 0.0;
            for (int i = 0; i < this.fractions.length; ++i) {
                this.fractions[i] = (currentSum += weights[i].doubleValue()) / weightSum;
            }
        }

        public void open() throws Exception {
            super.open();
            this.random = new Random(Tuple2.of((Object)this.initSeed, (Object)this.getRuntimeContext().getIndexOfThisSubtask()).hashCode());
        }

        public void processElement(StreamRecord<Row> streamRecord) throws Exception {
            int index;
            int searchResult = Arrays.binarySearch(this.fractions, this.random.nextDouble());
            int n = index = searchResult < 0 ? -searchResult - 2 : searchResult - 1;
            if (index == -1) {
                this.output.collect(streamRecord);
            } else {
                this.output.collect(this.outputTag[index], streamRecord);
            }
        }
    }
}

