package ai.djl.inference;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.streaming.StreamingBlock;
import ai.djl.inference.streaming.StreamingTranslator;
import ai.djl.metric.Dimension;
import ai.djl.metric.Metrics;
import ai.djl.metric.Unit;
import ai.djl.ndarray.LazyNDArray;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Block;
import ai.djl.training.ParameterStore;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/inference/Predictor.class */
/**
 * Runs inference with a {@link Model}: the {@link Translator} converts inputs into {@link NDList}s,
 * the model's {@link Block} runs the forward pass, and the translator converts the results back.
 *
 * <p>A predictor owns a sub-manager of the model's {@link NDManager}; call {@link #close()} to
 * release it. NOTE(review): the {@code prepared} flag and the {@code timestamp} used for metrics
 * are plain fields mutated without synchronization, so sharing one instance across threads looks
 * unsafe — confirm intended usage.
 *
 * @param <I> the input type
 * @param <O> the output type
 */
public class Predictor<I, O> implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(Predictor.class);

    protected Translator<I, O> translator;
    /** Start of the current timing interval, in {@link System#nanoTime()} units (metrics only). */
    protected long timestamp;
    /** Whether {@code translator.prepare(...)} has already been run by this predictor. */
    protected boolean prepared;
    protected Model model;
    protected NDManager manager;
    protected Metrics metrics;
    protected Block block;
    protected ParameterStore parameterStore;
    protected Dimension dimension;

    /**
     * Per-call {@link TranslatorContext} backed by its own sub-manager of the predictor's manager.
     * Closing the context closes that sub-manager. Attachments live in a {@link ConcurrentHashMap},
     * which rejects {@code null} keys and values.
     */
    // NOTE(review): the decompiler reports this class was declared protected in the original
    // source; it is kept public here so existing callers are not broken by narrowed access.
    public class PredictorContext implements TranslatorContext {
        private final NDManager ctxManager;
        private final Map<String, Object> attachments;

        /** Creates a context with a fresh sub-manager named {@code "predictor ctx"}. */
        public PredictorContext() {
            ctxManager = Predictor.this.manager.newSubManager();
            ctxManager.setName("predictor ctx");
            attachments = new ConcurrentHashMap<>();
        }

        /** {@inheritDoc} */
        @Override
        public Model getModel() {
            return Predictor.this.model;
        }

        /** Returns this context's own sub-manager (closed together with the context). */
        @Override
        public NDManager getNDManager() {
            return ctxManager;
        }

        /** Returns the owning predictor's manager, which outlives this context. */
        @Override
        public NDManager getPredictorManager() {
            return Predictor.this.manager;
        }

        /** {@inheritDoc} */
        @Override
        public Block getBlock() {
            return Predictor.this.block;
        }

        /** Returns the predictor's metrics sink, or {@code null} if none was set. */
        @Override
        public Metrics getMetrics() {
            return Predictor.this.metrics;
        }

        /** Closes this context's sub-manager. */
        @Override
        public void close() {
            ctxManager.close();
        }

        /** {@inheritDoc} */
        @Override
        public Object getAttachment(String key) {
            return attachments.get(key);
        }

        /** {@inheritDoc} */
        @Override
        public void setAttachment(String key, Object value) {
            attachments.put(key, value);
        }
    }

    /**
     * Creates a predictor for {@code model} whose manager lives on {@code device}.
     *
     * @param model the model providing the block and the parent manager
     * @param translator converts between {@code I}/{@code O} and {@link NDList}
     * @param device the device of the predictor's sub-manager
     * @param copy forwarded to {@link ParameterStore}; forced to {@code true} when {@code device}
     *     differs from the model manager's device (presumably so parameters get copied to the
     *     target device — TODO confirm against ParameterStore)
     */
    public Predictor(Model model, Translator<I, O> translator, Device device, boolean copy) {
        // Evaluate the device comparison first, exactly as before (an NPE on a null device is kept).
        boolean copyParameters = !device.equals(model.getNDManager().getDevice()) || copy;
        this.model = model;
        this.manager = model.getNDManager().newSubManager(device);
        this.manager.setName("predictor");
        this.translator = translator;
        this.block = model.getBlock();
        this.parameterStore = new ParameterStore(this.manager, copyParameters);
        this.dimension = new Dimension("Model", model.getProperty("metric_dimension", "model"));
    }

    /**
     * Predicts a single input by delegating to {@link #batchPredict(List)} with a singleton list.
     *
     * @param input the input
     * @return the first output produced for {@code input}
     * @throws TranslateException if translation or inference fails
     */
    public O predict(I input) throws TranslateException {
        return batchPredict(Collections.singletonList(input)).get(0);
    }

    /**
     * Runs the block's forward pass (non-training) on already-processed inputs. Subclasses may
     * override this to customize the inference step.
     *
     * @param ctx the per-call context
     * @param ndList the processed model inputs
     * @return the raw model outputs
     * @throws TranslateException if inference fails
     */
    protected NDList predictInternal(TranslatorContext ctx, NDList ndList)
            throws TranslateException {
        logger.trace("Predictor input data: {}", ndList);
        return block.forward(parameterStore, ndList, false);
    }

    /**
     * Predicts a list of inputs.
     *
     * <p>With a batchifier, the whole list is processed as one batch; otherwise each input runs
     * through pre-processing, inference and post-processing individually. The per-call context is
     * always closed, including when an exception is thrown.
     *
     * @param inputs the inputs
     * @return the outputs
     * @throws TranslateException if translation or inference fails; other exceptions are wrapped
     */
    public List<O> batchPredict(List<I> inputs) throws TranslateException {
        try (PredictorContext context = new PredictorContext()) {
            if (!prepared) {
                translator.prepare(context);
                prepared = true;
            }
            if (translator.getBatchifier() != null) {
                int batchSize = inputs.size();
                timestamp = System.nanoTime();
                long begin = timestamp;
                NDList batch = translator.batchProcessInput(context, inputs);
                preprocessEnd(batch, batchSize);
                NDList result = predictInternal(context, batch);
                predictEnd(result, batchSize);
                List<O> batchOutputs = translator.batchProcessOutput(context, result);
                postProcessEnd(begin, batchSize);
                return batchOutputs;
            }

            List<O> outputs = new ArrayList<>(inputs.size());
            for (I input : inputs) {
                timestamp = System.nanoTime();
                long begin = timestamp;
                NDList processed = translator.processInput(context, input);
                preprocessEnd(processed, 1);
                NDList result = predictInternal(context, processed);
                predictEnd(result, 1);
                outputs.add(translator.processOutput(context, result));
                postProcessEnd(begin, 1);
            }
            return outputs;
        } catch (TranslateException e) {
            throw e;
        } catch (Exception e) {
            throw new TranslateException(e);
        }
    }

    /**
     * Predicts a single input, streaming the model's outputs.
     *
     * <p>On success, the per-call context is closed when the underlying result stream is closed.
     * If anything fails before the output is returned, the context is closed immediately.
     *
     * @param input the input
     * @return the streamed output
     * @throws IllegalStateException if the block or translator does not support streaming
     * @throws TranslateException if translation or inference fails; other exceptions are wrapped
     */
    public StreamingTranslator.StreamOutput<O> streamingPredict(I input)
            throws TranslateException {
        String error = streamingSupportError();
        if (error != null) {
            throw new IllegalStateException(error);
        }
        StreamingBlock streamingBlock = (StreamingBlock) block;
        // Safe: streamingSupportError() verified the type; the generic parameters match translator.
        @SuppressWarnings("unchecked")
        StreamingTranslator<I, O> streamingTranslator = (StreamingTranslator<I, O>) translator;

        try {
            PredictorContext context = new PredictorContext();
            try {
                return streamWithContext(context, input, streamingBlock, streamingTranslator);
            } catch (Exception e) {
                // Ownership was never handed to a returned stream, so release the context here.
                context.close();
                throw e;
            }
        } catch (TranslateException e) {
            throw e;
        } catch (Exception e) {
            throw new TranslateException(e);
        }
    }

    /** Runs the streaming forward pass; the returned output closes {@code context} on stream close. */
    private StreamingTranslator.StreamOutput<O> streamWithContext(
            PredictorContext context,
            I input,
            StreamingBlock streamingBlock,
            StreamingTranslator<I, O> streamingTranslator)
            throws Exception {
        if (!prepared) {
            translator.prepare(context);
            prepared = true;
        }
        Stream<NDList> results;
        if (translator.getBatchifier() == null) {
            NDList processed = translator.processInput(context, input);
            results = streamingBlock.forwardStream(parameterStore, processed, false);
        } else {
            // Wrap the single input in a batch of one and unbatchify every streamed result.
            NDList batch = processInputs(context, Collections.singletonList(input));
            results =
                    streamingBlock
                            .forwardStream(parameterStore, batch, false)
                            .map(this::unbatchifySingle);
        }
        return streamingTranslator.processStreamOutput(context, results.onClose(context::close));
    }

    /** Unbatchifies a streamed result that must contain exactly one item. */
    private NDList unbatchifySingle(NDList batched) {
        NDList[] unbatched = translator.getBatchifier().unbatchify(batched);
        if (unbatched.length != 1) {
            throw new IllegalStateException("Unexpected number of outputs from model");
        }
        return unbatched[0];
    }

    /**
     * Returns whether {@link #streamingPredict(Object)} can be used with this predictor's block and
     * translator.
     *
     * @return {@code true} if both support streaming
     */
    public boolean supportsStreaming() {
        return streamingSupportError() == null;
    }

    /** Returns why streaming is unsupported, or {@code null} when it is supported. */
    private String streamingSupportError() {
        if (!(block instanceof StreamingBlock)) {
            return "streamingPredict() can only be called with a StreamingBlock";
        }
        if (!(translator instanceof StreamingTranslator)) {
            return "streamingPredict() can only be called with a StreamingTranslator";
        }
        return null;
    }

    /**
     * Sets the metrics sink; timing metrics are recorded only while this is non-null.
     *
     * @param metrics the metrics sink, or {@code null} to disable recording
     */
    public void setMetrics(Metrics metrics) {
        this.metrics = metrics;
    }

    /** Blocks until every lazily evaluated array in {@code list} is readable. */
    private void waitToRead(NDList list) {
        for (NDArray array : list) {
            if (array instanceof LazyNDArray) {
                ((LazyNDArray) array).waitToRead();
            }
        }
    }

    /** Processes each input individually and batchifies the results (requires a batchifier). */
    private NDList processInputs(TranslatorContext ctx, List<I> inputs) throws Exception {
        int size = inputs.size();
        NDList[] processed = new NDList[size];
        for (int i = 0; i < size; i++) {
            processed[i] = translator.processInput(ctx, inputs.get(i));
        }
        return translator.getBatchifier().batchify(processed);
    }

    /** Records the per-item "Preprocess" time after waiting for {@code list} to be readable. */
    private void preprocessEnd(NDList list, int batchSize) {
        if (metrics != null) {
            waitToRead(list);
            recordPerItem("Preprocess", batchSize);
        }
    }

    /** Records the per-item "Inference" time after waiting for {@code list} to be readable. */
    private void predictEnd(NDList list, int batchSize) {
        if (metrics != null) {
            waitToRead(list);
            recordPerItem("Inference", batchSize);
        }
    }

    /** Records the per-item "Postprocess" time and the total "Prediction" time since {@code begin}. */
    private void postProcessEnd(long begin, int batchSize) {
        if (metrics != null) {
            long now = recordPerItem("Postprocess", batchSize);
            metrics.addMetric(
                    "Prediction", Long.valueOf((now - begin) / 1000), Unit.MICROSECONDS, dimension);
        }
    }

    /**
     * Records the microseconds elapsed since {@link #timestamp}, divided by {@code batchSize},
     * under {@code name}; then advances {@link #timestamp}. The divisor is clamped to at least 1
     * so an empty batch cannot cause a division by zero.
     *
     * @return the new timestamp
     */
    private long recordPerItem(String name, int batchSize) {
        long now = System.nanoTime();
        long perItemMicros = (now - timestamp) / 1000 / Math.max(batchSize, 1);
        timestamp = now;
        metrics.addMetric(name, Long.valueOf(perItemMicros), Unit.MICROSECONDS, dimension);
        return now;
    }

    /** Closes the predictor's manager. */
    @Override
    public void close() {
        manager.close();
    }

    /**
     * Safety net that closes a predictor that was never closed explicitly. The warning is logged
     * only when debug logging is enabled.
     */
    @Override
    protected void finalize() throws Throwable {
        if (manager.isOpen()) {
            if (logger.isDebugEnabled()) {
                logger.warn("Predictor for {} was not closed explicitly.", model.getName());
            }
            close();
        }
        super.finalize();
    }
}
