/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.linkPipeline;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.RelationshipWeightConfig;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.pipeline.ExecutableNodePropertyStep;
import org.neo4j.gds.ml.pipeline.TrainingPipeline;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureStep;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;

/**
 * Training pipeline for link prediction whose feature steps are {@link LinkFeatureStep}s and which
 * additionally carries a {@link LinkPredictionSplitConfig}.
 */
public class LinkPredictionTrainingPipeline extends TrainingPipeline<LinkFeatureStep> {

    /** Pipeline type string returned by {@link #type()}. */
    public static final String PIPELINE_TYPE = "Link prediction training pipeline";

    /** Model type identifier for link prediction models. */
    public static final String MODEL_TYPE = "LinkPrediction";

    // Node-property-step config keys inspected by extractRelationshipProperty.
    private static final String RELATIONSHIP_WEIGHT_PROPERTY_KEY = "relationshipWeightProperty";
    private static final String MODEL_NAME_KEY = "modelName";

    private LinkPredictionSplitConfig splitConfig = LinkPredictionSplitConfig.DEFAULT_CONFIG;

    /**
     * Creates a pipeline using {@link LinkPredictionSplitConfig#DEFAULT_CONFIG} as its split config,
     * registered with the {@code CLASSIFICATION} training type.
     */
    public LinkPredictionTrainingPipeline() {
        super(TrainingPipeline.TrainingType.CLASSIFICATION);
    }

    /** @return {@link #PIPELINE_TYPE} */
    @Override
    public String type() {
        return PIPELINE_TYPE;
    }

    /**
     * Describes the feature pipeline as map representations of its steps.
     *
     * @return a map with keys {@code "nodePropertySteps"} and {@code "featureSteps"}, each holding the
     *     {@code toMap()} form of the corresponding steps
     */
    @Override
    protected Map<String, List<Map<String, Object>>> featurePipelineDescription() {
        return Map.of(
            "nodePropertySteps", ToMapConvertible.toMap(this.nodePropertySteps),
            "featureSteps", ToMapConvertible.toMap(this.featureSteps)
        );
    }

    /** @return a single entry {@code "splitConfig"} holding the current split config as a map */
    @Override
    protected Map<String, Object> additionalEntries() {
        return Map.of("splitConfig", this.splitConfig.toMap());
    }

    /** @return the split config currently set on this pipeline */
    public LinkPredictionSplitConfig splitConfig() {
        return this.splitConfig;
    }

    /**
     * Replaces the split config.
     *
     * @param splitConfig the new split config; must not be {@code null}
     * @throws NullPointerException if {@code splitConfig} is {@code null}
     */
    public void setSplitConfig(LinkPredictionSplitConfig splitConfig) {
        // Fail fast here rather than with an NPE later inside additionalEntries().
        this.splitConfig = Objects.requireNonNull(splitConfig, "splitConfig");
    }

    /**
     * Checks link-prediction-specific preconditions before execution.
     *
     * @param graphStore the graph store the pipeline will run on (not inspected here)
     * @throws IllegalArgumentException if the pipeline has no feature steps
     */
    @Override
    public void specificValidateBeforeExecution(GraphStore graphStore) {
        if (this.featureSteps().isEmpty()) {
            throw new IllegalArgumentException("Training a Link prediction pipeline requires at least one feature. You can add features with the procedure `gds.beta.pipeline.linkPrediction.addFeature`.");
        }
    }

    /**
     * Groups the procedure names of node-property steps by the relationship weight property they use.
     *
     * <p>Steps for which no relationship weight property can be determined are omitted. Keys appear in
     * the order their property was first encountered while iterating the node-property steps.
     *
     * @param executionContext used to look up models referenced via {@code "modelName"}
     * @return mutable map from relationship property name to the procedure names using it
     */
    public Map<String, List<String>> tasksByRelationshipProperty(ExecutionContext executionContext) {
        // LinkedHashMap (not HashMap) so that iteration order, and hence relationshipWeightProperty(),
        // is deterministic.
        Map<String, List<String>> tasksByRelationshipProperty = new LinkedHashMap<>();
        for (ExecutableNodePropertyStep step : this.nodePropertySteps()) {
            extractRelationshipProperty(executionContext, step.config())
                .ifPresent(property -> tasksByRelationshipProperty
                    .computeIfAbsent(property, key -> new ArrayList<>())
                    .add(step.procName()));
        }
        return tasksByRelationshipProperty;
    }

    /**
     * Determines the relationship weight property referenced by a step's config.
     *
     * <p>An explicit {@code "relationshipWeightProperty"} entry takes precedence. Otherwise, if the config
     * names a model via {@code "modelName"}, that model's train config is used when it is a
     * {@link RelationshipWeightConfig}.
     *
     * @return the property, or empty if none could be determined (including an explicit {@code null}
     *     value or a {@code null} model lookup result)
     */
    private static Optional<String> extractRelationshipProperty(ExecutionContext executionContext, Map<String, Object> config) {
        if (config.containsKey(RELATIONSHIP_WEIGHT_PROPERTY_KEY)) {
            // ofNullable: a null value means "no property" instead of an NPE from Optional.of.
            return Optional.ofNullable((String) config.get(RELATIONSHIP_WEIGHT_PROPERTY_KEY));
        }
        if (config.containsKey(MODEL_NAME_KEY)) {
            String modelName = (String) config.get(MODEL_NAME_KEY);
            return Optional.ofNullable(executionContext.modelCatalog().getUntyped(executionContext.username(), modelName))
                .map(Model::trainConfig)
                .filter(RelationshipWeightConfig.class::isInstance)
                .map(RelationshipWeightConfig.class::cast)
                .flatMap(RelationshipWeightConfig::relationshipWeightProperty);
        }
        return Optional.empty();
    }

    /**
     * Returns the relationship weight property used by the node-property steps.
     *
     * <p>If steps use several different properties, the first one encountered in step order is returned.
     *
     * @param executionContext used to look up models referenced by steps
     * @return the property, or empty if no step uses one
     */
    public Optional<String> relationshipWeightProperty(ExecutionContext executionContext) {
        return this.tasksByRelationshipProperty(executionContext).keySet().stream().findFirst();
    }
}

