/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.evaluator.multilayer;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.evaluation.ModelEvaluator;
import org.deeplearning4j.arbiter.scoring.RegressionValue;
import org.deeplearning4j.arbiter.scoring.util.ScoreUtil;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * {@link ModelEvaluator} that scores a {@link MultiLayerNetwork} or {@link ComputationGraph}
 * on the test data supplied by a {@link DataProvider}, using the regression metric selected by
 * a {@link RegressionValue}.
 */
public class RegressionDataEvaluator
implements ModelEvaluator {
    /** Regression metric handed to {@link ScoreUtil} when scoring. */
    private RegressionValue regressionValue;
    /** Forwarded unchanged to {@code DataProvider.testData(...)}; not null-checked here. */
    private Map<String, Object> params = null;

    /**
     * Creates an evaluator.
     *
     * @param regressionValue regression metric used to score models
     * @param params          parameters forwarded to {@code dataProvider.testData(params)} when the
     *                        test data is requested
     */
    public RegressionDataEvaluator(RegressionValue regressionValue, Map<String, Object> params) {
        this.regressionValue = regressionValue;
        this.params = params;
    }

    /**
     * Scores {@code model} on the provider's test data.
     *
     * <p>Test data is obtained with {@code dataProvider.testData(params)}, converted to a
     * {@link DataSetIterator} by {@link ScoreUtil#getIterator}, and the model is then scored
     * with the configured {@link RegressionValue}.
     *
     * @param model        model to evaluate; must be a {@link MultiLayerNetwork} or a
     *                     {@link ComputationGraph}
     * @param dataProvider source of the test data
     * @return the score computed by {@link ScoreUtil}
     * @throws NullPointerException     if {@code model} is {@code null}
     * @throws IllegalArgumentException if {@code model} is neither a {@link MultiLayerNetwork}
     *                                  nor a {@link ComputationGraph}
     */
    public Double evaluateModel(Object model, DataProvider dataProvider) {
        Objects.requireNonNull(model, "model");
        // Reject unsupported models before requesting test data (which may be costly), and with a
        // clearer message than the ClassCastException an unchecked cast would otherwise produce.
        if (!(model instanceof MultiLayerNetwork) && !(model instanceof ComputationGraph)) {
            throw new IllegalArgumentException("Unsupported model type: " + model.getClass().getName()
                    + " (expected MultiLayerNetwork or ComputationGraph)");
        }

        // Both model types consume the same iterator, so build it once.
        DataSetIterator iterator = ScoreUtil.getIterator(dataProvider.testData(this.params));
        if (model instanceof MultiLayerNetwork) {
            return ScoreUtil.score((MultiLayerNetwork) model, iterator, this.regressionValue);
        }
        return ScoreUtil.score((ComputationGraph) model, iterator, this.regressionValue);
    }

    /**
     * Returns the model classes this evaluator accepts.
     *
     * @return {@link MultiLayerNetwork} and {@link ComputationGraph}
     */
    public List<Class<?>> getSupportedModelTypes() {
        return Arrays.asList(MultiLayerNetwork.class, ComputationGraph.class);
    }

    /**
     * Returns the test-data classes this evaluator advertises support for.
     *
     * <p>NOTE(review): {@link #evaluateModel} only consumes test data through
     * {@link ScoreUtil#getIterator} as a {@link DataSetIterator}; confirm that
     * {@link MultiDataSetIterator} test data is actually handled there before relying on it.
     *
     * @return {@link DataSetIterator} and {@link MultiDataSetIterator}
     */
    public List<Class<?>> getSupportedDataTypes() {
        return Arrays.asList(DataSetIterator.class, MultiDataSetIterator.class);
    }
}

