/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.immutables.value.Value;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.config.ElementTypeValidator;
import org.neo4j.gds.config.GraphNameConfig;
import org.neo4j.gds.config.RandomSeedConfig;
import org.neo4j.gds.core.CypherMapAccess;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.ml.metrics.LinkMetric;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfigImpl;
import org.neo4j.gds.model.ModelConfig;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Configuration for training a link prediction pipeline.
 *
 * <p>Besides the generic algorithm, graph-name, model and random-seed settings, it names the
 * pipeline to train, the (single) target relationship type, the source/target node labels and
 * the evaluation metrics. It also provides graph-store validation hooks for the configured
 * labels and relationship type.
 */
@Configuration
public interface LinkPredictionTrainConfig extends
    AlgoBaseConfig,
    GraphNameConfig,
    ModelConfig,
    RandomSeedConfig {

    /**
     * Weight applied to the negative class.
     *
     * @return the configured weight; defaults to {@code 1.0} and must be strictly greater than 0
     */
    @Value.Default
    @Configuration.DoubleRange(min = 0.0, minInclusive = false)
    default double negativeClassWeight() {
        return 1.0;
    }

    /** @return the name of the pipeline to train */
    String pipeline();

    /** @return the relationship type whose links are predicted; {@code "*"} is rejected by {@link #validate()} */
    String targetRelationshipType();

    /** @return the node label for source nodes; defaults to {@code "*"} */
    default String sourceNodeLabel() {
        return "*";
    }

    /** @return the node label for target nodes; defaults to {@code "*"} */
    default String targetNodeLabel() {
        return "*";
    }

    /**
     * Relationship types used by this configuration.
     *
     * @return a single-element list containing {@link #targetRelationshipType()}
     */
    @Configuration.Ignore
    default List<String> relationshipTypes() {
        return List.of(targetRelationshipType());
    }

    /**
     * Rejects the wildcard as the target relationship type.
     *
     * @throws IllegalArgumentException if {@link #targetRelationshipType()} is {@code "*"}
     */
    @Value.Check
    default void validate() {
        // Constant-first comparison so a missing value cannot cause an NPE here.
        if ("*".equals(targetRelationshipType())) {
            throw new IllegalArgumentException("'*' is not allowed as targetRelationshipType.");
        }
    }

    /** @return {@link #targetRelationshipType()} wrapped as a {@link RelationshipType} */
    @Configuration.Ignore
    default RelationshipType internalTargetRelationshipType() {
        return RelationshipType.of(targetRelationshipType());
    }

    /**
     * Node labels used by this configuration.
     *
     * @return the source and target node labels, in that order, with duplicates removed
     */
    @Configuration.Ignore
    default List<String> nodeLabels() {
        return Stream.of(sourceNodeLabel(), targetNodeLabel())
            .distinct()
            .collect(Collectors.toList());
    }

    /**
     * Metrics to compute. User input is parsed via {@link #namesToMetrics(List)} and serialized
     * back via {@link #metricsToNames(List)}.
     *
     * @return the configured metrics; defaults to {@code [LinkMetric.AUCPR]}
     */
    @Configuration.ConvertWith(method = "org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig#namesToMetrics")
    @Configuration.ToMapValue(value = "org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig#metricsToNames")
    default List<Metric> metrics() {
        return List.of(LinkMetric.AUCPR);
    }

    /**
     * The primary metric.
     *
     * @return the first entry of {@link #metrics()}
     * @throws IndexOutOfBoundsException if {@link #metrics()} is empty
     */
    @Configuration.Ignore
    default Metric mainMetric() {
        return metrics().get(0);
    }

    /**
     * The configured metrics that are not model specific.
     *
     * <p>NOTE(review): assumes every metric with {@code isModelSpecific() == false} is a
     * {@link LinkMetric}; any other type would fail the cast with a ClassCastException.
     *
     * @return the non-model-specific metrics, in configuration order
     */
    @Configuration.Ignore
    default List<LinkMetric> linkMetrics() {
        return metrics().stream()
            .filter(metric -> !metric.isModelSpecific())
            .map(LinkMetric.class::cast)
            .collect(Collectors.toList());
    }

    /**
     * Validates that {@link #sourceNodeLabel()} resolves against the graph store.
     *
     * @param graphStore                the graph store to validate against
     * @param selectedLabels            unused
     * @param selectedRelationshipTypes unused
     */
    @Configuration.GraphStoreValidationCheck
    default void validateSourceNodeLabel(
        GraphStore graphStore,
        Collection<NodeLabel> selectedLabels,
        Collection<RelationshipType> selectedRelationshipTypes
    ) {
        ElementTypeValidator.resolveAndValidate(graphStore, List.of(sourceNodeLabel()), "sourceNodeLabel");
    }

    /**
     * Validates that {@link #targetNodeLabel()} resolves against the graph store.
     *
     * @param graphStore                the graph store to validate against
     * @param selectedLabels            unused
     * @param selectedRelationshipTypes unused
     */
    @Configuration.GraphStoreValidationCheck
    default void validateTargetNodeLabel(
        GraphStore graphStore,
        Collection<NodeLabel> selectedLabels,
        Collection<RelationshipType> selectedRelationshipTypes
    ) {
        // Report the key actually being validated (was mistakenly "sourceNodeLabel").
        ElementTypeValidator.resolveAndValidate(graphStore, List.of(targetNodeLabel()), "targetNodeLabel");
    }

    /**
     * Ensures the target relationship type is stored as undirected in the graph store's schema.
     *
     * @param graphStore the graph store to validate against
     * @param ignoredNL  unused
     * @param ignoredRT  unused
     * @throws IllegalArgumentException if the schema filtered to the target type is not undirected
     */
    @Configuration.GraphStoreValidationCheck
    default void validateTargetRelIsUndirected(
        GraphStore graphStore,
        Collection<NodeLabel> ignoredNL,
        Collection<RelationshipType> ignoredRT
    ) {
        boolean undirected = graphStore.schema()
            .filterRelationshipTypes(Set.of(internalTargetRelationshipType()))
            .isUndirected();
        if (!undirected) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Target relationship type `%s` must be UNDIRECTED, but was directed.",
                targetRelationshipType()
            ));
        }
    }

    /**
     * Creates a configuration from the given user map.
     *
     * @param username the user creating the configuration
     * @param config   the raw configuration values
     * @return the generated implementation of this interface
     */
    static LinkPredictionTrainConfig of(String username, CypherMapWrapper config) {
        return new LinkPredictionTrainConfigImpl(username, (CypherMapAccess) config);
    }

    /**
     * Parses user-supplied metric names.
     *
     * @param names raw metric names, as supplied in the configuration map
     * @return the parsed metrics, in input order
     */
    static List<Metric> namesToMetrics(List<?> names) {
        return names.stream().map(LinkMetric::parseLinkMetric).collect(Collectors.toList());
    }

    /**
     * Serializes metrics back to their names.
     *
     * @param metrics the metrics to serialize
     * @return the result of {@code Metric::name} for each metric, in input order
     */
    static List<String> metricsToNames(List<Metric> metrics) {
        return metrics.stream().map(Metric::name).collect(Collectors.toList());
    }
}

