/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.util.onnx;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.GeneratedMessageV3;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperator;
import org.tribuo.util.onnx.ONNXOperators;

public abstract class ONNXRef<T extends GeneratedMessageV3> {
    protected final T backRef;
    private final String baseName;
    protected final ONNXContext context;

    ONNXRef(ONNXContext context, T backRef, String baseName) {
        this.context = context;
        this.backRef = backRef;
        this.baseName = baseName;
    }

    public abstract String getReference();

    public String getBaseName() {
        return this.baseName;
    }

    public ONNXContext onnxContext() {
        return this.context;
    }

    public List<ONNXNode> apply(ONNXOperator op, List<ONNXRef<?>> otherInputs, List<String> outputs, Map<String, Object> attributes) {
        ArrayList allInputs = new ArrayList();
        allInputs.add(this);
        allInputs.addAll(otherInputs);
        return this.context.operation(op, allInputs, outputs, attributes);
    }

    public List<ONNXNode> apply(ONNXOperator op, List<String> outputs, Map<String, Object> attributes) {
        return this.context.operation(op, Collections.singletonList(this), outputs, attributes);
    }

    public ONNXNode apply(ONNXOperator op) {
        return this.context.operation(op, Collections.singletonList(this), this.getBaseName() + "_" + op.getOpName(), Collections.emptyMap());
    }

    public ONNXNode apply(ONNXOperator op, String outputName) {
        return this.context.operation(op, Collections.singletonList(this), outputName, Collections.emptyMap());
    }

    public ONNXNode apply(ONNXOperator op, Map<String, Object> attributes) {
        return this.context.operation(op, Collections.singletonList(this), this.getBaseName() + "_" + op.getOpName(), attributes);
    }

    public ONNXNode apply(ONNXOperator op, ONNXRef<?> other, Map<String, Object> attributes) {
        return this.context.operation(op, Arrays.asList(this, other), this.getBaseName() + "_" + op.getOpName() + "_" + other.getBaseName(), attributes);
    }

    public ONNXNode apply(ONNXOperator op, ONNXRef<?> other, String outputName) {
        return this.context.operation(op, Arrays.asList(this, other), outputName, Collections.emptyMap());
    }

    public ONNXNode apply(ONNXOperator op, ONNXRef<?> other) {
        return this.context.operation(op, Arrays.asList(this, other), this.getBaseName() + "_" + op.getOpName() + "_" + other.getBaseName(), Collections.emptyMap());
    }

    public ONNXNode apply(ONNXOperator op, List<ONNXRef<?>> others) {
        return this.apply(op, others, Collections.singletonList(this.getBaseName() + "_" + others.stream().map(ONNXRef::getBaseName).collect(Collectors.joining("_"))), Collections.emptyMap()).get(0);
    }

    public ONNXNode apply(ONNXOperator op, List<ONNXRef<?>> others, String outputName) {
        return this.apply(op, others, Collections.singletonList(outputName), Collections.emptyMap()).get(0);
    }

    public <Ret extends ONNXRef<?>> Ret assignTo(Ret output) {
        return this.context.assignTo(this, output);
    }

    public ONNXNode cast(Class<?> clazz) {
        if (clazz.equals(Float.TYPE)) {
            return this.apply((ONNXOperator)ONNXOperators.CAST, Collections.singletonMap("to", OnnxMl.TensorProto.DataType.FLOAT.getNumber()));
        }
        if (clazz.equals(Double.TYPE)) {
            return this.apply((ONNXOperator)ONNXOperators.CAST, Collections.singletonMap("to", OnnxMl.TensorProto.DataType.DOUBLE.getNumber()));
        }
        if (clazz.equals(Integer.TYPE)) {
            return this.apply((ONNXOperator)ONNXOperators.CAST, Collections.singletonMap("to", OnnxMl.TensorProto.DataType.INT32.getNumber()));
        }
        if (clazz.equals(Long.TYPE)) {
            return this.apply((ONNXOperator)ONNXOperators.CAST, Collections.singletonMap("to", OnnxMl.TensorProto.DataType.INT64.getNumber()));
        }
        throw new IllegalArgumentException("unsupported class for casting: " + clazz.getName());
    }
}

