/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.ml.metrics.classification.OutOfBagError;
import org.neo4j.gds.ml.models.TrainingMethod;
import org.neo4j.gds.ml.pipeline.AutoTuningConfig;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.TrainingPipeline;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.utils.StringJoining;

/**
 * Static helpers shared by the ML pipeline procedures: graph-name normalization,
 * auto-tuning configuration and main-metric validation.
 */
public final class PipelineCompanion {
    /**
     * Graph name written into the algorithm configuration when the caller supplied
     * something other than a graph name (see {@link #preparePipelineConfig}).
     */
    public static final String ANONYMOUS_GRAPH = "__ANONYMOUS_GRAPH__";

    private PipelineCompanion() {
    }

    /**
     * Stores a {@code "graphName"} entry in {@code algoConfiguration}.
     *
     * @param graphNameOrConfiguration a graph name ({@link String}) or any other object;
     *                                 non-strings map to {@link #ANONYMOUS_GRAPH}
     * @param algoConfiguration        mutable configuration map that receives the entry
     */
    public static void preparePipelineConfig(Object graphNameOrConfiguration, Map<String, Object> algoConfiguration) {
        if (graphNameOrConfiguration instanceof String) {
            algoConfiguration.put("graphName", graphNameOrConfiguration);
        } else {
            algoConfiguration.put("graphName", ANONYMOUS_GRAPH);
        }
    }

    /**
     * Looks up a pipeline in the {@link PipelineCatalog}, parses {@code configMap} into an
     * {@link AutoTuningConfig}, rejects keys that config does not recognise, and stores the
     * config on the pipeline.
     *
     * @param userName     owner of the pipeline in the catalog
     * @param pipelineName name of the pipeline in the catalog
     * @param configMap    raw auto-tuning settings
     * @param factory      builds the result row from the updated pipeline
     * @param <PIPELINE>   concrete pipeline type the caller expects to find in the catalog
     * @param <INFO_RESULT> result row type
     * @return a single-element stream containing {@code factory.apply(pipeline)}
     */
    @SuppressWarnings("unchecked")
    public static <PIPELINE extends TrainingPipeline<?>, INFO_RESULT> Stream<INFO_RESULT> configureAutoTuning(
        String userName,
        String pipelineName,
        Map<String, Object> configMap,
        Function<PIPELINE, INFO_RESULT> factory
    ) {
        // The catalog lookup is not typed by PIPELINE, so narrowing it is an unchecked cast.
        // Callers are expected to request the pipeline type they registered; a mismatch would
        // only surface when `factory` touches the pipeline. NOTE(review): confirm callers uphold this.
        PIPELINE pipeline = (PIPELINE) PipelineCatalog.get(userName, pipelineName);
        CypherMapWrapper cypherConfig = CypherMapWrapper.create(configMap);
        AutoTuningConfig config = AutoTuningConfig.of(cypherConfig);
        // Reject keys that AutoTuningConfig did not consume (e.g. typos).
        cypherConfig.requireOnlyKeysFrom(config.configKeys());
        pipeline.setAutoTuningConfig(config);
        return Stream.of(factory.apply(pipeline));
    }

    /**
     * Ensures the out-of-bag error is only used as main metric when every training method
     * with at least one candidate is {@link TrainingMethod#RandomForestClassification}.
     *
     * @param pipeline   pipeline whose training parameter space is inspected
     * @param mainMetric name of the main (first) metric
     * @throws IllegalArgumentException if {@code mainMetric} is the out-of-bag error and other
     *                                  training methods have candidates
     */
    public static void validateMainMetric(TrainingPipeline<?> pipeline, String mainMetric) {
        if (!mainMetric.equals(OutOfBagError.OUT_OF_BAG_ERROR.name())) {
            return;
        }

        // TreeSet keeps the method names sorted, so the error message is deterministic.
        Set<String> nonRFMethods = pipeline.trainingParameterSpace()
            .entrySet()
            .stream()
            .filter(entry -> entry.getKey() != TrainingMethod.RandomForestClassification
                             && !entry.getValue().isEmpty())
            .map(entry -> entry.getKey().toString())
            .collect(Collectors.toCollection(TreeSet::new));

        if (!nonRFMethods.isEmpty()) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "If %s is used as the main metric (the first one), then only RandomForest model candidates are allowed. Incompatible training methods used are: %s.",
                OutOfBagError.OUT_OF_BAG_ERROR.name(),
                StringJoining.join(nonRFMethods)
            ));
        }
    }
}

