/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.regression.linearregression;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.common.lossfunc.LeastSquareLoss;
import org.apache.flink.ml.common.optimizer.SGD;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
import org.apache.flink.ml.regression.linearregression.LinearRegressionParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Preconditions;

public class LinearRegression
implements Estimator<LinearRegression, LinearRegressionModel>,
LinearRegressionParams<LinearRegression> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public LinearRegression() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public LinearRegressionModel fit(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        SingleOutputStreamOperator trainData = tEnv.toDataStream(inputs[0]).map((MapFunction & Serializable)dataPoint -> {
            double weight = this.getWeightCol() == null ? 1.0 : ((Number)dataPoint.getField(this.getWeightCol())).doubleValue();
            double label = ((Number)dataPoint.getField(this.getLabelCol())).doubleValue();
            DenseVector features = ((Vector)dataPoint.getField(this.getFeaturesCol())).toDense();
            return new LabeledPointWithWeight(features, label, weight);
        });
        SingleOutputStreamOperator initModelData = DataStreamUtils.reduce((DataStream)trainData.map((MapFunction & Serializable)x -> x.getFeatures().size()), (ReduceFunction & Serializable)(t0, t1) -> {
            Preconditions.checkState((boolean)t0.equals(t1), (Object)"The training data should all have same dimensions.");
            return t0;
        }).map(DenseVector::new);
        SGD optimizer = new SGD(this.getMaxIter(), this.getLearningRate(), this.getGlobalBatchSize(), this.getTol(), this.getReg(), this.getElasticNet());
        DataStream<DenseVector> rawModelData = optimizer.optimize((DataStream<DenseVector>)initModelData, (DataStream<LabeledPointWithWeight>)trainData, LeastSquareLoss.INSTANCE);
        SingleOutputStreamOperator modelData = rawModelData.map(LinearRegressionModelData::new);
        LinearRegressionModel model = new LinearRegressionModel().setModelData(tEnv.fromDataStream((DataStream)modelData));
        ParamUtils.updateExistingParams((WithParams)model, this.paramMap);
        return model;
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static LinearRegression load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (LinearRegression)ReadWriteUtils.loadStageParam((String)path);
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }
}

