/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.samediff;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Base configuration class for layers whose computation is supplied by subclasses through
 * {@link #defineLayer(SameDiff, SDVariable, Map, SDVariable)}.
 *
 * <p>Carries a default {@link WeightInit} plus optional per-parameter {@link IWeightInit}
 * overrides, both normally populated from a {@link Builder}.
 */
public abstract class SameDiffLayer
extends AbstractSameDiffLayer {
    /** Default weight initialization scheme; the builder defaults it to {@link WeightInit#XAVIER}. */
    protected WeightInit weightInit;
    /** Optional per-parameter weight initialization overrides keyed by parameter name; may be null. */
    protected Map<String, IWeightInit> paramWeightInit;

    /**
     * Creates the layer configuration from a builder.
     *
     * @param builder the builder holding the configured values
     */
    protected SameDiffLayer(Builder<?> builder) {
        super(builder);
        this.weightInit = builder.weightInit;
        // Copy the map: the builder keeps mutating its own map on later weightInit(param, init)
        // calls, which must not leak into layers that were already built from it.
        this.paramWeightInit = builder.paramWeightInit == null
                ? null
                : new HashMap<>(builder.paramWeightInit);
    }

    /** No-arg constructor; fields are left unset and can be populated via the setters. */
    protected SameDiffLayer() {
    }

    /**
     * Implemented by subclasses to build this layer's output variable inside the given SameDiff
     * instance.
     *
     * @param sameDiff   the SameDiff instance to define the layer in
     * @param layerInput the variable holding the layer input
     * @param paramTable the layer's parameter variables keyed by parameter name
     * @param mask       mask variable; presumably may be null when no mask is present (confirm
     *                   against callers)
     * @return the variable representing the layer output
     */
    public abstract SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask);

    /**
     * Default mask propagation: returns the incoming mask and mask state unchanged.
     * {@code minibatchSize} is not used by this implementation.
     *
     * @param maskArray        the incoming mask array
     * @param currentMaskState the incoming mask state
     * @param minibatchSize    minibatch size (unused here)
     * @return a pair of the unmodified mask array and mask state
     */
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        return new Pair<>(maskArray, currentMaskState);
    }

    /**
     * Hook for validating layer input; the default implementation accepts any input.
     *
     * @param input the input to validate
     */
    public void validateInput(INDArray input) {
    }

    /**
     * Creates the runtime {@code org.deeplearning4j.nn.layers.samediff.SameDiffLayer} for this
     * configuration. It sets the layer index and parameter view, and builds the parameter table
     * via {@link #initializer()}. {@code trainingListeners} is not used by this implementation.
     */
    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        org.deeplearning4j.nn.layers.samediff.SameDiffLayer ret = new org.deeplearning4j.nn.layers.samediff.SameDiffLayer(conf, networkDataType);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map<String, INDArray> paramTable = this.initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    /** @return the default weight initialization scheme */
    public WeightInit getWeightInit() {
        return this.weightInit;
    }

    /** @return the per-parameter weight initialization overrides, or null if none were set */
    public Map<String, IWeightInit> getParamWeightInit() {
        return this.paramWeightInit;
    }

    /** @param weightInit the default weight initialization scheme */
    public void setWeightInit(WeightInit weightInit) {
        this.weightInit = weightInit;
    }

    /** @param paramWeightInit per-parameter weight initialization overrides (stored as given) */
    public void setParamWeightInit(Map<String, IWeightInit> paramWeightInit) {
        this.paramWeightInit = paramWeightInit;
    }

    @Override
    public String toString() {
        return "SameDiffLayer(weightInit=" + this.getWeightInit() + ", paramWeightInit=" + this.getParamWeightInit() + ")";
    }

    /**
     * Equality in the Lombok {@code canEqual} style: the superclass state must be equal, and so
     * must {@code weightInit} and {@code paramWeightInit}.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof SameDiffLayer)) {
            return false;
        }
        SameDiffLayer other = (SameDiffLayer) o;
        return other.canEqual(this)
                && super.equals(o)
                && Objects.equals(this.getWeightInit(), other.getWeightInit())
                && Objects.equals(this.getParamWeightInit(), other.getParamWeightInit());
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof SameDiffLayer;
    }

    /** Hash consistent with {@link #equals(Object)}; null fields contribute the constant 43. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        WeightInit init = this.getWeightInit();
        result = result * prime + (init == null ? 43 : init.hashCode());
        Map<String, IWeightInit> perParam = this.getParamWeightInit();
        result = result * prime + (perParam == null ? 43 : perParam.hashCode());
        return result;
    }

    /**
     * Base builder for {@link SameDiffLayer} subclasses.
     *
     * @param <T> the concrete builder type, returned from fluent setters
     */
    public static abstract class Builder<T extends Builder<T>>
    extends AbstractSameDiffLayer.Builder<T> {
        /** Default weight initialization scheme; XAVIER unless overridden. */
        protected WeightInit weightInit = WeightInit.XAVIER;
        /** Per-parameter overrides keyed by parameter name; created lazily, null until first set. */
        protected Map<String, IWeightInit> paramWeightInit;

        /**
         * Sets the default weight initialization scheme.
         *
         * @param weightInit the scheme to use
         * @return this builder
         */
        @SuppressWarnings("unchecked") // safe: T is this builder's own concrete type
        public T weightInit(WeightInit weightInit) {
            this.setWeightInit(weightInit);
            return (T) this;
        }

        /**
         * Sets a weight initialization override for a single named parameter.
         *
         * @param param      the parameter name
         * @param weightInit the initializer for that parameter
         * @return this builder
         * @throws NullPointerException if {@code param} or {@code weightInit} is null
         */
        @SuppressWarnings("unchecked") // safe: T is this builder's own concrete type
        public T weightInit(@NonNull String param, @NonNull IWeightInit weightInit) {
            Objects.requireNonNull(param, "param is marked non-null but is null");
            Objects.requireNonNull(weightInit, "weightInit is marked non-null but is null");
            if (this.paramWeightInit == null) {
                this.paramWeightInit = new HashMap<>();
            }
            this.paramWeightInit.put(param, weightInit);
            return (T) this;
        }

        /** @return the default weight initialization scheme */
        public WeightInit getWeightInit() {
            return this.weightInit;
        }

        /** @return the per-parameter overrides, or null if none were set */
        public Map<String, IWeightInit> getParamWeightInit() {
            return this.paramWeightInit;
        }

        /** @param weightInit the default weight initialization scheme */
        public void setWeightInit(WeightInit weightInit) {
            this.weightInit = weightInit;
        }

        /** @param paramWeightInit per-parameter overrides (stored as given) */
        public void setParamWeightInit(Map<String, IWeightInit> paramWeightInit) {
            this.paramWeightInit = paramWeightInit;
        }
    }
}

