/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.memory;

import java.util.Map;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.memory.AllocationsTracker;
import org.nd4j.linalg.api.memory.BasicMemoryManager;
import org.nd4j.linalg.api.memory.enums.AllocationKind;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Memory manager for the CUDA backend.
 *
 * <p>Raw allocation, zero-fill and release go through the native ops obtained from
 * {@link NativeOpsHolder}; allocation-point bookkeeping, host/device pointers and device
 * contexts come from {@link AtomicAllocator}.
 */
public class CudaMemoryManager
extends BasicMemoryManager {
    private static final Logger log = LoggerFactory.getLogger(CudaMemoryManager.class);

    /**
     * Allocates {@code bytes} bytes of the requested kind of memory.
     *
     * @param bytes      number of bytes to allocate
     * @param kind       {@link MemoryKind#HOST} or {@link MemoryKind#DEVICE}
     * @param initialize when {@code true} the new region is filled with zeros before it is returned
     * @return pointer to the newly allocated region
     * @throws RuntimeException          if the native allocation returns {@code null}, if the native
     *                                   layer reports a non-zero error code after a device
     *                                   allocation, or if {@code kind} is neither HOST nor DEVICE
     * @throws ND4JIllegalStateException if zero-filling freshly allocated device memory fails
     */
    @Override
    public Pointer allocate(long bytes, MemoryKind kind, boolean initialize) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        if (kind == MemoryKind.HOST) {
            Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().mallocHost(bytes, 0);
            if (ptr == null) {
                throw new RuntimeException("Failed to allocate " + bytes + " bytes from HOST memory");
            }
            if (initialize) {
                Pointer.memset(ptr, 0, bytes);
            }
            return ptr;
        }
        if (kind == MemoryKind.DEVICE) {
            Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().mallocDevice(bytes, 0, 0);
            log.trace("Allocating {} bytes for device_{}", bytes, Nd4j.getAffinityManager().getDeviceForCurrentThread());
            // The native layer reports allocation problems through its last-error state.
            int ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
            if (ec != 0) {
                String em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
                throw new RuntimeException(em + "; Bytes: [" + bytes + "]; Error code [" + ec + "]; DEVICE [" + Nd4j.getAffinityManager().getDeviceForCurrentThread() + "]");
            }
            if (ptr == null) {
                throw new RuntimeException("Failed to allocate " + bytes + " bytes from DEVICE [" + Nd4j.getAffinityManager().getDeviceForCurrentThread() + "] memory");
            }
            if (initialize) {
                CudaContext context = allocator.getDeviceContext();
                int rc = NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(ptr, 0, bytes, 0, context.getSpecialStream());
                if (rc == 0) {
                    // The pointer never reaches the caller on this path, so nobody else can
                    // free it: release it here instead of leaking device memory.
                    NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(ptr, 0);
                    throw new ND4JIllegalStateException("memset failed on device_" + Nd4j.getAffinityManager().getDeviceForCurrentThread());
                }
                // Wait for the asynchronous fill before handing the memory out.
                context.getSpecialStream().synchronize();
            }
            return ptr;
        }
        throw new RuntimeException("Unknown MemoryKind requested: " + kind);
    }

    /**
     * Frees the memory behind the given arrays and marks their allocation points as
     * {@link AllocationStatus#DEALLOCATED}. {@code null} entries and views are skipped.
     * The executioner is committed first.
     *
     * @param arrays arrays whose buffers should be released
     * @throws RuntimeException if an allocation point is in a status other than HOST, DEVICE or
     *                          DEALLOCATED; the message carries the index of the offending argument
     */
    @Override
    public void collect(INDArray ... arrays) {
        Nd4j.getExecutioner().commit();
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        for (int idx = 0; idx < arrays.length; idx++) {
            INDArray array = arrays[idx];
            if (array == null || array.isView()) {
                continue;
            }
            AllocationPoint point = allocator.getAllocationPoint(array);
            AllocationStatus status = point.getAllocationStatus();
            if (status == AllocationStatus.HOST) {
                allocator.getMemoryHandler().free(point, AllocationStatus.HOST);
            } else if (status == AllocationStatus.DEVICE) {
                // A device-resident point is freed on both sides.
                allocator.getMemoryHandler().free(point, AllocationStatus.DEVICE);
                allocator.getMemoryHandler().free(point, AllocationStatus.HOST);
            } else if (status != AllocationStatus.DEALLOCATED) {
                throw new RuntimeException("Unknown AllocationStatus: " + status + " for argument: " + idx);
            }
            point.setAllocationStatus(AllocationStatus.DEALLOCATED);
        }
    }

    /**
     * No-op in this implementation.
     */
    @Override
    public synchronized void purgeCaches() {
    }

    /**
     * Calls {@code lazyAllocateHostPointer()} on every buffer that is a
     * {@link BaseCudaDataBuffer}; other buffers (and {@code null} entries) are ignored.
     *
     * @param dataBuffers buffers to prepare
     */
    protected void allocateHostPointers(DataBuffer ... dataBuffers) {
        for (DataBuffer buffer : dataBuffers) {
            // instanceof is false for null, so no separate null check is needed.
            if (buffer instanceof BaseCudaDataBuffer) {
                ((BaseCudaDataBuffer)buffer).lazyAllocateHostPointer();
            }
        }
    }

    /**
     * Copies the contents of {@code srcBuffer} into {@code dstBuffer}.
     *
     * <p>When at least one side is a {@link CompressedDataBuffer} the copy is a host-side
     * {@code Pointer.memcpy} of {@code srcBuffer.length() * srcBuffer.getElementSize()} bytes;
     * otherwise the copy is delegated to {@link AtomicAllocator}.
     *
     * @param dstBuffer destination buffer
     * @param srcBuffer source buffer
     */
    @Override
    public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        // NOTE(review): the returned context is not used here; the call is kept because it
        // may set up the device context for the current thread - TODO confirm and drop if not.
        allocator.getDeviceContext();

        boolean dstCompressed = dstBuffer instanceof CompressedDataBuffer;
        boolean srcCompressed = srcBuffer instanceof CompressedDataBuffer;

        if (dstCompressed && !srcCompressed) {
            AllocationPoint srcPoint = allocator.getAllocationPoint(srcBuffer);
            this.allocateHostPointers(dstBuffer, srcBuffer);
            long size = (long)srcBuffer.getElementSize() * srcBuffer.length();
            // Make sure the host copy of the source is current before reading it.
            if (!srcPoint.isActualOnHostSide()) {
                allocator.synchronizeHostData(srcBuffer);
            }
            Pointer src = allocator.getHostPointer(srcBuffer);
            Pointer.memcpy(dstBuffer.addressPointer(), src, size);
        } else if (!dstCompressed && srcCompressed) {
            this.allocateHostPointers(dstBuffer, srcBuffer);
            AllocationPoint dstPoint = allocator.getAllocationPoint(dstBuffer);
            long size = (long)srcBuffer.getElementSize() * srcBuffer.length();
            Pointer.memcpy(dstBuffer.addressPointer(), srcBuffer.addressPointer(), size);
            // Record that the destination was modified on the host side.
            dstPoint.tickHostWrite();
        } else if (dstCompressed && srcCompressed) {
            this.allocateHostPointers(dstBuffer, srcBuffer);
            Pointer.memcpy(dstBuffer.addressPointer(), srcBuffer.addressPointer(), srcBuffer.length() * (long)srcBuffer.getElementSize());
        } else {
            allocator.memcpy(dstBuffer, srcBuffer);
        }
    }

    /**
     * Frees a pointer previously obtained for the given memory kind and nulls it out.
     * Kinds other than {@link MemoryKind#DEVICE} and {@link MemoryKind#HOST} are ignored.
     *
     * @param pointer pointer to free
     * @param kind    kind of memory the pointer refers to
     */
    @Override
    public void release(Pointer pointer, MemoryKind kind) {
        if (kind == MemoryKind.DEVICE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(pointer, 0);
            pointer.setNull();
        } else if (kind == MemoryKind.HOST) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(pointer);
            pointer.setNull();
        }
    }

    /**
     * Sets the automatic GC window on the superclass and mirrors the value into the
     * {@link CudaEnvironment} configuration as the no-GC window.
     *
     * @param windowMillis window length in milliseconds
     */
    @Override
    public void setAutoGcWindow(int windowMillis) {
        super.setAutoGcWindow(windowMillis);
        CudaEnvironment.getInstance().getConfiguration().setNoGcWindowMs(windowMillis);
    }

    /**
     * Fills the given array with zeros.
     *
     * <p>Views are zeroed through {@code assign(0.0)}. For non-views the underlying buffer is
     * cleared directly: on both device and host when the allocation point is in
     * {@link AllocationStatus#DEVICE}, on the host only when it is in
     * {@link AllocationStatus#HOST}. Other statuses are left untouched.
     *
     * @param array array to clear
     * @throws ND4JIllegalStateException if the device-side memset reports failure
     */
    @Override
    public void memset(INDArray array) {
        if (array.isView()) {
            array.assign(0.0);
            Nd4j.getExecutioner().commit();
            return;
        }
        Nd4j.getExecutioner().push();
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        AllocationPoint point = allocator.getAllocationPoint(array);
        if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
            CudaContext context = allocator.getDeviceContext();
            Pointer devicePtr = allocator.getPointer(array, context);
            long sizeBytes = array.data().length() * (long)Nd4j.sizeOfDataType(array.data().dataType());
            int rc = NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, sizeBytes, 0, context.getOldStream());
            // allocate() treats a 0 return from memsetAsync as failure; do the same here so a
            // failed device clear is not silently recorded as a successful device write.
            if (rc == 0) {
                throw new ND4JIllegalStateException("memset failed on device_" + Nd4j.getAffinityManager().getDeviceForCurrentThread());
            }
            Pointer.memset(allocator.getHostPointer(array), 0, sizeBytes);
            context.getOldStream().synchronize();
            point.tickDeviceWrite();
            point.tickHostRead();
        } else if (point.getAllocationStatus() == AllocationStatus.HOST) {
            Nd4j.getExecutioner().commit();
            long sizeBytes = array.data().length() * (long)Nd4j.sizeOfDataType(array.data().dataType());
            Pointer.memset(allocator.getHostPointer(array), 0, sizeBytes);
            point.tickHostWrite();
        }
    }

    /**
     * Bandwidth statistics are not provided by this implementation.
     *
     * @return always {@code null}; callers must be prepared for that
     */
    @Override
    public Map<Integer, Long> getBandwidthUse() {
        return null;
    }

    /**
     * Returns the number of bytes tracked for the given device, summed over the
     * {@link AllocationKind#GENERAL} and {@link AllocationKind#WORKSPACE} allocation kinds.
     *
     * @param deviceId device to query
     * @return tracked bytes on that device
     */
    @Override
    public long allocatedMemory(Integer deviceId) {
        AllocationsTracker tracker = AllocationsTracker.getInstance();
        return tracker.bytesOnDevice(AllocationKind.GENERAL, deviceId) + tracker.bytesOnDevice(AllocationKind.WORKSPACE, deviceId);
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void releaseCurrentContext() {
        throw new UnsupportedOperationException("Not implemented yet");
    }
}

