/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas.blas;

import org.bytedeco.cuda.cublas.cublasContext;
import org.bytedeco.cuda.cudart.CUstream_st;
import org.bytedeco.cuda.cudart.__half;
import org.bytedeco.cuda.global.cublas;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.ShortPointer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.blas.impl.BaseLevel3;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.factory.DataTypeValidation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.blas.CudaBlas;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.Nd4jBlas;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Level-3 BLAS routines (matrix-matrix operations) implemented on top of cuBLAS.
 *
 * <p>Every implemented routine follows the same sequence:
 * <ol>
 *   <li>flush pending work via {@code Nd4j.getExecutioner().push()};</li>
 *   <li>obtain a {@link CudaContext} from the allocator's flow controller for the
 *       operands (result array first);</li>
 *   <li>wrap each operand in a {@link CublasPointer} to get its device pointer;</li>
 *   <li>while holding the monitor of the context's cuBLAS handle, bind the handle to
 *       the context's cuBLAS stream and issue the cuBLAS call;</li>
 *   <li>register the action with the allocator and run
 *       {@link OpExecutionerUtil#checkForAny(INDArray)} on the result array.</li>
 * </ol>
 *
 * <p>The {@code Order2} parameter of every method is accepted for signature
 * compatibility with {@link BaseLevel3} but is not read by any implementation here.
 */
public class JcublasLevel3
extends BaseLevel3 {
    // NOTE(review): `log` and `logger` are duplicate loggers for the same class and
    // neither is referenced by any method below; kept to leave the field set untouched.
    private static final Logger log = LoggerFactory.getLogger(JcublasLevel3.class);
    private Allocator allocator = AtomicAllocator.getInstance();
    // NOTE(review): nd4jBlas and nativeOps are not referenced by any method in this class.
    private Nd4jBlas nd4jBlas = (Nd4jBlas)Nd4j.factory().blas();
    private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    private static Logger logger = LoggerFactory.getLogger(JcublasLevel3.class);

    /**
     * Half-precision general matrix multiply.
     *
     * <p>When the current device architecture is 53, 60, or at least 70 the call is
     * dispatched to {@code cublasHgemm} with {@code alpha}/{@code beta} converted to
     * half precision. Otherwise it falls back to {@code cublasSgemmEx} with float
     * scalars and data type code {@code 2} for all three matrices (presumably
     * {@code CUDA_R_16F} - confirm against the cudaDataType enum).
     *
     * @param Order2 unused
     * @param TransA BLAS transpose flag for A, translated by {@link CudaBlas#convertTranspose}
     * @param TransB BLAS transpose flag for B, translated by {@link CudaBlas#convertTranspose}
     * @param M      passed to cuBLAS as m
     * @param N      passed to cuBLAS as n
     * @param K      passed to cuBLAS as k
     * @param alpha  scalar multiplier passed to cuBLAS as alpha
     * @param A      first input matrix
     * @param lda    leading dimension of A
     * @param B      second input matrix
     * @param ldb    leading dimension of B
     * @param beta   scalar multiplier passed to cuBLAS as beta
     * @param C      result matrix
     * @param ldc    leading dimension of C
     */
    @Override
    protected void hgemm(char Order2, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda, INDArray B, int ldb, float beta, INDArray C, int ldc) {
        Nd4j.getExecutioner().push();
        CudaContext ctx = this.allocator.getFlowController().prepareAction(C, A, B);
        CublasPointer cAPointer = new CublasPointer(A, ctx);
        CublasPointer cBPointer = new CublasPointer(B, ctx);
        CublasPointer cCPointer = new CublasPointer(C, ctx);
        cublasHandle_t handle = ctx.getCublasHandle();
        synchronized (handle) {
            cublas.cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
            int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture();
            if (arch == 53 || arch == 60 || arch >= 70) {
                // Write the half-precision bit patterns of the scalars into the __half structs.
                __half alphaHalf = new __half();
                __half betaHalf = new __half();
                new ShortPointer(alphaHalf).put((short)HalfIndexer.fromFloat(alpha));
                new ShortPointer(betaHalf).put((short)HalfIndexer.fromFloat(beta));
                cublas.cublasHgemm(new cublasContext(handle), CudaBlas.convertTranspose(TransA), CudaBlas.convertTranspose(TransB), M, N, K, alphaHalf, new __half(cAPointer.getDevicePointer()), lda, new __half(cBPointer.getDevicePointer()), ldb, betaHalf, new __half(cCPointer.getDevicePointer()), ldc);
            } else {
                cublas.cublasSgemmEx(new cublasContext(handle), CudaBlas.convertTranspose(TransA), CudaBlas.convertTranspose(TransB), M, N, K, new FloatPointer(alpha), (Pointer)((ShortPointer)cAPointer.getDevicePointer()), 2, lda, (Pointer)((ShortPointer)cBPointer.getDevicePointer()), 2, ldb, new FloatPointer(beta), (Pointer)((ShortPointer)cCPointer.getDevicePointer()), 2, ldc);
            }
            ctx.getOldStream().synchronize();
        }
        this.allocator.registerAction(ctx, C, A, B);
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * Single-precision general matrix multiply via {@code cublasSgemm_v2}.
     *
     * <p>The context's old stream is synchronized both before the device pointers are
     * created and after the cuBLAS call.
     *
     * @param Order2 unused
     * @param TransA BLAS transpose flag for A, translated by {@link CudaBlas#convertTranspose}
     * @param TransB BLAS transpose flag for B, translated by {@link CudaBlas#convertTranspose}
     * @param M      passed to cuBLAS as m
     * @param N      passed to cuBLAS as n
     * @param K      passed to cuBLAS as k
     * @param alpha  scalar multiplier passed to cuBLAS as alpha
     * @param A      first input matrix
     * @param lda    leading dimension of A
     * @param B      second input matrix
     * @param ldb    leading dimension of B
     * @param beta   scalar multiplier passed to cuBLAS as beta
     * @param C      result matrix
     * @param ldc    leading dimension of C
     */
    @Override
    protected void sgemm(char Order2, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda, INDArray B, int ldb, float beta, INDArray C, int ldc) {
        Nd4j.getExecutioner().push();
        CudaContext ctx = this.allocator.getFlowController().prepareAction(C, A, B);
        ctx.getOldStream().synchronize();
        CublasPointer cAPointer = new CublasPointer(A, ctx);
        CublasPointer cBPointer = new CublasPointer(B, ctx);
        CublasPointer cCPointer = new CublasPointer(C, ctx);
        cublasHandle_t handle = ctx.getCublasHandle();
        synchronized (handle) {
            cublas.cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
            cublas.cublasSgemm_v2(new cublasContext(handle), CudaBlas.convertTranspose(TransA), CudaBlas.convertTranspose(TransB), M, N, K, new FloatPointer(alpha), (FloatPointer)cAPointer.getDevicePointer(), lda, (FloatPointer)cBPointer.getDevicePointer(), ldb, new FloatPointer(beta), (FloatPointer)cCPointer.getDevicePointer(), ldc);
            ctx.getOldStream().synchronize();
        }
        this.allocator.registerAction(ctx, C, A, B);
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * Single-precision symmetric matrix multiply via {@code cublasSsymm_v2}.
     *
     * @param Order2 unused
     * @param Side2  BLAS side flag, translated by {@link CudaBlas#convertSideMode}
     * @param Uplo   BLAS upper/lower flag, translated by {@link CudaBlas#convertUplo}
     * @param M      passed to cuBLAS as m
     * @param N      passed to cuBLAS as n
     * @param alpha  scalar multiplier passed to cuBLAS as alpha
     * @param A      first input matrix
     * @param lda    leading dimension of A
     * @param B      second input matrix
     * @param ldb    leading dimension of B
     * @param beta   scalar multiplier passed to cuBLAS as beta
     * @param C      result matrix
     * @param ldc    leading dimension of C
     */
    @Override
    protected void ssymm(char Order2, char Side2, char Uplo, int M, int N, float alpha, INDArray A, int lda, INDArray B, int ldb, float beta, INDArray C, int ldc) {
        Nd4j.getExecutioner().push();
        CudaContext ctx = this.allocator.getFlowController().prepareAction(C, A, B);
        CublasPointer aPointer = new CublasPointer(A, ctx);
        CublasPointer bPointer = new CublasPointer(B, ctx);
        CublasPointer cPointer = new CublasPointer(C, ctx);
        cublasHandle_t handle = ctx.getCublasHandle();
        synchronized (handle) {
            cublas.cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
            cublas.cublasSsymm_v2(new cublasContext(handle), CudaBlas.convertSideMode(Side2), CudaBlas.convertUplo(Uplo), M, N, new FloatPointer(alpha), (FloatPointer)aPointer.getDevicePointer(), lda, (FloatPointer)bPointer.getDevicePointer(), ldb, new FloatPointer(beta), (FloatPointer)cPointer.getDevicePointer(), ldc);
        }
        this.allocator.registerAction(ctx, C, A, B);
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * Single-precision symmetric rank-k update via {@code cublasSsyrk_v2}.
     *
     * @param Order2 unused
     * @param Uplo   BLAS upper/lower flag, translated by {@link CudaBlas#convertUplo}
     * @param Trans  BLAS transpose flag, translated by {@link CudaBlas#convertTranspose}
     * @param N      passed to cuBLAS as n
     * @param K      passed to cuBLAS as k
     * @param alpha  scalar multiplier passed to cuBLAS as alpha
     * @param A      input matrix
     * @param lda    leading dimension of A
     * @param beta   scalar multiplier passed to cuBLAS as beta
     * @param C      result matrix
     * @param ldc    leading dimension of C
     */
    @Override
    protected void ssyrk(char Order2, char Uplo, char Trans, int N, int K, float alpha, INDArray A, int lda, float beta, INDArray C, int ldc) {
        Nd4j.getExecutioner().push();
        CudaContext ctx = this.allocator.getFlowController().prepareAction(C, A);
        CublasPointer aPointer = new CublasPointer(A, ctx);
        CublasPointer cPointer = new CublasPointer(C, ctx);
        cublasHandle_t handle = ctx.getCublasHandle();
        synchronized (handle) {
            cublas.cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
            cublas.cublasSsyrk_v2(new cublasContext(handle), CudaBlas.convertUplo(Uplo), CudaBlas.convertTranspose(Trans), N, K, new FloatPointer(alpha), (FloatPointer)aPointer.getDevicePointer(), lda, new FloatPointer(beta), (FloatPointer)cPointer.getDevicePointer(), ldc);
        }
        this.allocator.registerAction(ctx, C, A);
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * Single-precision symmetric rank-2k update; not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected void ssyr2k(char Order2, char Uplo, char Trans, int N, int K, float alpha, INDArray A, int lda, INDArray B, int ldb, float beta, INDArray C, int ldc) {
        throw new UnsupportedOperationException();
    }

    /**
     * Single-precision triangular matrix multiply; not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected void strmm(char Order2, char Side2, char Uplo, char TransA, char Diag2, int M, int N, float alpha, INDArray A, int lda, INDArray B, int ldb) {
        throw new UnsupportedOperationException();
    }

    /**
     * Single-precision triangular solve via {@code cublasStrsm_v2}; B is both an input
     * and the array registered/checked as the result.
     *
     * @param Order2 unused
     * @param Side2  BLAS side flag, translated by {@link CudaBlas#convertSideMode}
     * @param Uplo   BLAS upper/lower flag, translated by {@link CudaBlas#convertUplo}
     * @param TransA BLAS transpose flag for A, translated by {@link CudaBlas#convertTranspose}
     * @param Diag2  BLAS diagonal flag, translated by {@link CudaBlas#convertDiag}
     * @param M      passed to cuBLAS as m
     * @param N      passed to cuBLAS as n
     * @param alpha  scalar multiplier passed to cuBLAS as alpha
     * @param A      triangular input matrix
     * @param lda    leading dimension of A
     * @param B      right-hand side / result matrix
     * @param ldb    leading dimension of B
     */
    @Override
    protected void strsm(char Order2, char Side2, char Uplo, char TransA, char Diag2, int M, int N, float alpha, INDArray A, int lda, INDArray B, int ldb) {
        Nd4j.getExecutioner().push();
        CudaContext ctx = this.allocator.getFlowController().prepareAction(B, A);
        CublasPointer aPointer = new CublasPointer(A, ctx);
        CublasPointer bPointer = new CublasPointer(B, ctx);
        cublasHandle_t handle = ctx.getCublasHandle();
        synchronized (handle) {
            cublas.cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
            cublas.cublasStrsm_v2(new cublasContext(handle), CudaBlas.convertSideMode(Side2), CudaBlas.convertUplo(Uplo), CudaBlas.convertTranspose(TransA), CudaBlas.convertDiag(Diag2), M, N, new FloatPointer(alpha), (FloatPointer)aPointer.getDevicePointer(), lda, (FloatPointer)bPointer.getDevicePointer(), ldb);
        }
        this.allocator.registerAction(ctx, B, A);
        OpExecutionerUtil.checkForAny(B);
    }

    /**
     * Double-precision general matrix multiply via {@code cublasDgemm_v2}.
     *
     * <p>{@link DataTypeValidation#assertDouble} is applied to A, B and C before the
     * device pointers are created.
     *
     * @param Order2 unused
     * @param TransA BLAS transpose flag for A, translated by {@link CudaBlas#convertTranspose}
     * @param TransB BLAS transpose flag for B, translated by {@link CudaBlas#convertTranspose}
     * @param M      passed to cuBLAS as m
     * @param N      passed to cuBLAS as n
     * @param K      passed to cuBLAS as k
     * @param alpha  scalar multiplier passed to cuBLAS as alpha
     * @param A      first input matrix
     * @param lda    leading dimension of A
     * @param B      second input matrix
     * @param ldb    leading dimension of B
     * @param beta   scalar multiplier passed to cuBLAS as beta
     * @param C      result matrix
     * @param ldc    leading dimension of C
     */
    @Override
    protected void dgemm(char Order2, char TransA, char TransB, int M, int N, int K, double alpha, INDArray A, int lda, INDArray B, int ldb, double beta, INDArray C, int ldc) {
        Nd4j.getExecutioner().push();
        CudaContext ctx = this.allocator.getFlowController().prepareAction(C, A, B);
        DataTypeValidation.assertDouble(A, B, C);
        CublasPointer cAPointer = new CublasPointer(A, ctx);
        CublasPointer cBPointer = new CublasPointer(B, ctx);
        CublasPointer cCPointer = new CublasPointer(C, ctx);
        cublasHandle_t handle = ctx.getCublasHandle();
        synchronized (handle) {
            cublas.cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
            cublas.cublasDgemm_v2(new cublasContext(handle), CudaBlas.convertTranspose(TransA), CudaBlas.convertTranspose(TransB), M, N, K, new DoublePointer(alpha), (DoublePointer)cAPointer.getDevicePointer(), lda, (DoublePointer)cBPointer.getDevicePointer(), ldb, new DoublePointer(beta), (DoublePointer)cCPointer.getDevicePointer(), ldc);
            ctx.getOldStream().synchronize();
        }
        this.allocator.registerAction(ctx, C, A, B);
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * Double-precision symmetric matrix multiply via {@code cublasDsymm_v2}.
     *
     * @param Order2 unused
     * @param Side2  BLAS side flag, translated by {@link CudaBlas#convertSideMode}
     * @param Uplo   BLAS upper/lower flag, translated by {@link CudaBlas#convertUplo}
     * @param M      passed to cuBLAS as m
     * @param N      passed to cuBLAS as n
     * @param alpha  scalar multiplier passed to cuBLAS as alpha
     * @param A      first input matrix
     * @param lda    leading dimension of A
     * @param B      second input matrix
     * @param ldb    leading dimension of B
     * @param beta   scalar multiplier passed to cuBLAS as beta
     * @param C      result matrix
     * @param ldc    leading dimension of C
     */
    @Override
    protected void dsymm(char Order2, char Side2, char Uplo, int M, int N, double alpha, INDArray A, int lda, INDArray B, int ldb, double beta, INDArray C, int ldc) {
        Nd4j.getExecutioner().push();
        CudaContext ctx = this.allocator.getFlowController().prepareAction(C, A, B);
        CublasPointer aPointer = new CublasPointer(A, ctx);
        CublasPointer bPointer = new CublasPointer(B, ctx);
        CublasPointer cPointer = new CublasPointer(C, ctx);
        cublasHandle_t handle = ctx.getCublasHandle();
        synchronized (handle) {
            cublas.cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
            cublas.cublasDsymm_v2(new cublasContext(handle), CudaBlas.convertSideMode(Side2), CudaBlas.convertUplo(Uplo), M, N, new DoublePointer(alpha), (DoublePointer)aPointer.getDevicePointer(), lda, (DoublePointer)bPointer.getDevicePointer(), ldb, new DoublePointer(beta), (DoublePointer)cPointer.getDevicePointer(), ldc);
        }
        this.allocator.registerAction(ctx, C, A, B);
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * Double-precision symmetric rank-k update via {@code cublasDsyrk_v2}.
     *
     * @param Order2 unused
     * @param Uplo   BLAS upper/lower flag, translated by {@link CudaBlas#convertUplo}
     * @param Trans  BLAS transpose flag, translated by {@link CudaBlas#convertTranspose}
     * @param N      passed to cuBLAS as n
     * @param K      passed to cuBLAS as k
     * @param alpha  scalar multiplier passed to cuBLAS as alpha
     * @param A      input matrix
     * @param lda    leading dimension of A
     * @param beta   scalar multiplier passed to cuBLAS as beta
     * @param C      result matrix
     * @param ldc    leading dimension of C
     */
    @Override
    protected void dsyrk(char Order2, char Uplo, char Trans, int N, int K, double alpha, INDArray A, int lda, double beta, INDArray C, int ldc) {
        Nd4j.getExecutioner().push();
        CudaContext ctx = this.allocator.getFlowController().prepareAction(C, A);
        CublasPointer aPointer = new CublasPointer(A, ctx);
        CublasPointer cPointer = new CublasPointer(C, ctx);
        cublasHandle_t handle = ctx.getCublasHandle();
        synchronized (handle) {
            cublas.cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
            // Translate the BLAS character flag like ssyrk does; a raw (int) cast would hand
            // cuBLAS the character code instead of an operation constant.
            cublas.cublasDsyrk_v2(new cublasContext(handle), CudaBlas.convertUplo(Uplo), CudaBlas.convertTranspose(Trans), N, K, new DoublePointer(alpha), (DoublePointer)aPointer.getDevicePointer(), lda, new DoublePointer(beta), (DoublePointer)cPointer.getDevicePointer(), ldc);
        }
        this.allocator.registerAction(ctx, C, A);
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * Double-precision symmetric rank-2k update via {@code cublasDsyr2k_v2}.
     *
     * @param Order2 unused
     * @param Uplo   BLAS upper/lower flag, translated by {@link CudaBlas#convertUplo}
     * @param Trans  BLAS transpose flag, translated by {@link CudaBlas#convertTranspose}
     * @param N      passed to cuBLAS as n
     * @param K      passed to cuBLAS as k
     * @param alpha  scalar multiplier passed to cuBLAS as alpha
     * @param A      first input matrix
     * @param lda    leading dimension of A
     * @param B      second input matrix
     * @param ldb    leading dimension of B
     * @param beta   scalar multiplier passed to cuBLAS as beta
     * @param C      result matrix
     * @param ldc    leading dimension of C
     */
    @Override
    protected void dsyr2k(char Order2, char Uplo, char Trans, int N, int K, double alpha, INDArray A, int lda, INDArray B, int ldb, double beta, INDArray C, int ldc) {
        Nd4j.getExecutioner().push();
        CudaContext ctx = this.allocator.getFlowController().prepareAction(C, A, B);
        CublasPointer aPointer = new CublasPointer(A, ctx);
        CublasPointer bPointer = new CublasPointer(B, ctx);
        CublasPointer cPointer = new CublasPointer(C, ctx);
        cublasHandle_t handle = ctx.getCublasHandle();
        synchronized (handle) {
            cublas.cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
            // Translate the BLAS character flag, consistent with the other routines in this class.
            cublas.cublasDsyr2k_v2(new cublasContext(handle), CudaBlas.convertUplo(Uplo), CudaBlas.convertTranspose(Trans), N, K, new DoublePointer(alpha), (DoublePointer)aPointer.getDevicePointer(), lda, (DoublePointer)bPointer.getDevicePointer(), ldb, new DoublePointer(beta), (DoublePointer)cPointer.getDevicePointer(), ldc);
        }
        this.allocator.registerAction(ctx, C, A, B);
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * Double-precision triangular matrix multiply via {@code cublasDtrmm_v2}.
     *
     * <p>B's device pointer is passed both as the B operand and as the output operand,
     * so the result overwrites B.
     *
     * @param Order2 unused
     * @param Side2  BLAS side flag, translated by {@link CudaBlas#convertSideMode}
     * @param Uplo   BLAS upper/lower flag, translated by {@link CudaBlas#convertUplo}
     * @param TransA BLAS transpose flag for A, translated by {@link CudaBlas#convertTranspose}
     * @param Diag2  BLAS diagonal flag, translated by {@link CudaBlas#convertDiag}
     * @param M      passed to cuBLAS as m
     * @param N      passed to cuBLAS as n
     * @param alpha  scalar multiplier passed to cuBLAS as alpha
     * @param A      triangular input matrix
     * @param lda    leading dimension of A
     * @param B      input / result matrix
     * @param ldb    leading dimension of B
     */
    @Override
    protected void dtrmm(char Order2, char Side2, char Uplo, char TransA, char Diag2, int M, int N, double alpha, INDArray A, int lda, INDArray B, int ldb) {
        Nd4j.getExecutioner().push();
        CudaContext ctx = this.allocator.getFlowController().prepareAction(B, A);
        CublasPointer aPointer = new CublasPointer(A, ctx);
        CublasPointer bPointer = new CublasPointer(B, ctx);
        cublasHandle_t handle = ctx.getCublasHandle();
        synchronized (handle) {
            cublas.cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
            cublas.cublasDtrmm_v2(new cublasContext(handle), CudaBlas.convertSideMode(Side2), CudaBlas.convertUplo(Uplo), CudaBlas.convertTranspose(TransA), CudaBlas.convertDiag(Diag2), M, N, new DoublePointer(alpha), (DoublePointer)aPointer.getDevicePointer(), lda, (DoublePointer)bPointer.getDevicePointer(), ldb, (DoublePointer)bPointer.getDevicePointer(), ldb);
        }
        this.allocator.registerAction(ctx, B, A);
        OpExecutionerUtil.checkForAny(B);
    }

    /**
     * Double-precision triangular solve via {@code cublasDtrsm_v2}; B is both an input
     * and the array registered/checked as the result.
     *
     * @param Order2 unused
     * @param Side2  BLAS side flag, translated by {@link CudaBlas#convertSideMode}
     * @param Uplo   BLAS upper/lower flag, translated by {@link CudaBlas#convertUplo}
     * @param TransA BLAS transpose flag for A, translated by {@link CudaBlas#convertTranspose}
     * @param Diag2  BLAS diagonal flag, translated by {@link CudaBlas#convertDiag}
     * @param M      passed to cuBLAS as m
     * @param N      passed to cuBLAS as n
     * @param alpha  scalar multiplier passed to cuBLAS as alpha
     * @param A      triangular input matrix
     * @param lda    leading dimension of A
     * @param B      right-hand side / result matrix
     * @param ldb    leading dimension of B
     */
    @Override
    protected void dtrsm(char Order2, char Side2, char Uplo, char TransA, char Diag2, int M, int N, double alpha, INDArray A, int lda, INDArray B, int ldb) {
        Nd4j.getExecutioner().push();
        CudaContext ctx = this.allocator.getFlowController().prepareAction(B, A);
        CublasPointer aPointer = new CublasPointer(A, ctx);
        CublasPointer bPointer = new CublasPointer(B, ctx);
        cublasHandle_t handle = ctx.getCublasHandle();
        synchronized (handle) {
            cublas.cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
            cublas.cublasDtrsm_v2(new cublasContext(handle), CudaBlas.convertSideMode(Side2), CudaBlas.convertUplo(Uplo), CudaBlas.convertTranspose(TransA), CudaBlas.convertDiag(Diag2), M, N, new DoublePointer(alpha), (DoublePointer)aPointer.getDevicePointer(), lda, (DoublePointer)bPointer.getDevicePointer(), ldb);
        }
        this.allocator.registerAction(ctx, B, A);
        OpExecutionerUtil.checkForAny(B);
    }
}

