/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.objdetect;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.objdetect.BoundingBoxesDeserializer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossL2;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

public class Yolo2OutputLayer
extends Layer {
    private double lambdaCoord;
    private double lambdaNoObj;
    private ILossFunction lossPositionScale;
    private ILossFunction lossClassPredictions;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=BoundingBoxesDeserializer.class)
    private INDArray boundingBoxes;
    private CNN2DFormat format = CNN2DFormat.NCHW;

    private Yolo2OutputLayer() {
    }

    private Yolo2OutputLayer(Builder builder) {
        super(builder);
        this.lambdaCoord = builder.lambdaCoord;
        this.lambdaNoObj = builder.lambdaNoObj;
        this.lossPositionScale = builder.lossPositionScale;
        this.lossClassPredictions = builder.lossClassPredictions;
        this.boundingBoxes = builder.boundingBoxes;
    }

    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer ret = new org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer(conf, networkDataType);
        ret.setListeners(trainingListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map<String, INDArray> paramTable = this.initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    @Override
    public ParamInitializer initializer() {
        return EmptyParamInitializer.getInstance();
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        return inputType;
    }

    @Override
    public void setNIn(InputType inputType, boolean override) {
        InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType;
        this.format = c.getFormat();
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        switch (inputType.getType()) {
            case FF: 
            case RNN: {
                throw new UnsupportedOperationException("Cannot use FF or RNN input types");
            }
            case CNN: {
                return null;
            }
            case CNNFlat: {
                InputType.InputTypeConvolutionalFlat cf = (InputType.InputTypeConvolutionalFlat)inputType;
                return new FeedForwardToCnnPreProcessor(cf.getHeight(), cf.getWidth(), cf.getDepth());
            }
        }
        return null;
    }

    @Override
    public List<Regularization> getRegularizationByParam(String paramName) {
        return null;
    }

    @Override
    public boolean isPretrainParam(String paramName) {
        return false;
    }

    @Override
    public GradientNormalization getGradientNormalization() {
        return GradientNormalization.None;
    }

    @Override
    public double getGradientNormalizationThreshold() {
        return 1.0;
    }

    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        long numValues = inputType.arrayElementsPerExample();
        return new LayerMemoryReport.Builder(this.layerName, Yolo2OutputLayer.class, inputType, inputType).standardMemory(0L, 0L).workingMemory(0L, numValues, 0L, 6L * numValues).cacheMemory(0L, 0L).build();
    }

    public double getLambdaCoord() {
        return this.lambdaCoord;
    }

    public double getLambdaNoObj() {
        return this.lambdaNoObj;
    }

    public ILossFunction getLossPositionScale() {
        return this.lossPositionScale;
    }

    public ILossFunction getLossClassPredictions() {
        return this.lossClassPredictions;
    }

    public INDArray getBoundingBoxes() {
        return this.boundingBoxes;
    }

    public CNN2DFormat getFormat() {
        return this.format;
    }

    public void setLambdaCoord(double lambdaCoord) {
        this.lambdaCoord = lambdaCoord;
    }

    public void setLambdaNoObj(double lambdaNoObj) {
        this.lambdaNoObj = lambdaNoObj;
    }

    public void setLossPositionScale(ILossFunction lossPositionScale) {
        this.lossPositionScale = lossPositionScale;
    }

    public void setLossClassPredictions(ILossFunction lossClassPredictions) {
        this.lossClassPredictions = lossClassPredictions;
    }

    public void setBoundingBoxes(INDArray boundingBoxes) {
        this.boundingBoxes = boundingBoxes;
    }

    public void setFormat(CNN2DFormat format) {
        this.format = format;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof Yolo2OutputLayer)) {
            return false;
        }
        Yolo2OutputLayer other = (Yolo2OutputLayer)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (Double.compare(this.getLambdaCoord(), other.getLambdaCoord()) != 0) {
            return false;
        }
        if (Double.compare(this.getLambdaNoObj(), other.getLambdaNoObj()) != 0) {
            return false;
        }
        ILossFunction this$lossPositionScale = this.getLossPositionScale();
        ILossFunction other$lossPositionScale = other.getLossPositionScale();
        if (this$lossPositionScale == null ? other$lossPositionScale != null : !this$lossPositionScale.equals(other$lossPositionScale)) {
            return false;
        }
        ILossFunction this$lossClassPredictions = this.getLossClassPredictions();
        ILossFunction other$lossClassPredictions = other.getLossClassPredictions();
        if (this$lossClassPredictions == null ? other$lossClassPredictions != null : !this$lossClassPredictions.equals(other$lossClassPredictions)) {
            return false;
        }
        INDArray this$boundingBoxes = this.getBoundingBoxes();
        INDArray other$boundingBoxes = other.getBoundingBoxes();
        if (this$boundingBoxes == null ? other$boundingBoxes != null : !this$boundingBoxes.equals(other$boundingBoxes)) {
            return false;
        }
        CNN2DFormat this$format = this.getFormat();
        CNN2DFormat other$format = other.getFormat();
        return !(this$format == null ? other$format != null : !this$format.equals(other$format));
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof Yolo2OutputLayer;
    }

    @Override
    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        long $lambdaCoord = Double.doubleToLongBits(this.getLambdaCoord());
        result = result * 59 + (int)($lambdaCoord >>> 32 ^ $lambdaCoord);
        long $lambdaNoObj = Double.doubleToLongBits(this.getLambdaNoObj());
        result = result * 59 + (int)($lambdaNoObj >>> 32 ^ $lambdaNoObj);
        ILossFunction $lossPositionScale = this.getLossPositionScale();
        result = result * 59 + ($lossPositionScale == null ? 43 : $lossPositionScale.hashCode());
        ILossFunction $lossClassPredictions = this.getLossClassPredictions();
        result = result * 59 + ($lossClassPredictions == null ? 43 : $lossClassPredictions.hashCode());
        INDArray $boundingBoxes = this.getBoundingBoxes();
        result = result * 59 + ($boundingBoxes == null ? 43 : $boundingBoxes.hashCode());
        CNN2DFormat $format = this.getFormat();
        result = result * 59 + ($format == null ? 43 : $format.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "Yolo2OutputLayer(lambdaCoord=" + this.getLambdaCoord() + ", lambdaNoObj=" + this.getLambdaNoObj() + ", lossPositionScale=" + this.getLossPositionScale() + ", lossClassPredictions=" + this.getLossClassPredictions() + ", boundingBoxes=" + this.getBoundingBoxes() + ", format=" + this.getFormat() + ")";
    }

    public static class Builder
    extends Layer.Builder<Builder> {
        private double lambdaCoord = 5.0;
        private double lambdaNoObj = 0.5;
        private ILossFunction lossPositionScale = new LossL2();
        private ILossFunction lossClassPredictions = new LossL2();
        private INDArray boundingBoxes;

        public Builder lambdaCoord(double lambdaCoord) {
            this.setLambdaCoord(lambdaCoord);
            return this;
        }

        public Builder lambdaNoObj(double lambdaNoObj) {
            this.setLambdaNoObj(lambdaNoObj);
            return this;
        }

        public Builder lossPositionScale(ILossFunction lossPositionScale) {
            this.setLossPositionScale(lossPositionScale);
            return this;
        }

        public Builder lossClassPredictions(ILossFunction lossClassPredictions) {
            this.setLossClassPredictions(lossClassPredictions);
            return this;
        }

        public Builder boundingBoxPriors(INDArray boundingBoxes) {
            this.setBoundingBoxes(boundingBoxes);
            return this;
        }

        @Override
        public Yolo2OutputLayer build() {
            if (this.boundingBoxes == null) {
                throw new IllegalStateException("Bounding boxes have not been set");
            }
            if (this.boundingBoxes.rank() != 2 || this.boundingBoxes.size(1) != 2L) {
                throw new IllegalStateException("Bounding box priors must have shape [nBoxes, 2]. Has shape: " + Arrays.toString(this.boundingBoxes.shape()));
            }
            return new Yolo2OutputLayer(this);
        }

        public double getLambdaCoord() {
            return this.lambdaCoord;
        }

        public double getLambdaNoObj() {
            return this.lambdaNoObj;
        }

        public ILossFunction getLossPositionScale() {
            return this.lossPositionScale;
        }

        public ILossFunction getLossClassPredictions() {
            return this.lossClassPredictions;
        }

        public INDArray getBoundingBoxes() {
            return this.boundingBoxes;
        }

        public void setLambdaCoord(double lambdaCoord) {
            this.lambdaCoord = lambdaCoord;
        }

        public void setLambdaNoObj(double lambdaNoObj) {
            this.lambdaNoObj = lambdaNoObj;
        }

        public void setLossPositionScale(ILossFunction lossPositionScale) {
            this.lossPositionScale = lossPositionScale;
        }

        public void setLossClassPredictions(ILossFunction lossClassPredictions) {
            this.lossClassPredictions = lossClassPredictions;
        }

        public void setBoundingBoxes(INDArray boundingBoxes) {
            this.boundingBoxes = boundingBoxes;
        }
    }
}

