package org.jpmml.xgboost;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Visitor;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.model.visitors.DataDictionaryCleaner;
import org.jpmml.model.visitors.MiningSchemaCleaner;

/**
 * In-memory form of a binary XGBoost model.
 *
 * <p>{@link #load(XGBoostDataInput)} reads the learner header (base score, feature count,
 * class count), the objective function name and the booster name from a model stream.
 * {@link #encodePMML(FieldName, List, FeatureMap)} converts the loaded model into a
 * PMML 4.3 document.
 */
public class Learner {

    /** Base score read from the model header; handed to the booster when encoding. */
    private float base_score;

    /** Number of features, as declared in the model header. */
    private int num_features;

    /** Number of classes, as declared in the model header; passed to softmax objectives. */
    private int num_class;

    /** Objective function, selected by name in {@link #load(XGBoostDataInput)}. */
    private ObjFunction obj;

    /** Tree-ensemble booster; "gbtree" is the only booster name accepted by {@link #load}. */
    private GBTree gbtree;

    /**
     * Reads the learner state from the given stream.
     *
     * @param input the model stream, positioned at the start of the learner header
     * @throws IOException if reading from the stream fails
     * @throws IllegalArgumentException if the objective function name or the booster name
     *     is not supported; the message is the offending name
     */
    public void load(XGBoostDataInput input) throws IOException {
        this.base_score = input.readFloat();
        this.num_features = input.readInt();
        this.num_class = input.readInt();

        // Skip the reserved slots that follow the header fields.
        // NOTE(review): 31 is taken as-is from the original; confirm against the binary format.
        input.readReserved(31);

        String objective = input.readString();
        switch (objective) {
            case "reg:linear":
                this.obj = new LinearRegression();
                break;
            case "reg:logistic":
                this.obj = new LogisticRegression();
                break;
            case "count:poisson":
                this.obj = new PoissonRegression();
                break;
            case "binary:logistic":
                this.obj = new LogisticClassification();
                break;
            // Both multiclass objectives are encoded the same way.
            case "multi:softmax":
            case "multi:softprob":
                this.obj = new SoftMaxClassification(this.num_class);
                break;
            default:
                throw new IllegalArgumentException(objective);
        }

        String booster = input.readString();
        switch (booster) {
            case "gbtree":
                // Load into a local first so a failed read does not leave a
                // half-initialised booster on this instance.
                GBTree loaded = new GBTree();
                loaded.load(input);
                this.gbtree = loaded;
                break;
            default:
                throw new IllegalArgumentException(booster);
        }
    }

    /**
     * Encodes the loaded model as a PMML 4.3 document.
     *
     * @param fieldName name of the target field; {@code "_target"} is used when {@code null}
     * @param list target category labels; only allowed for classification objectives, where it
     *     must have exactly {@code getNumClass()} entries. When {@code null} for a classification
     *     objective, the labels "0", "1", ... are generated.
     * @param featureMap supplies the feature data fields and the model schema
     * @return the PMML document, after the MiningSchema and DataDictionary cleaner visitors
     *     have been applied to it
     * @throws IllegalStateException if {@link #load(XGBoostDataInput)} has not been called
     * @throws IllegalArgumentException if {@code list} is given for a non-classification
     *     objective, or its size does not match the number of classes
     */
    public PMML encodePMML(FieldName fieldName, List<String> list, FeatureMap featureMap) {
        if (this.obj == null || this.gbtree == null) {
            throw new IllegalStateException("Learner has not been loaded");
        }

        FieldName targetField = (fieldName != null) ? fieldName : FieldName.create("_target");

        List<String> targetCategories = list;
        if (this.obj instanceof Classification) {
            int numClass = ((Classification) this.obj).getNumClass();
            if (targetCategories == null) {
                targetCategories = createTargetCategories(numClass);
            } else if (targetCategories.size() != numClass) {
                throw new IllegalArgumentException("Expected " + numClass
                        + " target categories, got " + targetCategories.size());
            }
        } else if (targetCategories != null) {
            throw new IllegalArgumentException(
                    "Target categories are only supported for classification objectives");
        }

        DataField targetDataField =
                new DataField(targetField, this.obj.getOpType(), this.obj.getDataType());
        if (targetCategories != null) {
            targetDataField.getValues().addAll(PMMLUtil.createValues(targetCategories));
        }

        Model model = this.gbtree.encodeMiningModel(
                this.obj, this.base_score, featureMap.createSchema(targetField, targetCategories));

        // The target field comes first, followed by the feature fields.
        List<DataField> dataFields = new ArrayList<>();
        dataFields.add(targetDataField);
        dataFields.addAll(featureMap.getDataFields());

        PMML pmml = new PMML("4.3", PMMLUtil.createHeader(Learner.class), new DataDictionary(dataFields))
                .addModels(new Model[]{model});

        // Post-process the document with the jpmml-model cleaner visitors
        // (presumably pruning entries the model does not reference — see their docs).
        for (Visitor visitor : Arrays.<Visitor>asList(new MiningSchemaCleaner(), new DataDictionaryCleaner())) {
            visitor.applyTo(pmml);
        }
        return pmml;
    }

    /** Returns the base score read from the model header. */
    public float getBaseScore() {
        return this.base_score;
    }

    /** Returns the number of classes read from the model header. */
    public int getNumClass() {
        return this.num_class;
    }

    /** Returns the number of features read from the model header. */
    public int getNumFeatures() {
        return this.num_features;
    }

    /** Returns the objective function, or {@code null} before {@link #load} has run. */
    public ObjFunction getObj() {
        return this.obj;
    }

    /** Returns the booster, or {@code null} before {@link #load} has completed. */
    public GBTree getGBTree() {
        return this.gbtree;
    }

    /**
     * Builds default category labels {@code "0", "1", ..., String.valueOf(count - 1)}.
     *
     * @param count number of labels; a non-positive value yields an empty list
     * @return a new mutable list of labels
     */
    private static List<String> createTargetCategories(int count) {
        List<String> categories = new ArrayList<>();
        for (int i = 0; i < count; i++) {
            categories.add(String.valueOf(i));
        }
        return categories;
    }
}
