package org.jpmml.xgboost;

import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import java.io.DataInput;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ThresholdFeature;
import org.jpmml.converter.visitors.NaNAsMissingDecorator;
import org.jpmml.converter.visitors.TreeModelPruner;
import org.jpmml.xgboost.visitors.TreeModelCompactor;

/**
 * XGBoost learner: the model-level parameters, objective function and gradient booster of a
 * saved XGBoost model, plus the logic to convert them to PMML.
 *
 * <p>A model can be loaded from the binary layout ({@code loadBinary}) or from the JSON
 * layout ({@code loadJSON}).
 */
public class Learner implements BinaryLoadable, JSONLoadable {

    /** Global bias; passed through {@code ObjFunction#probToMargin} by the loaders (binary: version 1+ only). */
    private float base_score;

    private int num_feature;

    private int num_class;

    /** Binary format: non-zero when a string attribute map follows the booster. */
    private int contain_extra_attrs;

    /** Binary format: non-zero when a metric string vector trails a pre-1.0 model. */
    private int contain_eval_metrics;

    private int major_version;

    private int minor_version;

    private ObjFunction obj;

    private GBTree gbtree;

    /** Attribute map from the binary format; {@code null} when absent. */
    private Map<String, String> attributes = null;

    /** Metric strings from the binary format (pre-1.0 models only); {@code null} when absent. */
    private String[] metrics = null;

    /**
     * Reads the learner from the binary model layout.
     *
     * <p>Read order: base score, feature count, class count, the two trailer flags, major and
     * minor version, 27 reserved slots, objective name, booster name, booster body, and then the
     * optional trailers.
     *
     * @param input binary model reader
     * @throws IllegalArgumentException if the major version is not 0 or 1, or the objective or
     *     booster name is not supported
     * @throws IOException on read failure
     */
    @Override
    public void loadBinary(XGBoostDataInput input) throws IOException {
        this.base_score = input.readFloat();
        this.num_feature = input.readInt();
        this.num_class = input.readInt();
        this.contain_extra_attrs = input.readInt();
        this.contain_eval_metrics = input.readInt();
        this.major_version = input.readInt();
        this.minor_version = input.readInt();

        if (this.major_version < 0 || this.major_version > 1) {
            throw new IllegalArgumentException(this.major_version + "." + this.minor_version);
        }

        input.readReserved(27);

        // num_class must already be set here: parseObjective reads it for multi-class objectives.
        this.obj = parseObjective(input.readString());

        // Version 1+ models have their base score transformed by the objective; older models
        // use the stored value unchanged. Adding 0.0f turns a -0.0f result into +0.0f.
        if (this.major_version >= 1) {
            this.base_score = this.obj.probToMargin(this.base_score) + 0.0f;
        }

        this.gbtree = parseGradientBooster(input.readString());
        this.gbtree.loadBinary(input);

        if (this.contain_extra_attrs != 0) {
            this.attributes = input.readStringMap();
        }

        // The remaining trailers are only read for pre-1.0 models.
        if (this.major_version >= 1) {
            return;
        }

        if (this.obj instanceof PoissonRegression) {
            try {
                input.readString();
            } catch (EOFException ignored) {
                // Deliberately tolerated: this extra string may be absent at the end of the stream.
            }
        }

        if (this.contain_eval_metrics != 0) {
            this.metrics = input.readStringVector();
        }
    }

    /**
     * Reads the learner from a parsed JSON model document.
     *
     * @param model JSON object holding the {@code version} array and the {@code learner} object
     * @throws IllegalArgumentException if the model version is older than 1.3, or the objective
     *     or booster name is not supported
     */
    @Override
    public void loadJSON(JsonObject model) {
        JsonArray version = model.getAsJsonArray("version");
        this.major_version = version.get(0).getAsInt();
        this.minor_version = version.get(1).getAsInt();

        // Compare (major, minor) as a pair; testing "minor < 3" on its own would wrongly
        // reject e.g. 2.0 while accepting 2.3.
        if (this.major_version < 1 || (this.major_version == 1 && this.minor_version < 3)) {
            throw new IllegalArgumentException("Expected XGBoost version 1.3 or newer, got " + this.major_version + "." + this.minor_version);
        }

        JsonObject learner = model.getAsJsonObject("learner");

        JsonObject modelParam = learner.getAsJsonObject("learner_model_param");
        this.base_score = modelParam.getAsJsonPrimitive("base_score").getAsFloat();
        this.num_feature = modelParam.getAsJsonPrimitive("num_feature").getAsInt();
        this.num_class = modelParam.getAsJsonPrimitive("num_class").getAsInt();

        // num_class must already be set here: parseObjective reads it for multi-class objectives.
        String objective = learner.getAsJsonObject("objective").getAsJsonPrimitive("name").getAsString();
        this.obj = parseObjective(objective);
        // Adding 0.0f turns a -0.0f result into +0.0f.
        this.base_score = this.obj.probToMargin(this.base_score) + 0.0f;

        JsonObject booster = learner.getAsJsonObject("gradient_booster");
        this.gbtree = parseGradientBooster(booster.getAsJsonPrimitive("name").getAsString());
        this.gbtree.loadJSON(booster);
    }

    /**
     * Reads a binary model from a stream and closes it afterwards.
     *
     * <p>Two optional headers are probed and skipped: {@code XGBoostUtil.SERIALIZATION_HEADER}
     * (followed by a long that must be non-negative) and {@code XGBoostUtil.BINF_HEADER}. If the
     * serialization header is absent, the stream must hold nothing after the model. If it is
     * present, trailing bytes are not checked.
     *
     * @param is stream positioned at the model start; headers are probed with
     *     {@code mark()}/{@code reset()}, so it should support marking
     * @param charset charset name passed to {@code XGBoostDataInput}
     * @throws IOException on read failure, a negative value after the serialization header, or
     *     unexpected trailing data
     */
    public <DIS extends InputStream & DataInput> void loadBinary(DIS is, String charset) throws IOException {
        boolean serialized = consumeHeader(is, XGBoostUtil.SERIALIZATION_HEADER);
        if (serialized) {
            long headerValue = is.readLong();
            if (headerValue < 0) {
                throw new IOException("Expected a non-negative value after the serialization header, got " + headerValue);
            }
        }

        // Only the side effect matters: the header is skipped when present.
        consumeHeader(is, XGBoostUtil.BINF_HEADER);

        try (XGBoostDataInput input = new XGBoostDataInput(is, charset)) {
            loadBinary(input);

            if (!serialized && is.read() != -1) {
                throw new IOException("Unexpected trailing data after the model");
            }
        }
    }

    /**
     * Reads a JSON model from a stream and closes it afterwards.
     *
     * @param is the JSON document
     * @param charset charset name; {@code null} means {@code "UTF-8"}
     * @param jsonPath dot-separated path to the model object; the first element must be
     *     {@code $} (the document root), e.g. {@code "$"} or {@code "$.member"}
     * @throws IllegalArgumentException if {@code jsonPath} does not start with {@code $} or names
     *     a member that is missing
     * @throws IOException on read failure or if bytes remain in the stream after parsing
     */
    public void loadJSON(InputStream is, String charset, String jsonPath) throws IOException {
        JsonParser parser = new JsonParser();

        if (charset == null) {
            charset = "UTF-8";
        }

        try (InputStreamReader reader = new InputStreamReader(is, charset)) {
            JsonObject object = parser.parse(reader).getAsJsonObject();

            String[] names = jsonPath.split("\\.");
            if (names.length == 0 || !"$".equals(names[0])) {
                throw new IllegalArgumentException(jsonPath);
            }

            for (int i = 1; i < names.length; i++) {
                object = object.getAsJsonObject(names[i]);

                // A missing member yields null; report the bad path rather than failing later with an NPE.
                if (object == null) {
                    throw new IllegalArgumentException(jsonPath);
                }
            }

            loadJSON(object);

            if (is.read() != -1) {
                throw new IOException("Unexpected trailing data after the JSON document");
            }
        }
    }

    /**
     * Builds the conversion schema. The objective function encodes the label and
     * {@code featureMap} encodes the features.
     *
     * @param fieldName target field name; {@code null} means {@code "_target"}
     * @param targetCategories passed to {@code ObjFunction#encodeLabel} (presumably the class
     *     labels; may be unused by regression objectives — TODO confirm)
     * @param featureMap feature definitions
     * @param encoder encoder that collects the generated fields
     */
    public Schema encodeSchema(FieldName fieldName, List<String> targetCategories, FeatureMap featureMap, XGBoostEncoder encoder) {
        FieldName targetField = (fieldName != null) ? fieldName : FieldName.create("_target");

        // Label first, then features (same evaluation order as passing them inline).
        return new Schema(encoder, this.obj.encodeLabel(targetField, targetCategories, encoder), featureMap.encodeFeatures(encoder));
    }

    /**
     * Transforms a schema so that its features have types the XGBoost encoding accepts.
     *
     * <p>Binary and missing-value features are kept unchanged. Threshold features are kept
     * unless {@code numeric} is set. Every other feature is converted to a continuous feature:
     * integer and float are kept, and double is narrowed to float.
     *
     * @param numeric when {@code true}, threshold features are also converted to continuous ones
     * @param schema schema to transform
     * @throws IllegalArgumentException if a continuous feature has a data type other than
     *     integer, float or double
     */
    public Schema toXGBoostSchema(final boolean numeric, Schema schema) {
        Function<Feature, Feature> transformer = feature -> {
            if (feature instanceof BinaryFeature || feature instanceof MissingValueFeature) {
                return feature;
            }
            if (feature instanceof ThresholdFeature && !numeric) {
                return feature;
            }

            ContinuousFeature continuousFeature = feature.toContinuousFeature();
            DataType dataType = continuousFeature.getDataType();
            switch (dataType) {
                case INTEGER:
                case FLOAT:
                    return continuousFeature;
                case DOUBLE:
                    return continuousFeature.toContinuousFeature(DataType.FLOAT);
                default:
                    throw new IllegalArgumentException("Expected integer, float or double data type for continuous feature " + continuousFeature.getName() + ", got " + dataType.value() + " data type");
            }
        };

        return schema.toTransformedSchema(transformer);
    }

    /**
     * Encodes the whole model as a PMML document.
     *
     * @param options conversion options keyed by {@code HasXGBoostOptions} constants
     * @param fieldName target field name; {@code null} means {@code "_target"}
     * @param targetCategories passed through to {@link #encodeSchema}
     * @param featureMap feature definitions
     */
    public PMML encodePMML(Map<String, ?> options, FieldName fieldName, List<String> targetCategories, FeatureMap featureMap) {
        XGBoostEncoder encoder = new XGBoostEncoder();

        Boolean nanAsMissing = (Boolean) options.get(HasXGBoostOptions.OPTION_NAN_AS_MISSING);

        Schema schema = encodeSchema(fieldName, targetCategories, featureMap, encoder);
        MiningModel miningModel = encodeMiningModel(options, schema);
        PMML pmml = encoder.encodePMML(miningModel);

        if (Boolean.TRUE.equals(nanAsMissing)) {
            new NaNAsMissingDecorator().applyTo(pmml);
        }

        return pmml;
    }

    /**
     * Encodes the booster as a {@code MiningModel}. Compaction and pruning are applied when the
     * corresponding options are set.
     *
     * @param options conversion options; {@code OPTION_NUMERIC} defaults to {@code true}
     * @param schema conversion schema
     * @throws IllegalArgumentException if {@code OPTION_COMPACT} is true while
     *     {@code OPTION_NUMERIC} is false
     */
    public MiningModel encodeMiningModel(Map<String, ?> options, Schema schema) {
        Boolean compact = (Boolean) options.get(HasXGBoostOptions.OPTION_COMPACT);
        Boolean numeric = (Boolean) options.get(HasXGBoostOptions.OPTION_NUMERIC);
        Boolean prune = (Boolean) options.get(HasXGBoostOptions.OPTION_PRUNE);
        Integer ntreeLimit = (Integer) options.get(HasXGBoostOptions.OPTION_NTREE_LIMIT);

        if (numeric == null) {
            numeric = Boolean.TRUE;
        }

        // Validate the option combination before doing any encoding work.
        if (Boolean.TRUE.equals(compact) && Boolean.FALSE.equals(numeric)) {
            throw new IllegalArgumentException("Conflicting XGBoost options");
        }

        MiningModel miningModel = this.gbtree.encodeMiningModel(this.obj, this.base_score, ntreeLimit, numeric.booleanValue(), schema)
            .setAlgorithmName("XGBoost (" + this.gbtree.getAlgorithmName() + ")");

        if (Boolean.TRUE.equals(compact)) {
            new TreeModelCompactor().applyTo(miningModel);
        }

        if (Boolean.TRUE.equals(prune)) {
            new TreeModelPruner().applyTo(miningModel);
        }

        return miningModel;
    }

    public int num_feature() {
        return this.num_feature;
    }

    public int num_class() {
        return this.num_class;
    }

    public ObjFunction obj() {
        return this.obj;
    }

    /**
     * Creates an empty booster for the given booster name.
     *
     * @throws IllegalArgumentException if the name is neither {@code gbtree} nor {@code dart}
     */
    private GBTree parseGradientBooster(String booster) {
        switch (booster) {
            case "gbtree":
                return new GBTree();
            case "dart":
                return new Dart();
            default:
                throw new IllegalArgumentException(booster);
        }
    }

    /**
     * Creates the objective function for the given objective name.
     *
     * <p>Multi-class objectives read {@code num_class}, so it must be loaded first.
     *
     * @throws IllegalArgumentException if the objective name is not supported
     */
    private ObjFunction parseObjective(String objective) {
        switch (objective) {
            case "reg:linear":
            case "reg:squarederror":
            case "reg:squaredlogerror":
                return new LinearRegression(objective);
            case "reg:logistic":
                return new LogisticRegression(objective);
            case "reg:gamma":
            case "reg:tweedie":
                return new GeneralizedLinearRegression(objective);
            case "count:poisson":
                return new PoissonRegression(objective);
            case "binary:hinge":
                return new HingeClassification(objective);
            case "binary:logistic":
                return new BinomialLogisticRegression(objective);
            case "rank:map":
            case "rank:ndcg":
            case "rank:pairwise":
                return new LambdaMART(objective);
            case "multi:softmax":
            case "multi:softprob":
                return new MultinomialLogisticRegression(objective, this.num_class);
            default:
                throw new IllegalArgumentException(objective);
        }
    }

    /**
     * Skips {@code header} if the stream starts with its UTF-8 bytes.
     *
     * <p>Reads exactly as many bytes as the encoded header. On a mismatch the stream is reset to
     * its marked position. If the stream is shorter than the header, {@code readFully} throws
     * {@code EOFException} and the stream is not reset.
     *
     * @return {@code true} if the header was present and consumed
     */
    private static <DIS extends InputStream & DataInput> boolean consumeHeader(DIS is, String header) throws IOException {
        byte[] expected = header.getBytes(StandardCharsets.UTF_8);
        byte[] actual = new byte[expected.length];

        is.mark(actual.length);
        is.readFully(actual);

        boolean matches = Arrays.equals(expected, actual);
        if (!matches) {
            is.reset();
        }
        return matches;
    }
}
