/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.validation.listeners;

import java.security.MessageDigest;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public class NonInplaceValidationListener
extends BaseListener {
    private static AtomicInteger useCounter = new AtomicInteger();
    private static AtomicInteger passCounter = new AtomicInteger();
    private static AtomicInteger failCounter = new AtomicInteger();
    protected INDArray[] opInputs;
    protected INDArray[] opInputsOrig;

    public NonInplaceValidationListener() {
        useCounter.getAndIncrement();
    }

    @Override
    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext oc) {
        if (op.getOp().isInPlace()) {
            return;
        }
        if (op.getOp() instanceof Op) {
            Op o = (Op)((Object)op.getOp());
            if (oc.getInputArray(0) == null) {
                return;
            }
            if (oc.getInputArray(1) == null) {
                this.opInputsOrig = new INDArray[]{oc.getInputArray(0)};
                this.opInputs = new INDArray[]{oc.getInputArray(0).dup()};
            } else {
                this.opInputsOrig = new INDArray[]{oc.getInputArray(0), oc.getInputArray(1)};
                this.opInputs = new INDArray[]{oc.getInputArray(0).dup(), oc.getInputArray(1).dup()};
            }
        } else if (op.getOp() instanceof DynamicCustomOp) {
            List<INDArray> arr = oc.getInputArrays();
            this.opInputs = new INDArray[arr.size()];
            this.opInputsOrig = new INDArray[arr.size()];
            for (int i = 0; i < arr.size(); ++i) {
                this.opInputsOrig[i] = arr.get(i);
                this.opInputs[i] = arr.get(i).dup();
            }
        } else {
            throw new IllegalStateException("Unknown op type: " + op.getOp().getClass());
        }
    }

    @Override
    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
        MessageDigest md;
        if (op.getOp().isInPlace()) {
            return;
        }
        try {
            md = MessageDigest.getInstance("MD5");
        }
        catch (Throwable t) {
            throw new RuntimeException(t);
        }
        for (int i = 0; i < this.opInputs.length; ++i) {
            byte[] hash2;
            if (this.opInputs[i].isEmpty()) continue;
            byte[] before = this.opInputs[i].data().asBytes();
            INDArray after = this.opInputsOrig[i];
            boolean dealloc = false;
            if (this.opInputs[i].ordering() != this.opInputsOrig[i].ordering() || Arrays.equals(this.opInputs[i].stride(), this.opInputsOrig[i].stride()) || this.opInputs[i].elementWiseStride() != this.opInputsOrig[i].elementWiseStride()) {
                after = this.opInputsOrig[i].dup();
                dealloc = true;
            }
            byte[] afterB = after.data().asBytes();
            byte[] hash1 = md.digest(before);
            boolean eq = Arrays.equals(hash1, hash2 = md.digest(afterB));
            if (eq) {
                passCounter.addAndGet(1);
            } else {
                failCounter.addAndGet(1);
            }
            Preconditions.checkState(eq, "Input array for non-inplace op was modified during execution for op %s - input %s", op.getOp().getClass(), (Object)i);
            if (dealloc && after.closeable()) {
                after.close();
            }
            if (!this.opInputs[i].closeable()) continue;
            this.opInputs[i].close();
        }
    }

    @Override
    public boolean isActive(Operation operation) {
        return true;
    }

    public static AtomicInteger getUseCounter() {
        return useCounter;
    }

    public static AtomicInteger getPassCounter() {
        return passCounter;
    }

    public static AtomicInteger getFailCounter() {
        return failCounter;
    }
}

