package ml.dmlc.xgboost4j.java;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/Booster.class */
public class Booster implements Serializable, KryoSerializable {
    private static final Log logger = LogFactory.getLog(Booster.class);
    private long handle = 0;
    private int version = 0;

    /* loaded from: input_file:ml/dmlc/xgboost4j/java/Booster$FeatureImportanceType.class */
    public static class FeatureImportanceType {
        public static final String WEIGHT = "weight";
        public static final String GAIN = "gain";
        public static final String COVER = "cover";
        public static final String TOTAL_GAIN = "total_gain";
        public static final String TOTAL_COVER = "total_cover";
        public static final Set<String> ACCEPTED_TYPES = new HashSet(Arrays.asList(WEIGHT, GAIN, COVER, TOTAL_GAIN, TOTAL_COVER));
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public Booster(Map<String, Object> map, DMatrix[] dMatrixArr) throws XGBoostError {
        init(dMatrixArr);
        setParam("seed", "0");
        setParams(map);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static Booster loadModel(String str) throws XGBoostError {
        if (str == null) {
            throw new NullPointerException("modelPath : null");
        }
        Booster booster = new Booster(new HashMap(), new DMatrix[0]);
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(booster.handle, str));
        return booster;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static Booster loadModel(InputStream inputStream) throws XGBoostError, IOException {
        byte[] bArr = new byte[1048576];
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        while (true) {
            int read = inputStream.read(bArr);
            if (read == -1) {
                inputStream.close();
                Booster booster = new Booster(new HashMap(), new DMatrix[0]);
                XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(booster.handle, byteArrayOutputStream.toByteArray()));
                return booster;
            }
            byteArrayOutputStream.write(bArr, 0, read);
        }
    }

    public final void setParam(String str, Object obj) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetParam(this.handle, str, obj.toString()));
    }

    public void setParams(Map<String, Object> map) throws XGBoostError {
        if (map != null) {
            for (Map.Entry<String, Object> entry : map.entrySet()) {
                setParam(entry.getKey(), entry.getValue().toString());
            }
        }
    }

    public void update(DMatrix dMatrix, int i) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(this.handle, i, dMatrix.getHandle()));
    }

    public void update(DMatrix dMatrix, IObjective iObjective) throws XGBoostError {
        List<float[]> gradient = iObjective.getGradient(predict(dMatrix, true, 0, false, false), dMatrix);
        boost(dMatrix, gradient.get(0), gradient.get(1));
    }

    public void boost(DMatrix dMatrix, float[] fArr, float[] fArr2) throws XGBoostError {
        if (fArr.length != fArr2.length) {
            throw new AssertionError(String.format("grad/hess length mismatch %s / %s", Integer.valueOf(fArr.length), Integer.valueOf(fArr2.length)));
        }
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(this.handle, dMatrix.getHandle(), fArr, fArr2));
    }

    public String evalSet(DMatrix[] dMatrixArr, String[] strArr, int i) throws XGBoostError {
        String[] strArr2 = new String[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterEvalOneIter(this.handle, i, dmatrixsToHandles(dMatrixArr), strArr, strArr2));
        return strArr2[0];
    }

    public String evalSet(DMatrix[] dMatrixArr, String[] strArr, int i, float[] fArr) throws XGBoostError {
        String evalSet = evalSet(dMatrixArr, strArr, i);
        String[] split = evalSet.split("\t");
        for (int i2 = 1; i2 < split.length; i2++) {
            fArr[i2 - 1] = Float.valueOf(split[i2].split(":")[1]).floatValue();
        }
        return evalSet;
    }

    public String evalSet(DMatrix[] dMatrixArr, String[] strArr, IEvaluation iEvaluation) throws XGBoostError {
        return evalSet(dMatrixArr, strArr, iEvaluation, new float[strArr.length]);
    }

    public String evalSet(DMatrix[] dMatrixArr, String[] strArr, IEvaluation iEvaluation, float[] fArr) throws XGBoostError {
        String str = "";
        for (int i = 0; i < strArr.length; i++) {
            String str2 = strArr[i];
            DMatrix dMatrix = dMatrixArr[i];
            float eval = iEvaluation.eval(predict(dMatrix), dMatrix);
            str = str + String.format("\t%s-%s:%f", str2, iEvaluation.getMetric(), Float.valueOf(eval));
            fArr[i] = eval;
        }
        return str;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v5, types: [float[], float[][]] */
    private synchronized float[][] predict(DMatrix dMatrix, boolean z, int i, boolean z2, boolean z3) throws XGBoostError {
        int i2 = z ? 1 : 0;
        if (z2) {
            i2 = 2;
        }
        if (z3) {
            i2 = 4;
        }
        ?? r0 = new float[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(this.handle, dMatrix.getHandle(), i2, i, r0));
        int rowNum = (int) dMatrix.rowNum();
        int length = r0[0].length / rowNum;
        float[][] fArr = new float[rowNum][length];
        for (int i3 = 0; i3 < r0[0].length; i3++) {
            fArr[i3 / length][i3 % length] = r0[0][i3];
        }
        return fArr;
    }

    public float[][] predictLeaf(DMatrix dMatrix, int i) throws XGBoostError {
        return predict(dMatrix, false, i, true, false);
    }

    public float[][] predictContrib(DMatrix dMatrix, int i) throws XGBoostError {
        return predict(dMatrix, false, i, true, true);
    }

    public float[][] predict(DMatrix dMatrix) throws XGBoostError {
        return predict(dMatrix, false, 0, false, false);
    }

    public float[][] predict(DMatrix dMatrix, boolean z) throws XGBoostError {
        return predict(dMatrix, z, 0, false, false);
    }

    public float[][] predict(DMatrix dMatrix, boolean z, int i) throws XGBoostError {
        return predict(dMatrix, z, i, false, false);
    }

    public void saveModel(String str) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveModel(this.handle, str));
    }

    public void saveModel(OutputStream outputStream) throws XGBoostError, IOException {
        outputStream.write(toByteArray());
        outputStream.close();
    }

    public String[] getModelDump(String str, boolean z) throws XGBoostError {
        return getModelDump(str, z, "text");
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v5, types: [java.lang.String[], java.lang.String[][]] */
    public String[] getModelDump(String str, boolean z, String str2) throws XGBoostError {
        int i = 0;
        if (str == null) {
            str = "";
        }
        if (z) {
            i = 1;
        }
        if (str2 == null) {
            str2 = "text";
        }
        ?? r0 = new String[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(this.handle, str, i, str2, r0));
        return r0[0];
    }

    public String[] getModelDump(String[] strArr, boolean z) throws XGBoostError {
        return getModelDump(strArr, z, "text");
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v4, types: [java.lang.String[], java.lang.String[][]] */
    public String[] getModelDump(String[] strArr, boolean z, String str) throws XGBoostError {
        int i = 0;
        if (z) {
            i = 1;
        }
        if (str == null) {
            str = "text";
        }
        ?? r0 = new String[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelExWithFeatures(this.handle, strArr, i, str, r0));
        return r0[0];
    }

    public Map<String, Integer> getFeatureScore(String[] strArr) throws XGBoostError {
        return getFeatureWeightsFromModel(getModelDump(strArr, false));
    }

    public Map<String, Integer> getFeatureScore(String str) throws XGBoostError {
        return getFeatureWeightsFromModel(getModelDump(str, false));
    }

    private Map<String, Integer> getFeatureWeightsFromModel(String[] strArr) throws XGBoostError {
        HashMap hashMap = new HashMap();
        for (String str : strArr) {
            for (String str2 : str.split("\n")) {
                String[] split = str2.split("\\[");
                if (split.length != 1) {
                    String str3 = split[1].split("\\]")[0].split("<")[0];
                    if (hashMap.containsKey(str3)) {
                        hashMap.put(str3, Integer.valueOf(1 + ((Integer) hashMap.get(str3)).intValue()));
                    } else {
                        hashMap.put(str3, 1);
                    }
                }
            }
        }
        return hashMap;
    }

    public Map<String, Double> getScore(String[] strArr, String str) throws XGBoostError {
        return getFeatureImportanceFromModel(getModelDump(strArr, true), str);
    }

    public Map<String, Double> getScore(String str, String str2) throws XGBoostError {
        return getFeatureImportanceFromModel(getModelDump(str, true), str2);
    }

    private Map<String, Double> getFeatureImportanceFromModel(String[] strArr, String str) throws XGBoostError {
        if (!FeatureImportanceType.ACCEPTED_TYPES.contains(str)) {
            throw new AssertionError(String.format("Importance type %s is not supported", str));
        }
        HashMap hashMap = new HashMap();
        HashMap hashMap2 = new HashMap();
        if (str == FeatureImportanceType.WEIGHT) {
            Iterator<String> it = getFeatureWeightsFromModel(strArr).keySet().iterator();
            while (it.hasNext()) {
                hashMap.put(it.next(), new Double(r0.get(r0).intValue()));
            }
            return hashMap;
        }
        String str2 = (str == FeatureImportanceType.COVER || str == FeatureImportanceType.TOTAL_COVER) ? "cover=" : "gain=";
        for (String str3 : strArr) {
            for (String str4 : str3.split("\n")) {
                String[] split = str4.split("\\[");
                if (split.length != 1) {
                    String[] split2 = split[1].split("\\]");
                    Double valueOf = Double.valueOf(Double.parseDouble(split2[1].split(str2)[1].split(",")[0]));
                    String str5 = split2[0].split("<")[0];
                    if (hashMap.containsKey(str5)) {
                        hashMap.put(str5, Double.valueOf(valueOf.doubleValue() + ((Double) hashMap.get(str5)).doubleValue()));
                        hashMap2.put(str5, Double.valueOf(1.0d + ((Double) hashMap2.get(str5)).doubleValue()));
                    } else {
                        hashMap.put(str5, valueOf);
                        hashMap2.put(str5, Double.valueOf(1.0d));
                    }
                }
            }
        }
        if (str == FeatureImportanceType.COVER || str == FeatureImportanceType.GAIN) {
            for (String str6 : hashMap.keySet()) {
                hashMap.put(str6, Double.valueOf(((Double) hashMap.get(str6)).doubleValue() / ((Double) hashMap2.get(str6)).doubleValue()));
            }
        }
        return hashMap;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v3, types: [java.lang.String[], java.lang.String[][]] */
    private String[] getDumpInfo(boolean z) throws XGBoostError {
        int i = 0;
        if (z) {
            i = 1;
        }
        ?? r0 = new String[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(this.handle, "", i, "text", r0));
        return r0[0];
    }

    public int getVersion() {
        return this.version;
    }

    public void setVersion(int i) {
        this.version = i;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v1, types: [byte[], byte[][]] */
    public byte[] toByteArray() throws XGBoostError {
        ?? r0 = new byte[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, r0));
        return r0[0];
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public int loadRabitCheckpoint() throws XGBoostError {
        int[] iArr = new int[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, iArr));
        this.version = iArr[0];
        return this.version;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public void saveRabitCheckpoint() throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
        this.version++;
    }

    private void init(DMatrix[] dMatrixArr) throws XGBoostError {
        long[] jArr = null;
        if (dMatrixArr != null) {
            jArr = dmatrixsToHandles(dMatrixArr);
        }
        long[] jArr2 = new long[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterCreate(jArr, jArr2));
        this.handle = jArr2[0];
    }

    private static long[] dmatrixsToHandles(DMatrix[] dMatrixArr) {
        long[] jArr = new long[dMatrixArr.length];
        for (int i = 0; i < dMatrixArr.length; i++) {
            jArr[i] = dMatrixArr[i].getHandle();
        }
        return jArr;
    }

    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        try {
            objectOutputStream.writeInt(this.version);
            objectOutputStream.writeObject(toByteArray());
        } catch (XGBoostError e) {
            e.printStackTrace();
            logger.error(e.getMessage());
        }
    }

    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        try {
            init(null);
            this.version = objectInputStream.readInt();
            XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, (byte[]) objectInputStream.readObject()));
        } catch (XGBoostError e) {
            e.printStackTrace();
            logger.error(e.getMessage());
        }
    }

    protected void finalize() throws Throwable {
        super.finalize();
        dispose();
    }

    public synchronized void dispose() {
        if (this.handle != 0) {
            XGBoostJNI.XGBoosterFree(this.handle);
            this.handle = 0L;
        }
    }

    public void write(Kryo kryo, Output output) {
        try {
            byte[] byteArray = toByteArray();
            output.writeInt(byteArray.length);
            output.writeInt(this.version);
            output.write(byteArray);
        } catch (XGBoostError e) {
            logger.error(e.getMessage(), e);
        }
    }

    public void read(Kryo kryo, Input input) {
        try {
            init(null);
            int readInt = input.readInt();
            this.version = input.readInt();
            byte[] bArr = new byte[readInt];
            input.readBytes(bArr);
            XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bArr));
        } catch (XGBoostError e) {
            logger.error(e.getMessage(), e);
        }
    }
}
