package ai.djl.ndarray;

import ai.djl.util.Preconditions;
import java.util.Arrays;

/* loaded from: input_file:ai/djl/ndarray/NDArrays.class */
/**
 * Static helpers mirroring the instance operations of {@link NDArray}.
 *
 * <p>Most methods delegate directly to the matching {@link NDArray} method. Overloads that take
 * the scalar first do one of three things:
 *
 * <ul>
 *   <li>swap the operands, for {@code eq}, {@code neq}, {@code maximum}, {@code minimum},
 *       {@code add}, {@code mul}, {@code addi} and {@code muli};
 *   <li>call the mirrored comparison, for {@code gt}, {@code gte}, {@code lt} and {@code lte};
 *   <li>call the reversed variant on {@code getNDArrayInternal()}, for {@code sub}, {@code div},
 *       {@code mod}, {@code pow} and their in-place forms.
 * </ul>
 */
public final class NDArrays {

    /** Not instantiable: this class only holds static helpers. */
    private NDArrays() {
    }

    /**
     * Validates the arguments of the variadic {@code add}, {@code mul}, {@code addi} and
     * {@code muli}.
     *
     * <p>At least two arrays are required. When more than two are passed, every array must have
     * the same shape as the first. With exactly two arrays the shape check is skipped and the
     * pairwise operation decides which shapes it accepts.
     *
     * @param nDArrayArr the arrays to validate
     * @throws IllegalArgumentException if {@code nDArrayArr} is null or has fewer than two
     *     elements, or if it has more than two elements whose shapes differ
     */
    private static void checkInputs(NDArray[] nDArrayArr) {
        if (nDArrayArr == null || nDArrayArr.length < 2) {
            // The check demands two arrays; the message previously said "at least one element".
            throw new IllegalArgumentException("Passed in arrays must have at least two elements");
        }
        if (nDArrayArr.length > 2
                && Arrays.stream(nDArrayArr).skip(1L).anyMatch(a -> !nDArrayArr[0].shapeEquals(a))) {
            throw new IllegalArgumentException("The shape of all inputs must be the same");
        }
    }

    /**
     * Compares an array's content with a scalar, tolerating a null array.
     *
     * @param nDArray the array to compare; may be null
     * @param number the scalar to compare against
     * @return {@code false} if {@code nDArray} is null, otherwise
     *     {@code nDArray.contentEquals(number)}
     */
    public static boolean contentEquals(NDArray nDArray, Number number) {
        if (nDArray == null) {
            return false;
        }
        return nDArray.contentEquals(number);
    }

    /**
     * Compares the content of two arrays by delegating to {@link NDArray#contentEquals(NDArray)}.
     *
     * @param nDArray the first array; must not be null (it is dereferenced directly)
     * @param nDArray2 the second array
     * @return the result of {@code nDArray.contentEquals(nDArray2)}
     */
    public static boolean contentEquals(NDArray nDArray, NDArray nDArray2) {
        return nDArray.contentEquals(nDArray2);
    }

    /**
     * Compares the shapes of two arrays by delegating to {@link NDArray#shapeEquals(NDArray)}.
     *
     * @param nDArray the first array
     * @param nDArray2 the second array
     * @return the result of {@code nDArray.shapeEquals(nDArray2)}
     */
    public static boolean shapeEquals(NDArray nDArray, NDArray nDArray2) {
        return nDArray.shapeEquals(nDArray2);
    }

    /**
     * Checks approximate equality using the default tolerances of {@link NDArray#allClose(NDArray)}.
     *
     * @param nDArray the first array
     * @param nDArray2 the second array
     * @return the result of {@code nDArray.allClose(nDArray2)}
     */
    public static boolean allClose(NDArray nDArray, NDArray nDArray2) {
        return nDArray.allClose(nDArray2);
    }

    /**
     * Checks approximate equality with explicit settings.
     *
     * <p>Delegates to {@code nDArray.allClose(nDArray2, d, d2, z)}. The meaning of the three
     * settings is defined by {@link NDArray}; presumably they are rtol, atol and equalNan, but
     * this should be confirmed there.
     *
     * @param nDArray the first array
     * @param nDArray2 the second array
     * @param d first tolerance argument forwarded unchanged
     * @param d2 second tolerance argument forwarded unchanged
     * @param z boolean flag forwarded unchanged
     * @return the result of {@code nDArray.allClose(nDArray2, d, d2, z)}
     */
    public static boolean allClose(NDArray nDArray, NDArray nDArray2, double d, double d2, boolean z) {
        return nDArray.allClose(nDArray2, d, d2, z);
    }

    /**
     * Equality comparison with a scalar.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.eq(number)}
     */
    public static NDArray eq(NDArray nDArray, Number number) {
        return nDArray.eq(number);
    }

    /**
     * Equality comparison with the scalar first; equality is symmetric, so operands are swapped.
     *
     * @param number the scalar
     * @param nDArray the array
     * @return the result of {@code nDArray.eq(number)}
     */
    public static NDArray eq(Number number, NDArray nDArray) {
        return eq(nDArray, number);
    }

    /**
     * Equality comparison between two arrays.
     *
     * @param nDArray the first array
     * @param nDArray2 the second array
     * @return the result of {@code nDArray.eq(nDArray2)}
     */
    public static NDArray eq(NDArray nDArray, NDArray nDArray2) {
        return nDArray.eq(nDArray2);
    }

    /**
     * Inequality comparison with a scalar.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.neq(number)}
     */
    public static NDArray neq(NDArray nDArray, Number number) {
        return nDArray.neq(number);
    }

    /**
     * Inequality comparison with the scalar first; inequality is symmetric, so operands are
     * swapped.
     *
     * @param number the scalar
     * @param nDArray the array
     * @return the result of {@code nDArray.neq(number)}
     */
    public static NDArray neq(Number number, NDArray nDArray) {
        return neq(nDArray, number);
    }

    /**
     * Inequality comparison between two arrays.
     *
     * @param nDArray the first array
     * @param nDArray2 the second array
     * @return the result of {@code nDArray.neq(nDArray2)}
     */
    public static NDArray neq(NDArray nDArray, NDArray nDArray2) {
        return nDArray.neq(nDArray2);
    }

    /**
     * Computes {@code nDArray > number}.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.gt(number)}
     */
    public static NDArray gt(NDArray nDArray, Number number) {
        return nDArray.gt(number);
    }

    /**
     * Computes {@code number > nDArray}, expressed as the mirrored {@code nDArray < number}.
     *
     * @param number the scalar
     * @param nDArray the array
     * @return the result of {@code nDArray.lt(number)}
     */
    public static NDArray gt(Number number, NDArray nDArray) {
        return nDArray.lt(number);
    }

    /**
     * Computes {@code nDArray > nDArray2}.
     *
     * @param nDArray the left-hand array
     * @param nDArray2 the right-hand array
     * @return the result of {@code nDArray.gt(nDArray2)}
     */
    public static NDArray gt(NDArray nDArray, NDArray nDArray2) {
        return nDArray.gt(nDArray2);
    }

    /**
     * Computes {@code nDArray >= number}.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.gte(number)}
     */
    public static NDArray gte(NDArray nDArray, Number number) {
        return nDArray.gte(number);
    }

    /**
     * Computes {@code number >= nDArray}, expressed as the mirrored {@code nDArray <= number}.
     *
     * @param number the scalar
     * @param nDArray the array
     * @return the result of {@code nDArray.lte(number)}
     */
    public static NDArray gte(Number number, NDArray nDArray) {
        return nDArray.lte(number);
    }

    /**
     * Computes {@code nDArray >= nDArray2}.
     *
     * @param nDArray the left-hand array
     * @param nDArray2 the right-hand array
     * @return the result of {@code nDArray.gte(nDArray2)}
     */
    public static NDArray gte(NDArray nDArray, NDArray nDArray2) {
        return nDArray.gte(nDArray2);
    }

    /**
     * Computes {@code nDArray < number}.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.lt(number)}
     */
    public static NDArray lt(NDArray nDArray, Number number) {
        return nDArray.lt(number);
    }

    /**
     * Computes {@code number < nDArray}, expressed as the mirrored {@code nDArray > number}.
     *
     * @param number the scalar
     * @param nDArray the array
     * @return the result of {@code nDArray.gt(number)}
     */
    public static NDArray lt(Number number, NDArray nDArray) {
        return nDArray.gt(number);
    }

    /**
     * Computes {@code nDArray < nDArray2}.
     *
     * @param nDArray the left-hand array
     * @param nDArray2 the right-hand array
     * @return the result of {@code nDArray.lt(nDArray2)}
     */
    public static NDArray lt(NDArray nDArray, NDArray nDArray2) {
        return nDArray.lt(nDArray2);
    }

    /**
     * Computes {@code nDArray <= number}.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.lte(number)}
     */
    public static NDArray lte(NDArray nDArray, Number number) {
        return nDArray.lte(number);
    }

    /**
     * Computes {@code number <= nDArray}, expressed as the mirrored {@code nDArray >= number}.
     *
     * @param number the scalar
     * @param nDArray the array
     * @return the result of {@code nDArray.gte(number)}
     */
    public static NDArray lte(Number number, NDArray nDArray) {
        return nDArray.gte(number);
    }

    /**
     * Computes {@code nDArray <= nDArray2}.
     *
     * @param nDArray the left-hand array
     * @param nDArray2 the right-hand array
     * @return the result of {@code nDArray.lte(nDArray2)}
     */
    public static NDArray lte(NDArray nDArray, NDArray nDArray2) {
        return nDArray.lte(nDArray2);
    }

    /**
     * Conditional selection.
     *
     * <p>Delegates to {@code nDArray2.getNDArrayInternal().where(nDArray, nDArray3)}. Presumably
     * {@code nDArray} is the condition that picks between {@code nDArray2} and
     * {@code nDArray3}; confirm against the internal {@code where} implementation.
     *
     * @param nDArray the condition array
     * @param nDArray2 the first source array (its internal view performs the operation)
     * @param nDArray3 the second source array
     * @return the result of the internal {@code where} call
     */
    public static NDArray where(NDArray nDArray, NDArray nDArray2, NDArray nDArray3) {
        return nDArray2.getNDArrayInternal().where(nDArray, nDArray3);
    }

    /**
     * Maximum with a scalar.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.maximum(number)}
     */
    public static NDArray maximum(NDArray nDArray, Number number) {
        return nDArray.maximum(number);
    }

    /**
     * Maximum with the scalar first; operands are swapped.
     *
     * @param number the scalar
     * @param nDArray the array
     * @return the result of {@code nDArray.maximum(number)}
     */
    public static NDArray maximum(Number number, NDArray nDArray) {
        return maximum(nDArray, number);
    }

    /**
     * Maximum of two arrays.
     *
     * @param nDArray the first array
     * @param nDArray2 the second array
     * @return the result of {@code nDArray.maximum(nDArray2)}
     */
    public static NDArray maximum(NDArray nDArray, NDArray nDArray2) {
        return nDArray.maximum(nDArray2);
    }

    /**
     * Minimum with a scalar.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.minimum(number)}
     */
    public static NDArray minimum(NDArray nDArray, Number number) {
        return nDArray.minimum(number);
    }

    /**
     * Minimum with the scalar first; operands are swapped.
     *
     * @param number the scalar
     * @param nDArray the array
     * @return the result of {@code nDArray.minimum(number)}
     */
    public static NDArray minimum(Number number, NDArray nDArray) {
        return minimum(nDArray, number);
    }

    /**
     * Minimum of two arrays.
     *
     * @param nDArray the first array
     * @param nDArray2 the second array
     * @return the result of {@code nDArray.minimum(nDArray2)}
     */
    public static NDArray minimum(NDArray nDArray, NDArray nDArray2) {
        return nDArray.minimum(nDArray2);
    }

    /**
     * Boolean masking along axis 0.
     *
     * @param nDArray the array to mask
     * @param nDArray2 the mask
     * @return the result of {@code booleanMask(nDArray, nDArray2, 0)}
     */
    public static NDArray booleanMask(NDArray nDArray, NDArray nDArray2) {
        return booleanMask(nDArray, nDArray2, 0);
    }

    /**
     * Boolean masking along the given axis.
     *
     * @param nDArray the array to mask
     * @param nDArray2 the mask
     * @param i the axis forwarded to {@link NDArray#booleanMask(NDArray, int)}
     * @return the result of {@code nDArray.booleanMask(nDArray2, i)}
     */
    public static NDArray booleanMask(NDArray nDArray, NDArray nDArray2, int i) {
        return nDArray.booleanMask(nDArray2, i);
    }

    /**
     * Sequence masking with an explicit fill value.
     *
     * @param nDArray the array to mask
     * @param nDArray2 the sequence-length array
     * @param f the value forwarded to {@link NDArray#sequenceMask(NDArray, float)}
     * @return the result of {@code nDArray.sequenceMask(nDArray2, f)}
     */
    public static NDArray sequenceMask(NDArray nDArray, NDArray nDArray2, float f) {
        return nDArray.sequenceMask(nDArray2, f);
    }

    /**
     * Sequence masking with the default fill value of {@link NDArray#sequenceMask(NDArray)}.
     *
     * @param nDArray the array to mask
     * @param nDArray2 the sequence-length array
     * @return the result of {@code nDArray.sequenceMask(nDArray2)}
     */
    public static NDArray sequenceMask(NDArray nDArray, NDArray nDArray2) {
        return nDArray.sequenceMask(nDArray2);
    }

    /**
     * Adds a scalar.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.add(number)}
     */
    public static NDArray add(NDArray nDArray, Number number) {
        return nDArray.add(number);
    }

    /**
     * Adds a scalar given first; addition commutes, so operands are swapped.
     *
     * @param number the scalar
     * @param nDArray the array
     * @return the result of {@code nDArray.add(number)}
     */
    public static NDArray add(Number number, NDArray nDArray) {
        return nDArray.add(number);
    }

    /**
     * Sums two or more arrays.
     *
     * <p>Two arrays are added pairwise. More arrays are stacked along a new axis 0 and summed
     * over that axis; the temporary stacked array is closed before returning.
     *
     * @param nDArrayArr the arrays to add; at least two, all with equal shapes when more than two
     * @return the sum
     * @throws IllegalArgumentException if the inputs fail {@link #checkInputs(NDArray[])}
     */
    public static NDArray add(NDArray... nDArrayArr) {
        checkInputs(nDArrayArr);
        if (nDArrayArr.length == 2) {
            return nDArrayArr[0].add(nDArrayArr[1]);
        }
        try (NDArray stacked = stack(new NDList(nDArrayArr))) {
            return stacked.sum(new int[] {0});
        }
    }

    /**
     * Subtracts a scalar: {@code nDArray - number}.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.sub(number)}
     */
    public static NDArray sub(NDArray nDArray, Number number) {
        return nDArray.sub(number);
    }

    /**
     * Computes {@code number - nDArray} via the internal reversed subtraction.
     *
     * @param number the scalar minuend
     * @param nDArray the array subtrahend
     * @return the result of {@code nDArray.getNDArrayInternal().rsub(number)}
     */
    public static NDArray sub(Number number, NDArray nDArray) {
        return nDArray.getNDArrayInternal().rsub(number);
    }

    /**
     * Computes {@code nDArray - nDArray2}.
     *
     * @param nDArray the minuend
     * @param nDArray2 the subtrahend
     * @return the result of {@code nDArray.sub(nDArray2)}
     */
    public static NDArray sub(NDArray nDArray, NDArray nDArray2) {
        return nDArray.sub(nDArray2);
    }

    /**
     * Multiplies by a scalar.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.mul(number)}
     */
    public static NDArray mul(NDArray nDArray, Number number) {
        return nDArray.mul(number);
    }

    /**
     * Multiplies by a scalar given first; multiplication commutes, so operands are swapped.
     *
     * @param number the scalar
     * @param nDArray the array
     * @return the result of {@code nDArray.mul(number)}
     */
    public static NDArray mul(Number number, NDArray nDArray) {
        return nDArray.mul(number);
    }

    /**
     * Multiplies two or more arrays.
     *
     * <p>Two arrays are multiplied pairwise. More arrays are stacked along a new axis 0 and
     * reduced with {@code prod} over that axis; the temporary stacked array is closed before
     * returning.
     *
     * @param nDArrayArr the arrays to multiply; at least two, all with equal shapes when more
     *     than two
     * @return the product
     * @throws IllegalArgumentException if the inputs fail {@link #checkInputs(NDArray[])}
     */
    public static NDArray mul(NDArray... nDArrayArr) {
        checkInputs(nDArrayArr);
        if (nDArrayArr.length == 2) {
            return nDArrayArr[0].mul(nDArrayArr[1]);
        }
        try (NDArray stacked = stack(new NDList(nDArrayArr))) {
            return stacked.prod(new int[] {0});
        }
    }

    /**
     * Divides by a scalar: {@code nDArray / number}.
     *
     * @param nDArray the dividend
     * @param number the scalar divisor
     * @return the result of {@code nDArray.div(number)}
     */
    public static NDArray div(NDArray nDArray, Number number) {
        return nDArray.div(number);
    }

    /**
     * Computes {@code number / nDArray} via the internal reversed division.
     *
     * @param number the scalar dividend
     * @param nDArray the array divisor
     * @return the result of {@code nDArray.getNDArrayInternal().rdiv(number)}
     */
    public static NDArray div(Number number, NDArray nDArray) {
        return nDArray.getNDArrayInternal().rdiv(number);
    }

    /**
     * Computes {@code nDArray / nDArray2}.
     *
     * @param nDArray the dividend
     * @param nDArray2 the divisor
     * @return the result of {@code nDArray.div(nDArray2)}
     */
    public static NDArray div(NDArray nDArray, NDArray nDArray2) {
        return nDArray.div(nDArray2);
    }

    /**
     * Computes {@code nDArray mod number}.
     *
     * @param nDArray the dividend
     * @param number the scalar divisor
     * @return the result of {@code nDArray.mod(number)}
     */
    public static NDArray mod(NDArray nDArray, Number number) {
        return nDArray.mod(number);
    }

    /**
     * Computes {@code number mod nDArray} via the internal reversed modulo.
     *
     * @param number the scalar dividend
     * @param nDArray the array divisor
     * @return the result of {@code nDArray.getNDArrayInternal().rmod(number)}
     */
    public static NDArray mod(Number number, NDArray nDArray) {
        return nDArray.getNDArrayInternal().rmod(number);
    }

    /**
     * Computes {@code nDArray mod nDArray2}.
     *
     * @param nDArray the dividend
     * @param nDArray2 the divisor
     * @return the result of {@code nDArray.mod(nDArray2)}
     */
    public static NDArray mod(NDArray nDArray, NDArray nDArray2) {
        return nDArray.mod(nDArray2);
    }

    /**
     * Raises an array to a scalar power.
     *
     * @param nDArray the base
     * @param number the scalar exponent
     * @return the result of {@code nDArray.pow(number)}
     */
    public static NDArray pow(NDArray nDArray, Number number) {
        return nDArray.pow(number);
    }

    /**
     * Raises a scalar to an array power via the internal reversed power.
     *
     * @param number the scalar base
     * @param nDArray the array exponent
     * @return the result of {@code nDArray.getNDArrayInternal().rpow(number)}
     */
    public static NDArray pow(Number number, NDArray nDArray) {
        return nDArray.getNDArrayInternal().rpow(number);
    }

    /**
     * Raises one array to the power of another.
     *
     * @param nDArray the base
     * @param nDArray2 the exponent
     * @return the result of {@code nDArray.pow(nDArray2)}
     */
    public static NDArray pow(NDArray nDArray, NDArray nDArray2) {
        return nDArray.pow(nDArray2);
    }

    /**
     * In-place scalar addition (presumably mutating {@code nDArray}; see {@link NDArray#addi}).
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.addi(number)}
     */
    public static NDArray addi(NDArray nDArray, Number number) {
        return nDArray.addi(number);
    }

    /**
     * In-place scalar addition with the scalar first; operands are swapped.
     *
     * @param number the scalar
     * @param nDArray the array
     * @return the result of {@code nDArray.addi(number)}
     */
    public static NDArray addi(Number number, NDArray nDArray) {
        return nDArray.addi(number);
    }

    /**
     * Accumulates every following array into the first one with {@code addi}, in order.
     *
     * @param nDArrayArr the arrays; at least two, all with equal shapes when more than two
     * @return {@code nDArrayArr[0]} after accumulation
     * @throws IllegalArgumentException if the inputs fail {@link #checkInputs(NDArray[])}
     */
    public static NDArray addi(NDArray... nDArrayArr) {
        checkInputs(nDArrayArr);
        NDArray target = nDArrayArr[0];
        for (int i = 1; i < nDArrayArr.length; i++) {
            target.addi(nDArrayArr[i]);
        }
        return target;
    }

    /**
     * In-place scalar subtraction.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.subi(number)}
     */
    public static NDArray subi(NDArray nDArray, Number number) {
        return nDArray.subi(number);
    }

    /**
     * In-place {@code number - nDArray} via the internal reversed subtraction.
     *
     * @param number the scalar minuend
     * @param nDArray the array subtrahend
     * @return the result of {@code nDArray.getNDArrayInternal().rsubi(number)}
     */
    public static NDArray subi(Number number, NDArray nDArray) {
        return nDArray.getNDArrayInternal().rsubi(number);
    }

    /**
     * In-place subtraction of two arrays.
     *
     * @param nDArray the minuend
     * @param nDArray2 the subtrahend
     * @return the result of {@code nDArray.subi(nDArray2)}
     */
    public static NDArray subi(NDArray nDArray, NDArray nDArray2) {
        return nDArray.subi(nDArray2);
    }

    /**
     * In-place scalar multiplication.
     *
     * @param nDArray the array
     * @param number the scalar
     * @return the result of {@code nDArray.muli(number)}
     */
    public static NDArray muli(NDArray nDArray, Number number) {
        return nDArray.muli(number);
    }

    /**
     * In-place scalar multiplication with the scalar first; operands are swapped.
     *
     * @param number the scalar
     * @param nDArray the array
     * @return the result of {@code nDArray.muli(number)}
     */
    public static NDArray muli(Number number, NDArray nDArray) {
        return nDArray.muli(number);
    }

    /**
     * Multiplies every following array into the first one with {@code muli}, in order.
     *
     * @param nDArrayArr the arrays; at least two, all with equal shapes when more than two
     * @return {@code nDArrayArr[0]} after accumulation
     * @throws IllegalArgumentException if the inputs fail {@link #checkInputs(NDArray[])}
     */
    public static NDArray muli(NDArray... nDArrayArr) {
        checkInputs(nDArrayArr);
        NDArray target = nDArrayArr[0];
        for (int i = 1; i < nDArrayArr.length; i++) {
            target.muli(nDArrayArr[i]);
        }
        return target;
    }

    /**
     * In-place scalar division.
     *
     * @param nDArray the dividend
     * @param number the scalar divisor
     * @return the result of {@code nDArray.divi(number)}
     */
    public static NDArray divi(NDArray nDArray, Number number) {
        return nDArray.divi(number);
    }

    /**
     * In-place {@code number / nDArray} via the internal reversed division.
     *
     * @param number the scalar dividend
     * @param nDArray the array divisor
     * @return the result of {@code nDArray.getNDArrayInternal().rdivi(number)}
     */
    public static NDArray divi(Number number, NDArray nDArray) {
        return nDArray.getNDArrayInternal().rdivi(number);
    }

    /**
     * In-place division of two arrays.
     *
     * @param nDArray the dividend
     * @param nDArray2 the divisor
     * @return the result of {@code nDArray.divi(nDArray2)}
     */
    public static NDArray divi(NDArray nDArray, NDArray nDArray2) {
        return nDArray.divi(nDArray2);
    }

    /**
     * In-place scalar modulo.
     *
     * @param nDArray the dividend
     * @param number the scalar divisor
     * @return the result of {@code nDArray.modi(number)}
     */
    public static NDArray modi(NDArray nDArray, Number number) {
        return nDArray.modi(number);
    }

    /**
     * In-place {@code number mod nDArray} via the internal reversed modulo.
     *
     * @param number the scalar dividend
     * @param nDArray the array divisor
     * @return the result of {@code nDArray.getNDArrayInternal().rmodi(number)}
     */
    public static NDArray modi(Number number, NDArray nDArray) {
        return nDArray.getNDArrayInternal().rmodi(number);
    }

    /**
     * In-place modulo of two arrays.
     *
     * @param nDArray the dividend
     * @param nDArray2 the divisor
     * @return the result of {@code nDArray.modi(nDArray2)}
     */
    public static NDArray modi(NDArray nDArray, NDArray nDArray2) {
        return nDArray.modi(nDArray2);
    }

    /**
     * In-place scalar power.
     *
     * @param nDArray the base
     * @param number the scalar exponent
     * @return the result of {@code nDArray.powi(number)}
     */
    public static NDArray powi(NDArray nDArray, Number number) {
        return nDArray.powi(number);
    }

    /**
     * In-place scalar-base power via the internal reversed power.
     *
     * @param number the scalar base
     * @param nDArray the array exponent
     * @return the result of {@code nDArray.getNDArrayInternal().rpowi(number)}
     */
    public static NDArray powi(Number number, NDArray nDArray) {
        return nDArray.getNDArrayInternal().rpowi(number);
    }

    /**
     * In-place power of two arrays.
     *
     * @param nDArray the base
     * @param nDArray2 the exponent
     * @return the result of {@code nDArray.powi(nDArray2)}
     */
    public static NDArray powi(NDArray nDArray, NDArray nDArray2) {
        return nDArray.powi(nDArray2);
    }

    /**
     * Dot product.
     *
     * @param nDArray the first operand
     * @param nDArray2 the second operand
     * @return the result of {@code nDArray.dot(nDArray2)}
     */
    public static NDArray dot(NDArray nDArray, NDArray nDArray2) {
        return nDArray.dot(nDArray2);
    }

    /**
     * Matrix product.
     *
     * @param nDArray the left operand
     * @param nDArray2 the right operand
     * @return the result of {@code nDArray.matMul(nDArray2)}
     */
    public static NDArray matMul(NDArray nDArray, NDArray nDArray2) {
        return nDArray.matMul(nDArray2);
    }

    /**
     * Stacks the arrays of a list along axis 0.
     *
     * @param nDList the arrays to stack; must not be empty
     * @return the result of {@code stack(nDList, 0)}
     */
    public static NDArray stack(NDList nDList) {
        return stack(nDList, 0);
    }

    /**
     * Stacks the arrays of a list along the given axis.
     *
     * <p>The head of the list performs the operation with the remaining elements as arguments.
     *
     * @param nDList the arrays to stack; must not be empty
     * @param i the axis forwarded to the internal {@code stack}
     * @return the stacked array
     */
    public static NDArray stack(NDList nDList, int i) {
        Preconditions.checkArgument(nDList.size() > 0, "need at least one array to stack");
        return nDList.head().getNDArrayInternal().stack(nDList.subNDList(1), i);
    }

    /**
     * Concatenates the arrays of a list along axis 0.
     *
     * @param nDList the arrays to concatenate; must not be empty
     * @return the result of {@code concat(nDList, 0)}
     */
    public static NDArray concat(NDList nDList) {
        return concat(nDList, 0);
    }

    /**
     * Concatenates the arrays of a list along the given axis.
     *
     * <p>A single-element list yields {@code duplicate()} of that element, so the result is
     * never the input array itself.
     *
     * @param nDList the arrays to concatenate; must not be empty
     * @param i the axis forwarded to the internal {@code concat}
     * @return the concatenated array
     */
    public static NDArray concat(NDList nDList, int i) {
        Preconditions.checkArgument(nDList.size() > 0, "need at least one array to concatenate");
        if (nDList.size() == 1) {
            return nDList.singletonOrThrow().duplicate();
        }
        return nDList.head().getNDArrayInternal().concat(nDList.subNDList(1), i);
    }

    /**
     * Logical AND of two arrays.
     *
     * @param nDArray the first operand
     * @param nDArray2 the second operand
     * @return the result of {@code nDArray.logicalAnd(nDArray2)}
     */
    public static NDArray logicalAnd(NDArray nDArray, NDArray nDArray2) {
        return nDArray.logicalAnd(nDArray2);
    }

    /**
     * Logical OR of two arrays.
     *
     * @param nDArray the first operand
     * @param nDArray2 the second operand
     * @return the result of {@code nDArray.logicalOr(nDArray2)}
     */
    public static NDArray logicalOr(NDArray nDArray, NDArray nDArray2) {
        return nDArray.logicalOr(nDArray2);
    }

    /**
     * Logical XOR of two arrays.
     *
     * @param nDArray the first operand
     * @param nDArray2 the second operand
     * @return the result of {@code nDArray.logicalXor(nDArray2)}
     */
    public static NDArray logicalXor(NDArray nDArray, NDArray nDArray2) {
        return nDArray.logicalXor(nDArray2);
    }

    /**
     * Inverse error function.
     *
     * @param nDArray the input
     * @return the result of {@code nDArray.erfinv()}
     */
    public static NDArray erfinv(NDArray nDArray) {
        return nDArray.erfinv();
    }

    /**
     * Error function.
     *
     * @param nDArray the input
     * @return the result of {@code nDArray.erf()}
     */
    public static NDArray erf(NDArray nDArray) {
        return nDArray.erf();
    }
}
