/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.iterator;

import java.util.List;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.LockSupport;
import lombok.NonNull;
import org.deeplearning4j.datasets.iterator.callbacks.DataSetCallback;
import org.deeplearning4j.datasets.iterator.callbacks.DefaultCallback;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AsyncDataSetIterator
implements DataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(AsyncDataSetIterator.class);
    protected DataSetIterator backedIterator;
    protected DataSet terminator = new DataSet();
    protected DataSet nextElement = null;
    protected BlockingQueue<DataSet> buffer;
    protected AsyncPrefetchThread thread;
    protected AtomicBoolean shouldWork = new AtomicBoolean(true);
    protected volatile RuntimeException throwable = null;
    protected boolean useWorkspace = true;
    protected int prefetchSize;
    protected String workspaceId;
    protected Integer deviceId;
    protected AtomicBoolean hasDepleted = new AtomicBoolean(false);
    protected DataSetCallback callback;

    protected AsyncDataSetIterator() {
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator) {
        this(baseIterator, 8);
    }

    public AsyncDataSetIterator(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue) {
        this(iterator, queueSize, queue, true);
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize));
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, boolean useWorkspace) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace);
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, boolean useWorkspace, Integer deviceId) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace, new DefaultCallback(), deviceId);
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, boolean useWorkspace, DataSetCallback callback) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace, callback);
    }

    public AsyncDataSetIterator(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace) {
        this(iterator, queueSize, queue, useWorkspace, new DefaultCallback());
    }

    public AsyncDataSetIterator(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace, DataSetCallback callback) {
        this(iterator, queueSize, queue, useWorkspace, callback, Nd4j.getAffinityManager().getDeviceForCurrentThread());
    }

    public AsyncDataSetIterator(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
        if (queueSize < 2) {
            queueSize = 2;
        }
        this.deviceId = deviceId;
        this.callback = callback;
        this.useWorkspace = useWorkspace;
        this.buffer = queue;
        this.prefetchSize = queueSize;
        this.backedIterator = iterator;
        this.workspaceId = "ADSI_ITER-" + UUID.randomUUID().toString();
        if (iterator.resetSupported()) {
            this.backedIterator.reset();
        }
        this.thread = new AsyncPrefetchThread(this.buffer, iterator, this.terminator, null);
        Nd4j.getAffinityManager().attachThreadToDevice((Thread)this.thread, deviceId);
        this.thread.setDaemon(true);
        this.thread.start();
    }

    public DataSet next(int num) {
        throw new UnsupportedOperationException();
    }

    public int totalExamples() {
        return this.backedIterator.totalExamples();
    }

    public int inputColumns() {
        return this.backedIterator.inputColumns();
    }

    public int totalOutcomes() {
        return this.backedIterator.totalOutcomes();
    }

    public boolean resetSupported() {
        return this.backedIterator.resetSupported();
    }

    public boolean asyncSupported() {
        return false;
    }

    protected void externalCall() {
    }

    public void reset() {
        this.buffer.clear();
        if (this.thread != null) {
            this.thread.interrupt();
        }
        try {
            if (this.thread != null) {
                this.thread.join();
            }
        }
        catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
        this.thread.shutdown();
        this.buffer.clear();
        this.backedIterator.reset();
        this.shouldWork.set(true);
        this.thread = new AsyncPrefetchThread(this.buffer, this.backedIterator, this.terminator, null);
        Nd4j.getAffinityManager().attachThreadToDevice((Thread)this.thread, this.deviceId);
        this.thread.setDaemon(true);
        this.thread.start();
        this.hasDepleted.set(false);
        this.nextElement = null;
    }

    public void shutdown() {
        this.buffer.clear();
        if (this.thread != null) {
            this.thread.interrupt();
        }
        try {
            if (this.thread != null) {
                this.thread.join();
            }
        }
        catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
        this.thread.shutdown();
        this.buffer.clear();
    }

    public int batch() {
        return this.backedIterator.batch();
    }

    public int cursor() {
        return this.backedIterator.cursor();
    }

    public int numExamples() {
        return this.backedIterator.numExamples();
    }

    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.backedIterator.setPreProcessor(preProcessor);
    }

    public DataSetPreProcessor getPreProcessor() {
        return this.backedIterator.getPreProcessor();
    }

    public List<String> getLabels() {
        return this.backedIterator.getLabels();
    }

    public boolean hasNext() {
        if (this.throwable != null) {
            throw this.throwable;
        }
        try {
            if (this.hasDepleted.get()) {
                return false;
            }
            if (this.nextElement != null && this.nextElement != this.terminator) {
                return true;
            }
            if (this.nextElement == this.terminator) {
                return false;
            }
            this.nextElement = this.buffer.take();
            if (this.nextElement == this.terminator) {
                this.hasDepleted.set(true);
                return false;
            }
            return true;
        }
        catch (Exception e) {
            log.error("Premature end of loop!");
            throw new RuntimeException(e);
        }
    }

    public DataSet next() {
        if (this.throwable != null) {
            throw this.throwable;
        }
        if (this.hasDepleted.get()) {
            return null;
        }
        DataSet temp = this.nextElement;
        this.nextElement = null;
        return temp;
    }

    public void remove() {
    }

    protected class AsyncPrefetchThread
    extends Thread
    implements Runnable {
        private BlockingQueue<DataSet> queue;
        private DataSetIterator iterator;
        private DataSet terminator;
        private AtomicBoolean isShutdown = new AtomicBoolean(false);
        private WorkspaceConfiguration configuration;
        private MemoryWorkspace workspace;

        protected AsyncPrefetchThread(@NonNull BlockingQueue<DataSet> queue, @NonNull DataSetIterator iterator, DataSet terminator, MemoryWorkspace workspace) {
            this.configuration = WorkspaceConfiguration.builder().minSize(0xA00000L).overallocationLimit((double)(AsyncDataSetIterator.this.prefetchSize + 1)).policyReset(ResetPolicy.ENDOFBUFFER_REACHED).policyLearning(LearningPolicy.FIRST_LOOP).policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE).build();
            if (queue == null) {
                throw new NullPointerException("queue");
            }
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            if (terminator == null) {
                throw new NullPointerException("terminator");
            }
            this.queue = queue;
            this.iterator = iterator;
            this.terminator = terminator;
            this.setDaemon(true);
            this.setName("ADSI prefetch thread");
        }

        /*
         * Unable to fully structure code
         */
        @Override
        public void run() {
            AsyncDataSetIterator.this.externalCall();
            try {
                if (AsyncDataSetIterator.this.useWorkspace) {
                    this.workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(this.configuration, AsyncDataSetIterator.this.workspaceId);
                }
                while (this.iterator.hasNext() && AsyncDataSetIterator.this.shouldWork.get()) {
                    smth = null;
                    if (AsyncDataSetIterator.this.useWorkspace) {
                        ws = this.workspace.notifyScopeEntered();
                        var3_6 = null;
                        try {
                            smth = (DataSet)this.iterator.next();
                            if (AsyncDataSetIterator.this.callback == null) ** GOTO lbl32
                            AsyncDataSetIterator.this.callback.call((org.nd4j.linalg.dataset.api.DataSet)smth);
                        }
                        catch (Throwable var4_8) {
                            var3_6 = var4_8;
                            throw var4_8;
                        }
                        finally {
                            if (ws != null) {
                                if (var3_6 != null) {
                                    try {
                                        ws.close();
                                    }
                                    catch (Throwable var4_7) {
                                        var3_6.addSuppressed(var4_7);
                                    }
                                } else {
                                    ws.close();
                                }
                            }
                        }
                    } else {
                        smth = (DataSet)this.iterator.next();
                        if (AsyncDataSetIterator.this.callback != null) {
                            AsyncDataSetIterator.this.callback.call((org.nd4j.linalg.dataset.api.DataSet)smth);
                        }
                    }
lbl32:
                    // 5 sources

                    Nd4j.getExecutioner().commit();
                    if (smth == null) continue;
                    this.queue.put(smth);
                }
                this.queue.put(this.terminator);
            }
            catch (InterruptedException e) {
                AsyncDataSetIterator.this.shouldWork.set(false);
            }
            catch (RuntimeException e) {
                AsyncDataSetIterator.this.throwable = e;
                throw new RuntimeException(e);
            }
            catch (Exception e) {
                AsyncDataSetIterator.this.throwable = new RuntimeException(e);
                throw new RuntimeException(e);
            }
            finally {
                this.isShutdown.set(true);
            }
        }

        public void shutdown() {
            while (!this.isShutdown.get()) {
                LockSupport.parkNanos(100L);
            }
            if (this.workspace != null) {
                log.debug("Manually destroying ADSI workspace");
                this.workspace.destroyWorkspace(true);
            }
        }
    }
}

