/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.util.Collection;
import java.util.List;
import java.util.Optional;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.config.ElementTypeValidator;
import org.neo4j.gds.core.GraphDimensions;
import org.neo4j.gds.core.loading.SingleTypeRelationships;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.progress.tasks.LeafTask;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.ml.negativeSampling.NegativeSampler;
import org.neo4j.gds.ml.pipeline.NonEmptySetValidation;
import org.neo4j.gds.ml.pipeline.linkPipeline.ExpectedSetSizes;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.splitting.EdgeSplitter;
import org.neo4j.gds.ml.splitting.UndirectedEdgeSplitter;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Prepares the relationship sets used by a link prediction training pipeline.
 *
 * <p>{@link #splitAndSampleRelationships(Optional)} does the following:
 * <ul>
 *   <li>splits the projected relationships into a test set and a test-complement set;</li>
 *   <li>splits the test-complement set again into a train set and a feature-input set;</li>
 *   <li>hands the test and train builders to a {@link NegativeSampler};</li>
 *   <li>registers the results in the {@link GraphStore} and validates the resulting set sizes;</li>
 *   <li>finally deletes the intermediate test-complement relationship type.</li>
 * </ul>
 * The static estimation methods provide the memory estimate for this step.
 */
public class LinkPredictionRelationshipSampler {
    /** Wildcard token accepted for node labels and for the target relationship type. */
    private static final String PROJECT_ALL = "*";

    /** Minimum size of the test, feature-input and per-fold validation sets. */
    private static final long MIN_SET_SIZE = 1L;
    /** Minimum size of the train set. */
    private static final long MIN_TRAIN_SET_SIZE = 2L;
    /** Minimum size of the test-complement set (the input of the train / feature-input split). */
    private static final long MIN_TEST_COMPLEMENT_SET_SIZE = 3L;

    // NOTE(review): passed as the last UndirectedEdgeSplitter constructor argument, presumably its
    // concurrency. It is fixed here rather than read from the train config — confirm this is intended.
    private static final int EDGE_SPLITTER_CONCURRENCY = 4;

    // Pessimistic bytes per stored relationship used by the memory estimations: 16 bytes, plus 8 more
    // when a double weight is carried (presumably two long node ids plus the weight).
    private static final int UNWEIGHTED_REL_SIZE = 2 * Long.BYTES;
    private static final int WEIGHTED_REL_SIZE = 2 * Long.BYTES + Double.BYTES;
    // Negative relationships are always estimated at 24 bytes, i.e. as if a double property is stored.
    private static final int NEGATIVE_REL_SIZE = 2 * Long.BYTES + Double.BYTES;

    private final LinkPredictionSplitConfig splitConfig;
    private final LinkPredictionTrainConfig trainConfig;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;
    private final GraphStore graphStore;

    /**
     * @param graphStore      store whose relationships are split; split results are added to it
     * @param splitConfig     fractions, ratios and relationship type names of the split
     * @param trainConfig     node labels, relationship types and random seed of the training run
     * @param progressTracker receives the "Split relationships" sub task and warnings
     * @param terminationFlag checked between the expensive phases of the split
     */
    public LinkPredictionRelationshipSampler(GraphStore graphStore, LinkPredictionSplitConfig splitConfig, LinkPredictionTrainConfig trainConfig, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
        this.graphStore = graphStore;
        this.splitConfig = splitConfig;
        this.trainConfig = trainConfig;
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Creates the progress leaf task for the split.
     *
     * @param sizes expected sizes of the four relationship sets
     * @return a leaf task whose volume is the sum of the train, feature-input, test and test-complement sizes
     */
    @NotNull
    static LeafTask progressTask(ExpectedSetSizes sizes) {
        return Tasks.leaf(
            "Split relationships",
            sizes.trainSize() + sizes.featureInputSize() + sizes.testSize() + sizes.testComplementSize()
        );
    }

    /**
     * Splits the relationships of the graph store and samples negative relationships.
     *
     * <p>The phases run in this order:
     * <ol>
     *   <li>validate the split config against the graph store;</li>
     *   <li>split the projected graph into test and test-complement;</li>
     *   <li>split test-complement into train and feature-input;</li>
     *   <li>let the negative sampler see the test and train builders;</li>
     *   <li>add test and train to the graph store;</li>
     *   <li>validate the set sizes;</li>
     *   <li>delete the test-complement relationship type.</li>
     * </ol>
     *
     * <p>NOTE(review): there is no try/finally. If a validation or termination check throws, the sub task
     * is not ended and the relationship types already added are not removed. Presumably the caller
     * handles this; confirm.
     *
     * @param relationshipWeightProperty optional relationship property used for the graph projections
     *                                   and passed to the edge splitter
     * @throws IllegalArgumentException if a projected graph is not undirected (see {@code split})
     */
    public void splitAndSampleRelationships(Optional<String> relationshipWeightProperty) {
        this.progressTracker.beginSubTask("Split relationships");
        this.splitConfig.validateAgainstGraphStore(this.graphStore, this.trainConfig.internalTargetRelationshipType());
        if (this.trainConfig.sourceNodeLabel().equals(PROJECT_ALL) || this.trainConfig.targetNodeLabel().equals(PROJECT_ALL)) {
            this.progressTracker.logWarning(StringFormatting.formatWithLocale(
                "Using %s for the `sourceNodeLabel` or `targetNodeLabel` results in not ideal negative link sampling.",
                PROJECT_ALL
            ));
        }

        RelationshipType testComplementRelationshipType = this.splitConfig.testComplementRelationshipType();

        // Node sets restricting the source and target side of the split / sampled relationships.
        // NOTE(review): the raw Collection types stem from decompilation; presumably Collection<NodeLabel>.
        Collection sourceLabels = ElementTypeValidator.resolve(this.graphStore, List.of(this.trainConfig.sourceNodeLabel()));
        Collection targetLabels = ElementTypeValidator.resolve(this.graphStore, List.of(this.trainConfig.targetNodeLabel()));
        Graph sourceNodes = this.graphStore.getGraph(sourceLabels);
        Graph targetNodes = this.graphStore.getGraph(targetLabels);

        Collection sourceAndTargetNodeLabels = this.trainConfig.nodeLabelIdentifiers(this.graphStore);
        Graph graph = this.graphStore.getGraph(
            sourceAndTargetNodeLabels,
            this.trainConfig.internalRelationshipTypes(this.graphStore),
            relationshipWeightProperty
        );

        // First split: projected graph -> test (returned) + test-complement (added to the store by split()).
        this.terminationFlag.assertRunning();
        EdgeSplitter.SplitResult testSplitResult = this.split(
            sourceNodes,
            targetNodes,
            graph,
            relationshipWeightProperty,
            this.splitConfig.testRelationshipType(),
            testComplementRelationshipType,
            this.splitConfig.testFraction()
        );

        // Second split: test-complement -> train (returned) + feature-input (added to the store by split()).
        Graph testComplementGraph = this.graphStore.getGraph(
            sourceAndTargetNodeLabels,
            List.of(testComplementRelationshipType),
            relationshipWeightProperty
        );
        this.terminationFlag.assertRunning();
        EdgeSplitter.SplitResult trainSplitResult = this.split(
            sourceNodes,
            targetNodes,
            testComplementGraph,
            relationshipWeightProperty,
            this.splitConfig.trainRelationshipType(),
            this.splitConfig.featureInputRelationshipType(),
            this.splitConfig.trainFraction()
        );

        NegativeSampler negativeSampler = NegativeSampler.of(
            this.graphStore,
            graph,
            sourceAndTargetNodeLabels,
            this.splitConfig.negativeRelationshipType(),
            this.splitConfig.negativeSamplingRatio(),
            testSplitResult.selectedRelCount(),
            trainSplitResult.selectedRelCount(),
            sourceNodes,
            targetNodes,
            sourceLabels,
            targetLabels,
            this.trainConfig.randomSeed()
        );
        this.terminationFlag.assertRunning();
        // The sampler receives the test and train builders before they are built
        // (presumably it appends the negative examples to them).
        negativeSampler.produceNegativeSamples(testSplitResult.selectedRels(), trainSplitResult.selectedRels());
        this.graphStore.addRelationshipType(testSplitResult.selectedRels().build());
        this.graphStore.addRelationshipType(trainSplitResult.selectedRels().build());

        this.validateTestSplit(this.graphStore);
        this.validateTrainSplit(this.graphStore);

        // The test-complement set was only needed as input of the second split.
        this.graphStore.deleteRelationships(testComplementRelationshipType);
        this.progressTracker.endSubTask("Split relationships");
    }

    /**
     * Splits the relationships of {@code graph} with an {@link UndirectedEdgeSplitter}.
     *
     * <p>The remaining part is built and added to the graph store immediately. The selected part is
     * returned in the {@link EdgeSplitter.SplitResult}; the caller builds it later.
     *
     * @throws IllegalArgumentException if {@code graph} is not undirected
     */
    private EdgeSplitter.SplitResult split(IdMap sourceNodes, IdMap targetNodes, Graph graph, Optional<String> relationshipWeightProperty, RelationshipType selectedRelType, RelationshipType remainingRelType, double selectedFraction) {
        if (!graph.schema().isUndirected()) {
            throw new IllegalArgumentException("EdgeSplitter requires graph to be UNDIRECTED");
        }
        UndirectedEdgeSplitter splitter = new UndirectedEdgeSplitter(
            this.trainConfig.randomSeed(),
            this.graphStore.nodes(),
            sourceNodes,
            targetNodes,
            selectedRelType,
            remainingRelType,
            EDGE_SPLITTER_CONCURRENCY
        );
        EdgeSplitter.SplitResult splitResult = splitter.splitPositiveExamples(graph, selectedFraction, relationshipWeightProperty);
        SingleTypeRelationships remainingRels = splitResult.remainingRels().build();
        this.graphStore.addRelationshipType(remainingRels);
        return splitResult;
    }

    /**
     * Checks the sizes of the test and test-complement sets via
     * {@code NonEmptySetValidation.validateRelSetSize}. That method presumably throws when a count is
     * below its minimum.
     */
    private void validateTestSplit(GraphStore graphStore) {
        NonEmptySetValidation.validateRelSetSize(graphStore.relationshipCount(this.splitConfig.testRelationshipType()), MIN_SET_SIZE, "test", "`testFraction` is too low");
        NonEmptySetValidation.validateRelSetSize(graphStore.relationshipCount(this.splitConfig.testComplementRelationshipType()), MIN_TEST_COMPLEMENT_SET_SIZE, "test-complement", "`testFraction` is too high");
    }

    /**
     * Checks the sizes of the train set, the feature-input set and a single validation fold
     * (train count divided by {@code validationFolds}).
     */
    private void validateTrainSplit(GraphStore graphStore) {
        long trainRelCount = graphStore.relationshipCount(this.splitConfig.trainRelationshipType());
        NonEmptySetValidation.validateRelSetSize(trainRelCount, MIN_TRAIN_SET_SIZE, "train", "`trainFraction` is too low");
        NonEmptySetValidation.validateRelSetSize(graphStore.relationshipCount(this.splitConfig.featureInputRelationshipType()), MIN_SET_SIZE, "feature-input", "`trainFraction` is too high");
        NonEmptySetValidation.validateRelSetSize(trainRelCount / this.splitConfig.validationFolds(), MIN_SET_SIZE, "validation", "`validationFolds` is too high or the `trainFraction` too low");
    }

    /**
     * Memory estimation of the whole split step: positive relationships plus negative samples.
     * A target relationship type of {@code "*"} is estimated as {@link RelationshipType#ALL_RELATIONSHIPS}.
     */
    static MemoryEstimation splitEstimation(LinkPredictionSplitConfig splitConfig, String targetRelationshipType, Optional<String> relationshipWeight) {
        RelationshipType checkTargetRelType = targetRelationshipType.equals(PROJECT_ALL)
            ? RelationshipType.ALL_RELATIONSHIPS
            : RelationshipType.of(targetRelationshipType);
        return MemoryEstimations.builder("Split relationships")
            .add(estimatePositiveRelations(
                checkTargetRelType.name,
                splitConfig.testFraction(),
                splitConfig.trainFraction(),
                relationshipWeight
            ))
            .add(estimateNegativeSampling(
                checkTargetRelType.name,
                splitConfig.testFraction(),
                splitConfig.trainFraction(),
                splitConfig.negativeSamplingRatio(),
                splitConfig.negativeRelationshipType()
            ))
            .build();
    }

    /**
     * Estimates the memory of the positive test/train relationships and of the feature-input relationships.
     *
     * @param relationshipType   type whose estimated count is split
     * @param testFraction       fraction selected by the test split
     * @param trainFraction      fraction of the test-complement selected by the train split
     * @param relationshipWeight if present, relationships are estimated with an extra double per relationship
     */
    public static MemoryEstimation estimatePositiveRelations(String relationshipType, double testFraction, double trainFraction, Optional<String> relationshipWeight) {
        int pessimisticSizePerRel = relationshipWeight.isPresent() ? WEIGHTED_REL_SIZE : UNWEIGHTED_REL_SIZE;
        return MemoryEstimations.builder("Relationship splitter").perGraphDimension("Test and train positive relationships", (graphDimensions, threads) -> {
            long testAndTrainRelCount = (long) (graphDimensions.estimatedRelCount(List.of(relationshipType)) * testAndTrainFraction(testFraction, trainFraction));
            // Halved, presumably because the selected sets store each undirected relationship once while the
            // estimated count includes both directions — TODO confirm.
            return MemoryRange.of(testAndTrainRelCount / 2L).times(pessimisticSizePerRel);
        }).perGraphDimension("Feature input relationships", (graphDimensions, threads) -> {
            long featureInputRelCount = (long) (graphDimensions.estimatedRelCount(List.of(relationshipType)) * (1.0 - testFraction) * (1.0 - trainFraction));
            return MemoryRange.of(featureInputRelCount).times(pessimisticSizePerRel);
        }).build();
    }

    /**
     * Estimates the memory of the negative relationships.
     *
     * <p>If {@code negativeRelationshipType} is present, the count of that type is used. Otherwise the count
     * is the positive test+train count times {@code negativeSamplingRatio}.
     */
    public static MemoryEstimation estimateNegativeSampling(String relationshipType, double testFraction, double trainFraction, double negativeSamplingRatio, Optional<String> negativeRelationshipType) {
        return MemoryEstimations.builder("Relationship splitter").perGraphDimension("Negative relationships", (graphDimensions, threads) -> {
            long negativeRelCount = estimateNegativeRelCount(graphDimensions, relationshipType, testFraction, trainFraction, negativeSamplingRatio, negativeRelationshipType);
            return MemoryRange.of(negativeRelCount / 2L).times(NEGATIVE_REL_SIZE);
        }).build();
    }

    private static long estimateNegativeRelCount(GraphDimensions graphDimensions, String relationshipType, double testFraction, double trainFraction, double negativeSamplingRatio, Optional<String> negativeRelationshipType) {
        if (negativeRelationshipType.isPresent()) {
            return graphDimensions.estimatedRelCount(List.of(negativeRelationshipType.get()));
        }
        double testAndTrainPositiveRelCount = graphDimensions.estimatedRelCount(List.of(relationshipType)) * testAndTrainFraction(testFraction, trainFraction);
        return (long) (testAndTrainPositiveRelCount * negativeSamplingRatio);
    }

    /**
     * Fraction of the split relationships that end up in the test or the train set.
     *
     * <p>The test split selects {@code t = testFraction}. The train split selects {@code tr = trainFraction}
     * of the remaining {@code 1 - t}. Together this gives {@code t + (1 - t) * tr = t + tr - t * tr}.
     */
    private static double testAndTrainFraction(double testFraction, double trainFraction) {
        return testFraction + trainFraction - testFraction * trainFraction;
    }
}

