/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers.accumulation;

import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;
import lombok.NonNull;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.MessageHandler;
import org.deeplearning4j.optimize.solvers.accumulation.Registerable;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.AtomicThrowable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link GradientsAccumulator} that shares encoded gradient updates between a fixed number of
 * parties (consumer threads).
 *
 * <p>Each party owns a bounded {@link BlockingQueue} of incoming messages, a dedicated
 * {@link MemoryWorkspace} named {@code "CGA-<i>"} in which incoming messages are duplicated, and a
 * {@link ReentrantLock} that serialises writers to that party's queue. Producers call
 * {@link #storeUpdate(INDArray)}, which adds the local gradient into a per-thread accumulator and
 * passes it to {@link MessageHandler#broadcastUpdates}; NOTE(review): the handler is presumably what
 * calls back into {@link #receiveUpdate(INDArray)} (it receives {@code this} via
 * {@code handler.initialize(this)}) — confirm against the handler implementation. Consumers call
 * {@code applyUpdate(...)} to decode every pending message (the int at index 3 of the message
 * buffer selects the codec: 0 = threshold, 1 = bitmap) into their update array and then apply it
 * through a {@link StepFunction}.
 *
 * <p>Consumer threads are coordinated with spin-wait barriers (see {@link #synchronize(int, boolean)}).
 * The first exception raised on any thread is recorded in {@link #throwable}; every spinning thread
 * re-throws it so that one failing thread does not leave the others waiting forever.
 */
public class EncodedGradientsAccumulator implements GradientsAccumulator, Registerable {
    private static final Logger log = LoggerFactory.getLogger(EncodedGradientsAccumulator.class);

    /** Index of the int in a message's data buffer that holds the encoding header. */
    private static final long ENCODING_HEADER_INDEX = 3L;
    /** Header value for threshold-encoded messages. */
    private static final int ENCODING_THRESHOLD = 0;
    /** Header value for bitmap-encoded messages. */
    private static final int ENCODING_BITMAP = 1;

    /** Per-thread running sum of gradients passed to {@link #storeUpdate(INDArray)}. */
    protected ThreadLocal<INDArray> accumulator = new ThreadLocal<>();
    /** Number of parties; one message queue, workspace and lock is created per party. */
    protected int parties;
    protected MessageHandler handler;
    /** Incoming (encoded) messages, one bounded queue per party. */
    protected List<BlockingQueue<INDArray>> messages = new ArrayList<BlockingQueue<INDArray>>();
    /** Workspace per party in which incoming messages are duplicated by {@link #receiveUpdate}. */
    protected List<MemoryWorkspace> workspaces = new ArrayList<MemoryWorkspace>();
    /** Lock per party guarding writes into that party's queue/workspace. */
    protected List<ReentrantLock> locks = new ArrayList<ReentrantLock>();
    /** Source of party indices handed out by {@link #touch()} on single-device setups. */
    protected AtomicInteger workersCounter = new AtomicInteger(0);
    /** Party index of the calling thread, assigned by {@link #touch()}. */
    protected ThreadLocal<Integer> index = new ThreadLocal<>();
    protected long initialMemory = 0x6400000L;
    protected int queueSize = 5;
    protected Double boundary = 1.0;
    /** Optional additional queue of encoded updates drained by every {@code applyUpdate} call. */
    protected Queue<INDArray> externalSource;
    // Flags and counters for the two-phase spin barrier in synchronize().
    protected AtomicBoolean isFirst = new AtomicBoolean(false);
    protected AtomicBoolean isDone = new AtomicBoolean(true);
    protected AtomicInteger barrier = new AtomicInteger(0);
    protected AtomicInteger secondary = new AtomicInteger(0);
    /** Set by registerConsumers(), cleared by the final synchronize() of an iteration. */
    protected AtomicBoolean registered = new AtomicBoolean(false);
    /** When set, consumer barriers are skipped (single-consumer fallback). */
    protected AtomicBoolean bypassMode = new AtomicBoolean(false);
    protected final AtomicInteger currentConsumers = new AtomicInteger(0);
    /** First exception seen on any thread; spinning threads re-throw it. */
    protected final AtomicThrowable throwable = new AtomicThrowable();
    protected boolean isDebug = false;
    /**
     * True when there are several devices without cross-device access; external messages are then
     * duplicated locally before decoding.
     */
    protected final boolean relocatable;
    protected WorkspaceConfiguration appliedConfiguration = WorkspaceConfiguration.builder().minSize(0x500000L).overallocationLimit(0.3).policyMirroring(MirroringPolicy.FULL).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).build();

    /**
     * Creates an accumulator with one party per available device and an encoding threshold of 1e-3.
     *
     * <p>NOTE(review): the {@code parties} argument is ignored; the party count always comes from
     * {@code getNumberOfDevices()}. The behaviour is kept for compatibility; confirm the intent.
     *
     * @param parties ignored
     */
    public EncodedGradientsAccumulator(double parties) {
        this(Nd4j.getAffinityManager().getNumberOfDevices(), 0.001);
    }

    /**
     * Creates an accumulator for {@code parties} consumers with an encoding threshold of 1e-3.
     *
     * @param parties number of consumer threads
     */
    public EncodedGradientsAccumulator(int parties) {
        this(parties, 0.001);
    }

    /**
     * Creates an accumulator using an {@link EncodingHandler} with the given threshold, 100 MiB of
     * message memory per party and a queue size of 10.
     *
     * @param parties number of consumer threads
     * @param threshold encoding threshold passed to {@link EncodingHandler}
     */
    public EncodedGradientsAccumulator(int parties, double threshold) {
        this(parties, new EncodingHandler(threshold), 0x6400000L, 10, 1.0);
    }

    /**
     * Full constructor: allocates one message queue, workspace and lock per party and initialises
     * the handler.
     *
     * @param parties number of consumer threads; must not exceed the device count on multi-device hosts
     * @param handler message handler; initialised with this accumulator at the end of construction
     * @param initialMemory size in bytes of each party's message workspace
     * @param queueSize capacity of each party's message queue
     * @param boundary stored as-is (may be {@code null} when built through {@link Builder})
     * @throws NullPointerException if {@code handler} is null
     * @throws ND4JIllegalStateException if {@code parties} exceeds the number of devices (when > 1)
     */
    protected EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory, int queueSize, Double boundary) {
        if (handler == null) {
            throw new NullPointerException("handler");
        }
        this.parties = parties;
        this.handler = handler;
        this.initialMemory = initialMemory;
        this.queueSize = queueSize;
        this.boundary = boundary;
        // Fixed-size workspace: STRICT allocation with SpillPolicy.FAIL. receiveUpdate() checks the
        // message size against initialMemory / queueSize before duplicating into it.
        WorkspaceConfiguration configuration = WorkspaceConfiguration.builder()
                .initialSize(initialMemory)
                .policyReset(ResetPolicy.ENDOFBUFFER_REACHED)
                .policyAllocation(AllocationPolicy.STRICT)
                .policySpill(SpillPolicy.FAIL)
                .policyLearning(LearningPolicy.NONE)
                .build();
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        this.relocatable = numDevices > 1 && !Nd4j.getAffinityManager().isCrossDeviceAccessSupported();
        if (parties > numDevices && numDevices != 1) {
            throw new ND4JIllegalStateException("Number of parties [" + parties + "] should be less or equal to number of devices [" + numDevices + "]");
        }
        int curDev = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        try {
            for (int i = 0; i < parties; ++i) {
                this.messages.add(new LinkedBlockingQueue<INDArray>(queueSize));
                int cDevice = numDevices > 1 ? i % numDevices : 0;
                // Temporarily switch this thread's device so party i's workspace is created on it.
                Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(cDevice));
                MemoryWorkspace ws = Nd4j.getWorkspaceManager().createNewWorkspace(configuration, "CGA-" + i, Integer.valueOf(cDevice));
                this.workspaces.add(ws);
                this.locks.add(new ReentrantLock());
            }
        } finally {
            // Always give the caller its original device back, even if workspace creation failed.
            Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(curDev));
        }
        handler.initialize(this);
    }

    /**
     * Estimates a message-buffer size for the given parameter count:
     * {@code (paramsLength / 16 + 65536) * numWorkers * queueSize * 4}.
     *
     * <p>NOTE(review): the arithmetic is done in {@code int}. It silently overflows for very large
     * models or worker/queue counts. This is kept as-is because the return type is {@code int}.
     *
     * @param paramsLength number of model parameters
     * @param numWorkers number of workers
     * @param queueSize per-worker queue size
     * @return suggested buffer size
     */
    public static int getOptimalBufferSize(int paramsLength, int numWorkers, int queueSize) {
        int bufferSize = (paramsLength / 16 + 65536) * numWorkers * queueSize * 4;
        return bufferSize;
    }

    /**
     * Same as {@link #getOptimalBufferSize(int, int, int)} using {@code model.params().length()}.
     */
    public static int getOptimalBufferSize(Model model, int numWorkers, int queueSize) {
        return EncodedGradientsAccumulator.getOptimalBufferSize(model.params().length(), numWorkers, queueSize);
    }

    /**
     * Enables or disables single-consumer (bypass) mode, forwarding the request to the external
     * source when it is {@link Registerable}.
     */
    @Override
    public void fallbackToSingleConsumerMode(boolean reallyFallback) {
        if (this.externalSource instanceof Registerable) {
            ((Registerable) this.externalSource).fallbackToSingleConsumerMode(reallyFallback);
        }
        this.bypassMode.set(reallyFallback);
    }

    /**
     * Registers the number of consumers for the next iteration. If a previous registration has not
     * been released yet (by the final {@code synchronize}), this spins until it is.
     *
     * @param numConsumers number of consumer threads taking part in the next iteration
     * @throws RuntimeException wrapping the first failure recorded on any thread while waiting
     */
    @Override
    public void registerConsumers(int numConsumers) {
        if (this.registered.get()) {
            if (this.isDebug) {
                log.info("Master thread locks at RC");
            }
            spinUntil(this.registered, false, 100L);
            if (this.isDebug) {
                log.info("Master thread unlocks at RC");
            }
        }
        if (this.externalSource instanceof Registerable) {
            ((Registerable) this.externalSource).registerConsumers(numConsumers);
        }
        this.currentConsumers.set(numConsumers);
        this.registered.set(true);
    }

    /** Barrier across {@code consumers} threads without releasing the registration. */
    protected void synchronize(int consumers) {
        this.synchronize(consumers, false);
    }

    /**
     * Two-phase spin barrier across {@code consumers} threads.
     *
     * @param consumers number of threads expected at the barrier
     * @param finalLock when true, {@code registered} is cleared once all threads pass, which
     *                  lets a blocked {@link #registerConsumers(int)} proceed
     * @throws RuntimeException wrapping the first failure recorded on any thread while spinning
     */
    protected void synchronize(int consumers, boolean finalLock) {
        // No peers to wait for: just release the registration on the final lock.
        if (consumers == 1 || this.bypassMode.get()) {
            if (finalLock) {
                this.registered.set(false);
            }
            return;
        }
        if (this.isDebug) {
            log.info("thread {} locking at CGA: {}", Thread.currentThread().getId(), this.currentConsumers.get());
        }
        // Phase 1: the last thread to arrive resets the counters and sets isDone; others spin on it.
        this.isDone.compareAndSet(true, false);
        if (this.barrier.incrementAndGet() == consumers) {
            this.secondary.set(0);
            this.barrier.set(0);
            this.isFirst.set(false);
            this.isDone.set(true);
        } else {
            spinUntil(this.isDone, true, 1000L);
        }
        // Phase 2: the last thread through `secondary` sets isFirst (and clears registration on the
        // final lock); others spin until it does.
        if (this.secondary.incrementAndGet() == consumers) {
            if (finalLock) {
                this.registered.set(false);
            }
            this.isFirst.set(true);
        } else {
            spinUntil(this.isFirst, true, 1000L);
        }
        if (this.isDebug) {
            log.info("thread {} unlocking at CGA: {}", Thread.currentThread().getId(), this.currentConsumers.get());
        }
    }

    /**
     * Parks in {@code parkNanos} slices until {@code flag} equals {@code expected}, aborting as soon
     * as any thread has recorded a failure in {@link #throwable}.
     */
    private void spinUntil(AtomicBoolean flag, boolean expected, long parkNanos) {
        while (flag.get() != expected) {
            LockSupport.parkNanos(parkNanos);
            if (this.throwable.isTriggered()) {
                throw new RuntimeException(this.throwable.get());
            }
        }
    }

    /**
     * Zeroes {@code updates}, decodes all pending local and external messages into it, waits for
     * all consumers, and applies the result with {@code function.step(params, updates)} if at least
     * one message was decoded.
     *
     * @throws RuntimeException wrapping any failure; the failure is also recorded for other threads
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray updates) {
        try {
            Nd4j.getMemoryManager().memset(updates);
            int cnt = drainPendingUpdates(updates);
            this.synchronize(this.currentConsumers.get(), true);
            if (cnt > 0) {
                function.step(params, updates);
            }
        } catch (Exception e) {
            this.throwable.setIfFirst((Throwable) e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Same as {@link #applyUpdate(StepFunction, INDArray, INDArray)} but applies the result with
     * {@code function.step(params, updates, alpha)}.
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray updates, double alpha) {
        try {
            Nd4j.getMemoryManager().memset(updates);
            int cnt = drainPendingUpdates(updates);
            this.synchronize(this.currentConsumers.get(), true);
            if (cnt > 0) {
                function.step(params, updates, alpha);
            }
        } catch (Exception e) {
            this.throwable.setIfFirst((Throwable) e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Decodes every message queued for the calling thread's party, then every message in the
     * external source (if any), into {@code updates}.
     *
     * @return total number of messages decoded (local + external)
     */
    private int drainPendingUpdates(INDArray updates) {
        BlockingQueue<INDArray> localQueue = this.messages.get(this.index.get());
        int cnt = 0;
        INDArray message;
        // poll() returns null once the queue is empty, so a null can never reach the decoder.
        while ((message = localQueue.poll()) != null) {
            decodeInto(message, updates);
            ++cnt;
        }
        if (cnt > 0 && this.isDebug) {
            log.info("Local updates to be applied: {}", cnt);
        }
        if (this.externalSource != null) {
            int ent = 0;
            // NOTE(review): the isEmpty()/poll() pairing is kept on purpose. The external source may
            // be a Registerable queue (see registerConsumers) whose isEmpty() has side effects; confirm
            // that before switching to poll-until-null here.
            while (!this.externalSource.isEmpty()) {
                INDArray external = this.externalSource.poll();
                if (this.relocatable) {
                    // No cross-device access: duplicate the message into the CGA_APPLY workspace
                    // before decoding it.
                    try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(this.appliedConfiguration, "CGA_APPLY")) {
                        decodeInto(external.unsafeDuplication(true), updates);
                    }
                } else {
                    decodeInto(external, updates);
                }
                ++cnt;
                ++ent;
            }
            if (this.isDebug) {
                log.info("thread {} finished at Externals", Thread.currentThread().getId());
            }
            if (ent > 0 && this.isDebug) {
                log.info("External updates to be applied: {}", ent);
            }
        }
        return cnt;
    }

    /**
     * Decodes one encoded message into {@code target}, dispatching on the header int.
     *
     * @throws DL4JInvalidConfigException if the header is neither threshold nor bitmap encoding
     */
    private static void decodeInto(INDArray compressed, INDArray target) {
        int encoding = compressed.data().getInt(ENCODING_HEADER_INDEX);
        if (encoding == ENCODING_THRESHOLD) {
            Nd4j.getExecutioner().thresholdDecode(compressed, target);
        } else if (encoding == ENCODING_BITMAP) {
            Nd4j.getExecutioner().bitmapDecode(compressed, target);
        } else {
            throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
        }
    }

    @Override
    public void setExternalSource(Queue<INDArray> source) {
        this.externalSource = source;
    }

    /**
     * Assigns the calling thread its party index if it does not have one yet: the current device
     * id on multi-device, multi-party setups, otherwise the next value of {@link #workersCounter}.
     */
    @Override
    public void touch() {
        if (this.index.get() == null) {
            int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
            if (numDevices > 1 && this.parties > 1) {
                int localIndex = Nd4j.getAffinityManager().getDeviceForCurrentThread();
                this.index.set(localIndex);
            } else {
                this.index.set(this.workersCounter.getAndIncrement());
            }
        }
    }

    /**
     * Adds {@code array} into the calling thread's accumulator, waits for consumer registration
     * (unless in bypass mode), broadcasts the accumulator through the handler, then waits at the
     * consumer barrier.
     *
     * @throws RuntimeException wrapping any failure; the failure is also recorded for other threads
     */
    @Override
    public void storeUpdate(INDArray array) {
        try {
            if (this.accumulator.get() == null) {
                // The accumulator is reused across calls, so allocate it outside any active workspace.
                try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    this.accumulator.set(Nd4j.create((int[]) array.shape(), (char) array.ordering()));
                }
            }
            this.accumulator.get().addi(array);
            if (this.isDebug) {
                log.info("thread {} locking at Register", Thread.currentThread().getId());
            }
            if (!this.bypassMode.get()) {
                // Wait until registerConsumers() has set `registered`.
                spinUntil(this.registered, true, 100L);
            }
            if (this.isDebug) {
                log.info("thread {} unlocking at Register", Thread.currentThread().getId());
            }
            this.handler.broadcastUpdates(this.accumulator.get());
            this.synchronize(this.currentConsumers.get());
        } catch (Exception e) {
            this.throwable.setIfFirst((Throwable) e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Duplicates {@code array} into every party's workspace and enqueues the copy on that party's
     * queue. Blocks while a party's queue is full.
     *
     * @throws RuntimeException wrapping an {@link ND4JIllegalStateException} if the message is larger
     *         than {@code initialMemory / queueSize} bytes, or any other failure (recorded for other
     *         threads)
     */
    @Override
    public void receiveUpdate(INDArray array) {
        try {
            for (int i = 0; i < this.parties; ++i) {
                ReentrantLock lock = this.locks.get(i);
                lock.lock();
                // Unlock in finally: previously an oversized message left this party's lock held forever.
                try (MemoryWorkspace workspace = this.workspaces.get(i).notifyScopeEntered()) {
                    long elementSize = Nd4j.sizeOfDataType((DataBuffer.Type) array.data().dataType());
                    if (array.data().length() > this.initialMemory / (long) this.queueSize / elementSize) {
                        throw new ND4JIllegalStateException("Not enough memory to handle update: [" + array.data().length() * elementSize + " bytes required]. Please increase memory amount for GradientsAccumulator");
                    }
                    INDArray compressed = array.unsafeDuplication();
                    try {
                        this.messages.get(i).put(compressed);
                    } catch (InterruptedException e) {
                        log.info("Something bad at index_{}", i);
                        // Preserve the interrupt for callers further up the stack.
                        Thread.currentThread().interrupt();
                        throw new RuntimeException(e);
                    }
                } finally {
                    lock.unlock();
                }
            }
        } catch (Exception e) {
            this.throwable.setIfFirst((Throwable) e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Drops all per-thread accumulators, restarts party-index assignment, and clears every
     * party's pending messages.
     */
    @Override
    public void reset() {
        this.accumulator = new ThreadLocal<>();
        this.workersCounter.set(0);
        for (int i = 0; i < this.parties; ++i) {
            this.messages.get(i).clear();
        }
    }

    /** Fluent builder for {@link EncodedGradientsAccumulator}. */
    public static class Builder {
        protected int parties;
        protected double threshold = 0.001;
        protected long initialMemory = 0x6400000L;
        protected int queueSize = 5;
        protected MessageHandler handler;
        protected Double boundary = null;

        /**
         * @param parties number of consumer threads
         * @throws DL4JInvalidConfigException if {@code parties < 1}
         */
        public Builder(int parties) {
            if (parties < 1) {
                throw new DL4JInvalidConfigException("Number of parties for GradientsAccumulation should be positive value");
            }
            this.parties = parties;
        }

        /** Uses a custom handler; when set, {@link #encodingThreshold} and {@link #updatesBoundary} are not used to build one. */
        public Builder messageHandler(@NonNull MessageHandler handler) {
            if (handler == null) {
                throw new NullPointerException("handler");
            }
            this.handler = handler;
            return this;
        }

        /** Threshold passed to the default {@link EncodingHandler}. */
        public Builder encodingThreshold(double threshold) {
            this.threshold = threshold;
            return this;
        }

        /**
         * Boundary passed to the default {@link EncodingHandler}. Values {@code >= 1.0} are ignored
         * (no boundary is configured).
         *
         * @throws DL4JInvalidConfigException if {@code boundary <= 0}
         */
        public Builder updatesBoundary(double boundary) {
            if (boundary >= 1.0) {
                return this;
            }
            if (boundary <= 0.0) {
                throw new DL4JInvalidConfigException("Boundary should have positive value");
            }
            this.boundary = boundary;
            return this;
        }

        /**
         * @param initialMemory size in bytes of each party's message workspace
         * @param queueSize capacity of each party's message queue
         */
        public Builder memoryParameters(long initialMemory, int queueSize) {
            this.initialMemory = initialMemory;
            this.queueSize = queueSize;
            return this;
        }

        /** Builds the accumulator, creating a default {@link EncodingHandler} if none was set. */
        public EncodedGradientsAccumulator build() {
            if (this.handler == null) {
                this.handler = this.boundary == null ? new EncodingHandler(this.threshold) : new EncodingHandler(this.threshold, this.boundary);
            }
            EncodedGradientsAccumulator accumulator = new EncodedGradientsAccumulator(this.parties, this.handler, this.initialMemory, this.queueSize, this.boundary);
            return accumulator;
        }
    }
}

