/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.benchmark;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.accumulators.LongCounter;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.benchmark.BenchmarkResult;
import org.apache.flink.ml.benchmark.datagenerator.DataGenerator;
import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.util.Preconditions;

/**
 * Utility methods for Flink ML benchmarks: parsing a JSON benchmark configuration file and running
 * a single benchmark described by such a configuration.
 */
public class BenchmarkUtils {
    /** Top-level JSON key holding the configuration format version. */
    private static final String VERSION_KEY = "version";

    /** The only configuration format version this parser accepts. */
    private static final Integer SUPPORTED_VERSION = 1;

    /**
     * Loads benchmark configurations from a JSON file.
     *
     * <p>The file must contain a top-level {@code "version"} entry whose value is {@code 1}. Every
     * other top-level entry is copied into the result under its own key; its value is expected to
     * be a JSON object of parameter groups (for example {@code "stage"}, {@code "inputData"} and the
     * optional {@code "modelData"}, which {@link #runBenchmark(StreamTableEnvironment, String, Map)}
     * reads).
     *
     * @param path path of the JSON configuration file
     * @return a map from benchmark name to that benchmark's parameter groups
     * @throws IOException if the file cannot be opened, read or parsed
     * @throws IllegalArgumentException if the {@code "version"} entry is missing or is not 1
     */
    @SuppressWarnings("unchecked")
    public static Map<String, Map<String, Map<String, ?>>> parseJsonFile(String path)
            throws IOException {
        Map<String, ?> jsonMap;
        // try-with-resources so the file handle is released even if parsing fails.
        try (InputStream inputStream = new FileInputStream(path)) {
            jsonMap = ReadWriteUtils.OBJECT_MAPPER.readValue(inputStream, Map.class);
        }

        // Null-safe: a JSON `"version": null` must be rejected, not cause an NPE.
        Preconditions.checkArgument(
                SUPPORTED_VERSION.equals(jsonMap.get(VERSION_KEY)),
                "Benchmark config file %s must declare \"version\": 1, but found %s.",
                path,
                jsonMap.get(VERSION_KEY));

        Map<String, Map<String, Map<String, ?>>> result = new HashMap<>();
        for (Map.Entry<String, ?> entry : jsonMap.entrySet()) {
            if (VERSION_KEY.equals(entry.getKey())) {
                continue;
            }
            // Unchecked: the JSON structure is trusted to match the documented layout.
            result.put(entry.getKey(), (Map<String, Map<String, ?>>) entry.getValue());
        }
        return result;
    }

    /**
     * Instantiates the stage and data generators described by {@code params} and runs the benchmark.
     *
     * @param tEnv the table environment the benchmark job is built in
     * @param name the benchmark name, used in the job name and the returned result
     * @param params parameter groups; {@code "stage"} and {@code "inputData"} are read
     *     unconditionally, {@code "modelData"} only when present
     * @return the measured benchmark result
     * @throws Exception if instantiation or job execution fails
     */
    public static BenchmarkResult runBenchmark(
            StreamTableEnvironment tEnv, String name, Map<String, Map<String, ?>> params)
            throws Exception {
        Stage<?> stage = (Stage<?>) ReadWriteUtils.instantiateWithParams(params.get("stage"));
        InputDataGenerator<?> inputDataGenerator =
                (InputDataGenerator<?>) ReadWriteUtils.instantiateWithParams(params.get("inputData"));

        DataGenerator<?> modelDataGenerator = null;
        if (params.containsKey("modelData")) {
            modelDataGenerator =
                    (DataGenerator<?>) ReadWriteUtils.instantiateWithParams(params.get("modelData"));
        }
        return runBenchmark(tEnv, name, stage, inputDataGenerator, modelDataGenerator);
    }

    /**
     * Builds and executes one benchmark job, then derives throughput figures from its runtime.
     *
     * <p>If {@code stage} is an {@link Estimator}, it is fitted and the fitted model's model-data
     * tables are consumed. If it is an {@link AlgoOperator}, its transformed output tables are
     * consumed. Every output table is drained into a counting sink whose total becomes the output
     * record count.
     *
     * @param modelDataGenerator optional generator for model data; when non-null, {@code stage}
     *     must be a {@link Model}
     * @throws IllegalArgumentException if model data is supplied for a non-{@link Model} stage, or
     *     if the stage is neither an {@link Estimator} nor an {@link AlgoOperator}
     */
    private static BenchmarkResult runBenchmark(
            StreamTableEnvironment tEnv,
            String name,
            Stage<?> stage,
            InputDataGenerator<?> inputDataGenerator,
            DataGenerator<?> modelDataGenerator)
            throws Exception {
        StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);

        Table[] inputTables = inputDataGenerator.getData(tEnv);
        if (modelDataGenerator != null) {
            // Fail with a clear message rather than an opaque ClassCastException.
            Preconditions.checkArgument(
                    stage instanceof Model,
                    "Model data was provided, but stage %s is not a Model.",
                    stage.getClass());
            ((Model<?>) stage).setModelData(modelDataGenerator.getData(tEnv));
        }

        Table[] outputTables;
        if (stage instanceof Estimator) {
            outputTables = ((Estimator<?, ?>) stage).fit(inputTables).getModelData();
        } else if (stage instanceof AlgoOperator) {
            outputTables = ((AlgoOperator<?>) stage).transform(inputTables);
        } else {
            throw new IllegalArgumentException("Unsupported Stage class " + stage.getClass());
        }

        for (Table table : outputTables) {
            tEnv.toDataStream(table).addSink(new CountingAndDiscardingSink<>());
        }

        JobExecutionResult executionResult = env.execute("Flink ML Benchmark Job " + name);

        double totalTimeMs = executionResult.getNetRuntime(TimeUnit.MILLISECONDS);
        long inputRecordNum = inputDataGenerator.getNumValues();
        double inputThroughput = inputRecordNum * 1000.0 / totalTimeMs;

        // Use the sink's constant instead of a duplicated literal. Treat a missing accumulator as
        // zero records instead of failing on unboxing.
        Long countedOutputs =
                executionResult.getAccumulatorResult(CountingAndDiscardingSink.COUNTER_NAME);
        long outputRecordNum = countedOutputs == null ? 0L : countedOutputs;
        double outputThroughput = outputRecordNum * 1000.0 / totalTimeMs;

        return new BenchmarkResult(
                name, totalTimeMs, inputRecordNum, inputThroughput, outputRecordNum, outputThroughput);
    }

    /**
     * A sink that discards every record while counting them in a {@link LongCounter} accumulator
     * registered under {@link #COUNTER_NAME}.
     */
    private static class CountingAndDiscardingSink<T> extends RichSinkFunction<T> {
        /** Accumulator name under which the record count is registered. */
        public static final String COUNTER_NAME = "numElements";

        private static final long serialVersionUID = 1L;

        private final LongCounter numElementsCounter = new LongCounter();

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            getRuntimeContext().addAccumulator(COUNTER_NAME, numElementsCounter);
        }

        @Override
        public void invoke(T value, SinkFunction.Context context) {
            numElementsCounter.add(1L);
        }
    }
}

