/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow.lite;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.tensorflow.lite.Delegate;
import org.tensorflow.lite.InterpreterImpl;
import org.tensorflow.lite.NativeSignatureRunnerWrapper;
import org.tensorflow.lite.TensorFlowLite;
import org.tensorflow.lite.TensorImpl;
import org.tensorflow.lite.XnnpackDelegate;
import org.tensorflow.lite.nnapi.NnApiDelegate;

/**
 * Java-side wrapper around a native TensorFlow Lite interpreter.
 *
 * <p>Holds the native error-reporter, model, interpreter and (optional) cancellation-flag handles,
 * lazily creates {@link TensorImpl} views of the interpreter's inputs and outputs, caches
 * per-signature {@link NativeSignatureRunnerWrapper}s, and tracks the delegates applied to the
 * interpreter. {@link #close()} deletes the native handles, closes the cached tensors and closes
 * the delegates this class created itself.
 *
 * <p>NOTE(review): this class performs no synchronization; presumably callers must confine an
 * instance to one thread at a time — confirm against InterpreterImpl.
 */
class NativeInterpreterWrapper implements AutoCloseable {
    /** Capacity passed to the native error reporter when it is created. */
    private static final int ERROR_BUFFER_SIZE = 512;

    /** Native error-reporter handle; 0 after {@link #close()}. */
    long errorHandle;
    /** Native interpreter handle; 0 after {@link #close()}. */
    long interpreterHandle;
    /** Native model handle; 0 after {@link #close()}. */
    private long modelHandle;
    /** Native cancellation flag; 0 unless the options requested a cancellable interpreter. */
    private long cancellationFlagHandle = 0L;
    /** Duration of the last completed inference, or -1 when none has completed since the last run started. */
    private long inferenceDurationNanoseconds = -1L;
    /** Kept referenced so the buffer backing the native model stays reachable while in use. */
    private ByteBuffer modelByteBuffer;
    /** Lazily built map from input tensor name to input index. */
    private Map<String, Integer> inputsIndexes;
    /** Lazily built map from output tensor name to output index. */
    private Map<String, Integer> outputsIndexes;
    /** Lazily built map from interpreter tensor index to input position. */
    private Map<Integer, Integer> tensorToInputsIndexes;
    /** Lazily built map from interpreter tensor index to output position. */
    private Map<Integer, Integer> tensorToOutputsIndexes;
    /** Lazily created cache of signature runners, keyed by signature key. */
    private Map<String, NativeSignatureRunnerWrapper> signatureRunnerMap;
    /** Lazily populated input tensor views, one slot per interpreter input. */
    private TensorImpl[] inputTensors;
    /** Lazily populated output tensor views, one slot per interpreter output. */
    private TensorImpl[] outputTensors;
    /** Whether native tensor memory is allocated for the current input shapes. */
    private boolean isMemoryAllocated = false;
    /** Whether the delegate-free interpreter reported unresolved Flex ops. */
    private boolean originalGraphHasUnresolvedFlexOp = false;
    /** Every delegate applied to the interpreter, in application order. */
    private final List<Delegate> delegates = new ArrayList<Delegate>();
    /** Delegates created by this class, which it must therefore close. */
    private final List<AutoCloseable> ownedDelegates = new ArrayList<AutoCloseable>();

    /**
     * Creates an interpreter for the model file at {@code modelPath} using default options.
     *
     * @param modelPath path of the model file
     */
    NativeInterpreterWrapper(String modelPath) {
        this(modelPath, null);
    }

    /**
     * Creates an interpreter for the model contained in {@code byteBuffer} using default options.
     *
     * @param byteBuffer a {@link MappedByteBuffer} of the model file, or a direct buffer in native
     *     byte order holding the model bytes
     * @throws IllegalArgumentException if the buffer does not meet those requirements
     */
    NativeInterpreterWrapper(ByteBuffer byteBuffer) {
        this(byteBuffer, null);
    }

    /**
     * Creates an interpreter for the model file at {@code modelPath}.
     *
     * @param modelPath path of the model file
     * @param options interpreter options, or {@code null} for defaults
     */
    NativeInterpreterWrapper(String modelPath, InterpreterImpl.Options options) {
        TensorFlowLite.init();
        long newErrorHandle = NativeInterpreterWrapper.createErrorReporter(ERROR_BUFFER_SIZE);
        long newModelHandle = NativeInterpreterWrapper.createModel(modelPath, newErrorHandle);
        this.init(newErrorHandle, newModelHandle, options);
    }

    /**
     * Creates an interpreter for the model contained in {@code buffer}.
     *
     * @param buffer a {@link MappedByteBuffer} of the model file, or a direct buffer in native byte
     *     order holding the model bytes
     * @param options interpreter options, or {@code null} for defaults
     * @throws IllegalArgumentException if the buffer does not meet those requirements
     */
    NativeInterpreterWrapper(ByteBuffer buffer, InterpreterImpl.Options options) {
        TensorFlowLite.init();
        if (!isValidModelBuffer(buffer)) {
            throw new IllegalArgumentException("Model ByteBuffer should be either a MappedByteBuffer of the model file, or a direct ByteBuffer using ByteOrder.nativeOrder() which contains bytes of model content.");
        }
        this.modelByteBuffer = buffer;
        long newErrorHandle = NativeInterpreterWrapper.createErrorReporter(ERROR_BUFFER_SIZE);
        long newModelHandle = NativeInterpreterWrapper.createModelWithBuffer(this.modelByteBuffer, newErrorHandle);
        this.init(newErrorHandle, newModelHandle, options);
    }

    /** Returns true if {@code buffer} is non-null and either memory-mapped or direct in native byte order. */
    private static boolean isValidModelBuffer(ByteBuffer buffer) {
        if (buffer == null) {
            return false;
        }
        if (buffer instanceof MappedByteBuffer) {
            return true;
        }
        return buffer.isDirect() && buffer.order() == ByteOrder.nativeOrder();
    }

    /**
     * Builds the native interpreter, applies delegates and options, and allocates tensors.
     *
     * @param errorHandle native error-reporter handle
     * @param modelHandle native model handle
     * @param options interpreter options, or {@code null} for defaults
     */
    private void init(long errorHandle, long modelHandle, InterpreterImpl.Options options) {
        if (options == null) {
            options = new InterpreterImpl.Options();
        }
        this.errorHandle = errorHandle;
        this.modelHandle = modelHandle;
        // First build a delegate-free interpreter: it is used to detect unresolved Flex ops and is
        // the handle passed to XNNPACK delegate creation in addDelegates().
        ArrayList<Long> delegateHandles = new ArrayList<Long>();
        this.interpreterHandle = NativeInterpreterWrapper.createInterpreter(modelHandle, errorHandle, options.getNumThreads(), delegateHandles);
        this.originalGraphHasUnresolvedFlexOp = NativeInterpreterWrapper.hasUnresolvedFlexOp(this.interpreterHandle);
        this.addDelegates(options);
        delegateHandles.ensureCapacity(this.delegates.size());
        for (Delegate delegate : this.delegates) {
            delegateHandles.add(delegate.getNativeHandle());
        }
        if (!delegateHandles.isEmpty()) {
            // Rebuild the interpreter with the delegates applied. 0 is passed for the error-reporter
            // and model handles because both are reused by the new interpreter.
            NativeInterpreterWrapper.delete(0L, 0L, this.interpreterHandle);
            this.interpreterHandle = NativeInterpreterWrapper.createInterpreter(modelHandle, errorHandle, options.getNumThreads(), delegateHandles);
        }
        // Applied once to the final interpreter (these used to be applied twice with identical values).
        if (options.allowFp16PrecisionForFp32 != null) {
            NativeInterpreterWrapper.allowFp16PrecisionForFp32(this.interpreterHandle, options.allowFp16PrecisionForFp32);
        }
        if (options.allowBufferHandleOutput != null) {
            NativeInterpreterWrapper.allowBufferHandleOutput(this.interpreterHandle, options.allowBufferHandleOutput);
        }
        if (options.isCancellable()) {
            this.cancellationFlagHandle = NativeInterpreterWrapper.createCancellationFlag(this.interpreterHandle);
        }
        this.inputTensors = new TensorImpl[NativeInterpreterWrapper.getInputCount(this.interpreterHandle)];
        this.outputTensors = new TensorImpl[NativeInterpreterWrapper.getOutputCount(this.interpreterHandle)];
        NativeInterpreterWrapper.allocateTensors(this.interpreterHandle, errorHandle);
        this.isMemoryAllocated = true;
    }

    /**
     * Releases the cached tensors, the native handles, and the delegates this instance created.
     *
     * <p>Errors while closing owned delegates are printed to {@code System.err} and do not stop
     * the remaining delegates from being closed.
     */
    @Override
    public void close() {
        // Tensors are closed before the interpreter is deleted; presumably they refer to
        // interpreter-owned memory.
        closeAndClearTensors(this.inputTensors);
        closeAndClearTensors(this.outputTensors);
        NativeInterpreterWrapper.delete(this.errorHandle, this.modelHandle, this.interpreterHandle);
        NativeInterpreterWrapper.deleteCancellationFlag(this.cancellationFlagHandle);
        this.errorHandle = 0L;
        this.modelHandle = 0L;
        this.interpreterHandle = 0L;
        this.cancellationFlagHandle = 0L;
        this.modelByteBuffer = null;
        this.inputsIndexes = null;
        this.outputsIndexes = null;
        // Drop caches built against the now-deleted interpreter.
        // NOTE(review): cached signature runners are discarded, not closed — confirm whether
        // NativeSignatureRunnerWrapper holds native resources of its own.
        this.tensorToInputsIndexes = null;
        this.tensorToOutputsIndexes = null;
        this.signatureRunnerMap = null;
        this.isMemoryAllocated = false;
        this.delegates.clear();
        for (AutoCloseable ownedDelegate : this.ownedDelegates) {
            try {
                ownedDelegate.close();
            }
            catch (Exception e) {
                // Owned delegates include Flex and NNAPI delegates, so the message is generic.
                System.err.println("Failed to close delegate: " + e);
            }
        }
        this.ownedDelegates.clear();
    }

    /** Closes every non-null tensor in {@code tensors} and clears its slot. */
    private static void closeAndClearTensors(TensorImpl[] tensors) {
        for (int i = 0; i < tensors.length; ++i) {
            if (tensors[i] != null) {
                tensors[i].close();
                tensors[i] = null;
            }
        }
    }

    /** Calls {@code refreshShape()} on every non-null tensor in {@code tensors}. */
    private static void refreshShapes(TensorImpl[] tensors) {
        for (TensorImpl tensor : tensors) {
            if (tensor != null) {
                tensor.refreshShape();
            }
        }
    }

    /**
     * Runs inference through the signature identified by {@code signatureKey}.
     *
     * <p>Signatures on subgraph 0 are routed through {@link #run(Object[], Map)} using positional
     * indexes. Other subgraphs are resized, filled and invoked via their signature runner.
     *
     * @param inputs input values keyed by signature input name; must be non-empty
     * @param outputs output destinations keyed by signature output name; null values are skipped
     * @param signatureKey the signature to run
     * @throws IllegalArgumentException if {@code inputs} is null/empty or {@code outputs} is null
     */
    public void runSignature(Map<String, Object> inputs, Map<String, Object> outputs, String signatureKey) {
        this.inferenceDurationNanoseconds = -1L;
        if (inputs == null || inputs.isEmpty()) {
            throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
        }
        if (outputs == null) {
            throw new IllegalArgumentException("Input error: Outputs should not be null.");
        }
        NativeSignatureRunnerWrapper signatureRunnerWrapper = this.getSignatureRunnerWrapper(signatureKey);
        int subgraphIndex = signatureRunnerWrapper.getSubgraphIndex();
        if (subgraphIndex == 0) {
            this.initTensorIndexesMaps();
            // NOTE(review): sized by the number of supplied inputs; a signature input index >= that
            // size (e.g. a subset of inputs supplied) raises ArrayIndexOutOfBoundsException.
            Object[] inputsList = new Object[inputs.size()];
            for (Map.Entry<String, Object> entry : inputs.entrySet()) {
                inputsList[signatureRunnerWrapper.getInputIndex(entry.getKey())] = entry.getValue();
            }
            TreeMap<Integer, Object> outputsWithOutputIndex = new TreeMap<Integer, Object>();
            for (Map.Entry<String, Object> entry : outputs.entrySet()) {
                outputsWithOutputIndex.put(signatureRunnerWrapper.getOutputIndex(entry.getKey()), entry.getValue());
            }
            this.run(inputsList, outputsWithOutputIndex);
            return;
        }
        for (Map.Entry<String, Object> input : inputs.entrySet()) {
            TensorImpl tensor = this.getInputTensor(input.getKey(), signatureKey);
            int[] newShape = tensor.getInputShapeIfDifferent(input.getValue());
            if (newShape != null) {
                signatureRunnerWrapper.resizeInput(input.getKey(), newShape);
            }
        }
        signatureRunnerWrapper.allocateTensorsIfNeeded();
        for (Map.Entry<String, Object> input : inputs.entrySet()) {
            signatureRunnerWrapper.getInputTensor(input.getKey()).setTo(input.getValue());
        }
        long inferenceStartNanos = System.nanoTime();
        signatureRunnerWrapper.invoke();
        long elapsedNanos = System.nanoTime() - inferenceStartNanos;
        for (Map.Entry<String, Object> output : outputs.entrySet()) {
            if (output.getValue() != null) {
                signatureRunnerWrapper.getOutputTensor(output.getKey()).copyTo(output.getValue());
            }
        }
        this.inferenceDurationNanoseconds = elapsedNanos;
    }

    /**
     * Runs inference on the primary subgraph with positional inputs.
     *
     * <p>Inputs whose shape differs from the tensor's are resized first. Tensors are (re)allocated
     * if needed, inputs are copied in, the interpreter is invoked, and outputs with non-null
     * destinations are copied out. The inference duration is recorded only on success.
     *
     * @param inputs input values, one per input index starting at 0; must be non-empty
     * @param outputs output destinations keyed by output index; null values are skipped
     * @throws IllegalArgumentException if {@code inputs} is null/empty or {@code outputs} is null
     */
    void run(Object[] inputs, Map<Integer, Object> outputs) {
        this.inferenceDurationNanoseconds = -1L;
        if (inputs == null || inputs.length == 0) {
            throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
        }
        if (outputs == null) {
            throw new IllegalArgumentException("Input error: Outputs should not be null.");
        }
        for (int i = 0; i < inputs.length; ++i) {
            int[] newShape = this.getInputTensor(i).getInputShapeIfDifferent(inputs[i]);
            if (newShape != null) {
                this.resizeInput(i, newShape);
            }
        }
        boolean allocatedTensors = this.allocateTensorsIfNeeded();
        for (int i = 0; i < inputs.length; ++i) {
            this.getInputTensor(i).setTo(inputs[i]);
        }
        long inferenceStartNanos = System.nanoTime();
        NativeInterpreterWrapper.run(this.interpreterHandle, this.errorHandle);
        long elapsedNanos = System.nanoTime() - inferenceStartNanos;
        if (allocatedTensors) {
            // Refreshed again after invocation; presumably some output shapes are only known once
            // the graph has executed.
            refreshShapes(this.outputTensors);
        }
        for (Map.Entry<Integer, Object> output : outputs.entrySet()) {
            if (output.getValue() != null) {
                this.getOutputTensor(output.getKey()).copyTo(output.getValue());
            }
        }
        this.inferenceDurationNanoseconds = elapsedNanos;
    }

    /**
     * Resizes input {@code idx} to {@code dims} (non-strict).
     *
     * @param idx input index
     * @param dims new dimensions
     */
    void resizeInput(int idx, int[] dims) {
        this.resizeInput(idx, dims, false);
    }

    /**
     * Resizes input {@code idx} to {@code dims}. When the native call reports a change, tensor
     * memory is marked for reallocation and the cached input tensor's shape is refreshed.
     *
     * @param idx input index
     * @param dims new dimensions
     * @param strict forwarded to the native resize call
     */
    void resizeInput(int idx, int[] dims, boolean strict) {
        if (NativeInterpreterWrapper.resizeInput(this.interpreterHandle, this.errorHandle, idx, dims, strict)) {
            this.isMemoryAllocated = false;
            if (this.inputTensors[idx] != null) {
                this.inputTensors[idx].refreshShape();
            }
        }
    }

    /** Allocates native tensor memory if it is not currently allocated. */
    void allocateTensors() {
        this.allocateTensorsIfNeeded();
    }

    /**
     * Allocates native tensor memory if needed and refreshes cached output shapes.
     *
     * @return true if an allocation was performed, false if memory was already allocated
     */
    private boolean allocateTensorsIfNeeded() {
        if (this.isMemoryAllocated) {
            return false;
        }
        NativeInterpreterWrapper.allocateTensors(this.interpreterHandle, this.errorHandle);
        // Set only after the native call returns, so an allocation that throws is retried next time
        // instead of being silently skipped.
        this.isMemoryAllocated = true;
        refreshShapes(this.outputTensors);
        return true;
    }

    /** Builds a map from each name in {@code names} to its array position; empty when null. */
    private static Map<String, Integer> buildNameToIndexMap(String[] names) {
        Map<String, Integer> indexes = new HashMap<String, Integer>();
        if (names != null) {
            for (int i = 0; i < names.length; ++i) {
                indexes.put(names[i], i);
            }
        }
        return indexes;
    }

    /**
     * Returns the input index for tensor {@code name}.
     *
     * @throws IllegalArgumentException if no input has that name
     */
    int getInputIndex(String name) {
        if (this.inputsIndexes == null) {
            this.inputsIndexes = buildNameToIndexMap(NativeInterpreterWrapper.getInputNames(this.interpreterHandle));
        }
        Integer index = this.inputsIndexes.get(name);
        if (index != null) {
            return index;
        }
        throw new IllegalArgumentException(String.format("Input error: '%s' is not a valid name for any input. Names of inputs and their indexes are %s", name, this.inputsIndexes));
    }

    /**
     * Builds, once, the maps from interpreter tensor index to input/output position.
     *
     * <p>NOTE(review): these maps are populated but not read anywhere in this class.
     */
    private void initTensorIndexesMaps() {
        if (this.tensorToInputsIndexes != null) {
            return;
        }
        this.tensorToInputsIndexes = new HashMap<Integer, Integer>();
        this.tensorToOutputsIndexes = new HashMap<Integer, Integer>();
        int inputCount = this.getInputTensorCount();
        for (int i = 0; i < inputCount; ++i) {
            this.tensorToInputsIndexes.put(NativeInterpreterWrapper.getInputTensorIndex(this.interpreterHandle, i), i);
        }
        int outputCount = this.getOutputTensorCount();
        for (int i = 0; i < outputCount; ++i) {
            this.tensorToOutputsIndexes.put(NativeInterpreterWrapper.getOutputTensorIndex(this.interpreterHandle, i), i);
        }
    }

    /**
     * Returns the output index for tensor {@code name}.
     *
     * @throws IllegalArgumentException if no output has that name
     */
    int getOutputIndex(String name) {
        if (this.outputsIndexes == null) {
            this.outputsIndexes = buildNameToIndexMap(NativeInterpreterWrapper.getOutputNames(this.interpreterHandle));
        }
        Integer index = this.outputsIndexes.get(name);
        if (index != null) {
            return index;
        }
        throw new IllegalArgumentException(String.format("Input error: '%s' is not a valid name for any output. Names of outputs and their indexes are %s", name, this.outputsIndexes));
    }

    /**
     * Returns the duration of the last successful inference, or {@code null} if the most recent run
     * has not completed (or none has been attempted).
     */
    Long getLastNativeInferenceDurationNanoseconds() {
        return this.inferenceDurationNanoseconds < 0L ? null : Long.valueOf(this.inferenceDurationNanoseconds);
    }

    /** Returns the number of interpreter inputs. */
    int getInputTensorCount() {
        return this.inputTensors.length;
    }

    /**
     * Returns the (lazily created, cached) tensor for input {@code index}.
     *
     * @throws IllegalArgumentException if {@code index} is out of range
     */
    TensorImpl getInputTensor(int index) {
        if (index < 0 || index >= this.inputTensors.length) {
            throw new IllegalArgumentException("Invalid input Tensor index: " + index);
        }
        TensorImpl inputTensor = this.inputTensors[index];
        if (inputTensor == null) {
            inputTensor = TensorImpl.fromIndex(this.interpreterHandle, NativeInterpreterWrapper.getInputTensorIndex(this.interpreterHandle, index));
            this.inputTensors[index] = inputTensor;
        }
        return inputTensor;
    }

    /**
     * Returns the input tensor named {@code inputName} within signature {@code signatureKey}.
     *
     * @throws IllegalArgumentException if {@code inputName} is null
     */
    TensorImpl getInputTensor(String inputName, String signatureKey) {
        if (inputName == null) {
            throw new IllegalArgumentException("Invalid input tensor name provided (null)");
        }
        NativeSignatureRunnerWrapper signatureRunnerWrapper = this.getSignatureRunnerWrapper(signatureKey);
        if (signatureRunnerWrapper.getSubgraphIndex() > 0) {
            return signatureRunnerWrapper.getInputTensor(inputName);
        }
        return this.getInputTensor(signatureRunnerWrapper.getInputIndex(inputName));
    }

    /** Returns the signature keys defined by the model. */
    public String[] getSignatureKeys() {
        return NativeInterpreterWrapper.getSignatureKeys(this.interpreterHandle);
    }

    /** Returns the input names of signature {@code signatureKey}. */
    String[] getSignatureInputs(String signatureKey) {
        return this.getSignatureRunnerWrapper(signatureKey).inputNames();
    }

    /** Returns the output names of signature {@code signatureKey}. */
    String[] getSignatureOutputs(String signatureKey) {
        return this.getSignatureRunnerWrapper(signatureKey).outputNames();
    }

    /** Returns the number of interpreter outputs. */
    int getOutputTensorCount() {
        return this.outputTensors.length;
    }

    /**
     * Returns the (lazily created, cached) tensor for output {@code index}.
     *
     * @throws IllegalArgumentException if {@code index} is out of range
     */
    TensorImpl getOutputTensor(int index) {
        if (index < 0 || index >= this.outputTensors.length) {
            throw new IllegalArgumentException("Invalid output Tensor index: " + index);
        }
        TensorImpl outputTensor = this.outputTensors[index];
        if (outputTensor == null) {
            outputTensor = TensorImpl.fromIndex(this.interpreterHandle, NativeInterpreterWrapper.getOutputTensorIndex(this.interpreterHandle, index));
            this.outputTensors[index] = outputTensor;
        }
        return outputTensor;
    }

    /**
     * Returns the output tensor named {@code outputName} within signature {@code signatureKey}.
     *
     * @throws IllegalArgumentException if {@code outputName} is null
     */
    TensorImpl getOutputTensor(String outputName, String signatureKey) {
        if (outputName == null) {
            throw new IllegalArgumentException("Invalid output tensor name provided (null)");
        }
        NativeSignatureRunnerWrapper signatureRunnerWrapper = this.getSignatureRunnerWrapper(signatureKey);
        if (signatureRunnerWrapper.getSubgraphIndex() > 0) {
            return signatureRunnerWrapper.getOutputTensor(outputName);
        }
        return this.getOutputTensor(signatureRunnerWrapper.getOutputIndex(outputName));
    }

    /** Returns the length of the interpreter's execution plan. */
    int getExecutionPlanLength() {
        return NativeInterpreterWrapper.getExecutionPlanLength(this.interpreterHandle);
    }

    /**
     * Sets the native cancellation flag to {@code value}.
     *
     * @throws IllegalStateException if the interpreter was not created as cancellable
     */
    void setCancelled(boolean value) {
        if (this.cancellationFlagHandle == 0L) {
            throw new IllegalStateException("Cannot cancel the inference. Have you called InterpreterApi.Options.setCancellable?");
        }
        NativeInterpreterWrapper.setCancelled(this.interpreterHandle, this.cancellationFlagHandle, value);
    }

    /**
     * Collects the delegates to apply, in order:
     * <ol>
     *   <li>an auto-created Flex delegate, when the graph has unresolved Flex ops and none was supplied;</li>
     *   <li>the user-supplied delegates;</li>
     *   <li>an NNAPI delegate, if requested;</li>
     *   <li>an XNNPACK delegate, if explicitly requested.</li>
     * </ol>
     * Delegates created here, except XNNPACK, are recorded as owned so {@link #close()} closes them.
     */
    private void addDelegates(InterpreterImpl.Options options) {
        if (this.originalGraphHasUnresolvedFlexOp) {
            Delegate flexDelegate = NativeInterpreterWrapper.maybeCreateFlexDelegate(options.getDelegates());
            if (flexDelegate != null) {
                this.ownedDelegates.add((AutoCloseable) flexDelegate);
                this.delegates.add(flexDelegate);
            }
        }
        this.delegates.addAll(options.getDelegates());
        if (options.getUseNNAPI()) {
            NnApiDelegate nnApiDelegate = new NnApiDelegate();
            this.ownedDelegates.add(nnApiDelegate);
            this.delegates.add(nnApiDelegate);
        }
        this.maybeAddXnnpackDelegate(options);
    }

    /**
     * Adds a native XNNPACK delegate only when {@code options.useXNNPACK} is explicitly true.
     * Unset (null) and false add nothing.
     *
     * <p>NOTE(review): the XNNPACK delegate is not added to {@code ownedDelegates}, so
     * {@link #close()} does not close it — confirm whether the native side owns its lifetime.
     */
    private void maybeAddXnnpackDelegate(InterpreterImpl.Options options) {
        if (!Boolean.TRUE.equals(options.useXNNPACK)) {
            return;
        }
        final int applyXnnpackMode = 1;
        XnnpackDelegate xnnpackDelegate = NativeInterpreterWrapper.createXNNPACKDelegate(this.interpreterHandle, this.errorHandle, applyXnnpackMode, options.getNumThreads());
        // Guard against a null return from native code; a null entry would NPE at
        // getNativeHandle() in init().
        if (xnnpackDelegate != null) {
            this.delegates.add(xnnpackDelegate);
        }
    }

    /** Returns the cached signature runner for {@code signatureKey}, creating it on first use. */
    private NativeSignatureRunnerWrapper getSignatureRunnerWrapper(String signatureKey) {
        if (this.signatureRunnerMap == null) {
            this.signatureRunnerMap = new HashMap<String, NativeSignatureRunnerWrapper>();
        }
        NativeSignatureRunnerWrapper wrapper = this.signatureRunnerMap.get(signatureKey);
        if (wrapper == null) {
            wrapper = new NativeSignatureRunnerWrapper(this.interpreterHandle, this.errorHandle, signatureKey);
            this.signatureRunnerMap.put(signatureKey, wrapper);
        }
        return wrapper;
    }

    /**
     * Reflectively instantiates {@code org.tensorflow.lite.flex.FlexDelegate}.
     *
     * @param delegates the user-supplied delegates
     * @return a new Flex delegate, or {@code null} if one is already among {@code delegates} or it
     *     cannot be created (e.g. the class is not on the classpath)
     */
    private static Delegate maybeCreateFlexDelegate(List<Delegate> delegates) {
        try {
            Class<?> flexDelegateClass = Class.forName("org.tensorflow.lite.flex.FlexDelegate");
            for (Delegate delegate : delegates) {
                if (flexDelegateClass.isInstance(delegate)) {
                    // The caller already supplied a Flex delegate; don't create a second one.
                    return null;
                }
            }
            return (Delegate) flexDelegateClass.getConstructor().newInstance();
        }
        catch (Exception ignored) {
            // Best effort: the Flex delegate is optional, so any reflective failure means "none".
            return null;
        }
    }

    private static native void run(long interpreterHandle, long errorHandle);

    private static native boolean resizeInput(long interpreterHandle, long errorHandle, int inputIdx, int[] dims, boolean strict);

    private static native long allocateTensors(long interpreterHandle, long errorHandle);

    private static native String[] getSignatureKeys(long interpreterHandle);

    private static native void setCancelled(long interpreterHandle, long cancellationFlagHandle, boolean value);

    private static native int getOutputDataType(long var0, int var2);

    private static native boolean hasUnresolvedFlexOp(long interpreterHandle);

    private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);

    private static native int getOutputTensorIndex(long interpreterHandle, int outputIdx);

    private static native int getInputCount(long interpreterHandle);

    private static native int getOutputCount(long interpreterHandle);

    private static native int getExecutionPlanLength(long interpreterHandle);

    private static native String[] getInputNames(long interpreterHandle);

    private static native String[] getOutputNames(long interpreterHandle);

    private static native void allowFp16PrecisionForFp32(long interpreterHandle, boolean allow);

    private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);

    private static native XnnpackDelegate createXNNPACKDelegate(long interpreterHandle, long errorHandle, int applyMode, int numThreads);

    private static native long createErrorReporter(int bufferSize);

    private static native long createModel(String modelPath, long errorHandle);

    private static native long createModelWithBuffer(ByteBuffer modelBuffer, long errorHandle);

    private static native long createInterpreter(long modelHandle, long errorHandle, int numThreads, List<Long> delegateHandles);

    private static native void resetVariableTensors(long var0, long var2);

    private static native long createCancellationFlag(long interpreterHandle);

    private static native long deleteCancellationFlag(long cancellationFlagHandle);

    private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);
}

