/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas.util;

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.shade.guava.collect.ArrayListMultimap;
import org.nd4j.shade.guava.collect.Multimap;

public class CudaArgs {
    /**
     * Private so the class cannot be instantiated from outside; every
     * method declared on {@code CudaArgs} is static.
     */
    private CudaArgs() {
    }

    /**
     * Resolves the module name string associated with an op.
     *
     * <p>The op's type is checked in this order: {@link ReduceOp},
     * {@link TransformOp}, {@link ScalarOp}, {@link BroadcastOp},
     * {@link IndexAccumulation}. For reduce and transform ops the op name
     * further refines the result:
     * <ul>
     *   <li>reduce ops named {@code cosinesimilarity}, {@code euclidean} or
     *       {@code manhattan} map to {@code "reduce3"}, all others to
     *       {@code "reduce"};</li>
     *   <li>transform ops named {@code add}, {@code copy}, {@code div},
     *       {@code mul}, {@code rdiv}, {@code rsub} or {@code sub} map to
     *       {@code "pairWiseTransform"}, all others to {@code "transform"}.</li>
     * </ul>
     *
     * @param op the op to classify
     * @return the module name, or {@code null} when {@code op} matches none
     *         of the checked op types
     */
    public static String getModuleNameFor(Op op) {
        if (op instanceof ReduceOp) {
            switch (op.opName()) {
                case "cosinesimilarity":
                case "euclidean":
                case "manhattan":
                    return "reduce3";
                default:
                    return "reduce";
            }
        }
        if (op instanceof TransformOp) {
            switch (op.opName()) {
                case "add":
                case "copy":
                case "div":
                case "mul":
                case "rdiv":
                case "rsub":
                case "sub":
                    return "pairWiseTransform";
                default:
                    return "transform";
            }
        }
        if (op instanceof ScalarOp) {
            return "scalar";
        }
        if (op instanceof BroadcastOp) {
            return "broadcast";
        }
        if (op instanceof IndexAccumulation) {
            return "indexReduce";
        }
        // Not one of the op families handled here.
        return null;
    }

    /**
     * Looks up the integer op code for an op, based on the op's type and name.
     *
     * <p>The op's type is checked in this order: {@link ReduceOp},
     * {@link TransformOp}, {@link ScalarOp}, {@link BroadcastOp},
     * {@link IndexAccumulation}; the first match decides which name table
     * is consulted. Codes are only unique within one table, and some tables
     * contain two independently numbered groups (for example the reduce names
     * {@code manhattan}/{@code euclidean}/{@code cosinesimilarity} restart at 0).
     *
     * @param op the op to look up; its name is read before any type check
     * @return the op code, or {@code -1} when the op type or name is not recognised
     */
    public static int getOpCode(Op op) {
        final String opName = op.opName();
        if (op instanceof ReduceOp) {
            return reduceOpCode(opName);
        }
        if (op instanceof TransformOp) {
            return transformOpCode(opName);
        }
        if (op instanceof ScalarOp) {
            return scalarOpCode(opName);
        }
        if (op instanceof BroadcastOp) {
            return broadcastOpCode(opName);
        }
        if (op instanceof IndexAccumulation) {
            return indexAccumulationOpCode(opName);
        }
        return -1;
    }

    /** Maps a reduce op name to its code, or -1 if the name is unknown. */
    private static int reduceOpCode(String opName) {
        switch (opName) {
            case "mean":
                return 0;
            case "sum":
                return 1;
            case "bias":
                return 2;
            case "max":
                return 3;
            case "min":
                return 4;
            case "norm1":
                return 5;
            case "norm2":
                return 6;
            case "normmax":
                return 7;
            case "prod":
                return 8;
            case "std":
                return 9;
            case "var":
                return 10;
            // Second group, numbered from 0 again.
            case "manhattan":
                return 0;
            case "euclidean":
                return 1;
            case "cosinesimilarity":
                return 2;
            default:
                return -1;
        }
    }

    /** Maps a transform op name to its code, or -1 if the name is unknown. */
    private static int transformOpCode(String opName) {
        switch (opName) {
            case "abs":
                return 0;
            case "ceil":
                return 1;
            case "cos":
                return 2;
            case "exp":
                return 3;
            case "floor":
                return 4;
            case "log":
                return 5;
            case "neg":
                return 6;
            case "pow":
                return 7;
            case "round":
                return 8;
            case "setrange":
                return 9;
            case "sigmoid":
                return 10;
            case "sign":
                return 11;
            case "sin":
                return 12;
            case "softplus":
                return 13;
            case "sqrt":
                return 14;
            case "tanh":
                return 15;
            case "acos":
                return 16;
            case "asin":
                return 17;
            case "atan":
                return 18;
            // Second group, numbered from 0 again.
            case "add":
                return 0;
            case "copy":
                return 1;
            case "div":
                return 2;
            case "eq":
                return 3;
            case "gt":
                return 4;
            case "lt":
                return 5;
            case "mul":
                return 6;
            case "rdiv":
                return 7;
            case "rsub":
                return 8;
            case "sub":
                return 9;
            case "eps":
                return 10;
            case "gte":
                return 11;
            case "lte":
                return 12;
            case "max":
                return 13;
            case "min":
                return 14;
            case "neq":
                return 15;
            default:
                return -1;
        }
    }

    /**
     * Maps a scalar op name to its code by prefix, or -1 if no prefix matches.
     * Prefixes are tried in table order and the first hit wins; a prefix's
     * position in the table is its code.
     */
    private static int scalarOpCode(String opName) {
        final String[] prefixes = {
            "add", "sub", "mul", "div", "rdiv", "rsub", "max",
            "lessthan", "greaterthan", "eq", "lte", "neq", "min", "set"
        };
        for (int code = 0; code < prefixes.length; code++) {
            if (opName.startsWith(prefixes[code])) {
                return code;
            }
        }
        return -1;
    }

    /** Maps a broadcast op name to its code, or -1 if the name is unknown. */
    private static int broadcastOpCode(String opName) {
        switch (opName) {
            case "broadcastadd":
                return 0;
            case "broadcastsub":
                return 1;
            case "broadcastmul":
                return 2;
            case "broadcastdiv":
                return 3;
            case "broadcastrdiv":
                return 4;
            case "broadcastrsub":
                return 5;
            case "broadcastcopy":
                return 6;
            default:
                return -1;
        }
    }

    /** Maps an index-accumulation op name to its code, or -1 if the name is unknown. */
    private static int indexAccumulationOpCode(String opName) {
        switch (opName) {
            case "imax":
                return 0;
            case "imin":
                return 1;
            default:
                return -1;
        }
    }

    /**
     * Returns the core count associated with a CUDA compute capability.
     *
     * <p>The values follow the per-multiprocessor table used by NVIDIA's
     * CUDA samples ({@code _ConvertSMVer2Cores}). Compute capabilities
     * 1.x through 5.x return the same values as before; 6.x through 9.x,
     * which previously fell through to {@code -1}, are now recognised.
     *
     * @param ccMajor            major compute capability (e.g. 3 for 3.5)
     * @param ccMinor            minor compute capability (e.g. 5 for 3.5)
     * @param numberOfProcessors kept for signature compatibility; it is not
     *                           used, and the result is not multiplied by it
     * @return the core count for the given compute capability, or {@code -1}
     *         if the major version is not in the table
     */
    public static int convertMPtoCores(int ccMajor, int ccMinor, int numberOfProcessors) {
        switch (ccMajor) {
            case 1:
                return 8;
            case 2:
                // 2.1 has 48 cores; every other 2.x is treated as 2.0.
                return ccMinor == 1 ? 48 : 32;
            case 3:
                return 192;
            case 5:
                return 128;
            case 6:
                // 6.0 has 64 cores; 6.1 and 6.2 have 128.
                return ccMinor == 0 ? 64 : 128;
            case 7:
                return 64;
            case 8:
                // 8.0 has 64 cores; 8.6, 8.7 and 8.9 have 128.
                return ccMinor == 0 ? 64 : 128;
            case 9:
                return 128;
            default:
                // Unknown architecture (including the non-existent 4.x).
                return -1;
        }
    }

    /**
     * Builds the argument array for a kernel call from a mixed list of
     * parameters.
     *
     * <p>Each parameter is handled by type, checked in this order:
     * <ul>
     *   <li>{@link JCudaBuffer}: wrapped in a new {@link CublasPointer} and
     *       replaced by that pointer's device pointer. The wrapper is not
     *       recorded in the returned multimap.</li>
     *   <li>{@link INDArray}: wrapped in a new {@link CublasPointer}, replaced
     *       by its device pointer, and the wrapper is recorded in the returned
     *       multimap under the array.</li>
     *   <li>anything else: passed through unchanged.</li>
     * </ul>
     *
     * @param context      the context handed to each {@link CublasPointer}
     * @param kernelParams the raw parameters, in call order
     * @return the converted arguments (same length and order as
     *         {@code kernelParams}) together with the array-to-pointer map
     */
    public static ArgsAndReferences argsAndReference(CudaContext context, Object ... kernelParams) {
        Object[] converted = new Object[kernelParams.length];
        Multimap<INDArray, CublasPointer> pointersByArray = ArrayListMultimap.create();
        int slot = 0;
        for (Object param : kernelParams) {
            if (param instanceof JCudaBuffer) {
                CublasPointer bufferPointer = new CublasPointer((JCudaBuffer) param, context);
                converted[slot] = bufferPointer.getDevicePointer();
            } else if (param instanceof INDArray) {
                INDArray ndArray = (INDArray) param;
                CublasPointer arrayPointer = new CublasPointer(ndArray, context);
                converted[slot] = arrayPointer.getDevicePointer();
                // Only array-backed pointers are tracked for the caller.
                pointersByArray.put(ndArray, arrayPointer);
            } else {
                converted[slot] = param;
            }
            slot++;
        }
        return new ArgsAndReferences(converted, pointersByArray);
    }

    /**
     * Mutable holder pairing an argument array with a multimap from
     * {@link INDArray} to the {@link CublasPointer}s created for it.
     *
     * <p>Equality, hashing and the string form are value-based: the argument
     * array is compared deeply and the multimap via its own {@code equals}.
     */
    public static class ArgsAndReferences {
        private Object[] args;
        private Multimap<INDArray, CublasPointer> arrayToPointer;

        /**
         * @param args           the argument array (stored as-is, not copied)
         * @param arrayToPointer the array-to-pointer multimap (stored as-is)
         */
        public ArgsAndReferences(Object[] args, Multimap<INDArray, CublasPointer> arrayToPointer) {
            this.args = args;
            this.arrayToPointer = arrayToPointer;
        }

        /** @return the stored argument array (not a copy) */
        public Object[] getArgs() {
            return args;
        }

        /** @param args the new argument array (stored as-is) */
        public void setArgs(Object[] args) {
            this.args = args;
        }

        /** @return the stored array-to-pointer multimap (not a copy) */
        public Multimap<INDArray, CublasPointer> getArrayToPointer() {
            return arrayToPointer;
        }

        /** @param arrayToPointer the new array-to-pointer multimap (stored as-is) */
        public void setArrayToPointer(Multimap<INDArray, CublasPointer> arrayToPointer) {
            this.arrayToPointer = arrayToPointer;
        }

        /**
         * Hook that lets a subclass refuse equality with instances of other
         * types; consulted on the other object inside {@link #equals(Object)}.
         */
        protected boolean canEqual(Object other) {
            return other instanceof ArgsAndReferences;
        }

        /**
         * Two holders are equal when their argument arrays are deeply equal
         * and their multimaps are both {@code null} or equal to each other.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof ArgsAndReferences)) {
                return false;
            }
            ArgsAndReferences that = (ArgsAndReferences) o;
            if (!that.canEqual(this)) {
                return false;
            }
            if (!Arrays.deepEquals(getArgs(), that.getArgs())) {
                return false;
            }
            Multimap<INDArray, CublasPointer> mine = getArrayToPointer();
            Multimap<INDArray, CublasPointer> theirs = that.getArrayToPointer();
            if (mine == null) {
                return theirs == null;
            }
            return mine.equals(theirs);
        }

        /**
         * Hash combining the deep hash of the argument array with the
         * multimap's hash (43 stands in for a {@code null} multimap).
         */
        @Override
        public int hashCode() {
            final int prime = 59;
            Multimap<INDArray, CublasPointer> pointers = getArrayToPointer();
            int hash = prime + Arrays.deepHashCode(getArgs());
            int pointersHash = pointers == null ? 43 : pointers.hashCode();
            return hash * prime + pointersHash;
        }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("CudaArgs.ArgsAndReferences(args=");
            text.append(Arrays.deepToString(getArgs()));
            text.append(", arrayToPointer=");
            text.append(getArrayToPointer());
            text.append(")");
            return text.toString();
        }
    }
}

