package ml.dmlc.xgboost4j.java;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/ExternalCheckpointManager.class */
/**
 * Stores and restores XGBoost model checkpoints on a Hadoop {@link FileSystem}.
 *
 * <p>Each checkpoint is written as {@code <checkpointPath>/<version>.ubj}, where {@code version}
 * is {@code booster.getNumBoostedRound() - 1} at the time of saving. After a successful
 * {@link #updateCheckpoint(Booster)} only the newly written checkpoint remains; older ones are
 * deleted on a best-effort basis.
 */
public class ExternalCheckpointManager {
    private final Log logger = LogFactory.getLog("ExternalCheckpointManager");
    private final String modelSuffix = ".ubj";
    private final Path checkpointPath;
    private final FileSystem fs;

    /**
     * Creates a manager rooted at the given checkpoint directory.
     *
     * @param str checkpoint directory; must be non-null and non-empty
     * @param fileSystem file system used for every read, write, rename and delete
     * @throws XGBoostError if {@code str} is null or empty
     */
    public ExternalCheckpointManager(String str, FileSystem fileSystem) throws XGBoostError {
        if (str == null || str.isEmpty()) {
            throw new XGBoostError("cannot create ExternalCheckpointManager with null or empty checkpoint path");
        }
        this.checkpointPath = new Path(str);
        this.fs = fileSystem;
    }

    /** Returns the scheme-less path string of the checkpoint file for {@code version}. */
    private String getPath(int version) {
        return this.checkpointPath.toUri().getPath() + "/" + version + this.modelSuffix;
    }

    /**
     * Lists the versions of all checkpoint files currently present in the checkpoint directory.
     * Returns an empty list when the directory does not exist.
     */
    private List<Integer> getExistingVersions() throws IOException {
        if (!this.fs.exists(this.checkpointPath)) {
            return new ArrayList<>();
        }
        return Arrays.stream(this.fs.listStatus(this.checkpointPath))
            .map(status -> parseVersion(status.getPath().getName()))
            .filter(version -> version != null)
            .collect(Collectors.toList());
    }

    /**
     * Extracts the version from a checkpoint file name such as {@code "12.ubj"}.
     *
     * @return the parsed version, or {@code null} if the name does not end with the model suffix
     *     or its stem is not an integer (such files are ignored instead of aborting the listing)
     */
    private Integer parseVersion(String fileName) {
        if (!fileName.endsWith(this.modelSuffix)) {
            return null;
        }
        String stem = fileName.substring(0, fileName.length() - this.modelSuffix.length());
        try {
            return Integer.valueOf(stem);
        } catch (NumberFormatException e) {
            this.logger.warn("ignoring unrecognized file in checkpoint directory: " + fileName);
            return null;
        }
    }

    /** Returns the largest version in a non-empty list. */
    private Integer latest(List<Integer> versions) {
        return versions.stream().max(Comparator.naturalOrder()).get();
    }

    /** Deletes {@code path} recursively, logging (rather than propagating) any IOException. */
    private void deleteQuietly(Path path, String errorMessage) {
        try {
            this.fs.delete(path, true);
        } catch (IOException e) {
            this.logger.error(errorMessage, e);
        }
    }

    /** Recursively deletes the whole checkpoint directory. */
    public void cleanPath() throws IOException {
        this.fs.delete(this.checkpointPath, true);
    }

    /**
     * Loads the checkpoint with the highest version.
     *
     * @return the restored booster, or {@code null} if no checkpoint exists
     */
    public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
        List<Integer> versions = getExistingVersions();
        if (versions.isEmpty()) {
            return null;
        }
        String path = getPath(latest(versions));
        // try-with-resources guarantees the stream is released even if loading fails;
        // Closeable.close is idempotent, so this is harmless if loadModel closes it too.
        try (InputStream in = this.fs.open(new Path(path))) {
            this.logger.info("loaded checkpoint from " + path);
            return XGBoost.loadModel(in);
        }
    }

    /**
     * Saves {@code booster} as a new checkpoint and then removes all previously existing ones.
     *
     * <p>The model is first written to a uniquely named temporary file, which is closed and then
     * renamed into place. Older checkpoints are deleted only after the rename succeeded, so a failed
     * save never leaves the directory without its previous checkpoints.
     *
     * @throws IOException if writing fails or the temporary file cannot be renamed into place
     */
    public void updateCheckpoint(Booster booster) throws IOException, XGBoostError {
        List<String> outdatedPaths = getExistingVersions().stream()
            .map(this::getPath)
            .collect(Collectors.toCollection(ArrayList::new));
        // Checkpointing happens after a round is applied, so the newest round is numBoostedRound - 1.
        int version = booster.getNumBoostedRound() - 1;
        String eventualPath = getPath(version);
        Path finalPath = new Path(eventualPath);
        Path tempPath = new Path(eventualPath + "-" + UUID.randomUUID());

        boolean published = false;
        try {
            // Close the stream before renaming so all buffered bytes reach the file first.
            try (OutputStream out = this.fs.create(tempPath, true)) {
                booster.saveModel(out);
            }
            // Hadoop's rename is not guaranteed to overwrite an existing destination file
            // (HDFS returns false), so replace a same-version checkpoint explicitly. Removing it
            // from the list also prevents deleting the freshly written file below.
            if (outdatedPaths.removeIf(eventualPath::equals)) {
                this.fs.delete(finalPath, true);
            }
            published = this.fs.rename(tempPath, finalPath);
            if (!published) {
                throw new IOException("failed to rename checkpoint " + tempPath + " to " + finalPath);
            }
        } finally {
            if (!published) {
                deleteQuietly(tempPath, "failed to delete temporary checkpoint at " + tempPath);
            }
        }
        this.logger.info("saving checkpoint with version " + version);
        for (String path : outdatedPaths) {
            deleteQuietly(new Path(path), "failed to delete outdated checkpoint at " + path);
        }
    }

    /**
     * Deletes every checkpoint whose version is greater than {@code currentRound}. Individual
     * delete failures are logged and skipped.
     */
    public void cleanUpHigherVersions(int currentRound) throws IOException {
        getExistingVersions().stream()
            .filter(version -> version > currentRound)
            .forEach(version -> deleteQuietly(new Path(getPath(version)),
                "failed to clean checkpoint from other training instance"));
    }

    /**
     * Computes the rounds at which a checkpoint should be taken.
     *
     * <p>When {@code checkpointInterval > 0} the result contains {@code firstIter},
     * {@code firstIter + checkpointInterval}, ... for every value below
     * {@code firstIter + numOfRounds}. The final round {@code firstIter + numOfRounds - 1} is
     * always included (appended last if not already present).
     *
     * @throws IllegalArgumentException if {@code firstIter + numOfRounds - 1 < 0}
     * @throws IOException declared for compatibility; not thrown by this implementation
     */
    public List<Integer> getCheckpointRounds(int firstIter, int checkpointInterval, int numOfRounds)
            throws IOException {
        int end = firstIter + numOfRounds;
        int lastRound = end - 1;
        if (lastRound < 0) {
            throw new IllegalArgumentException("Invalid `numOfRounds`.");
        }
        List<Integer> rounds = new ArrayList<>();
        if (checkpointInterval > 0) {
            // A long counter cannot wrap around when the interval is added near Integer.MAX_VALUE.
            for (long round = firstIter; round < end; round += checkpointInterval) {
                rounds.add((int) round);
            }
        }
        if (!rounds.contains(lastRound)) {
            rounds.add(lastRound);
        }
        return rounds;
    }
}
