/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.extensions.cuda;

import jcuda.jcublas.JCublas;
import jcuda.jcudnn.JCudnn;
import jcuda.runtime.JCuda;
import org.dromara.easyai.extensions.cuda.CudaFp32Util;
import org.dromara.easyai.extensions.cuda.MatrixUtil;
import org.dromara.easyai.matrixTools.CudaMatrix;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.jcudax.JCudax;

/**
 * {@link CudaMatrix} implementation that forwards matrix operations to the
 * FP32 CUDA helpers in {@link CudaFp32Util}.
 *
 * <p>Each operation obtains the operand contents as a {@code float[]},
 * hands that array plus the matrix dimensions ({@code getX()}, {@code getY()})
 * to a {@link CudaFp32Util} routine, and stores the returned array back into a
 * {@link Matrix}.
 *
 * <p>NOTE(review): {@code getX()} is treated as the row count and
 * {@code getY()} as the column count, which is what the result shape built in
 * {@link #mulMatrix(Matrix, Matrix)} implies — confirm against {@link Matrix}.
 */
public class CudaMatrixImpl
implements CudaMatrix {
    /**
     * Initializes the JCuda, JCublas, JCudnn and JCudax bindings, in that
     * order, and prints a confirmation line to standard output.
     *
     * @throws Exception declared by the interface contract; propagated from the
     *                   underlying initializers
     */
    public void init() throws Exception {
        JCuda.initialize();
        JCublas.cublasInit();
        JCudnn.initialize();
        JCudax.initialize();
        System.out.println("EasyAI CUDA-12.0.0 extensions init success.");
    }

    /**
     * Shuts down cuBLAS.
     *
     * <p>NOTE(review): only cuBLAS is torn down here; the other libraries
     * initialized in {@link #init()} have no matching call — confirm that none
     * is required.
     *
     * @throws Exception declared by the interface contract
     */
    public void destroy() throws Exception {
        JCublas.cublasShutdown();
    }

    /**
     * Replaces the contents of {@code a} with the result of
     * {@link CudaFp32Util#matrixSoftmax} applied to its row-major data.
     *
     * @param a matrix to transform in place
     * @throws Exception if the conversion helpers or the CUDA routine fail
     */
    public void softMax(Matrix a) throws Exception {
        float[] input = MatrixUtil.toRowMajorArray(a);
        float[] output = CudaFp32Util.matrixSoftmax(input, a.getX(), a.getY());
        MatrixUtil.setRowMajorArray(a, output);
    }

    /**
     * Runs {@link CudaFp32Util#matrixSoftMaxPd} on {@code qkt} and
     * {@code errorMatrix} and returns the result as a new matrix with the same
     * dimensions as {@code qkt}.
     *
     * <p>NOTE(review): judging by the name this is the softmax backward pass
     * (partial derivative) used by attention — confirm against
     * {@link CudaFp32Util}.
     *
     * @param qkt                 input matrix; its dimensions define the result shape
     * @param errorMatrix         error matrix passed alongside {@code qkt}
     * @param wordVectorDimension value forwarded unchanged to the CUDA routine
     * @return a new matrix of {@code qkt.getX()} by {@code qkt.getY()}
     * @throws Exception if the conversion helpers or the CUDA routine fail
     */
    public Matrix matrixSoftMaxPd(Matrix qkt, Matrix errorMatrix, float wordVectorDimension) throws Exception {
        int rows = qkt.getX();
        int cols = qkt.getY();
        float[] qktData = MatrixUtil.toRowMajorArray(qkt);
        float[] errorData = MatrixUtil.toRowMajorArray(errorMatrix);
        // Output buffer filled by the CUDA routine.
        float[] gradient = new float[rows * cols];
        CudaFp32Util.matrixSoftMaxPd(qktData, errorData, gradient, rows, cols, wordVectorDimension);
        return MatrixUtil.rowMajorArrayToMatrix(gradient, rows, cols);
    }

    /**
     * Multiplies {@code a} by {@code b} and returns the product as a new
     * matrix of {@code a.getX()} by {@code b.getY()}.
     *
     * @param a left operand
     * @param b right operand; {@code b.getX()} must equal {@code a.getY()}
     * @return the product matrix
     * @throws IllegalArgumentException if the inner dimensions do not match
     * @throws Exception                if the CUDA routine fails
     */
    public Matrix mulMatrix(Matrix a, Matrix b) throws Exception {
        int rows = a.getX();
        int inner = a.getY();
        int cols = b.getY();
        // b.getX() is never passed to the native routine, so a mismatch would
        // otherwise go undetected and the kernel would read b's buffer using
        // the wrong dimensions.
        if (b.getX() != inner) {
            throw new IllegalArgumentException("Matrix dimensions do not match for multiplication: "
                    + rows + "x" + inner + " * " + b.getX() + "x" + cols);
        }
        float[] product = CudaFp32Util.matrixMulMatrix(a.getCudaMatrix(), b.getCudaMatrix(), rows, inner, cols);
        Matrix res = new Matrix(rows, cols);
        res.setCudaMatrix(product, rows, cols);
        return res;
    }

    /**
     * Adds {@code scalar} to every element of {@code a}, in place.
     *
     * @param a      matrix to update
     * @param scalar value to add
     * @throws Exception if the CUDA routine fails
     */
    public void mathAdd(Matrix a, float scalar) throws Exception {
        addScalarInPlace(a, scalar);
    }

    /**
     * Subtracts {@code scalar} from every element of {@code a}, in place.
     * Implemented as an addition of {@code -scalar}.
     *
     * @param a      matrix to update
     * @param scalar value to subtract
     * @throws Exception if the CUDA routine fails
     */
    public void mathSub(Matrix a, float scalar) throws Exception {
        addScalarInPlace(a, -scalar);
    }

    /**
     * Multiplies every element of {@code a} by {@code scalar}, in place.
     *
     * @param a      matrix to update
     * @param scalar scale factor
     * @throws Exception if the CUDA routine fails
     */
    public void mathMul(Matrix a, float scalar) throws Exception {
        scaleInPlace(a, scalar);
    }

    /**
     * Divides every element of {@code a} by {@code scalar}, in place.
     * Implemented as a scale by the reciprocal {@code 1.0f / scalar}; a zero
     * {@code scalar} therefore produces an infinite scale factor rather than
     * an exception.
     *
     * @param a      matrix to update
     * @param scalar divisor
     * @throws Exception if the CUDA routine fails
     */
    public void mathDiv(Matrix a, float scalar) throws Exception {
        // Written back through setCudaMatrix, like the other scalar
        // operations, so the layout stored always matches the layout that
        // getCudaMatrix produced.
        scaleInPlace(a, 1.0f / scalar);
    }

    /** Adds {@code delta} to every element of {@code a} via {@link CudaFp32Util#matrixAddScalar}. */
    private static void addScalarInPlace(Matrix a, float delta) throws Exception {
        int rows = a.getX();
        int cols = a.getY();
        float[] result = CudaFp32Util.matrixAddScalar(a.getCudaMatrix(), delta, rows, cols);
        a.setCudaMatrix(result, rows, cols);
    }

    /** Multiplies every element of {@code a} by {@code factor} via {@link CudaFp32Util#matrixScale}. */
    private static void scaleInPlace(Matrix a, float factor) throws Exception {
        int rows = a.getX();
        int cols = a.getY();
        float[] result = CudaFp32Util.matrixScale(a.getCudaMatrix(), factor, rows, cols);
        a.setCudaMatrix(result, rows, cols);
    }

    // Announces on standard output that the class has been loaded.
    static {
        System.out.println("EasyAI CUDA-12.0.0 extensions loaded.");
    }
}

