/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.iterator;

import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * Iterator for k-fold cross-validation over a single in-memory {@link DataSet}.
 *
 * <p>The examples of {@code allData} are split, in their current order, into {@code k}
 * contiguous folds. The first {@code N % k} folds each receive one extra example, so every
 * example belongs to exactly one fold. Each call to {@link #next()} advances to the next fold.
 * It returns the training set, which is the concatenation of all other folds. The held-out fold
 * is then available via {@link #testFold()}.
 *
 * <p>{@link #reset()} shuffles {@code allData} in place and restarts from the first fold. The
 * first pass, before any reset, uses the data in its original order.
 *
 * <p>NOTE(review): a preprocessor set via {@link #setPreProcessor} is stored and returned by
 * {@link #getPreProcessor()}, but this class never applies it to the folds it produces.
 */
public class KFoldIterator
implements DataSetIterator {
    private static final long serialVersionUID = 6130298603412865817L;
    protected DataSet allData;
    protected int k;
    protected int N;
    /** Fold {@code i} covers example indices {@code [intervalBoundaries[i], intervalBoundaries[i + 1])}. */
    protected int[] intervalBoundaries;
    /** Index of the fold that the next call to {@link #next()} will hold out. */
    protected int kCursor = 0;
    protected DataSet test;
    protected DataSet train;
    protected DataSetPreProcessor preProcessor;

    /**
     * Creates a 10-fold iterator over {@code allData}.
     *
     * @param allData the full dataset to split; must contain at least 10 examples
     * @throws NullPointerException if {@code allData} is null
     * @throws IllegalArgumentException if {@code allData} has fewer than 10 examples
     */
    public KFoldIterator(DataSet allData) {
        this(10, allData);
    }

    /**
     * Creates a k-fold iterator over {@code allData}.
     *
     * @param k number of folds; must satisfy {@code 2 <= k <= allData.numExamples()}
     * @param allData the full dataset to split
     * @throws NullPointerException if {@code allData} is null
     * @throws IllegalArgumentException if {@code k <= 1}, or if {@code k} exceeds the number of
     *     examples (which would produce empty test folds)
     */
    public KFoldIterator(int k, DataSet allData) {
        if (k <= 1) {
            throw new IllegalArgumentException("k must be greater than 1, got " + k);
        }
        Objects.requireNonNull(allData, "allData");
        int numExamples = allData.numExamples();
        if (k > numExamples) {
            // With more folds than examples some folds would be empty: their "test" set would
            // hold nothing and their "train" set would be the entire dataset.
            throw new IllegalArgumentException("k (" + k + ") must not exceed the number of examples ("
                    + numExamples + ")");
        }
        this.k = k;
        this.N = numExamples;
        this.allData = allData;
        int baseBatchSize = this.N / k;
        int numIncrementedBatches = this.N % k;
        this.intervalBoundaries = new int[k + 1];
        this.intervalBoundaries[0] = 0;
        for (int i = 1; i <= k; ++i) {
            // The first (N % k) folds take one extra example so that all N examples are covered.
            int foldSize = i <= numIncrementedBatches ? baseBatchSize + 1 : baseBatchSize;
            this.intervalBoundaries[i] = this.intervalBoundaries[i - 1] + foldSize;
        }
    }

    /**
     * Not supported: folds are produced whole, not in caller-sized batches.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public DataSet next(int num) throws UnsupportedOperationException {
        throw new UnsupportedOperationException("KFoldIterator does not support next(int); use next()");
    }

    /** Returns the total number of examples across all folds. */
    public int totalExamples() {
        return this.N;
    }

    /** Returns the size of dimension 1 of the feature array of the full dataset. */
    @Override
    public int inputColumns() {
        return (int)this.allData.getFeatures().size(1);
    }

    /** Returns the size of dimension 1 of the label array of the full dataset. */
    @Override
    public int totalOutcomes() {
        return (int)this.allData.getLabels().size(1);
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }

    /** Shuffles {@code allData} in place and restarts iteration from the first fold. */
    @Override
    public void reset() {
        this.allData.shuffle();
        this.kCursor = 0;
    }

    /**
     * Returns the number of examples in the fold that the next call to {@link #next()} will hold
     * out as the test fold.
     *
     * <p>NOTE(review): once all folds have been consumed ({@code hasNext() == false}) this indexes
     * past the end of {@link #intervalBoundaries} and throws
     * {@link ArrayIndexOutOfBoundsException}.
     */
    @Override
    public int batch() {
        return this.intervalBoundaries[this.kCursor + 1] - this.intervalBoundaries[this.kCursor];
    }

    /**
     * Stores a preprocessor. NOTE(review): this class does not apply it to the folds it returns.
     */
    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    /** Returns the label names of the full dataset. */
    @Override
    public List<String> getLabels() {
        return this.allData.getLabelNamesList();
    }

    /** Returns true while at least one fold has not yet been held out. */
    @Override
    public boolean hasNext() {
        return this.kCursor < this.k;
    }

    /**
     * Advances to the next fold and returns its training set, which holds every example outside
     * that fold. The matching held-out fold is then available via {@link #testFold()}.
     *
     * @throws NoSuchElementException if all {@code k} folds have already been returned
     */
    @Override
    public DataSet next() {
        if (!this.hasNext()) {
            throw new NoSuchElementException("All " + this.k + " folds have already been returned");
        }
        this.nextFold();
        return this.train;
    }

    /** No-op: examples cannot be removed from a k-fold split. */
    @Override
    public void remove() {
    }

    /**
     * Computes {@link #train} and {@link #test} for fold {@link #kCursor}, then advances the cursor.
     * The caller must ensure {@code kCursor < k}.
     */
    protected void nextFold() {
        int left = this.intervalBoundaries[this.kCursor];
        int right = this.intervalBoundaries[this.kCursor + 1];
        if (right < this.totalExamples()) {
            // Training data = examples before the held-out fold (if any) + examples after it.
            List<DataSet> kMinusOneFoldList = new ArrayList<DataSet>(2);
            if (left > 0) {
                kMinusOneFoldList.add((DataSet)this.allData.getRange(0, left));
            }
            kMinusOneFoldList.add((DataSet)this.allData.getRange(right, this.totalExamples()));
            this.train = DataSet.merge(kMinusOneFoldList);
        } else {
            // The held-out fold is the last one, so everything before it is training data.
            this.train = (DataSet)this.allData.getRange(0, left);
        }
        this.test = (DataSet)this.allData.getRange(left, right);
        ++this.kCursor;
    }

    /**
     * Returns the held-out fold produced by the most recent call to {@link #next()}, or
     * {@code null} if {@code next()} has not been called yet.
     */
    public DataSet testFold() {
        return this.test;
    }
}

