/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.scorecalc;

import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * Early-stopping score calculator that scores a network by the reconstruction probability (or
 * reconstruction log probability) computed by its first layer, which must be a
 * {@link VariationalAutoencoder}.
 *
 * <p>Only the features of each minibatch are used; labels, masks and network output are ignored.
 * Depending on {@code logProb}, the per-minibatch score is either:
 * <ul>
 *   <li>{@code logProb == false}: the summed reconstruction probability (larger is better), or</li>
 *   <li>{@code logProb == true}: the negated summed reconstruction log probability, i.e. a negative
 *       log likelihood (smaller is better).</li>
 * </ul>
 * {@link #minimizeScore()} reports the optimisation direction that matches the chosen score.
 */
public class VAEReconProbScoreCalculator
extends BaseScoreCalculator<Model> {
    /** Number of samples passed to the VAE's reconstruction (log) probability methods. */
    protected final int reconstructionProbNumSamples;
    /** If true, score with the negated reconstruction log probability; otherwise with the probability. */
    protected final boolean logProb;
    /** If true, the final score is the summed score divided by the example count; otherwise the raw sum. */
    protected final boolean average;

    /**
     * Creates a calculator that reports the score averaged over examples.
     *
     * @param iterator                     data to score the network on
     * @param reconstructionProbNumSamples number of samples passed to the VAE reconstruction methods
     * @param logProb                      true to use the negated log probability, false for the probability
     */
    public VAEReconProbScoreCalculator(DataSetIterator iterator, int reconstructionProbNumSamples, boolean logProb) {
        this(iterator, reconstructionProbNumSamples, logProb, true);
    }

    /**
     * Creates a calculator.
     *
     * @param iterator                     data to score the network on
     * @param reconstructionProbNumSamples number of samples passed to the VAE reconstruction methods
     * @param logProb                      true to use the negated log probability, false for the probability
     * @param average                      true to divide the summed score by the number of examples
     */
    public VAEReconProbScoreCalculator(DataSetIterator iterator, int reconstructionProbNumSamples, boolean logProb, boolean average) {
        super(iterator);
        this.reconstructionProbNumSamples = reconstructionProbNumSamples;
        this.logProb = logProb;
        this.average = average;
    }

    /** Clears the accumulated score and counters inherited from the base class. */
    @Override
    protected void reset() {
        this.scoreSum = 0.0;
        this.minibatchCount = 0;
        this.exampleCount = 0;
    }

    /**
     * Returns {@code null}: scoring feeds the features straight into the VAE layer, so no network
     * output needs to be computed.
     */
    @Override
    protected INDArray output(Model network, INDArray input, INDArray fMask, INDArray lMask) {
        return null;
    }

    /**
     * Returns {@code null}: the multi-input path is not used by this calculator (it is only
     * constructed from a {@link DataSetIterator}).
     */
    @Override
    protected INDArray[] output(Model network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) {
        return null;
    }

    /**
     * Scores one minibatch using the first layer of {@code net}, which must be a
     * {@link VariationalAutoencoder}.
     *
     * @return the summed reconstruction probability, or the negated summed reconstruction log
     *         probability when {@code logProb} is true
     * @throws UnsupportedOperationException if {@code net} is not a {@link MultiLayerNetwork} or
     *         {@link ComputationGraph}, or its first layer is not a {@link VariationalAutoencoder}
     */
    @Override
    protected double scoreMinibatch(Model net, INDArray features, INDArray labels, INDArray fMask, INDArray lMask, INDArray output) {
        VariationalAutoencoder vae = firstLayerAsVae(net);
        if (this.logProb) {
            // Negated so the score is a negative log likelihood; minimizeScore() returns true here.
            return -vae.reconstructionLogProbability(features, this.reconstructionProbNumSamples).sumNumber().doubleValue();
        }
        return vae.reconstructionProbability(features, this.reconstructionProbNumSamples).sumNumber().doubleValue();
    }

    /**
     * Extracts layer 0 of a {@link MultiLayerNetwork} or {@link ComputationGraph} and checks that it
     * is a {@link VariationalAutoencoder}.
     */
    private static VariationalAutoencoder firstLayerAsVae(Model net) {
        Layer first;
        if (net instanceof MultiLayerNetwork) {
            first = ((MultiLayerNetwork) net).getLayer(0);
        } else if (net instanceof ComputationGraph) {
            first = ((ComputationGraph) net).getLayer(0);
        } else {
            // Previously an opaque ClassCastException; report the unsupported model type explicitly.
            throw new UnsupportedOperationException("Can only score MultiLayerNetwork or ComputationGraph models - got "
                    + (net == null ? "null" : net.getClass().getSimpleName()));
        }
        if (!(first instanceof VariationalAutoencoder)) {
            throw new UnsupportedOperationException("Can only score networks with VariationalAutoencoder layers as first layer - got "
                    + (first == null ? "null" : first.getClass().getSimpleName()));
        }
        return (VariationalAutoencoder) first;
    }

    /**
     * Multi-input scoring is not supported by this calculator; contributes 0 to the score.
     * NOTE(review): this calculator is only built from a DataSetIterator, so presumably the base
     * class never routes here — confirm against BaseScoreCalculator.
     */
    @Override
    protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) {
        return 0.0;
    }

    /**
     * Produces the final score from the accumulated minibatch scores.
     *
     * @return {@code scoreSum / exampleCount} when {@code average} is true (NaN or infinite if
     *         {@code exampleCount} is 0), otherwise {@code scoreSum}
     */
    @Override
    protected double finalScore(double scoreSum, int minibatchCount, int exampleCount) {
        if (this.average) {
            return scoreSum / (double) exampleCount;
        }
        return scoreSum;
    }

    /**
     * Reports the optimisation direction matching {@link #scoreMinibatch}. With {@code logProb} the
     * score is a negated log probability, so smaller is better. Otherwise it is a probability, so
     * larger is better. Always returning false would make early stopping keep the model with the
     * worst likelihood whenever {@code logProb} is true.
     */
    @Override
    public boolean minimizeScore() {
        return this.logProb;
    }
}

