/*
 * Decompiled with CFR 0.152.
 */
package org.flinkextended.examples.tensorflow.linear;

import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.concurrent.ExecutionException;
import org.apache.flink.api.java.utils.MultipleParameterTool;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.StatementSet;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableDescriptor;
import org.apache.flink.table.api.bridge.java.StreamStatementSet;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.types.AbstractDataType;
import org.flinkextended.flink.ml.tensorflow.client.TFClusterConfig;
import org.flinkextended.flink.ml.tensorflow.client.TFUtils;

/**
 * Example job that trains, or runs inference with, the TensorFlow script {@code linear.py}
 * through Flink ML, using synthetic rows produced by Flink's {@code datagen} connector.
 *
 * <p>Each generated row has a random {@code x} with {@code fields.x.min=0} and
 * {@code fields.x.max=1}, plus a computed column {@code y = 2 * x + 1}.
 *
 * <p>Command-line options:
 *
 * <ul>
 *   <li>{@code --mode}: {@code train} (default) or {@code inference}.
 *   <li>{@code --model-path}: model location passed to the script as {@code model_save_path};
 *       defaults to {@code /tmp/linear/<current epoch millis>}. In inference mode this should
 *       point at the output of an earlier training run, because the default is a fresh path.
 *   <li>{@code --epoch}: positive number of training epochs (default 1).
 *   <li>{@code --sample-count}: positive number of generated rows (default 512000).
 *   <li>{@code --inference-output-path}: path of the CSV filesystem sink used in inference mode
 *       (default {@code /tmp/linear/output.csv}).
 * </ul>
 */
public class Linear {
    private static final String MODEL_PATH = "model-path";
    private static final String EPOCH = "epoch";
    private static final String SAMPLE_COUNT = "sample-count";
    private static final String MODE = "mode";
    private static final String INFERENCE_OUTPUT_PATH = "inference-output-path";

    /** Parallelism of the Flink streaming environment. */
    private static final int PARALLELISM = 2;

    /** Number of TensorFlow workers requested from Flink ML for both train and inference. */
    private static final int WORKER_COUNT = 2;

    /** Classpath resource holding the Python script whose functions are used as node entries. */
    private static final String SCRIPT_RESOURCE = "linear.py";

    /**
     * Parses the command line and runs the selected mode until the submitted job finishes.
     *
     * @param args command-line options, see the class documentation
     * @throws IllegalArgumentException if {@code --mode} is unknown or a numeric option is not a
     *     positive integer
     * @throws ExecutionException if the Flink job fails
     * @throws InterruptedException if interrupted while waiting for the job to finish
     */
    public static void main(String[] args) throws ExecutionException, InterruptedException {
        MultipleParameterTool params = MultipleParameterTool.fromArgs(args);
        String mode = params.get(MODE, "train");
        String inferenceOutputPath = params.get(INFERENCE_OUTPUT_PATH, "/tmp/linear/output.csv");
        String modelPath =
                params.get(MODEL_PATH, String.format("/tmp/linear/%s", System.currentTimeMillis()));
        int epoch = parsePositiveInt(params, EPOCH, "1");
        int sampleCount = parsePositiveInt(params, SAMPLE_COUNT, "512000");

        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
        StreamStatementSet statementSet = tEnv.createStatementSet();
        Table sample = createSampleTable(tEnv, sampleCount);

        if ("train".equals(mode)) {
            System.out.printf(
                    "Model will be trained with %d samples for %d epochs and saved at: %s%n",
                    sampleCount, epoch, modelPath);
            train(modelPath, epoch, statementSet, sample);
        } else if ("inference".equals(mode)) {
            System.out.printf(
                    "Inference with model at %s, output will be at %s%n",
                    modelPath, inferenceOutputPath);
            inference(modelPath, statementSet, sample, inferenceOutputPath);
        } else {
            // IllegalArgumentException is a RuntimeException, so existing handling still applies.
            throw new IllegalArgumentException(
                    String.format("Unknown mode %s, expected 'train' or 'inference'", mode));
        }
    }

    /**
     * Reads an integer option and rejects values that are not positive.
     *
     * @param params parsed command line
     * @param key option name without the leading dashes
     * @param defaultValue value used when the option is absent
     * @return the parsed, strictly positive value
     * @throws IllegalArgumentException if the value is not an integer or is {@code <= 0}
     */
    private static int parsePositiveInt(
            MultipleParameterTool params, String key, String defaultValue) {
        String raw = params.get(key, defaultValue);
        final int value;
        try {
            value = Integer.parseInt(raw);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(
                    String.format("--%s must be an integer but was '%s'", key, raw), e);
        }
        if (value <= 0) {
            throw new IllegalArgumentException(
                    String.format("--%s must be positive but was %d", key, value));
        }
        return value;
    }

    /**
     * Builds the synthetic input table: a DOUBLE column {@code x} from the datagen connector and
     * a computed column {@code y = 2 * x + 1}.
     *
     * @param tEnv table environment the source is registered in
     * @param sampleCount value of the datagen {@code number-of-rows} option
     * @return the sample table with columns {@code x} and {@code y}
     */
    private static Table createSampleTable(StreamTableEnvironment tEnv, int sampleCount) {
        Schema schema =
                Schema.newBuilder()
                        .column("x", DataTypes.DOUBLE())
                        .columnByExpression("y", "2 * x + 1")
                        .build();
        return tEnv.from(
                TableDescriptor.forConnector("datagen")
                        .schema(schema)
                        .option("fields.x.min", "0")
                        .option("fields.x.max", "1")
                        .option("number-of-rows", String.valueOf(sampleCount))
                        .build());
    }

    /**
     * Creates a cluster-config builder carrying the settings shared by training and inference:
     * worker count, script entry, local-file storage and the model path.
     *
     * @param entryFunction name passed to {@code setNodeEntry} together with the script path
     * @param modelPath value of the {@code model_save_path} property
     * @return a builder the caller can extend with mode-specific properties
     */
    private static TFClusterConfig.Builder newClusterConfigBuilder(
            String entryFunction, String modelPath) {
        return TFClusterConfig.newBuilder()
                .setWorkerCount(WORKER_COUNT)
                .setNodeEntry(getScriptPathFromResources(), entryFunction)
                .setProperty("storage_type", "local_file")
                .setProperty("model_save_path", modelPath);
    }

    /**
     * Runs the script's {@code inference} entry on the {@code x} column. The two FLOAT_64
     * outputs are written as columns {@code x}, {@code y} to a CSV filesystem sink.
     *
     * @param modelPath model location handed to the script
     * @param statementSet statement set the inference pipeline and sink are added to
     * @param sample generated input table; its {@code y} column is dropped before inference
     * @param inferenceOutputPath {@code path} option of the filesystem sink
     * @throws ExecutionException if the job fails
     * @throws InterruptedException if interrupted while awaiting the job
     */
    private static void inference(
            String modelPath,
            StreamStatementSet statementSet,
            Table sample,
            String inferenceOutputPath)
            throws ExecutionException, InterruptedException {
        // Only the feature column is fed to the model; the label would not match input_types.
        Table features = sample.dropColumns(Expressions.$("y"));
        TFClusterConfig config =
                newClusterConfigBuilder("inference", modelPath)
                        .setProperty("input_types", "FLOAT_64")
                        .setProperty("output_types", "FLOAT_64,FLOAT_64")
                        .build();
        Schema outputSchema =
                Schema.newBuilder()
                        .column("x", DataTypes.DOUBLE())
                        .column("y", DataTypes.DOUBLE())
                        .build();
        Table output = TFUtils.inference(statementSet, features, config, outputSchema);
        statementSet.addInsert(
                TableDescriptor.forConnector("filesystem")
                        .format("csv")
                        .option("path", inferenceOutputPath)
                        .build(),
                output);
        statementSet.execute().await();
    }

    /**
     * Runs the script's {@code train} entry on both sample columns for the given epochs.
     *
     * @param modelPath model location handed to the script
     * @param epoch number of epochs passed to {@code TFUtils.train}
     * @param statementSet statement set the training pipeline is added to
     * @param sample generated table with columns {@code x} and {@code y}
     * @throws InterruptedException if interrupted while awaiting the job
     * @throws ExecutionException if the job fails
     */
    private static void train(
            String modelPath, Integer epoch, StreamStatementSet statementSet, Table sample)
            throws InterruptedException, ExecutionException {
        TFClusterConfig config =
                newClusterConfigBuilder("train", modelPath)
                        .setProperty("input_types", "FLOAT_64,FLOAT_64")
                        .build();
        TFUtils.train(statementSet, sample, config, epoch);
        statementSet.execute().await();
    }

    /**
     * Resolves {@code linear.py} on the classpath to a filesystem path string.
     *
     * @return decoded filesystem path of the script. For non-{@code file:} URLs (e.g. inside a
     *     jar), or URLs that cannot be converted, returns the raw {@link URL#getPath()}.
     * @throws RuntimeException if the resource cannot be found
     */
    private static String getScriptPathFromResources() {
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        if (loader == null) {
            // The context class loader may be unset; fall back to the loader of this class.
            loader = Linear.class.getClassLoader();
        }
        URL resource = loader.getResource(SCRIPT_RESOURCE);
        if (resource == null) {
            throw new RuntimeException(String.format("Fail to find resource %s", SCRIPT_RESOURCE));
        }
        if (!"file".equals(resource.getProtocol())) {
            // No local file to decode to; keep the previous raw-path behaviour.
            return resource.getPath();
        }
        try {
            // URL#getPath() keeps percent-encoding (a space stays "%20"), which does not name a
            // real file; going through URI decodes it into a usable filesystem path.
            return Paths.get(resource.toURI()).toString();
        } catch (URISyntaxException | IllegalArgumentException e) {
            // Malformed or non-hierarchical URL: fall back to the raw path as before.
            return resource.getPath();
        }
    }
}

