/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.activations.impl;

import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative;
import org.nd4j.linalg.factory.Nd4j;

/**
 * GELU activation function.
 *
 * <p>Two variants are selectable via the {@code precise} flag: {@code false} (what the no-arg
 * constructor uses) applies the {@link GELU} op for the forward pass and {@link GELUDerivative}
 * for backprop; {@code true} applies {@link PreciseGELU} and {@link PreciseGELUDerivative}.
 * NOTE(review): the non-precise variant is presumably an approximation of the exact GELU —
 * confirm against the op implementations.
 */
public class ActivationGELU extends BaseActivationFunction {
    /** Whether the {@code Precise*} GELU ops are used instead of the default ones. */
    private boolean precise;

    /**
     * Creates a GELU activation.
     *
     * @param precise {@code true} to use {@link PreciseGELU}/{@link PreciseGELUDerivative},
     *                {@code false} to use {@link GELU}/{@link GELUDerivative}
     */
    public ActivationGELU(boolean precise) {
        this.precise = precise;
    }

    /** Creates a GELU activation that uses the non-precise ops ({@code precise = false}). */
    public ActivationGELU() {
        this(false);
    }

    /**
     * Applies GELU to {@code in}.
     *
     * <p>The op is constructed with {@code in} as both its input and its result array, and
     * {@code in} is what gets returned, so the input array is overwritten with the activations.
     *
     * @param in       pre-activation values; modified in place
     * @param training not used by this activation
     * @return {@code in}, now holding the activations
     */
    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        if (this.precise) {
            Nd4j.getExecutioner().execAndReturn(new PreciseGELU(in, in));
        } else {
            Nd4j.getExecutioner().execAndReturn(new GELU(in, in));
        }
        return in;
    }

    /**
     * Computes the gradient with respect to the pre-activation input.
     *
     * <p>The GELU derivative op is constructed with {@code in} as both input and result array
     * (so {@code in} is overwritten), and the result is then multiplied element-wise by
     * {@code epsilon} in place.
     *
     * @param in      pre-activation values; overwritten by this call
     * @param epsilon upstream gradient multiplied into the derivative (presumably dL/d(output)
     *                of this activation — confirm against callers)
     * @return a pair whose first element is dL/dz and whose second element is {@code null}
     *         (this class holds no parameter arrays to return gradients for)
     */
    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        // Inherited check; presumably rejects mismatched in/epsilon shapes.
        this.assertShape(in, epsilon);
        final INDArray dLdz = this.precise
                ? Nd4j.getExecutioner().exec(new PreciseGELUDerivative(in, in))
                : Nd4j.getExecutioner().exec(new GELUDerivative(in, in));
        dLdz.muli(epsilon);
        // Diamond lets the compiler infer Pair<INDArray, INDArray> to match the return type;
        // an explicit Pair<INDArray, Object> is not assignable to it (generics are invariant).
        return new Pair<>(dLdz, null);
    }

    /** Returns a description of this activation including the {@code precise} flag. */
    @Override
    public String toString() {
        return "gelu(precise=" + this.precise + ")";
    }

    /**
     * Two instances are equal when both are {@code ActivationGELU}s that accept each other via
     * {@link #canEqual(Object)} and have the same {@code precise} flag.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ActivationGELU)) {
            return false;
        }
        ActivationGELU other = (ActivationGELU) o;
        // Lets a subclass overriding canEqual refuse equality with plain ActivationGELU,
        // keeping equals symmetric across the hierarchy.
        if (!other.canEqual(this)) {
            return false;
        }
        return this.isPrecise() == other.isPrecise();
    }

    /**
     * Hook used by {@link #equals(Object)} for symmetric equality with subclasses.
     *
     * @param other the object being compared against
     * @return {@code true} if {@code other} is an {@code ActivationGELU}
     */
    protected boolean canEqual(Object other) {
        return other instanceof ActivationGELU;
    }

    /** Hash derived solely from the {@code precise} flag, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + (this.isPrecise() ? 79 : 97);
        return result;
    }

    /**
     * Returns whether the precise GELU ops are used.
     *
     * @return the {@code precise} flag
     */
    public boolean isPrecise() {
        return this.precise;
    }
}

