/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.nodePipeline;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.CypherMapAccess;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.ml.pipeline.NonEmptySetValidation;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfigImpl;

@Configuration
public interface NodePropertyPredictionSplitConfig
extends ToMapConvertible {
    /**
     * Split configuration with every setting at its default value
     * (built from an empty configuration map).
     */
    NodePropertyPredictionSplitConfig DEFAULT_CONFIG = NodePropertyPredictionSplitConfig.of(CypherMapWrapper.empty());

    /**
     * Fraction of the nodes that is held out as the test set.
     *
     * @return a value in {@code [0.0, 1.0]}; defaults to {@code 0.3}
     */
    @Value.Default
    @Configuration.DoubleRange(min=0.0, max=1.0)
    default double testFraction() {
        return 0.3;
    }

    /**
     * Number of folds the train set is divided into for cross-validation.
     *
     * @return an integer {@code >= 2}; defaults to {@code 3}
     */
    @Value.Default
    @Configuration.IntegerRange(min=2)
    default int validationFolds() {
        return 3;
    }

    /**
     * Builds a split configuration from user-supplied configuration values.
     *
     * @param config the raw configuration map
     * @return the parsed configuration (the generated {@code NodePropertyPredictionSplitConfigImpl})
     */
    static NodePropertyPredictionSplitConfig of(CypherMapWrapper config) {
        // The explicit upcast is kept from the original; it may be needed to pick a
        // specific constructor overload of the generated implementation.
        return new NodePropertyPredictionSplitConfigImpl((CypherMapAccess)config);
    }

    /**
     * @return this configuration as a map of configuration keys to values
     */
    @Configuration.ToMap
    Map<String, Object> toMap();

    /**
     * @return the configuration keys collected by the configuration processor;
     *         this default body returns an empty list
     */
    @Configuration.CollectKeys
    default Collection<String> configKeys() {
        return Collections.emptyList();
    }

    /**
     * Checks that, for the given graph, the test set, train set and each
     * validation fold are large enough: at least 1 test node, 2 train nodes and
     * 1 node per validation fold (presumably the minimums passed to
     * {@code NonEmptySetValidation.validateNodeSetSize}; confirm against that helper).
     *
     * <p>The sizes are the same ones reported by {@link #testSetSize(long)},
     * {@link #trainSetSize(long)} and {@link #foldTestSetSize(long)}.
     *
     * @param graph the graph whose node count is split
     */
    @Value.Derived
    @Configuration.Ignore
    default void validateMinNumNodesInSplitSets(Graph graph) {
        long nodeCount = graph.nodeCount();
        long numberNodesInTestSet = this.testSetSize(nodeCount);
        long numberNodesInTrainSet = this.trainSetSize(nodeCount);
        long numberNodesInValidationSet = this.foldTestSetSize(nodeCount);
        NonEmptySetValidation.validateNodeSetSize(numberNodesInTestSet, 1L, "test", "`testFraction` is too low");
        NonEmptySetValidation.validateNodeSetSize(numberNodesInTrainSet, 2L, "train", "`testFraction` is too high");
        NonEmptySetValidation.validateNodeSetSize(numberNodesInValidationSet, 1L, "validation", "`validationFolds` or `testFraction` is too high");
    }

    /**
     * Number of test nodes: {@code floor(nodeCount * testFraction)}.
     *
     * @param nodeCount total number of nodes being split
     * @return the size of the held-out test set
     */
    @Value.Auxiliary
    @Value.Derived
    @Configuration.Ignore
    default long testSetSize(long nodeCount) {
        return (long)(this.testFraction() * (double)nodeCount);
    }

    /**
     * Number of train nodes: every node that is not in the test set.
     *
     * <p>Computed as the complement of {@link #testSetSize(long)}, so that
     * {@code testSetSize(n) + trainSetSize(n) == n}. Flooring
     * {@code n * (1 - testFraction)} independently would drop a node whenever
     * {@code n * testFraction} is not an integer.
     *
     * @param nodeCount total number of nodes being split
     * @return the size of the train set
     */
    @Value.Auxiliary
    @Value.Derived
    @Configuration.Ignore
    default long trainSetSize(long nodeCount) {
        return nodeCount - this.testSetSize(nodeCount);
    }

    /**
     * Number of nodes used for training within one cross-validation fold:
     * {@code floor(trainSetSize * (validationFolds - 1) / validationFolds)}.
     *
     * @param nodeCount total number of nodes being split
     * @return the per-fold training size
     */
    @Value.Auxiliary
    @Value.Derived
    @Configuration.Ignore
    default long foldTrainSetSize(long nodeCount) {
        return this.trainSetSize(nodeCount) * (long)(this.validationFolds() - 1) / (long)this.validationFolds();
    }

    /**
     * Number of nodes held out for validation within one cross-validation fold:
     * {@code floor(trainSetSize / validationFolds)}.
     *
     * @param nodeCount total number of nodes being split
     * @return the per-fold validation size
     */
    @Value.Auxiliary
    @Value.Derived
    @Configuration.Ignore
    default long foldTestSetSize(long nodeCount) {
        // Divide the long train size directly; the former `(long)(1 / folds)` used
        // int division and was always 0 because folds >= 2.
        return this.trainSetSize(nodeCount) / (long)this.validationFolds();
    }
}

