package org.deeplearning4j.parallelism;

import java.lang.Thread;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.core.storage.listener.RoutingIterationListener;
import org.deeplearning4j.datasets.iterator.DummyBlockDataSetIterator;
import org.deeplearning4j.datasets.iterator.DummyBlockMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.callbacks.InterleavedDataSetCallback;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.SharedGradient;
import org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.Registerable;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
import org.deeplearning4j.parallelism.factory.DefaultTrainerContext;
import org.deeplearning4j.parallelism.factory.SymmetricTrainerContext;
import org.deeplearning4j.parallelism.factory.TrainerContext;
import org.deeplearning4j.parallelism.trainer.Trainer;
import org.nd4j.common.function.Supplier;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/parallelism/ParallelWrapper.class */
public class ParallelWrapper implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(ParallelWrapper.class);
    /** Optional source of fresh model parameters, polled at the start of every training round in fit(). */
    protected Supplier<INDArray> modelParamsSupplier;
    /** Optional source of fresh updater state, polled at the start of every training round in fit(). */
    protected Supplier<INDArray> updaterParamsSupplier;
    /** Flag raised when a training thread dies; null unless assigned externally (e.g. through its setter). */
    protected AtomicBoolean exceptionEncountered;
    /** Last throwable recorded by {@link #handler} while {@link #exceptionEncountered} was non-null. */
    protected Throwable exception;
    /** The master model being trained. */
    protected Model model;
    /** Number of trainers / training threads. */
    protected int workers;
    /** Size of the asynchronous prefetch buffer; fit() only creates a prefetcher when this is positive. */
    protected int prefetchSize;
    /** One Trainer per worker; created lazily by createZooIfNeccessary() and discarded by close(). */
    protected Trainer[] zoo;
    /** Factory used to create the trainers in {@link #zoo}. */
    protected TrainerContext trainerContext;
    /** Stats storage router handed to RoutingIterationListener clones that have none of their own. */
    protected StatsStorageRouter storageRouter;
    /** Whether more than one device is used; selects the interleaved prefetch queue in fit(). */
    protected boolean isMQ;
    /** Workspace mode passed to every trainer. */
    protected WorkspaceMode workspaceMode;
    /** Extra arguments passed to {@link TrainerContext#init}. */
    protected Object[] trainerContextArgs;
    /** Pool running the trainers; (re)created by init(). */
    protected ThreadPoolExecutor executorService;
    /** Gradient exchange for shared-gradient / custom modes; null in averaging mode. */
    protected GradientsAccumulator gradientsAccumulator;
    /** Identifier of this wrapper instance, passed to every trainer. */
    protected final String uuid = UUID.randomUUID().toString();
    /** Average parameters every N iterations (when the trainers require averaging). */
    protected int averagingFrequency = 1;
    /** Number of training rounds started by this wrapper (not reset between fit() calls). */
    protected AtomicLong iterationsCounter = new AtomicLong(0);
    /** Whether to log the averaged score. */
    protected boolean reportScore = false;
    /** Whether updater state is averaged together with parameters. */
    protected boolean averageUpdaters = true;
    /** Stored and exposed through its getter; not otherwise read in this class as visible here. */
    protected boolean legacyAveraging = false;
    /** Set to true whenever parameters are averaged. */
    protected boolean wasAveraged = false;
    /** Request flag for terminating a running fit() early; cleared at the start of every fit(). */
    protected AtomicBoolean stopFit = new AtomicBoolean(false);
    /** Listeners registered on the wrapper (moved off the model by the Builder). */
    protected List<TrainingListener> listeners = new ArrayList<>();
    /** Enables additional informational logging in fit(). */
    protected boolean debug = false;
    /** Index handed to the next training thread; used to pick that thread's device. */
    protected final AtomicInteger workerCounter = new AtomicInteger(0);
    /**
     * Installed on every training thread created by {@link #init()}. Logs the failure and, when an
     * {@link #exceptionEncountered} flag has been supplied, records the throwable and raises the flag.
     */
    Thread.UncaughtExceptionHandler handler = new Thread.UncaughtExceptionHandler() {
        @Override
        public void uncaughtException(Thread thread, Throwable th) {
            // Passing the throwable as the last argument lets SLF4J log the full stack trace through the
            // configured appenders instead of dumping it separately to stderr via printStackTrace().
            log.error("Uncaught exception in thread {}", thread.getName(), th);
            AtomicBoolean flag = exceptionEncountered;
            if (flag != null) {
                // Record the throwable BEFORE raising the flag: the atomic set() then publishes it, so a
                // thread that observes flag == true also observes the recorded exception.
                exception = th;
                flag.set(true);
            }
        }
    };

    /* loaded from: input_file:org/deeplearning4j/parallelism/ParallelWrapper$Builder.class */
    public static class Builder<T extends Model> {
        protected T model;
        protected boolean isMQ;
        protected TrainerContext trainerContext;
        protected Object[] trainerContextArgs;
        protected WorkspaceMode workspaceMode;
        protected Supplier<INDArray> modelParamsSupplier;
        protected Supplier<INDArray> updaterParamsSupplier;
        protected ThresholdAlgorithm thresholdAlgorithm;
        protected ResidualPostProcessor residualPostProcessor;
        protected Long encoderMemory;
        protected GradientsAccumulator accumulator;
        protected TrainingMode trainingMode = TrainingMode.AVERAGING;
        protected int workers = Nd4j.getAffinityManager().getNumberOfDevices();
        protected int prefetchSize = 16;
        protected int averagingFrequency = 1;
        protected boolean reportScore = false;
        protected boolean averageUpdaters = true;
        protected boolean legacyAveraging = true;

        public Builder trainerContextArgs(Object... objArr) {
            this.trainerContextArgs = objArr;
            return this;
        }

        public Builder trainerFactory(@NonNull TrainerContext trainerContext) {
            if (trainerContext == null) {
                throw new NullPointerException("trainerContext is marked non-null but is null");
            }
            this.trainerContext = trainerContext;
            return this;
        }

        public Builder workspaceMode(@NonNull WorkspaceMode workspaceMode) {
            if (workspaceMode == null) {
                throw new NullPointerException("mode is marked non-null but is null");
            }
            this.workspaceMode = workspaceMode;
            return this;
        }

        public Builder modelParamsSupplier(Supplier<INDArray> supplier) {
            this.modelParamsSupplier = supplier;
            return this;
        }

        public Builder updaterParamsSupplier(Supplier<INDArray> supplier) {
            this.updaterParamsSupplier = supplier;
            return this;
        }

        public Builder(@NonNull T t) {
            this.isMQ = Nd4j.getAffinityManager().getNumberOfDevices() > 1;
            this.trainerContext = new DefaultTrainerContext();
            this.workspaceMode = WorkspaceMode.ENABLED;
            this.encoderMemory = -1L;
            if (t == null) {
                throw new NullPointerException("model is marked non-null but is null");
            }
            this.model = t;
        }

        public Builder workers(int i) {
            if (i < 2) {
                throw new RuntimeException("Number of workers can't be lower then 2!");
            }
            this.workers = i;
            return this;
        }

        public Builder averagingFrequency(int i) {
            if (i < 0) {
                i = 0;
            }
            this.averagingFrequency = i;
            return this;
        }

        public Builder averageUpdaters(boolean z) {
            this.averageUpdaters = z;
            return this;
        }

        public Builder prefetchBuffer(int i) {
            if (i < 0) {
                i = 0;
            }
            this.prefetchSize = i;
            return this;
        }

        public Builder trainingMode(@NonNull TrainingMode trainingMode) {
            if (trainingMode == null) {
                throw new NullPointerException("mode is marked non-null but is null");
            }
            this.trainingMode = trainingMode;
            return this;
        }

        public Builder gradientsAccumulator(@NonNull GradientsAccumulator gradientsAccumulator) {
            if (gradientsAccumulator == null) {
                throw new NullPointerException("accumulator is marked non-null but is null");
            }
            this.accumulator = gradientsAccumulator;
            return this;
        }

        public Builder reportScoreAfterAveraging(boolean z) {
            this.reportScore = z;
            return this;
        }

        public Builder thresholdAlgorithm(ThresholdAlgorithm thresholdAlgorithm) {
            this.thresholdAlgorithm = thresholdAlgorithm;
            return this;
        }

        public Builder temporaryMemory(@NonNull Long l) {
            if (l == null) {
                throw new NullPointerException("numBytes is marked non-null but is null");
            }
            this.encoderMemory = l;
            return this;
        }

        public Builder residualPostProcessor(ResidualPostProcessor residualPostProcessor) {
            this.residualPostProcessor = residualPostProcessor;
            return this;
        }

        public ParallelWrapper build() {
            ParallelWrapper parallelWrapper = new ParallelWrapper(this.model, this.workers, this.prefetchSize);
            parallelWrapper.averagingFrequency = this.averagingFrequency;
            parallelWrapper.reportScore = this.reportScore;
            parallelWrapper.averageUpdaters = this.averageUpdaters;
            parallelWrapper.legacyAveraging = this.legacyAveraging;
            parallelWrapper.isMQ = this.isMQ;
            parallelWrapper.workspaceMode = this.workspaceMode;
            parallelWrapper.modelParamsSupplier = this.modelParamsSupplier;
            parallelWrapper.updaterParamsSupplier = this.updaterParamsSupplier;
            switch (this.trainingMode) {
                case AVERAGING:
                    this.trainerContext = new DefaultTrainerContext();
                    this.accumulator = null;
                    ParallelWrapper.log.info("Creating new AveragingTraining instance");
                    break;
                case SHARED_GRADIENTS:
                    if (this.thresholdAlgorithm == null) {
                        this.thresholdAlgorithm = new AdaptiveThresholdAlgorithm();
                    }
                    this.trainerContext = new SymmetricTrainerContext();
                    if (this.accumulator == null) {
                        ParallelWrapper.log.info("Creating new GradientsAccumulator instance with default threshold of [5e-4]");
                        this.accumulator = new EncodedGradientsAccumulator(this.workers, new EncodingHandler(this.thresholdAlgorithm, this.residualPostProcessor, Integer.valueOf((int) ((this.model.numParams() / 16) + 5)), false), (this.encoderMemory == null || this.encoderMemory.longValue() < 0) ? r0 * 4 * (this.workers + 3) : this.encoderMemory.longValue(), this.workers + 2, Integer.MAX_VALUE, false);
                        break;
                    }
                    break;
                case CUSTOM:
                    this.trainerContext = new SymmetricTrainerContext();
                    if (this.accumulator == null) {
                        throw new DL4JInvalidConfigException("Please specify GradientsAccumulator fo encoded gradients mode");
                    }
                    break;
                default:
                    throw new UnsupportedOperationException("Unknown trainingMode: [" + this.trainingMode + "]");
            }
            parallelWrapper.trainerContext = this.trainerContext;
            parallelWrapper.gradientsAccumulator = this.accumulator;
            parallelWrapper.init();
            ArrayList arrayList = null;
            if (this.model instanceof MultiLayerNetwork) {
                arrayList = new ArrayList(this.model.getListeners());
                this.model.setListeners(Collections.emptyList());
            } else if (this.model instanceof ComputationGraph) {
                arrayList = new ArrayList(this.model.getListeners());
                this.model.setListeners(Collections.emptyList());
            }
            if (arrayList != null && !arrayList.isEmpty()) {
                parallelWrapper.setListeners(arrayList);
            }
            return parallelWrapper;
        }
    }

    /**
     * How {@link Builder#build()} wires the trainers together.
     */
    public enum TrainingMode {
        /** Uses a DefaultTrainerContext and no gradients accumulator. */
        AVERAGING,
        /** Uses a SymmetricTrainerContext; a default encoded-gradients accumulator is created if none was given. */
        SHARED_GRADIENTS,
        /** Uses a SymmetricTrainerContext; the caller must supply the gradients accumulator. */
        CUSTOM;
    }

    /**
     * Creates a wrapper around {@code model}. Normally invoked through {@link Builder#build()}.
     *
     * @param model        master model to train
     * @param workers      number of trainers / training threads
     * @param prefetchSize size of the asynchronous prefetch buffer (prefetching is skipped when not positive)
     */
    protected ParallelWrapper(Model model, int workers, int prefetchSize) {
        this.model = model;
        this.workers = workers;
        this.prefetchSize = prefetchSize;

        // getUpdater() is called only for its side effect. Presumably it forces the updater to be
        // instantiated before trainers are created; TODO confirm.
        if (model instanceof MultiLayerNetwork) {
            ((MultiLayerNetwork) model).getUpdater();
        } else if (model instanceof ComputationGraph) {
            ((ComputationGraph) model).getUpdater();
        }
    }

    /**
     * Resets the worker counter and (re)creates the fixed-size pool that runs the trainers.
     *
     * <p>Every pool thread is a daemon named {@code "ParallelWrapper training thread N"}. It reports
     * uncaught exceptions to {@link #handler} and binds itself to device {@code N % numberOfDevices}
     * before running its task.
     */
    protected void init() {
        workerCounter.set(0);
        ThreadFactory factory = task -> {
            if (task == null) {
                throw new NullPointerException("r is marked non-null but is null");
            }
            final int workerIndex = workerCounter.getAndIncrement();
            Thread worker = new Thread(() -> {
                int deviceCount = Nd4j.getAffinityManager().getNumberOfDevices();
                Nd4j.getAffinityManager().unsafeSetDevice(workerIndex % deviceCount);
                task.run();
            });
            worker.setName("ParallelWrapper training thread " + workerIndex);
            worker.setDaemon(true);
            worker.setUncaughtExceptionHandler(handler);
            return worker;
        };
        executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(workers, factory);
    }

    /**
     * Shuts down every trainer, the worker pool, and resets the gradients accumulator (if any).
     *
     * <p>The pool shutdown and accumulator reset run in a {@code finally} block, so a trainer that fails to
     * shut down does not leak the pool's threads. The trainer array is discarded either way, so a later fit()
     * creates fresh trainers.
     *
     * @throws Exception propagated from a trainer's shutdown, after the remaining cleanup has run
     */
    @Override
    public void close() throws Exception {
        try {
            if (this.zoo != null) {
                for (Trainer trainer : this.zoo) {
                    if (trainer != null) {
                        trainer.shutdown();
                    }
                }
            }
        } finally {
            this.zoo = null;
            if (this.executorService != null) {
                this.executorService.shutdown();
                this.executorService = null;
            }
            if (this.gradientsAccumulator != null) {
                this.gradientsAccumulator.reset();
            }
        }
    }

    /**
     * Synchronized variant of {@link #close()} that rethrows any failure as an unchecked exception.
     *
     * @throws RuntimeException wrapping whatever {@link #close()} threw
     */
    public synchronized void shutdown() {
        try {
            this.close();
        } catch (Exception cause) {
            throw new RuntimeException(cause);
        }
    }

    /**
     * Asks a running fit(...) call to stop before it starts its next training round.
     * The request is cleared when the next fit(...) call begins.
     */
    public void stopFit() {
        stopFit.set(true);
    }

    /**
     * Trains the wrapped model on {@code multiDataSetIterator}, handing one MultiDataSet per worker per round.
     *
     * <p>Each round does the following:
     * <ul>
     *   <li>Polls the optional model/updater suppliers and pushes non-null results to every trainer.</li>
     *   <li>Requests {@code workers} MultiDataSets, feeds them to the trainers, and waits for the trainers.</li>
     *   <li>Every {@code averagingFrequency} iterations, averages parameters (and updater state) when the
     *       trainers require it. A non-positive frequency disables averaging instead of dividing by zero.</li>
     * </ul>
     *
     * <p>The async prefetcher, if one was created, is shut down even when training fails. On normal
     * completion the trainers and the pool are closed.
     *
     * @param multiDataSetIterator training data; reset first if exhausted and resettable
     * @throws NullPointerException if the iterator is null
     * @throws ND4JIllegalStateException if the block iterator yields a null batch array
     * @throws IllegalStateException if the trainers were shut down while fitting
     */
    public synchronized void fit(@NonNull MultiDataSetIterator multiDataSetIterator) {
        if (multiDataSetIterator == null) {
            throw new NullPointerException("source is marked non-null but is null");
        }
        this.stopFit.set(false);
        // Reset per call so the "never averaged" warning below really refers to this fit() only.
        this.wasAveraged = false;
        createZooIfNeccessary(true);
        if (!multiDataSetIterator.hasNext() && multiDataSetIterator.resetSupported()) {
            multiDataSetIterator.reset();
        }

        boolean useAsync = this.prefetchSize > 0 && multiDataSetIterator.asyncSupported();
        MultiDataSetIterator iterator = multiDataSetIterator;
        if (useAsync) {
            if (this.isMQ) {
                int devices = Nd4j.getAffinityManager().getNumberOfDevices();
                if (this.workers % devices != 0) {
                    log.warn("Number of workers [{}] isn't optimal for available devices [{}]", this.workers, devices);
                }
                iterator = new AsyncMultiDataSetIterator(multiDataSetIterator, this.prefetchSize,
                        new LinkedBlockingQueue<>(this.prefetchSize * this.workers), true,
                        new InterleavedDataSetCallback(this.prefetchSize * 2));
            } else {
                iterator = new AsyncMultiDataSetIterator(multiDataSetIterator, this.prefetchSize);
            }
        }

        // True once any round reported that its trainers need parameter averaging.
        boolean averagingExpected = false;
        AtomicInteger activeWorkers = new AtomicInteger(0);
        DummyBlockMultiDataSetIterator blockIterator = new DummyBlockMultiDataSetIterator(iterator);
        long lastRoundEnd = System.currentTimeMillis();
        try {
            while (blockIterator.hasAnything() && !this.stopFit.get()) {
                if (this.modelParamsSupplier != null) {
                    INDArray freshParams = this.modelParamsSupplier.get();
                    if (freshParams != null && this.zoo != null) {
                        for (Trainer trainer : this.zoo) {
                            trainer.updateModelParams(freshParams);
                        }
                    }
                }
                if (this.updaterParamsSupplier != null) {
                    INDArray freshUpdaterState = this.updaterParamsSupplier.get();
                    if (freshUpdaterState != null && this.zoo != null) {
                        for (Trainer trainer : this.zoo) {
                            trainer.updateUpdaterParams(freshUpdaterState);
                        }
                    }
                }

                MultiDataSet[] batches = blockIterator.next(this.workers);
                // Time since the previous round ended, passed to the trainers with each batch
                // (presumably as data-wait/ETL time; TODO confirm).
                long waitTime = System.currentTimeMillis() - lastRoundEnd;
                if (batches == null) {
                    throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");
                }
                if (this.zoo == null) {
                    throw new IllegalStateException("ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
                }
                activeWorkers.set(batches.length);
                if (this.gradientsAccumulator instanceof Registerable) {
                    this.gradientsAccumulator.registerConsumers(batches.length);
                }
                for (int w = 0; w < batches.length; w++) {
                    this.zoo[w].feedMultiDataSet(batches[w], waitTime);
                }
                this.iterationsCounter.incrementAndGet();
                for (int w = 0; w < batches.length; w++) {
                    this.zoo[w].waitTillRunning();
                }

                boolean averagingRequired = this.zoo[0].averagingRequired();
                averagingExpected |= averagingRequired;
                // averagingFrequency(0) is accepted by the Builder; guard the modulo against division by zero.
                if (averagingRequired && this.averagingFrequency > 0
                        && this.iterationsCounter.get() % this.averagingFrequency == 0) {
                    averageUpdatersState(activeWorkers, getScore(activeWorkers));
                }
                activeWorkers.set(0);
                lastRoundEnd = System.currentTimeMillis();
            }
        } finally {
            if (this.debug) {
                log.info("Stopping everyone...");
                log.info("Shutting down iterator...");
            }
            // Always stop the prefetch thread, including when training failed.
            if (useAsync) {
                ((AsyncMultiDataSetIterator) iterator).shutdown();
            }
        }

        try {
            close();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        if (averagingExpected && !this.wasAveraged) {
            log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
        }
        log.debug("Iterations passed: {}", this.iterationsCounter.get());
    }

    /**
     * Averages the parameters of the workers that took part in the last round and returns their mean score.
     *
     * <p>Despite its name this method has side effects. It marks the current fit as averaged
     * ({@link #wasAveraged}). It also passes the workers' parameter arrays to
     * {@code Nd4j.averageAndPropagate} with a null target, which, per its name, writes the average back
     * into each array.
     *
     * @param atomicInteger number of workers that received data in the last round
     * @return mean score over {@code min(workers, atomicInteger)} workers (NaN if that count is 0)
     */
    private double getScore(AtomicInteger atomicInteger) {
        this.wasAveraged = true;
        int participants = Math.min(this.workers, atomicInteger.get());
        List<INDArray> workerParams = new ArrayList<>();
        double scoreSum = 0.0d;
        for (int i = 0; i < participants; i++) {
            Model workerModel = this.zoo[i].getModel();
            workerParams.add(workerModel.params());
            scoreSum += workerModel.score();
        }
        Nd4j.averageAndPropagate((INDArray) null, workerParams);
        double averageScore = scoreSum / participants;
        if (this.reportScore) {
            log.info("Averaged score: {}", averageScore);
        }
        return averageScore;
    }

    /**
     * Completes an averaging step for MultiLayerNetwork/ComputationGraph models; other model types are ignored.
     *
     * <p>When {@link #averageUpdaters} is set and the master model has updater state, the state arrays of the
     * participating workers are averaged via {@code Nd4j.averageAndPropagate}. In all cases {@code d} is stored
     * as the master model's score.
     *
     * @param atomicInteger number of workers that received data in the last round
     * @param d             averaged score to store on the master model
     */
    private void averageUpdatersState(AtomicInteger atomicInteger, double d) {
        boolean isNetwork = this.model instanceof MultiLayerNetwork;
        if (!isNetwork && !(this.model instanceof ComputationGraph)) {
            return;
        }
        // Only average when the master model actually has an updater state view.
        if (this.averageUpdaters && updaterStateView(this.model) != null) {
            int participants = Math.min(this.workers, atomicInteger.get());
            List<INDArray> workerStates = new ArrayList<>();
            for (int i = 0; i < participants; i++) {
                workerStates.add(updaterStateView(this.zoo[i].getModel()));
            }
            Nd4j.averageAndPropagate((INDArray) null, workerStates);
        }
        if (isNetwork) {
            ((MultiLayerNetwork) this.model).setScore(d);
        } else {
            ((ComputationGraph) this.model).setScore(d);
        }
    }

    /** Returns the updater state view of a MultiLayerNetwork/ComputationGraph, or null if there is none. */
    private static INDArray updaterStateView(Model m) {
        if (m instanceof MultiLayerNetwork) {
            Updater updater = ((MultiLayerNetwork) m).getUpdater();
            return updater == null ? null : updater.getStateViewArray();
        }
        if (m instanceof ComputationGraph) {
            ComputationGraphUpdater updater = ((ComputationGraph) m).getUpdater();
            return updater == null ? null : updater.getStateViewArray();
        }
        return null;
    }

    /**
     * Registers {@code collection} without a stats storage router by delegating to
     * {@code setListeners(StatsStorageRouter, Collection)} with a null router.
     *
     * @throws NullPointerException if {@code collection} is null
     */
    public void setListeners(@NonNull Collection<TrainingListener> collection) {
        if (collection == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        StatsStorageRouter noRouter = null;
        Collection<? extends TrainingListener> toRegister = collection;
        setListeners(noRouter, toRegister);
    }

    /**
     * Varargs convenience form of {@code setListeners(Collection)}.
     *
     * @throws NullPointerException if the array itself is null
     */
    public void setListeners(@NonNull TrainingListener... trainingListenerArr) {
        if (trainingListenerArr == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        List<TrainingListener> asList = Arrays.asList(trainingListenerArr);
        setListeners(asList);
    }

    /** Varargs convenience form of {@code setListeners(StatsStorageRouter, Collection)}. */
    public void setListeners(StatsStorageRouter statsStorageRouter, TrainingListener... trainingListenerArr) {
        Collection<? extends TrainingListener> wrapped = Arrays.asList(trainingListenerArr);
        setListeners(statsStorageRouter, wrapped);
    }

    /**
     * Registers listeners and the stats storage router.
     *
     * <p>Despite the name, a non-null {@code collection} is APPENDED to the already registered listeners; a null
     * collection clears them. The router is always replaced with {@code statsStorageRouter}. A warning is logged
     * for each RoutingIterationListener that would end up without any storage router.
     *
     * @param statsStorageRouter router for routing listeners (may be null)
     * @param collection         listeners to add, or null to clear all listeners
     */
    public void setListeners(StatsStorageRouter statsStorageRouter, Collection<? extends TrainingListener> collection) {
        if (collection == null) {
            this.listeners.clear();
        } else {
            for (TrainingListener listener : collection) {
                if (listener instanceof RoutingIterationListener
                        && statsStorageRouter == null
                        && ((RoutingIterationListener) listener).getStorageRouter() == null) {
                    log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Iterator may not function without one. Listener: {}", listener);
                }
            }
            this.listeners.addAll(collection);
        }
        this.storageRouter = statsStorageRouter;
    }

    /**
     * No-op in this class; the given gradient is ignored.
     * NOTE(review): presumably kept for API compatibility or as an extension point; TODO confirm.
     */
    public void broadcastGradients(SharedGradient sharedGradient) {
        // intentionally empty
    }

    /**
     * Trains the wrapped model on {@code dataSetIterator}, handing one DataSet per worker per round.
     *
     * <p>Each round does the following:
     * <ul>
     *   <li>Polls the optional model/updater suppliers and pushes non-null results to every trainer.</li>
     *   <li>Requests {@code workers} DataSets, feeds them to the trainers, and waits for the trainers.</li>
     *   <li>Every {@code averagingFrequency} iterations, averages parameters (and updater state) when the
     *       trainers require it. A non-positive frequency disables averaging instead of throwing
     *       ArithmeticException.</li>
     * </ul>
     *
     * <p>The async prefetcher, if one was created, is shut down even when training fails. On normal
     * completion the trainers and the pool are closed.
     *
     * @param dataSetIterator training data; reset first if exhausted and resettable
     * @throws NullPointerException if the iterator is null
     * @throws ND4JIllegalStateException if the block iterator yields a null batch array
     * @throws IllegalStateException if the trainers were shut down while fitting
     */
    public synchronized void fit(@NonNull DataSetIterator dataSetIterator) {
        if (dataSetIterator == null) {
            throw new NullPointerException("source is marked non-null but is null");
        }
        log.info("Using workspaceMode {} for training", this.workspaceMode.name());
        this.stopFit.set(false);
        createZooIfNeccessary(false);
        if (!dataSetIterator.hasNext() && dataSetIterator.resetSupported()) {
            dataSetIterator.reset();
        }

        boolean useAsync = this.prefetchSize > 0 && dataSetIterator.asyncSupported();
        DataSetIterator iterator = dataSetIterator;
        if (useAsync) {
            log.info("Creating asynchronous prefetcher...");
            if (this.isMQ) {
                int devices = Nd4j.getAffinityManager().getNumberOfDevices();
                if (this.workers % devices != 0) {
                    log.warn("Number of workers [{}] isn't optimal for available devices [{}]", this.workers, devices);
                }
                iterator = new AsyncDataSetIterator(dataSetIterator, this.prefetchSize,
                        new LinkedBlockingQueue<>(this.prefetchSize * this.workers), true,
                        new InterleavedDataSetCallback(this.prefetchSize * 2));
            } else {
                iterator = new AsyncDataSetIterator(dataSetIterator, this.prefetchSize);
            }
        }

        AtomicInteger activeWorkers = new AtomicInteger(0);
        long lastRoundEnd = System.currentTimeMillis();
        log.info("Starting ParallelWrapper training round...");
        long round = 0;
        DummyBlockDataSetIterator blockIterator = new DummyBlockDataSetIterator(iterator);
        try {
            while (blockIterator.hasAnything() && !this.stopFit.get()) {
                if (this.modelParamsSupplier != null) {
                    INDArray freshParams = this.modelParamsSupplier.get();
                    if (freshParams != null && this.zoo != null) {
                        log.info("Updating model parameters...");
                        for (Trainer trainer : this.zoo) {
                            trainer.updateModelParams(freshParams);
                        }
                    }
                }
                if (this.updaterParamsSupplier != null) {
                    INDArray freshUpdaterState = this.updaterParamsSupplier.get();
                    if (freshUpdaterState != null && this.zoo != null) {
                        log.info("Updating updater parameters...");
                        for (Trainer trainer : this.zoo) {
                            trainer.updateUpdaterParams(freshUpdaterState);
                        }
                    }
                }

                round++;
                DataSet[] batches = blockIterator.next(this.workers);
                // Time since the previous round ended, passed to the trainers with each batch
                // (presumably as data-wait/ETL time; TODO confirm).
                long waitTime = System.currentTimeMillis() - lastRoundEnd;
                if (batches == null) {
                    throw new ND4JIllegalStateException("You can't have NULL as DataSet");
                }
                if (this.zoo == null) {
                    throw new IllegalStateException("ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
                }
                activeWorkers.set(batches.length);
                if (this.gradientsAccumulator instanceof Registerable) {
                    this.gradientsAccumulator.registerConsumers(batches.length);
                }
                for (int w = 0; w < batches.length; w++) {
                    if (this.debug) {
                        log.info("Feeding dataset {} to worker {}", round, w);
                    }
                    this.zoo[w].feedDataSet(batches[w], waitTime);
                }
                this.iterationsCounter.incrementAndGet();
                for (int w = 0; w < batches.length; w++) {
                    try {
                        this.zoo[w].waitTillRunning();
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }

                // averagingFrequency(0) is accepted by the Builder; guard the modulo against division by zero.
                if (this.averagingFrequency > 0
                        && this.iterationsCounter.get() % this.averagingFrequency == 0
                        && this.zoo[0].averagingRequired()) {
                    long averagingStart = System.currentTimeMillis();
                    averageUpdatersState(activeWorkers, getScore(activeWorkers));
                    long averagingEnd = System.currentTimeMillis();
                    if (this.reportScore) {
                        log.info("Averaging time: {} ms", averagingEnd - averagingStart);
                    }
                }
                lastRoundEnd = System.currentTimeMillis();
                activeWorkers.set(0);
            }

            if (this.debug) {
                log.info("Stopping everyone...");
            }
            for (int w = 0; w < this.workers; w++) {
                try {
                    this.zoo[w].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        } finally {
            if (this.debug) {
                log.info("Shutting down iterator...");
            }
            // Always stop the prefetch thread, including when training failed.
            if (useAsync) {
                ((AsyncDataSetIterator) iterator).shutdown();
            }
        }

        try {
            close();
            if (this.debug) {
                log.info("Iterations passed: {}", this.iterationsCounter.get());
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Lazily creates one Trainer per worker and submits each one to the worker pool. The pool is created
     * on demand if needed. Does nothing when the trainers already exist.
     *
     * @param useMultiDataSet true when called from fit(MultiDataSetIterator), false from fit(DataSetIterator);
     *                        forwarded to {@link TrainerContext#create}
     */
    private void createZooIfNeccessary(boolean useMultiDataSet) {
        if (this.zoo != null) {
            return;
        }
        this.trainerContext.init(this.model, this.trainerContextArgs);
        this.zoo = new Trainer[this.workers];
        for (int i = 0; i < this.workers; i++) {
            int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            this.zoo[i] = this.trainerContext.create(this.uuid, i, this.model, deviceId, useMultiDataSet, this,
                    this.workspaceMode, this.averagingFrequency);
            if (this.executorService == null) {
                init();
            }
            this.executorService.execute(this.zoo[i]);
        }
    }

    /**
     * Returns a clone of a RoutingIterationListener; any other listener is returned as-is.
     */
    private static TrainingListener cloneListener(TrainingListener trainingListener) {
        if (!(trainingListener instanceof RoutingIterationListener)) {
            return trainingListener;
        }
        RoutingIterationListener routing = (RoutingIterationListener) trainingListener;
        return routing.clone();
    }

    /**
     * Copies the listeners in {@code collection} into {@code collection2}.
     *
     * <p>Each RoutingIterationListener is cloned and the clone is configured with:
     * <ul>
     *   <li>the original's session ID;</li>
     *   <li>worker ID {@code str};</li>
     *   <li>the original's storage router, or this wrapper's router when the original has none.</li>
     * </ul>
     * Other listeners are added unchanged.
     *
     * @param str         worker ID assigned to cloned routing listeners
     * @param collection  source listeners
     * @param collection2 destination that receives the (possibly cloned) listeners
     */
    private void configureListeners(String str, Collection<TrainingListener> collection, Collection<TrainingListener> collection2) {
        for (TrainingListener original : collection) {
            TrainingListener copy = cloneListener(original);
            if (copy instanceof RoutingIterationListener) {
                // cloneListener only clones RoutingIterationListeners (others are returned unchanged),
                // so the original is a RoutingIterationListener too.
                RoutingIterationListener routedOriginal = (RoutingIterationListener) original;
                RoutingIterationListener routedCopy = (RoutingIterationListener) copy;
                routedCopy.setSessionID(routedOriginal.getSessionID());
                routedCopy.setWorkerID(str);
                StatsStorageRouter originalRouter = routedOriginal.getStorageRouter();
                routedCopy.setStorageRouter(originalRouter != null ? originalRouter : this.storageRouter);
            }
            collection2.add(copy);
        }
    }

    public Supplier<INDArray> getModelParamsSupplier() {
        return this.modelParamsSupplier;
    }

    public Supplier<INDArray> getUpdaterParamsSupplier() {
        return this.updaterParamsSupplier;
    }

    public AtomicBoolean getExceptionEncountered() {
        return this.exceptionEncountered;
    }

    public Throwable getException() {
        return this.exception;
    }

    /** Returns this wrapper's {@code uuid} string. */
    public String getUuid() {
        return uuid;
    }

    /** Returns the wrapped {@code model} reference. */
    public Model getModel() {
        return model;
    }

    /** Returns the value of the {@code workers} field. */
    public int getWorkers() {
        return workers;
    }

    /** Returns the value of the {@code prefetchSize} field. */
    public int getPrefetchSize() {
        return prefetchSize;
    }

    /** Returns the value of the {@code averagingFrequency} field. */
    public int getAveragingFrequency() {
        return averagingFrequency;
    }

    /** Returns the internal {@code zoo} array itself; it is not copied, so changes are visible to this wrapper. */
    public Trainer[] getZoo() {
        return zoo;
    }

    /** Returns the configured {@code trainerContext}. */
    public TrainerContext getTrainerContext() {
        return trainerContext;
    }

    /** Returns the {@code iterationsCounter} object itself (the live reference). */
    public AtomicLong getIterationsCounter() {
        return iterationsCounter;
    }

    /** Returns the {@code reportScore} flag. */
    public boolean isReportScore() {
        return reportScore;
    }

    /** Returns the {@code averageUpdaters} flag. */
    public boolean isAverageUpdaters() {
        return averageUpdaters;
    }

    /** Returns the {@code legacyAveraging} flag. */
    public boolean isLegacyAveraging() {
        return legacyAveraging;
    }

    /** Returns the {@code wasAveraged} flag. */
    public boolean isWasAveraged() {
        return wasAveraged;
    }

    /** Returns the {@code stopFit} flag object itself (the live reference, not a snapshot). */
    public AtomicBoolean getStopFit() {
        return stopFit;
    }

    /** Returns the internal {@code listeners} list itself; it is not copied or wrapped. */
    public List<TrainingListener> getListeners() {
        return listeners;
    }

    /** Returns the wrapper-level {@code storageRouter} (may be {@code null}). */
    public StatsStorageRouter getStorageRouter() {
        return storageRouter;
    }

    /** Returns the {@code isMQ} flag. */
    public boolean isMQ() {
        return isMQ;
    }

    /** Returns the configured {@code workspaceMode}. */
    public WorkspaceMode getWorkspaceMode() {
        return workspaceMode;
    }

    /** Returns the internal {@code trainerContextArgs} array itself; it is not copied. */
    public Object[] getTrainerContextArgs() {
        return trainerContextArgs;
    }

    /** Returns the {@code debug} flag. */
    public boolean isDebug() {
        return debug;
    }

    /** Returns the {@code executorService} currently held (may be {@code null} before initialisation). */
    public ThreadPoolExecutor getExecutorService() {
        return executorService;
    }

    /** Returns the {@code workerCounter} object itself (the live reference). */
    public AtomicInteger getWorkerCounter() {
        return workerCounter;
    }

    /** Returns the configured uncaught-exception {@code handler}. */
    public Thread.UncaughtExceptionHandler getHandler() {
        return handler;
    }

    /** Replaces the {@code modelParamsSupplier} field; {@code null} is accepted. */
    public void setModelParamsSupplier(Supplier<INDArray> modelParamsSupplier) {
        this.modelParamsSupplier = modelParamsSupplier;
    }

    /** Replaces the {@code updaterParamsSupplier} field; {@code null} is accepted. */
    public void setUpdaterParamsSupplier(Supplier<INDArray> updaterParamsSupplier) {
        this.updaterParamsSupplier = updaterParamsSupplier;
    }

    /** Replaces the {@code exceptionEncountered} flag object with the given instance. */
    public void setExceptionEncountered(AtomicBoolean exceptionEncountered) {
        this.exceptionEncountered = exceptionEncountered;
    }

    /** Stores the given throwable in the {@code exception} field. */
    public void setException(Throwable exception) {
        this.exception = exception;
    }

    /** Replaces the wrapped {@code model} reference. */
    public void setModel(Model newModel) {
        model = newModel;
    }

    /** Sets the {@code workers} field; no validation is performed. */
    public void setWorkers(int workers) {
        this.workers = workers;
    }

    /** Sets the {@code prefetchSize} field; no validation is performed. */
    public void setPrefetchSize(int prefetchSize) {
        this.prefetchSize = prefetchSize;
    }

    /** Sets the {@code averagingFrequency} field; no validation is performed. */
    public void setAveragingFrequency(int averagingFrequency) {
        this.averagingFrequency = averagingFrequency;
    }

    /** Stores the given array as {@code zoo} without copying it. */
    public void setZoo(Trainer[] zoo) {
        this.zoo = zoo;
    }

    /** Replaces the {@code trainerContext} field. */
    public void setTrainerContext(TrainerContext context) {
        trainerContext = context;
    }

    /** Replaces the {@code iterationsCounter} object with the given instance. */
    public void setIterationsCounter(AtomicLong iterationsCounter) {
        this.iterationsCounter = iterationsCounter;
    }

    /** Sets the {@code reportScore} flag. */
    public void setReportScore(boolean reportScore) {
        this.reportScore = reportScore;
    }

    /** Sets the {@code averageUpdaters} flag. */
    public void setAverageUpdaters(boolean averageUpdaters) {
        this.averageUpdaters = averageUpdaters;
    }

    /** Sets the {@code legacyAveraging} flag. */
    public void setLegacyAveraging(boolean legacyAveraging) {
        this.legacyAveraging = legacyAveraging;
    }

    /** Sets the {@code wasAveraged} flag. */
    public void setWasAveraged(boolean wasAveraged) {
        this.wasAveraged = wasAveraged;
    }

    /** Replaces the {@code stopFit} flag object with the given instance. */
    public void setStopFit(AtomicBoolean stopFit) {
        this.stopFit = stopFit;
    }

    /** Replaces the wrapper-level {@code storageRouter}; {@code null} is accepted. */
    public void setStorageRouter(StatsStorageRouter router) {
        storageRouter = router;
    }

    /** Sets the {@code isMQ} flag. */
    public void setMQ(boolean isMQ) {
        this.isMQ = isMQ;
    }

    /** Replaces the {@code workspaceMode} field. */
    public void setWorkspaceMode(WorkspaceMode mode) {
        workspaceMode = mode;
    }

    /** Stores the given array as {@code trainerContextArgs} without copying it. */
    public void setTrainerContextArgs(Object[] trainerContextArgs) {
        this.trainerContextArgs = trainerContextArgs;
    }

    /** Sets the {@code debug} flag. */
    public void setDebug(boolean debug) {
        this.debug = debug;
    }

    /** Replaces the {@code executorService} reference; the previous executor is not shut down here. */
    public void setExecutorService(ThreadPoolExecutor executor) {
        executorService = executor;
    }

    /** Replaces the uncaught-exception {@code handler}. */
    public void setHandler(Thread.UncaughtExceptionHandler handler) {
        this.handler = handler;
    }

    /**
     * Field-by-field equality over every property exposed by this class's getters.
     *
     * <p>{@code obj} must be a {@code ParallelWrapper} whose {@link #canEqual(Object)} accepts
     * this instance. Reference-typed fields are compared null-safely with their own
     * {@code equals}. The JDK atomics and {@code ThreadPoolExecutor} do not override
     * {@code equals}, so those fields effectively compare by identity. Arrays
     * ({@code zoo}, {@code trainerContextArgs}) are compared with {@link Arrays#deepEquals}.
     * Comparison stops at the first differing field.
     *
     * @param obj object to compare against
     * @return {@code true} if every compared field matches
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ParallelWrapper)) {
            return false;
        }
        ParallelWrapper other = (ParallelWrapper) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return sameField(getModelParamsSupplier(), other.getModelParamsSupplier())
                && sameField(getUpdaterParamsSupplier(), other.getUpdaterParamsSupplier())
                && sameField(getExceptionEncountered(), other.getExceptionEncountered())
                && sameField(getException(), other.getException())
                && sameField(getUuid(), other.getUuid())
                && sameField(getModel(), other.getModel())
                && getWorkers() == other.getWorkers()
                && getPrefetchSize() == other.getPrefetchSize()
                && getAveragingFrequency() == other.getAveragingFrequency()
                && Arrays.deepEquals(getZoo(), other.getZoo())
                && sameField(getTrainerContext(), other.getTrainerContext())
                && sameField(getIterationsCounter(), other.getIterationsCounter())
                && isReportScore() == other.isReportScore()
                && isAverageUpdaters() == other.isAverageUpdaters()
                && isLegacyAveraging() == other.isLegacyAveraging()
                && isWasAveraged() == other.isWasAveraged()
                && sameField(getStopFit(), other.getStopFit())
                && sameField(getListeners(), other.getListeners())
                && sameField(getStorageRouter(), other.getStorageRouter())
                && isMQ() == other.isMQ()
                && sameField(getWorkspaceMode(), other.getWorkspaceMode())
                && Arrays.deepEquals(getTrainerContextArgs(), other.getTrainerContextArgs())
                && isDebug() == other.isDebug()
                && sameField(getExecutorService(), other.getExecutorService())
                && sameField(getWorkerCounter(), other.getWorkerCounter())
                && sameField(getGradientsAccumulator(), other.getGradientsAccumulator())
                && sameField(getHandler(), other.getHandler());
    }

    /**
     * Null-safe field comparison used by {@link #equals(Object)}. Returns true when both values
     * are null, otherwise delegates to {@code mine.equals(theirs)}. There is deliberately no
     * identity short-cut, so {@code equals} is always invoked on a non-null left operand.
     */
    private static boolean sameField(Object mine, Object theirs) {
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /**
     * Symmetry hook for {@link #equals(Object)}. {@code equals} calls this on the <em>other</em>
     * object, so a subclass that overrides it can refuse equality with plain instances.
     *
     * @param other candidate object
     * @return {@code true} if {@code other} is a {@code ParallelWrapper}
     */
    protected boolean canEqual(Object other) {
        boolean compatible = other instanceof ParallelWrapper;
        return compatible;
    }

    /**
     * Hash code consistent with {@link #equals(Object)}. It folds the same fields, in the same
     * order, with multiplier 59. A null reference contributes 43; a {@code true} flag contributes
     * 79 and a {@code false} flag 97; arrays use {@link Arrays#deepHashCode}.
     *
     * <p>The value reads mutable state (counters, flags, worker count, ...), so it changes whenever
     * a setter is called. Do not mutate an instance while it is stored in a hash-based collection.
     *
     * @return the combined hash of all compared fields
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + hashOrSentinel(getModelParamsSupplier());
        result = result * prime + hashOrSentinel(getUpdaterParamsSupplier());
        result = result * prime + hashOrSentinel(getExceptionEncountered());
        result = result * prime + hashOrSentinel(getException());
        result = result * prime + hashOrSentinel(getUuid());
        result = result * prime + hashOrSentinel(getModel());
        result = result * prime + getWorkers();
        result = result * prime + getPrefetchSize();
        result = result * prime + getAveragingFrequency();
        result = result * prime + Arrays.deepHashCode(getZoo());
        result = result * prime + hashOrSentinel(getTrainerContext());
        result = result * prime + hashOrSentinel(getIterationsCounter());
        result = result * prime + flagHash(isReportScore());
        result = result * prime + flagHash(isAverageUpdaters());
        result = result * prime + flagHash(isLegacyAveraging());
        result = result * prime + flagHash(isWasAveraged());
        result = result * prime + hashOrSentinel(getStopFit());
        result = result * prime + hashOrSentinel(getListeners());
        result = result * prime + hashOrSentinel(getStorageRouter());
        result = result * prime + flagHash(isMQ());
        result = result * prime + hashOrSentinel(getWorkspaceMode());
        result = result * prime + Arrays.deepHashCode(getTrainerContextArgs());
        result = result * prime + flagHash(isDebug());
        result = result * prime + hashOrSentinel(getExecutorService());
        result = result * prime + hashOrSentinel(getWorkerCounter());
        result = result * prime + hashOrSentinel(getGradientsAccumulator());
        result = result * prime + hashOrSentinel(getHandler());
        return result;
    }

    /** Hash contribution of a reference field: 43 for {@code null}, otherwise its own hash code. */
    private static int hashOrSentinel(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    /** Hash contribution of a boolean field: 79 for {@code true}, 97 for {@code false}. */
    private static int flagHash(boolean flag) {
        return flag ? 79 : 97;
    }

    /**
     * Renders every field in {@code ParallelWrapper(name=value, ...)} form. Arrays are rendered
     * with {@link Arrays#deepToString}; every other value uses its own {@code toString()}, or
     * {@code "null"}.
     *
     * @return a diagnostic description of this wrapper
     */
    @Override
    public String toString() {
        StringBuilder out = new StringBuilder("ParallelWrapper(modelParamsSupplier=");
        out.append(getModelParamsSupplier())
                .append(", updaterParamsSupplier=").append(getUpdaterParamsSupplier())
                .append(", exceptionEncountered=").append(getExceptionEncountered())
                .append(", exception=").append(getException())
                .append(", uuid=").append(getUuid())
                .append(", model=").append(getModel())
                .append(", workers=").append(getWorkers())
                .append(", prefetchSize=").append(getPrefetchSize())
                .append(", averagingFrequency=").append(getAveragingFrequency())
                .append(", zoo=").append(Arrays.deepToString(getZoo()))
                .append(", trainerContext=").append(getTrainerContext())
                .append(", iterationsCounter=").append(getIterationsCounter())
                .append(", reportScore=").append(isReportScore())
                .append(", averageUpdaters=").append(isAverageUpdaters())
                .append(", legacyAveraging=").append(isLegacyAveraging())
                .append(", wasAveraged=").append(isWasAveraged())
                .append(", stopFit=").append(getStopFit())
                .append(", listeners=").append(getListeners())
                .append(", storageRouter=").append(getStorageRouter())
                .append(", isMQ=").append(isMQ())
                .append(", workspaceMode=").append(getWorkspaceMode())
                .append(", trainerContextArgs=").append(Arrays.deepToString(getTrainerContextArgs()))
                .append(", debug=").append(isDebug())
                .append(", executorService=").append(getExecutorService())
                .append(", workerCounter=").append(getWorkerCounter())
                .append(", gradientsAccumulator=").append(getGradientsAccumulator())
                .append(", handler=").append(getHandler())
                .append(")");
        return out.toString();
    }

    /** Returns the {@code gradientsAccumulator} currently held (may be {@code null}). */
    public GradientsAccumulator getGradientsAccumulator() {
        return gradientsAccumulator;
    }

    /** Replaces the {@code gradientsAccumulator} reference; {@code null} is accepted. */
    public void setGradientsAccumulator(GradientsAccumulator accumulator) {
        gradientsAccumulator = accumulator;
    }
}
