/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.factory.ops;

import org.nd4j.common.base.Preconditions;
import org.nd4j.enums.PadMode;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.PRelu;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
import org.nd4j.linalg.api.ops.impl.transforms.Pad;
import org.nd4j.linalg.api.ops.impl.transforms.ReluLayer;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.ops.random.impl.DropOut;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Neural-network op wrappers that operate directly on {@link INDArray} arguments.
 *
 * <p>Every method follows the same pattern:
 * <ol>
 *   <li>validate each array argument with {@code NDValidation.validateNumerical};</li>
 *   <li>construct the corresponding op object;</li>
 *   <li>run it through {@code Nd4j.exec} and return the result.</li>
 * </ol>
 * For ops whose execution yields an array of outputs, the first output is returned.
 */
public class NDNN {

    /**
     * Runs the {@link CReLU} op on {@code x}.
     *
     * @param x numerical input array
     * @return the first output of the op
     */
    public INDArray cReLU(INDArray x) {
        NDValidation.validateNumerical("CReLU", "x", x);
        CReLU op = new CReLU(x);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link BatchNorm} op.
     *
     * @param input    numerical input array
     * @param mean     numerical mean array
     * @param variance numerical variance array
     * @param gamma    numerical gamma array
     * @param beta     numerical beta array
     * @param epsilon  value passed through to the op
     * @param axis     axes passed through to the op; at least one is required
     * @return the first output of the op
     * @throws IllegalArgumentException if {@code axis} is empty
     */
    public INDArray batchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma,
            INDArray beta, double epsilon, int... axis) {
        final String opName = "batchNorm";
        NDValidation.validateNumerical(opName, "input", input);
        NDValidation.validateNumerical(opName, "mean", mean);
        NDValidation.validateNumerical(opName, "variance", variance);
        NDValidation.validateNumerical(opName, "gamma", gamma);
        NDValidation.validateNumerical(opName, "beta", beta);
        Preconditions.checkArgument(axis.length > 0,
                "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
        BatchNorm op = new BatchNorm(input, mean, variance, gamma, beta, epsilon, axis);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link BiasAdd} op.
     *
     * @param input numerical input array
     * @param bias  numerical bias array
     * @param nchw  flag passed through to the op (presumably selects channels-first layout — confirm)
     * @return the first output of the op
     */
    public INDArray biasAdd(INDArray input, INDArray bias, boolean nchw) {
        final String opName = "biasAdd";
        NDValidation.validateNumerical(opName, "input", input);
        NDValidation.validateNumerical(opName, "bias", bias);
        BiasAdd op = new BiasAdd(input, bias, nchw);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link DotProductAttention} op. The final constructor flag is always {@code false}.
     *
     * @param queries numerical queries array
     * @param keys    numerical keys array
     * @param values  numerical values array
     * @param mask    numerical mask array
     * @param scaled  flag passed through to the op
     * @return the first output of the op
     */
    public INDArray dotProductAttention(INDArray queries, INDArray keys, INDArray values,
            INDArray mask, boolean scaled) {
        final String opName = "dotProductAttention";
        NDValidation.validateNumerical(opName, "queries", queries);
        NDValidation.validateNumerical(opName, "keys", keys);
        NDValidation.validateNumerical(opName, "values", values);
        NDValidation.validateNumerical(opName, "mask", mask);
        DotProductAttention op = new DotProductAttention(queries, keys, values, mask, scaled, false);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link DropOut} op.
     *
     * @param input                  numerical input array
     * @param inputRetainProbability probability value passed through to the op
     * @return the result of executing the op
     */
    public INDArray dropout(INDArray input, double inputRetainProbability) {
        NDValidation.validateNumerical("dropout", "input", input);
        DropOut op = new DropOut(input, inputRetainProbability);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link ELU} op on {@code x}.
     *
     * @param x numerical input array
     * @return the first output of the op
     */
    public INDArray elu(INDArray x) {
        NDValidation.validateNumerical("elu", "x", x);
        ELU op = new ELU(x);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link GELU} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray gelu(INDArray x) {
        NDValidation.validateNumerical("gelu", "x", x);
        GELU op = new GELU(x);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link HardSigmoid} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray hardSigmoid(INDArray x) {
        NDValidation.validateNumerical("hardSigmoid", "x", x);
        HardSigmoid op = new HardSigmoid(x);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link HardTanh} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray hardTanh(INDArray x) {
        NDValidation.validateNumerical("hardTanh", "x", x);
        HardTanh op = new HardTanh(x);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link HardTanhDerivative} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray hardTanhDerivative(INDArray x) {
        NDValidation.validateNumerical("hardTanhDerivative", "x", x);
        HardTanhDerivative op = new HardTanhDerivative(x);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link LayerNorm} op with a bias array.
     *
     * @param input         numerical input array
     * @param gain          numerical gain array
     * @param bias          numerical bias array
     * @param channelsFirst flag passed through to the op
     * @param dimensions    dimensions passed through to the op; at least one is required
     * @return the first output of the op
     * @throws IllegalArgumentException if {@code dimensions} is empty
     */
    public INDArray layerNorm(INDArray input, INDArray gain, INDArray bias, boolean channelsFirst,
            int... dimensions) {
        final String opName = "layerNorm";
        NDValidation.validateNumerical(opName, "input", input);
        NDValidation.validateNumerical(opName, "gain", gain);
        NDValidation.validateNumerical(opName, "bias", bias);
        Preconditions.checkArgument(dimensions.length > 0,
                "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s",
                dimensions.length);
        LayerNorm op = new LayerNorm(input, gain, bias, channelsFirst, dimensions);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link LayerNorm} op without a bias; {@code null} is passed in the bias position.
     *
     * @param input         numerical input array
     * @param gain          numerical gain array
     * @param channelsFirst flag passed through to the op
     * @param dimensions    dimensions passed through to the op; at least one is required
     * @return the first output of the op
     * @throws IllegalArgumentException if {@code dimensions} is empty
     */
    public INDArray layerNorm(INDArray input, INDArray gain, boolean channelsFirst, int... dimensions) {
        final String opName = "layerNorm";
        NDValidation.validateNumerical(opName, "input", input);
        NDValidation.validateNumerical(opName, "gain", gain);
        Preconditions.checkArgument(dimensions.length > 0,
                "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s",
                dimensions.length);
        LayerNorm op = new LayerNorm(input, gain, null, channelsFirst, dimensions);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link LeakyReLU} op on {@code x}.
     *
     * @param x     numerical input array
     * @param alpha value passed through to the op
     * @return the result of executing the op
     */
    public INDArray leakyRelu(INDArray x, double alpha) {
        NDValidation.validateNumerical("leakyRelu", "x", x);
        LeakyReLU op = new LeakyReLU(x, alpha);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link LeakyReLUDerivative} op on {@code x}.
     *
     * @param x     numerical input array
     * @param alpha value passed through to the op
     * @return the result of executing the op
     */
    public INDArray leakyReluDerivative(INDArray x, double alpha) {
        NDValidation.validateNumerical("leakyReluDerivative", "x", x);
        LeakyReLUDerivative op = new LeakyReLUDerivative(x, alpha);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link XwPlusB} op.
     *
     * @param input   numerical input array
     * @param weights numerical weights array
     * @param bias    numerical bias array
     * @return the first output of the op
     */
    public INDArray linear(INDArray input, INDArray weights, INDArray bias) {
        final String opName = "linear";
        NDValidation.validateNumerical(opName, "input", input);
        NDValidation.validateNumerical(opName, "weights", weights);
        NDValidation.validateNumerical(opName, "bias", bias);
        XwPlusB op = new XwPlusB(input, weights, bias);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link LogSigmoid} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray logSigmoid(INDArray x) {
        NDValidation.validateNumerical("logSigmoid", "x", x);
        LogSigmoid op = new LogSigmoid(x);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link LogSoftMax} op on {@code x}, using the op's single-argument constructor.
     *
     * @param x numerical input array
     * @return the first output of the op
     */
    public INDArray logSoftmax(INDArray x) {
        NDValidation.validateNumerical("logSoftmax", "x", x);
        LogSoftMax op = new LogSoftMax(x);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link LogSoftMax} op on {@code x} along an explicit dimension.
     *
     * @param x         numerical input array
     * @param dimension dimension passed through to the op
     * @return the first output of the op
     */
    public INDArray logSoftmax(INDArray x, int dimension) {
        NDValidation.validateNumerical("logSoftmax", "x", x);
        LogSoftMax op = new LogSoftMax(x, dimension);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link MultiHeadDotProductAttention} op. The final constructor flag is always
     * {@code false}.
     *
     * @param queries numerical queries array
     * @param keys    numerical keys array
     * @param values  numerical values array
     * @param Wq      numerical query projection array
     * @param Wk      numerical key projection array
     * @param Wv      numerical value projection array
     * @param Wo      numerical output projection array
     * @param mask    numerical mask array
     * @param scaled  flag passed through to the op
     * @return the first output of the op
     */
    public INDArray multiHeadDotProductAttention(INDArray queries, INDArray keys, INDArray values,
            INDArray Wq, INDArray Wk, INDArray Wv, INDArray Wo, INDArray mask, boolean scaled) {
        final String opName = "multiHeadDotProductAttention";
        NDValidation.validateNumerical(opName, "queries", queries);
        NDValidation.validateNumerical(opName, "keys", keys);
        NDValidation.validateNumerical(opName, "values", values);
        NDValidation.validateNumerical(opName, "Wq", Wq);
        NDValidation.validateNumerical(opName, "Wk", Wk);
        NDValidation.validateNumerical(opName, "Wv", Wv);
        NDValidation.validateNumerical(opName, "Wo", Wo);
        NDValidation.validateNumerical(opName, "mask", mask);
        MultiHeadDotProductAttention op = new MultiHeadDotProductAttention(
                queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link Pad} op with an explicit padding mode.
     *
     * @param input    numerical input array
     * @param padding  numerical padding array
     * @param padMode  padding mode passed through to the op
     * @param constant value passed through to the op
     * @return the first output of the op
     */
    public INDArray pad(INDArray input, INDArray padding, PadMode padMode, double constant) {
        final String opName = "pad";
        NDValidation.validateNumerical(opName, "input", input);
        NDValidation.validateNumerical(opName, "padding", padding);
        Pad op = new Pad(input, padding, padMode, constant);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link Pad} op with {@code PadMode.CONSTANT}.
     *
     * @param input    numerical input array
     * @param padding  numerical padding array
     * @param constant value passed through to the op
     * @return the first output of the op
     */
    public INDArray pad(INDArray input, INDArray padding, double constant) {
        final String opName = "pad";
        NDValidation.validateNumerical(opName, "input", input);
        NDValidation.validateNumerical(opName, "padding", padding);
        Pad op = new Pad(input, padding, PadMode.CONSTANT, constant);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link PreciseGELU} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray preciseGelu(INDArray x) {
        NDValidation.validateNumerical("preciseGelu", "x", x);
        PreciseGELU op = new PreciseGELU(x);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link PRelu} op.
     *
     * @param input      numerical input array
     * @param alpha      numerical alpha array
     * @param sharedAxes axes passed through to the op; at least one is required
     * @return the first output of the op
     * @throws IllegalArgumentException if {@code sharedAxes} is empty
     */
    public INDArray prelu(INDArray input, INDArray alpha, int... sharedAxes) {
        final String opName = "prelu";
        NDValidation.validateNumerical(opName, "input", input);
        NDValidation.validateNumerical(opName, "alpha", alpha);
        Preconditions.checkArgument(sharedAxes.length > 0,
                "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s",
                sharedAxes.length);
        PRelu op = new PRelu(input, alpha, sharedAxes);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link RectifiedLinear} op on {@code x}.
     *
     * @param x      numerical input array
     * @param cutoff value passed through to the op
     * @return the result of executing the op
     */
    public INDArray relu(INDArray x, double cutoff) {
        NDValidation.validateNumerical("relu", "x", x);
        RectifiedLinear op = new RectifiedLinear(x, cutoff);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link Relu6} op on {@code x}.
     *
     * @param x      numerical input array
     * @param cutoff value passed through to the op
     * @return the result of executing the op
     */
    public INDArray relu6(INDArray x, double cutoff) {
        NDValidation.validateNumerical("relu6", "x", x);
        Relu6 op = new Relu6(x, cutoff);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link ReluLayer} op.
     *
     * @param input   numerical input array
     * @param weights numerical weights array
     * @param bias    numerical bias array
     * @return the first output of the op
     */
    public INDArray reluLayer(INDArray input, INDArray weights, INDArray bias) {
        final String opName = "reluLayer";
        NDValidation.validateNumerical(opName, "input", input);
        NDValidation.validateNumerical(opName, "weights", weights);
        NDValidation.validateNumerical(opName, "bias", bias);
        ReluLayer op = new ReluLayer(input, weights, bias);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link SELU} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray selu(INDArray x) {
        NDValidation.validateNumerical("selu", "x", x);
        SELU op = new SELU(x);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link Sigmoid} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray sigmoid(INDArray x) {
        NDValidation.validateNumerical("sigmoid", "x", x);
        Sigmoid op = new Sigmoid(x);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link SigmoidDerivative} op.
     *
     * @param x   numerical input array
     * @param wrt numerical array passed as the op's second argument
     * @return the first output of the op
     */
    public INDArray sigmoidDerivative(INDArray x, INDArray wrt) {
        final String opName = "sigmoidDerivative";
        NDValidation.validateNumerical(opName, "x", x);
        NDValidation.validateNumerical(opName, "wrt", wrt);
        SigmoidDerivative op = new SigmoidDerivative(x, wrt);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link SoftMax} op on {@code x} along an explicit dimension.
     *
     * @param x         numerical input array
     * @param dimension dimension passed through to the op
     * @return the first output of the op
     */
    public INDArray softmax(INDArray x, int dimension) {
        NDValidation.validateNumerical("softmax", "x", x);
        SoftMax op = new SoftMax(x, dimension);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link SoftMax} op on {@code x} with dimension {@code -1}.
     * This behaves the same as {@code softmax(x, -1)}.
     *
     * @param x numerical input array
     * @return the first output of the op
     */
    public INDArray softmax(INDArray x) {
        NDValidation.validateNumerical("softmax", "x", x);
        SoftMax op = new SoftMax(x, -1);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link SoftmaxBp} op.
     *
     * @param x         numerical input array
     * @param wrt       numerical array passed as the op's second argument
     * @param dimension dimension passed through to the op
     * @return the first output of the op
     */
    public INDArray softmaxDerivative(INDArray x, INDArray wrt, int dimension) {
        final String opName = "softmaxDerivative";
        NDValidation.validateNumerical(opName, "x", x);
        NDValidation.validateNumerical(opName, "wrt", wrt);
        SoftmaxBp op = new SoftmaxBp(x, wrt, dimension);
        return Nd4j.exec(op)[0];
    }

    /**
     * Runs the {@link SoftPlus} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray softplus(INDArray x) {
        NDValidation.validateNumerical("softplus", "x", x);
        SoftPlus op = new SoftPlus(x);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link SoftSign} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray softsign(INDArray x) {
        NDValidation.validateNumerical("softsign", "x", x);
        SoftSign op = new SoftSign(x);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link SoftSignDerivative} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray softsignDerivative(INDArray x) {
        NDValidation.validateNumerical("softsignDerivative", "x", x);
        SoftSignDerivative op = new SoftSignDerivative(x);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link Swish} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray swish(INDArray x) {
        NDValidation.validateNumerical("swish", "x", x);
        Swish op = new Swish(x);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link Tanh} op on {@code x}.
     *
     * @param x numerical input array
     * @return the result of executing the op
     */
    public INDArray tanh(INDArray x) {
        NDValidation.validateNumerical("tanh", "x", x);
        Tanh op = new Tanh(x);
        return Nd4j.exec(op);
    }
}

