package org.jpmml.xgboost;

import java.io.EOFException;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.BaseNFeature;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Schema;
import org.jpmml.converter.visitors.NaNAsMissingDecorator;
import org.jpmml.xgboost.visitors.TreeModelCompactor;

/* loaded from: input_file:org/jpmml/xgboost/Learner.class */
public class Learner implements Loadable {
    private float base_score;
    private int num_features;
    private int num_class;
    private int contain_extra_attrs;
    private int contain_eval_metrics;
    private int major_version;
    private int minor_version;
    private ObjFunction obj;
    private GBTree gbtree;
    // Only populated when contain_extra_attrs != 0
    private Map<String, String> attributes;
    // Only populated for pre-1.0 models with contain_eval_metrics != 0
    private String[] metrics;

    /**
     * Reads the learner state from a binary XGBoost model stream.
     *
     * <p>The header (base score, feature/class counts, flags and version) is read first, followed by
     * the objective function name, the booster name and the booster's own state, the optional
     * attribute map and, for pre-1.0 models only, the optional trailing sections.
     *
     * @param input the model stream, positioned at the start of the learner section
     * @throws IOException if reading from {@code input} fails
     * @throws IllegalArgumentException if the major version is not 0 or 1, or if the objective
     *     function or booster name is not supported
     */
    @Override // org.jpmml.xgboost.Loadable
    public void load(XGBoostDataInput input) throws IOException {
        this.base_score = input.readFloat();
        this.num_features = input.readInt();
        this.num_class = input.readInt();
        this.contain_extra_attrs = input.readInt();
        this.contain_eval_metrics = input.readInt();
        this.major_version = input.readInt();
        this.minor_version = input.readInt();

        if (this.major_version < 0 || this.major_version > 1) {
            throw new IllegalArgumentException(this.major_version + "." + this.minor_version);
        }

        input.readReserved(27);

        this.obj = parseObjFunction(input.readString(), this.num_class);

        if (this.major_version >= 1) {
            // NOTE(review): presumably 1.x models store base_score on the prediction scale, so it is
            // mapped onto the margin scale here; 0.x values are used as-is — confirm against XGBoost
            this.base_score = this.obj.probToMargin(this.base_score);
        }

        this.gbtree = parseBooster(input.readString());
        this.gbtree.load(input);

        if (this.contain_extra_attrs != 0) {
            this.attributes = input.readStringMap();
        }

        // The remaining sections are only present in pre-1.0 models
        if (this.major_version >= 1) {
            return;
        }

        if (this.obj instanceof PoissonRegression) {
            try {
                // Pre-1.0 Poisson models may carry one extra string here; its value is not needed
                input.readString();
            } catch (EOFException ignored) {
                // Deliberately best-effort: a stream that ends before this string is accepted
            }
        }

        if (this.contain_eval_metrics != 0) {
            this.metrics = input.readStringVector();
        }
    }

    /**
     * Maps an XGBoost objective function name to the corresponding converter object.
     *
     * @param name the objective name read from the model (e.g. {@code "binary:logistic"})
     * @param numClass the class count, forwarded to the multi-class objectives
     * @return a fresh objective function instance
     * @throws IllegalArgumentException if {@code name} is not a supported objective
     */
    private static ObjFunction parseObjFunction(String name, int numClass) {
        switch (name) {
            case "reg:linear":
            case "reg:squarederror":
            case "reg:squaredlogerror":
                return new LinearRegression();
            case "reg:logistic":
                return new LogisticRegression();
            case "reg:gamma":
            case "reg:tweedie":
                return new GeneralizedLinearRegression();
            case "count:poisson":
                return new PoissonRegression();
            case "binary:hinge":
                return new HingeClassification();
            case "binary:logistic":
                return new BinomialLogisticRegression();
            case "rank:map":
            case "rank:ndcg":
            case "rank:pairwise":
                return new LambdaMART();
            case "multi:softmax":
            case "multi:softprob":
                return new MultinomialLogisticRegression(numClass);
            default:
                throw new IllegalArgumentException(name);
        }
    }

    /**
     * Maps an XGBoost booster name to an empty booster object (not yet loaded).
     *
     * @param name the booster name read from the model
     * @return a fresh {@link GBTree} (or its {@link Dart} subtype)
     * @throws IllegalArgumentException if {@code name} is neither {@code "gbtree"} nor {@code "dart"}
     */
    private static GBTree parseBooster(String name) {
        switch (name) {
            case "gbtree":
                return new GBTree();
            case "dart":
                return new Dart();
            default:
                throw new IllegalArgumentException(name);
        }
    }

    /**
     * Builds the conversion schema from the objective's label encoding and the feature map.
     *
     * @param fieldName the target field name; {@code "_target"} is used when {@code null}
     * @param targetCategories passed through to the objective's label encoder (presumably the
     *     target category values — confirm against {@code ObjFunction.encodeLabel})
     * @param featureMap supplies the encoded features
     * @param encoder the encoder shared by the label and the features
     * @return the new schema
     */
    public Schema encodeSchema(FieldName fieldName, List<String> targetCategories, FeatureMap featureMap, XGBoostEncoder encoder) {
        FieldName targetName = (fieldName != null) ? fieldName : FieldName.create("_target");

        // Label is encoded before the features, matching the original evaluation order
        return new Schema(encoder, this.obj.encodeLabel(targetName, targetCategories, encoder), featureMap.encodeFeatures(encoder));
    }

    /**
     * Returns a copy of {@code schema} whose features are restricted to the kinds handled here.
     *
     * <p>{@link BaseNFeature} and {@link BinaryFeature} instances are kept as-is. Every other feature
     * is converted to a continuous feature: integer and float features are kept, double features
     * are narrowed to float, and any other data type is rejected.
     *
     * @param schema the source schema
     * @return the transformed schema
     * @throws IllegalArgumentException if a continuous feature has an unsupported data type
     */
    public Schema toXGBoostSchema(Schema schema) {
        return schema.toTransformedSchema(Learner::toXGBoostFeature);
    }

    /** Converts a single feature as described in {@link #toXGBoostSchema(Schema)}. */
    private static Feature toXGBoostFeature(Feature feature) {
        if (feature instanceof BaseNFeature || feature instanceof BinaryFeature) {
            return feature;
        }

        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        DataType dataType = continuousFeature.getDataType();

        switch (dataType) {
            case INTEGER:
            case FLOAT:
                return continuousFeature;
            case DOUBLE:
                // NOTE(review): presumably narrowed because XGBoost compares split values in single
                // precision — confirm
                return continuousFeature.toContinuousFeature(DataType.FLOAT);
            default:
                throw new IllegalArgumentException("Expected integer, float or double data type for continuous feature " + continuousFeature.getName() + ", got " + dataType.value() + " data type");
        }
    }

    /**
     * Encodes the whole model as a PMML document.
     *
     * <p>When {@link HasXGBoostOptions#OPTION_NAN_AS_MISSING} maps to {@code Boolean.TRUE}, the
     * {@link NaNAsMissingDecorator} is applied to the result.
     *
     * @param options conversion options keyed by {@link HasXGBoostOptions} names
     * @param fieldName the target field name, or {@code null} for the default
     * @param targetCategories forwarded to {@link #encodeSchema}
     * @param featureMap supplies the encoded features
     * @return the PMML document
     */
    public PMML encodePMML(Map<String, ?> options, FieldName fieldName, List<String> targetCategories, FeatureMap featureMap) {
        XGBoostEncoder encoder = new XGBoostEncoder();

        Boolean nanAsMissing = (Boolean) options.get(HasXGBoostOptions.OPTION_NAN_AS_MISSING);

        Schema schema = encodeSchema(fieldName, targetCategories, featureMap, encoder);
        PMML pmml = encoder.encodePMML(encodeMiningModel(options, schema));

        if (Boolean.TRUE.equals(nanAsMissing)) {
            new NaNAsMissingDecorator().applyTo(pmml);
        }

        return pmml;
    }

    /**
     * Encodes the booster as a {@link MiningModel} whose algorithm name is
     * {@code "XGBoost (<booster algorithm name>)"}.
     *
     * <p>{@link HasXGBoostOptions#OPTION_NTREE_LIMIT} is forwarded to the booster. When
     * {@link HasXGBoostOptions#OPTION_COMPACT} maps to {@code Boolean.TRUE}, the
     * {@link TreeModelCompactor} is applied to the result.
     *
     * @param options conversion options keyed by {@link HasXGBoostOptions} names
     * @param schema the schema to encode against
     * @return the mining model
     */
    public MiningModel encodeMiningModel(Map<String, ?> options, Schema schema) {
        Boolean compact = (Boolean) options.get(HasXGBoostOptions.OPTION_COMPACT);
        Integer ntreeLimit = (Integer) options.get(HasXGBoostOptions.OPTION_NTREE_LIMIT);

        MiningModel miningModel = this.gbtree.encodeMiningModel(this.obj, this.base_score, ntreeLimit, schema)
            .setAlgorithmName("XGBoost (" + this.gbtree.getAlgorithmName() + ")");

        if (Boolean.TRUE.equals(compact)) {
            new TreeModelCompactor().applyTo(miningModel);
        }

        return miningModel;
    }

    /** @return the feature count read from the model header */
    public int num_features() {
        return this.num_features;
    }

    /** @return the class count read from the model header */
    public int num_class() {
        return this.num_class;
    }

    /** @return the objective function selected in {@link #load}, or {@code null} before loading */
    public ObjFunction obj() {
        return this.obj;
    }
}
