/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.classification.logisticregression;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionParams;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
import org.apache.flink.ml.common.optimizer.SGD;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Preconditions;

public class LogisticRegression
implements Estimator<LogisticRegression, LogisticRegressionModel>,
LogisticRegressionParams<LogisticRegression> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public LogisticRegression() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public LogisticRegressionModel fit(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        String classificationType = this.getMultiClass();
        Preconditions.checkArgument(("auto".equals(classificationType) || "binomial".equals(classificationType) ? 1 : 0) != 0, (Object)"Multinomial classification is not supported yet. Supported options: [auto, binomial].");
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        SingleOutputStreamOperator trainData = tEnv.toDataStream(inputs[0]).map((MapFunction & Serializable)dataPoint -> {
            boolean isBinomial;
            double weight = this.getWeightCol() == null ? 1.0 : ((Number)dataPoint.getField(this.getWeightCol())).doubleValue();
            double label = ((Number)dataPoint.getField(this.getLabelCol())).doubleValue();
            boolean bl = isBinomial = Double.compare(0.0, label) == 0 || Double.compare(1.0, label) == 0;
            if (!isBinomial) {
                throw new RuntimeException("Multinomial classification is not supported yet. Supported options: [auto, binomial].");
            }
            DenseVector features = ((Vector)dataPoint.getField(this.getFeaturesCol())).toDense();
            return new LabeledPointWithWeight(features, label, weight);
        });
        SingleOutputStreamOperator initModelData = DataStreamUtils.reduce((DataStream)trainData.map((MapFunction & Serializable)x -> x.getFeatures().size()), (ReduceFunction & Serializable)(t0, t1) -> {
            Preconditions.checkState((boolean)t0.equals(t1), (Object)"The training data should all have same dimensions.");
            return t0;
        }).map(DenseVector::new);
        SGD optimizer = new SGD(this.getMaxIter(), this.getLearningRate(), this.getGlobalBatchSize(), this.getTol(), this.getReg(), this.getElasticNet());
        DataStream<DenseVector> rawModelData = optimizer.optimize((DataStream<DenseVector>)initModelData, (DataStream<LabeledPointWithWeight>)trainData, BinaryLogisticLoss.INSTANCE);
        SingleOutputStreamOperator modelData = rawModelData.map((MapFunction & Serializable)vector -> new LogisticRegressionModelData(vector, 0L));
        LogisticRegressionModel model = new LogisticRegressionModel().setModelData(tEnv.fromDataStream((DataStream)modelData));
        ParamUtils.updateExistingParams((WithParams)model, this.paramMap);
        return model;
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static LogisticRegression load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (LogisticRegression)ReadWriteUtils.loadStageParam((String)path);
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }
}

