/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.reduce.custom;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Batched matrix multiplication backed by the native {@code batched_gemm} op.
 *
 * <p>For every index {@code i} in {@code [0, batchSize)} the op multiplies {@code op(A_i)} by
 * {@code op(B_i)}, where {@code op} transposes its argument when the matching transpose flag is
 * set. It produces one output per (A, B) pair, so {@link #getNumOutputs()} returns
 * {@code batchSize}.
 *
 * <p>Op input layout: {@code [ones(batchSize), zeros(batchSize), A_0..A_{b-1}, B_0..B_{b-1}]}.
 * Both public constructors prepend the ones/zeros vectors. They are presumably the GEMM
 * alpha/beta coefficients (confirm against the native op). Every method below therefore skips
 * the first {@value #NUM_SCALING_INPUTS} inputs when it needs the matrices.
 */
public class BatchMmul
extends DynamicCustomOp {
    /** Number of leading non-matrix inputs (the ones vector and the zeros vector). */
    private static final int NUM_SCALING_INPUTS = 2;

    /** 1 if every A matrix is transposed before multiplication, 0 otherwise. */
    protected int transposeA;
    /** 1 if every B matrix is transposed before multiplication, 0 otherwise. */
    protected int transposeB;
    /** Number of (A, B) pairs, which is also the number of outputs. */
    protected int batchSize;
    /** Rows of op(A). */
    protected int M;
    /** Columns of op(A), i.e. the shared inner dimension. */
    protected int N;
    /** Columns of op(B). */
    protected int K;

    /**
     * Creates a SameDiff batched multiplication of {@code matricesA[i]} by {@code matricesB[i]}.
     *
     * @param sameDiff   graph the op is added to
     * @param matricesA  left-hand matrices; all must share one shape
     * @param matricesB  right-hand matrices; all must share one shape, same count as {@code matricesA}
     * @param transposeA whether each A matrix is transposed before multiplying
     * @param transposeB whether each B matrix is transposed before multiplying
     * @throws IllegalStateException if the counts differ, are zero, or shapes are inconsistent
     */
    public BatchMmul(SameDiff sameDiff, SDVariable[] matricesA, SDVariable[] matricesB, boolean transposeA, boolean transposeB) {
        this(sameDiff, BatchMmul.concatPairs(matricesA, matricesB), transposeA, transposeB);
    }

    /**
     * Creates a SameDiff batched multiplication from a single array holding all A matrices
     * followed by all B matrices.
     *
     * @param sameDiff   graph the op is added to
     * @param matrices   {@code [A_0..A_{b-1}, B_0..B_{b-1}]}; length must be even and non-zero
     * @param transposeA whether each A matrix is transposed before multiplying
     * @param transposeB whether each B matrix is transposed before multiplying
     * @throws IllegalStateException if the count is odd or zero, a shape is unknown, or shapes differ
     */
    public BatchMmul(SameDiff sameDiff, SDVariable[] matrices, boolean transposeA, boolean transposeB) {
        // Validation has to happen inside the super(...) argument: nothing may run before super().
        super(null, sameDiff, BatchMmul.prependScalingVariables(sameDiff, matrices));
        this.batchSize = matrices.length / 2;
        long[] shapeA = BatchMmul.commonShape(matrices, 0, this.batchSize, "A");
        long[] shapeB = BatchMmul.commonShape(matrices, this.batchSize, 2 * this.batchSize, "B");
        this.initDimensions(transposeA, transposeB, shapeA, shapeB);
        this.addArgs();
    }

    /**
     * Creates an eager batched multiplication of {@code matricesA[i]} by {@code matricesB[i]}.
     *
     * <p>The ones/zeros leading inputs are prepended here as well, so the input layout matches
     * the SameDiff constructors.
     *
     * @param matricesA  left-hand matrices; all must share one shape
     * @param matricesB  right-hand matrices; all must share one shape, same count as {@code matricesA}
     * @param transposeA whether each A matrix is transposed before multiplying
     * @param transposeB whether each B matrix is transposed before multiplying
     * @throws IllegalStateException if the counts differ, are zero, or shapes are inconsistent
     */
    public BatchMmul(INDArray[] matricesA, INDArray[] matricesB, boolean transposeA, boolean transposeB) {
        super(BatchMmul.prependScalingArrays(BatchMmul.concatPairs(matricesA, matricesB)), null);
        this.batchSize = matricesA.length;
        long[] shapeA = matricesA[0].shape();
        long[] shapeB = matricesB[0].shape();
        for (int i = 1; i < this.batchSize; ++i) {
            BatchMmul.checkSameShape(shapeA, matricesA[i].shape(), "A", i);
            BatchMmul.checkSameShape(shapeB, matricesB[i].shape(), "B", i);
        }
        this.initDimensions(transposeA, transposeB, shapeA, shapeB);
        this.addArgs();
    }

    /** No-arg constructor; leaves every field at its default. */
    public BatchMmul() {
    }

    /** Concatenates A and B after checking that both hold the same number of matrices. */
    private static <T> T[] concatPairs(T[] matricesA, T[] matricesB) {
        Preconditions.checkState(matricesA.length == matricesB.length,
                "Expected the same number of A and B matrices, got %s and %s", matricesA.length, matricesB.length);
        return ArrayUtils.addAll(matricesA, matricesB);
    }

    /** Fails unless {@code count} is a positive, even number of matrices. */
    private static void checkMatrixCount(int count) {
        Preconditions.checkState(count > 0, "At least one pair of matrices must be provided.");
        Preconditions.checkState(count % 2 == 0, "The number of provided matrices needs to be divisible by two.");
    }

    /** Returns {@code [ones, zeros, matrices...]} as SameDiff constants of length {@code matrices.length / 2}. */
    private static SDVariable[] prependScalingVariables(SameDiff sameDiff, SDVariable[] matrices) {
        BatchMmul.checkMatrixCount(matrices.length);
        int pairs = matrices.length / 2;
        // Constants rather than trainable variables: the op treats them as fixed, and doDiff
        // returns zero gradients for them.
        SDVariable ones = sameDiff.constant(Nd4j.ones(matrices[0].dataType(), pairs));
        SDVariable zeros = sameDiff.constant(Nd4j.zeros(matrices[1].dataType(), pairs));
        return ArrayUtils.addAll(new SDVariable[]{ones, zeros}, matrices);
    }

    /** Returns {@code [ones, zeros, matrices...]} as arrays of length {@code matrices.length / 2}. */
    private static INDArray[] prependScalingArrays(INDArray[] matrices) {
        BatchMmul.checkMatrixCount(matrices.length);
        int pairs = matrices.length / 2;
        INDArray ones = Nd4j.ones(matrices[0].dataType(), pairs);
        INDArray zeros = Nd4j.zeros(matrices[1].dataType(), pairs);
        return ArrayUtils.addAll(new INDArray[]{ones, zeros}, matrices);
    }

    /** Returns the shape shared by {@code matrices[from, to)}; fails if it is unknown or not uniform. */
    private static long[] commonShape(SDVariable[] matrices, int from, int to, String label) {
        long[] expected = matrices[from].getShape();
        Preconditions.checkState(expected != null,
                "Shape of %s matrix 0 is unknown; batch mmul requires known shapes", label);
        for (int i = from + 1; i < to; ++i) {
            BatchMmul.checkSameShape(expected, matrices[i].getShape(), label, i - from);
        }
        return expected;
    }

    /** Fails if {@code actual} differs from {@code expected}; {@code index} is the position within its group. */
    private static void checkSameShape(long[] expected, long[] actual, String label, int index) {
        Preconditions.checkState(Arrays.equals(expected, actual),
                "All %s matrices must have the same shape: expected %s but got %s at index %s",
                label, expected, actual, index);
    }

    /** Stores the transpose flags and derives M, N, K from the first A and B shapes. */
    private void initDimensions(boolean transposeA, boolean transposeB, long[] shapeA, long[] shapeB) {
        this.transposeA = transposeA ? 1 : 0;
        this.transposeB = transposeB ? 1 : 0;
        this.M = (int) (transposeA ? shapeA[1] : shapeA[0]);
        this.N = (int) (transposeA ? shapeA[0] : shapeA[1]);
        this.K = (int) (transposeB ? shapeB[0] : shapeB[1]);
    }

    /** Returns the number of outputs: one per (A, B) pair. */
    @Override
    public int getNumOutputs() {
        return this.batchSize;
    }

    /**
     * Appends the integer arguments in the order: transposeA, transposeB, M, K, N, M, K, N,
     * batchSize.
     *
     * <p>NOTE(review): the repeated M, K, N presumably fill the native leading-dimension slots.
     * Confirm this against the libnd4j {@code batched_gemm} argument layout.
     */
    public void addArgs() {
        this.addIArgument(this.transposeA, this.transposeB, this.M, this.K, this.N, this.M, this.K, this.N, this.batchSize);
    }

    @Override
    public String opName() {
        return "batched_gemm";
    }

    /**
     * Computes gradients for every op input, in input order:
     * {@code [d(ones), d(zeros), dA_0..dA_{b-1}, dB_0..dB_{b-1}]}.
     *
     * <p>With {@code C = op(A)·op(B)}, {@code dL/dop(A) = dC·op(B)ᵀ} and
     * {@code dL/dop(B) = op(A)ᵀ·dC}. Each result is transposed back when the corresponding input
     * was transposed. The two leading ones/zeros inputs are treated as fixed and receive zero
     * gradients.
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> grads) {
        SDVariable[] dLdOut = grads.toArray(new SDVariable[0]);
        SDVariable[] allArgs = this.args();
        int firstA = NUM_SCALING_INPUTS;
        int firstB = firstA + this.batchSize;
        SDVariable[] matricesA = Arrays.copyOfRange(allArgs, firstA, firstB);
        SDVariable[] matricesB = Arrays.copyOfRange(allArgs, firstB, firstB + this.batchSize);
        boolean tA = this.transposeA == 1;
        boolean tB = this.transposeB == 1;

        SDVariable[] dLdA;
        if (!tA) {
            // dA = dC · op(B)^T
            dLdA = this.sameDiff.batchMmul(dLdOut, matricesB, false, !tB);
        } else {
            // dA = (dC · op(B)^T)^T = op(B) · dC^T
            dLdA = this.sameDiff.batchMmul(matricesB, dLdOut, tB, true);
        }

        SDVariable[] dLdB;
        if (!tB) {
            // dB = op(A)^T · dC
            dLdB = this.sameDiff.batchMmul(matricesA, dLdOut, !tA, false);
        } else {
            // dB = (op(A)^T · dC)^T = dC^T · op(A)
            dLdB = this.sameDiff.batchMmul(dLdOut, matricesA, true, tA);
        }

        List<SDVariable> ret = new ArrayList<>(allArgs.length);
        ret.add(this.sameDiff.zerosLike(allArgs[0]));
        ret.add(this.sameDiff.zerosLike(allArgs[1]));
        Collections.addAll(ret, dLdA);
        Collections.addAll(ret, dLdB);
        return ret;
    }

    /**
     * Validates that all inputs are floating point. Returns one output type per (A, B) pair,
     * taken from that pair's A matrix.
     *
     * @throws IllegalStateException if the input count does not fit the
     *         {@code [ones, zeros, A..., B...]} layout, or any input is not floating point
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        int numInputs = dataTypes.size();
        Preconditions.checkState(numInputs >= NUM_SCALING_INPUTS && numInputs % 2 == 0,
                "Expected 2 leading inputs followed by pairs of A/B matrices: got input types %s", dataTypes);
        for (DataType dt : dataTypes) {
            Preconditions.checkState(dt.isFPType(), "Inputs to batch mmul op must all be a floating point type: got %s", dataTypes);
        }
        int pairs = (numInputs - NUM_SCALING_INPUTS) / 2;
        List<DataType> out = new ArrayList<>(pairs);
        for (int i = 0; i < pairs; ++i) {
            out.add(dataTypes.get(NUM_SCALING_INPUTS + i));
        }
        return out;
    }

    /** Equality on the transpose flags, batch size and M/N/K dimensions only. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof BatchMmul)) {
            return false;
        }
        BatchMmul other = (BatchMmul) o;
        return other.canEqual(this)
                && this.transposeA == other.transposeA
                && this.transposeB == other.transposeB
                && this.batchSize == other.batchSize
                && this.M == other.M
                && this.N == other.N
                && this.K == other.K;
    }

    /** Symmetry hook for {@link #equals(Object)} so that subclasses can opt out of equality. */
    protected boolean canEqual(Object other) {
        return other instanceof BatchMmul;
    }

    /** Hash over the same fields {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + this.transposeA;
        result = result * prime + this.transposeB;
        result = result * prime + this.batchSize;
        result = result * prime + this.M;
        result = result * prime + this.N;
        result = result * prime + this.K;
        return result;
    }
}

