/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.iterator;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MultiDataSetIteratorSplitter {
    private static final Logger log = LoggerFactory.getLogger(MultiDataSetIteratorSplitter.class);
    protected MultiDataSetIterator backedIterator;
    protected final long totalExamples;
    protected final double ratio;
    protected final long numTrain;
    protected final long numTest;
    protected AtomicLong counter = new AtomicLong(0L);
    protected AtomicBoolean resetPending = new AtomicBoolean(false);
    protected MultiDataSet firstTrain = null;

    public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, long totalBatches, double ratio) {
        if (baseIterator == null) {
            throw new NullPointerException("baseIterator is marked @NonNull but is null");
        }
        if (!(ratio > 0.0) || !(ratio < 1.0)) {
            throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0");
        }
        if (totalBatches < 0L) {
            throw new ND4JIllegalStateException("totalExamples number should be positive value");
        }
        if (!baseIterator.resetSupported()) {
            throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
        }
        this.backedIterator = baseIterator;
        this.totalExamples = totalBatches;
        this.ratio = ratio;
        this.numTrain = (long)((double)this.totalExamples * ratio);
        this.numTest = this.totalExamples - this.numTrain;
        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }

    public MultiDataSetIterator getTrainIterator() {
        return new MultiDataSetIterator(){

            public org.nd4j.linalg.dataset.api.MultiDataSet next(int num) {
                throw new UnsupportedOperationException("To be implemented yet");
            }

            public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
                MultiDataSetIteratorSplitter.this.backedIterator.setPreProcessor(preProcessor);
            }

            public MultiDataSetPreProcessor getPreProcessor() {
                return MultiDataSetIteratorSplitter.this.backedIterator.getPreProcessor();
            }

            public boolean resetSupported() {
                return MultiDataSetIteratorSplitter.this.backedIterator.resetSupported();
            }

            public boolean asyncSupported() {
                return MultiDataSetIteratorSplitter.this.backedIterator.asyncSupported();
            }

            public void reset() {
                MultiDataSetIteratorSplitter.this.resetPending.set(true);
            }

            public boolean hasNext() {
                boolean state;
                if (MultiDataSetIteratorSplitter.this.resetPending.get()) {
                    if (this.resetSupported()) {
                        MultiDataSetIteratorSplitter.this.backedIterator.reset();
                        MultiDataSetIteratorSplitter.this.counter.set(0L);
                        MultiDataSetIteratorSplitter.this.resetPending.set(false);
                    } else {
                        throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
                    }
                }
                return (state = MultiDataSetIteratorSplitter.this.backedIterator.hasNext()) && MultiDataSetIteratorSplitter.this.counter.get() < MultiDataSetIteratorSplitter.this.numTrain;
            }

            public org.nd4j.linalg.dataset.api.MultiDataSet next() {
                MultiDataSetIteratorSplitter.this.counter.incrementAndGet();
                org.nd4j.linalg.dataset.api.MultiDataSet p = (org.nd4j.linalg.dataset.api.MultiDataSet)MultiDataSetIteratorSplitter.this.backedIterator.next();
                if (MultiDataSetIteratorSplitter.this.counter.get() == 1L && MultiDataSetIteratorSplitter.this.firstTrain == null) {
                    MultiDataSetIteratorSplitter.this.firstTrain = (MultiDataSet)p.copy();
                    MultiDataSetIteratorSplitter.this.firstTrain.detach();
                } else if (MultiDataSetIteratorSplitter.this.counter.get() == 1L) {
                    int cnt = 0;
                    for (INDArray c : p.getFeatures()) {
                        if (c.equalsWithEps((Object)MultiDataSetIteratorSplitter.this.firstTrain.getFeatures()[cnt++], 1.0E-5)) continue;
                        throw new ND4JIllegalStateException("First examples do not match. Randomization was used?");
                    }
                }
                return p;
            }

            public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }

    public MultiDataSetIterator getTestIterator() {
        return new MultiDataSetIterator(){

            public org.nd4j.linalg.dataset.api.MultiDataSet next(int num) {
                throw new UnsupportedOperationException("To be implemented yet");
            }

            public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
                MultiDataSetIteratorSplitter.this.backedIterator.setPreProcessor(preProcessor);
            }

            public MultiDataSetPreProcessor getPreProcessor() {
                return MultiDataSetIteratorSplitter.this.backedIterator.getPreProcessor();
            }

            public boolean resetSupported() {
                return MultiDataSetIteratorSplitter.this.backedIterator.resetSupported();
            }

            public boolean asyncSupported() {
                return MultiDataSetIteratorSplitter.this.backedIterator.asyncSupported();
            }

            public void reset() {
                MultiDataSetIteratorSplitter.this.resetPending.set(true);
            }

            public boolean hasNext() {
                boolean state = MultiDataSetIteratorSplitter.this.backedIterator.hasNext();
                return state && MultiDataSetIteratorSplitter.this.counter.get() < MultiDataSetIteratorSplitter.this.numTrain + MultiDataSetIteratorSplitter.this.numTest;
            }

            public org.nd4j.linalg.dataset.api.MultiDataSet next() {
                return (org.nd4j.linalg.dataset.api.MultiDataSet)MultiDataSetIteratorSplitter.this.backedIterator.next();
            }

            public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }
}

