/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.nodePipeline.regression;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeSet;
import org.eclipse.collections.api.block.function.primitive.LongToLongFunction;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.MetricConsumer;
import org.neo4j.gds.ml.metrics.regression.RegressionMetrics;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.RegressionTrainerFactory;
import org.neo4j.gds.ml.models.Regressor;
import org.neo4j.gds.ml.models.RegressorTrainer;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.automl.RandomSearch;
import org.neo4j.gds.ml.nodePropertyPrediction.NodeSplitter;
import org.neo4j.gds.ml.pipeline.NodePropertyStepExecutor;
import org.neo4j.gds.ml.pipeline.PipelineTrainer;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureProducer;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyTrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.ImmutableNodeRegressionTrainResult;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineTrainConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainResult;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainingPipeline;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.ml.training.CrossValidation;
import org.neo4j.gds.ml.training.TrainingStatistics;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Trains a node regression model for a {@link NodeRegressionTrainingPipeline}.
 *
 * <p>A run splits the nodes into train and test sets, selects the best model candidate via
 * cross-validation on the outer train set, scores that candidate on the outer train and test
 * sets, and finally retrains it on all training examples.
 */
public final class NodeRegressionTrain implements PipelineTrainer<NodeRegressionTrainResult> {
    private final HugeDoubleArray targets;
    private final IdMap nodeIdMap;
    private final NodeRegressionTrainingPipeline pipeline;
    private final List<RegressionMetrics> metrics;
    private final NodeRegressionPipelineTrainConfig trainConfig;
    private final NodeFeatureProducer<NodeRegressionPipelineTrainConfig> nodeFeatureProducer;
    private final ProgressTracker progressTracker;
    private TerminationFlag terminationFlag = TerminationFlag.RUNNING_TRUE;

    /**
     * Builds the progress task tree for a training run of the given pipeline.
     *
     * <p>The children are, in order: the node property steps, the cross-validation tasks,
     * "Train best model", "Evaluate on test data" and "Retrain best model".
     *
     * @param pipeline supplies the split config, the node property steps and the number of trials
     * @param nodeCount node count from which the train and test set sizes are derived
     * @return the root task, named "Node Regression Train Pipeline"
     */
    public static Task progressTask(NodePropertyTrainingPipeline pipeline, long nodeCount) {
        NodePropertyPredictionSplitConfig splitConfig = pipeline.splitConfig();
        long trainSetSize = splitConfig.trainSetSize(nodeCount);
        long testSetSize = splitConfig.testSetSize(nodeCount);
        int validationFolds = splitConfig.validationFolds();

        List<Task> tasks = new ArrayList<>();
        tasks.add(NodePropertyStepExecutor.tasks(pipeline.nodePropertySteps(), nodeCount));
        tasks.addAll(CrossValidation.progressTasks(validationFolds, pipeline.numberOfModelSelectionTrials(), trainSetSize));
        // NOTE(review): the training leaves are created through ClassifierTrainer although this
        // pipeline trains a regressor, and their volume is 5x the set size. Neither the sharing
        // nor the factor is explained by the code visible here -- confirm both are intended.
        tasks.add(ClassifierTrainer.progressTask("Train best model", 5L * trainSetSize));
        tasks.add(Tasks.leaf("Evaluate on test data", testSetSize));
        tasks.add(ClassifierTrainer.progressTask("Retrain best model", 5L * nodeCount));
        return Tasks.task("Node Regression Train Pipeline", tasks);
    }

    /**
     * Creates a trainer for the nodes selected by the config's target node labels.
     *
     * <p>The split config is validated against the selected graph, and the target property is
     * copied into a {@link HugeDoubleArray} holding one double per node of that graph.
     *
     * @param graphStore store from which the graph of target nodes is obtained
     * @param pipeline the pipeline to train
     * @param config training configuration (target labels, target property, metrics, ...)
     * @param nodeFeatureProducer producer of the features used for training
     * @param progressTracker tracker that receives progress and log messages
     * @return a new trainer instance
     */
    public static NodeRegressionTrain create(
        GraphStore graphStore,
        NodeRegressionTrainingPipeline pipeline,
        NodeRegressionPipelineTrainConfig config,
        NodeFeatureProducer<NodeRegressionPipelineTrainConfig> nodeFeatureProducer,
        ProgressTracker progressTracker
    ) {
        Graph nodesGraph = graphStore.getGraph(config.targetNodeLabelIdentifiers(graphStore));
        pipeline.splitConfig().validateMinNumNodesInSplitSets(nodesGraph);

        NodePropertyValues targetNodeProperty = nodesGraph.nodeProperties(config.targetProperty());
        HugeDoubleArray targets = HugeDoubleArray.newArray(nodesGraph.nodeCount());
        targets.setAll(targetNodeProperty::doubleValue);

        return new NodeRegressionTrain(
            pipeline,
            config,
            nodeFeatureProducer,
            targets,
            nodesGraph,
            config.metrics(),
            progressTracker
        );
    }

    private NodeRegressionTrain(
        NodeRegressionTrainingPipeline pipeline,
        NodeRegressionPipelineTrainConfig trainConfig,
        NodeFeatureProducer<NodeRegressionPipelineTrainConfig> nodeFeatureProducer,
        HugeDoubleArray targets,
        IdMap nodeIdMap,
        List<RegressionMetrics> metrics,
        ProgressTracker progressTracker
    ) {
        this.pipeline = pipeline;
        this.trainConfig = trainConfig;
        this.nodeFeatureProducer = nodeFeatureProducer;
        this.nodeIdMap = nodeIdMap;
        this.metrics = metrics;
        this.progressTracker = progressTracker;
        this.targets = targets;
    }

    /**
     * Replaces the termination flag that is checked during training.
     *
     * @param terminationFlag the flag to use from now on
     */
    @Override
    public void setTerminationFlag(TerminationFlag terminationFlag) {
        this.terminationFlag = terminationFlag;
    }

    /**
     * Runs the full training: split, model selection, evaluation and retraining.
     *
     * @return the retrained regressor together with the collected training statistics
     */
    @Override
    public NodeRegressionTrainResult run() {
        this.progressTracker.beginSubTask();

        NodePropertyPredictionSplitConfig splitConfig = this.pipeline.splitConfig();
        NodeSplitter nodeSplitter = new NodeSplitter(
            this.trainConfig.concurrency(),
            this.nodeIdMap.nodeCount(),
            this.progressTracker,
            this.nodeIdMap::toOriginalNodeId,
            this.nodeIdMap::toMappedNodeId
        );
        NodeSplitter.NodeSplits splits = nodeSplitter.split(
            splitConfig.testFraction(),
            splitConfig.validationFolds(),
            this.trainConfig.randomSeed()
        );
        this.terminationFlag.assertRunning();

        TrainingStatistics trainingStatistics = new TrainingStatistics(this.metrics);
        Features features = this.nodeFeatureProducer.procedureFeatures(this.pipeline);

        this.findBestModelCandidate(splits.outerSplit().trainSet(), this.metrics, features, trainingStatistics);
        this.evaluateBestModel(splits.outerSplit(), features, trainingStatistics);
        Regressor retrainedModel = this.retrainBestModel(
            splits.allTrainingExamples(),
            features,
            trainingStatistics.bestParameters()
        );

        this.progressTracker.endSubTask();
        return ImmutableNodeRegressionTrainResult.of(retrainedModel, trainingStatistics);
    }

    /**
     * Cross-validates the candidates drawn by a {@link RandomSearch} over the pipeline's
     * parameter space and records the outcome in {@code trainingStatistics}.
     *
     * @param trainNodeIds ids of the outer train set
     * @param metrics metrics handed to the cross-validation
     * @param features features used for training and scoring
     * @param trainingStatistics receives the model selection results
     */
    private void findBestModelCandidate(
        ReadOnlyHugeLongArray trainNodeIds,
        List<RegressionMetrics> metrics,
        Features features,
        TrainingStatistics trainingStatistics
    ) {
        CrossValidation<Regressor> crossValidation = new CrossValidation<>(
            this.progressTracker,
            this.terminationFlag,
            metrics,
            this.pipeline.splitConfig().validationFolds(),
            this.trainConfig.randomSeed(),
            (trainSet, config, metricsHandler, messageLogLevel) -> this.trainModel(trainSet, config, features, messageLogLevel),
            (evaluationSet, regressor, scoreConsumer) -> this.registerMetricScores(evaluationSet, regressor, features, scoreConsumer)
        );

        RandomSearch modelCandidates = new RandomSearch(
            this.pipeline.trainingParameterSpace(),
            this.pipeline.autoTuningConfig().maxTrials(),
            this.trainConfig.randomSeed()
        );

        // Every id is mapped to the same value 0, and {0} is passed as the set of distinct values.
        crossValidation.selectModel(
            trainNodeIds,
            id -> 0L,
            new TreeSet<>(List.of(0L)),
            trainingStatistics,
            modelCandidates
        );
    }

    /**
     * Scores {@code regressor} on {@code evaluationSet} with every configured metric.
     *
     * <p>Predictions and the matching target values are gathered into arrays indexed by position
     * in the evaluation set; each metric is computed on those arrays and passed to
     * {@code scoreConsumer}.
     *
     * @param evaluationSet ids of the examples to score
     * @param regressor the model to evaluate
     * @param features features looked up by the ids in {@code evaluationSet}
     * @param scoreConsumer receives one score per metric
     */
    private void registerMetricScores(
        ReadOnlyHugeLongArray evaluationSet,
        Regressor regressor,
        Features features,
        MetricConsumer scoreConsumer
    ) {
        long evaluationSetSize = evaluationSet.size();
        int concurrency = this.trainConfig.concurrency();

        HugeDoubleArray localPredictions = HugeDoubleArray.newArray(evaluationSetSize);
        ParallelUtil.parallelForEachNode(
            evaluationSetSize,
            concurrency,
            idx -> localPredictions.set(idx, regressor.predict(features.get(evaluationSet.get(idx))))
        );

        this.terminationFlag.assertRunning();

        HugeDoubleArray localTargets = HugeDoubleArray.newArray(evaluationSetSize);
        ParallelUtil.parallelForEachNode(
            evaluationSetSize,
            concurrency,
            idx -> localTargets.set(idx, this.targets.get(evaluationSet.get(idx)))
        );

        this.metrics.forEach(metric -> scoreConsumer.consume(metric, metric.compute(localTargets, localPredictions)));
    }

    /**
     * Trains the best candidate on the outer train set and records its scores on both the outer
     * train set and the test set.
     *
     * @param outerSplit the outer train/test split
     * @param features features used for training and scoring
     * @param trainingStatistics supplies the best parameters and receives the scores
     */
    private void evaluateBestModel(
        TrainingExamplesSplit outerSplit,
        Features features,
        TrainingStatistics trainingStatistics
    ) {
        this.progressTracker.beginSubTask("Train best model");
        Regressor bestRegressor = this.trainModel(
            outerSplit.trainSet(),
            trainingStatistics.bestParameters(),
            features,
            LogLevel.INFO
        );
        this.progressTracker.endSubTask("Train best model");

        // Both the train-set and the test-set scoring happen inside this sub task.
        this.progressTracker.beginSubTask("Evaluate on test data");

        this.registerMetricScores(outerSplit.trainSet(), bestRegressor, features, trainingStatistics::addOuterTrainScore);
        Map<?, ?> outerTrainMetrics = trainingStatistics.winningModelOuterTrainMetrics();
        this.progressTracker.logInfo(
            StringFormatting.formatWithLocale("Final model metrics on full train set: %s", outerTrainMetrics)
        );

        this.registerMetricScores(outerSplit.testSet(), bestRegressor, features, trainingStatistics::addTestScore);
        Map<?, ?> testMetrics = trainingStatistics.winningModelTestMetrics();
        this.progressTracker.logInfo(
            StringFormatting.formatWithLocale("Final model metrics on test set: %s", testMetrics)
        );

        this.progressTracker.endSubTask("Evaluate on test data");
    }

    /**
     * Trains a model with the best parameters on the given train set, wrapped in the
     * "Retrain best model" sub task.
     *
     * @param trainSet ids of the examples to train on
     * @param features features used for training
     * @param bestParameters the winning trainer configuration
     * @return the retrained regressor
     */
    private Regressor retrainBestModel(ReadOnlyHugeLongArray trainSet, Features features, TrainerConfig bestParameters) {
        this.progressTracker.beginSubTask("Retrain best model");
        Regressor retrainedRegressor = this.trainModel(trainSet, bestParameters, features, LogLevel.INFO);
        this.progressTracker.endSubTask("Retrain best model");
        return retrainedRegressor;
    }

    /**
     * Creates a trainer for {@code trainerConfig} and trains it against this instance's targets.
     *
     * @param trainSet ids of the examples to train on
     * @param trainerConfig configuration of the model to train
     * @param features features used for training
     * @param messageLogLevel log level passed on to the trainer factory
     * @return the trained regressor
     */
    private Regressor trainModel(
        ReadOnlyHugeLongArray trainSet,
        TrainerConfig trainerConfig,
        Features features,
        LogLevel messageLogLevel
    ) {
        RegressorTrainer trainer = RegressionTrainerFactory.create(
            trainerConfig,
            this.terminationFlag,
            this.progressTracker,
            messageLogLevel,
            this.trainConfig.concurrency(),
            this.trainConfig.randomSeed()
        );
        return trainer.train(features, this.targets, trainSet);
    }
}

