/*
 * Decompiled with CFR 0.152.
 */
package sklearn.ensemble.hist_gradient_boosting;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.AttributeException;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.PythonObject;
import sklearn.HasMultiDecisionFunctionField;
import sklearn.SkLearnClassifier;
import sklearn.compose.ColumnTransformer;
import sklearn.ensemble.hist_gradient_boosting.BaseLoss;
import sklearn.ensemble.hist_gradient_boosting.BinMapper;
import sklearn.ensemble.hist_gradient_boosting.HistGradientBoostingUtil;
import sklearn.ensemble.hist_gradient_boosting.TreePredictor;

public class HistGradientBoostingClassifier
extends SkLearnClassifier
implements HasMultiDecisionFunctionField {
    public HistGradientBoostingClassifier(String module, String name) {
        super(module, name);
    }

    public MiningModel encodeModel(Schema schema) {
        MiningModel miningModel;
        List<Number> baselinePredictions = this.getBaselinePrediction();
        PythonObject loss = this.getLoss();
        BinMapper binMapper = this.getBinMapper();
        int numberOfTreesPerIteration = this.getNumberOfTreesPerIteration();
        List<List<TreePredictor>> predictors = this.getPredictors();
        ColumnTransformer preprocessor = this.getPreprocessor();
        if (!predictors.isEmpty()) {
            ClassDictUtil.checkSize((int)numberOfTreesPerIteration, (Collection[])new Collection[]{predictors.get(0), baselinePredictions});
        }
        if (preprocessor != null) {
            schema = HistGradientBoostingUtil.preprocess(preprocessor, schema);
        }
        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        if (numberOfTreesPerIteration == 1) {
            SchemaUtil.checkSize((int)2, (DiscreteLabel)categoricalLabel);
            MiningModel model = HistGradientBoostingUtil.encodeHistGradientBoosting(predictors, binMapper, baselinePredictions, 0, segmentSchema).setOutput(ModelUtil.createPredictedOutput((String)this.getMultiDecisionFunctionField(categoricalLabel.getValue(1)), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));
            miningModel = MiningModelUtil.createBinaryLogisticClassification((Model)model, (double)1.0, (double)0.0, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.LOGIT, (boolean)false, (Schema)schema);
        } else if (numberOfTreesPerIteration >= 3) {
            SchemaUtil.checkSize((int)numberOfTreesPerIteration, (DiscreteLabel)categoricalLabel);
            ArrayList<MiningModel> models = new ArrayList<MiningModel>();
            int columns = categoricalLabel.size();
            for (int i = 0; i < columns; ++i) {
                MiningModel model = HistGradientBoostingUtil.encodeHistGradientBoosting(predictors, binMapper, baselinePredictions, i, segmentSchema).setOutput(ModelUtil.createPredictedOutput((String)this.getMultiDecisionFunctionField(categoricalLabel.getValue(i)), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));
                models.add(model);
            }
            miningModel = MiningModelUtil.createClassification(models, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.SOFTMAX, (boolean)false, (Schema)schema);
        } else {
            throw new IllegalArgumentException();
        }
        this.encodePredictProbaOutput((Model)miningModel, DataType.DOUBLE, (DiscreteLabel)categoricalLabel);
        return miningModel;
    }

    public List<Number> getBaselinePrediction() {
        return this.getNumberArray("_baseline_prediction");
    }

    public BinMapper getBinMapper() {
        return (BinMapper)((Object)this.getOptional("_bin_mapper", BinMapper.class));
    }

    public PythonObject getLoss() {
        if (this.hasattr("loss_")) {
            this.get("loss_", BaseLoss.class);
        }
        try {
            return (PythonObject)this.get("_loss", BaseLoss.class);
        }
        catch (AttributeException ae) {
            return (PythonObject)this.get("_loss", sklearn.loss.BaseLoss.class);
        }
    }

    public Integer getNumberOfTreesPerIteration() {
        return this.getInteger("n_trees_per_iteration_");
    }

    public List<List<TreePredictor>> getPredictors() {
        return this.getList("_predictors", List.class);
    }

    public ColumnTransformer getPreprocessor() {
        return (ColumnTransformer)this.getOptional("_preprocessor", ColumnTransformer.class);
    }
}

