/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.autoscaler.tuning;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.time.Duration;
import java.util.Arrays;
import java.util.Map;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.autoscaler.JobAutoScalerContext;
import org.apache.flink.autoscaler.ScalingSummary;
import org.apache.flink.autoscaler.config.AutoScalerOptions;
import org.apache.flink.autoscaler.event.AutoScalerEventHandler;
import org.apache.flink.autoscaler.metrics.EvaluatedMetrics;
import org.apache.flink.autoscaler.metrics.EvaluatedScalingMetric;
import org.apache.flink.autoscaler.metrics.ScalingMetric;
import org.apache.flink.autoscaler.topology.JobTopology;
import org.apache.flink.autoscaler.topology.ShipStrategy;
import org.apache.flink.autoscaler.topology.VertexInfo;
import org.apache.flink.autoscaler.tuning.ConfigChanges;
import org.apache.flink.autoscaler.tuning.MemoryBudget;
import org.apache.flink.autoscaler.tuning.MemoryScaling;
import org.apache.flink.autoscaler.utils.ResourceCheckUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.IllegalConfigurationException;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.configuration.UnmodifiableConfiguration;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.util.config.memory.CommonProcessMemorySpec;
import org.apache.flink.runtime.util.config.memory.FlinkMemoryUtils;
import org.apache.flink.runtime.util.config.memory.JvmMetaspaceAndOverheadOptions;
import org.apache.flink.runtime.util.config.memory.ProcessMemoryOptions;
import org.apache.flink.runtime.util.config.memory.ProcessMemoryUtils;
import org.apache.flink.runtime.util.config.memory.taskmanager.TaskExecutorFlinkMemory;
import org.apache.flink.runtime.util.config.memory.taskmanager.TaskExecutorFlinkMemoryUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Derives TaskManager memory configuration recommendations from observed memory usage and the
 * (possibly rescaled) job topology.
 *
 * <p>The TaskManager memory size from the spec is used as a fixed budget. Framework off-heap, task
 * off-heap and JVM overhead are reserved at their spec sizes first. The remainder is handed out to
 * network, metaspace, heap and managed memory, in that order. Heap may afterwards be adjusted by
 * {@link MemoryScaling}. The resulting settings are always reported as an event. They are returned
 * as {@link ConfigChanges} only when memory tuning is enabled.
 */
public class MemoryTuning {
    private static final Logger LOG = LoggerFactory.getLogger(MemoryTuning.class);

    /** Resolves TaskManager process memory specs from a configuration. */
    public static final ProcessMemoryUtils<TaskExecutorFlinkMemory> FLINK_MEMORY_UTILS =
            new ProcessMemoryUtils<>(getMemoryOptions(), new TaskExecutorFlinkMemoryUtils());

    /**
     * Computes tuned TaskManager memory settings.
     *
     * @param context autoscaler context providing the job configuration and the spec TaskManager
     *     memory size
     * @param evaluatedMetrics evaluated metrics; heap, managed and metaspace usage are read from
     *     its global metrics
     * @param jobTopology topology used to size network memory
     * @param scalingSummaries pending scaling decisions, used to derive the new parallelisms
     * @param eventHandler receives a "Configuration recommendation" event with the proposed
     *     settings, even when automatic tuning is disabled
     * @return the recommended configuration changes, or an empty {@link ConfigChanges} when the
     *     current memory configuration is invalid, the spec TaskManager memory cannot be
     *     determined, the computed total memory is not positive, or memory tuning is disabled. A
     *     new instance is returned on every call, so callers may modify it freely.
     */
    public static ConfigChanges tuneTaskManagerMemory(
            JobAutoScalerContext<?> context,
            EvaluatedMetrics evaluatedMetrics,
            JobTopology jobTopology,
            Map<JobVertexID, ScalingSummary> scalingSummaries,
            AutoScalerEventHandler eventHandler) {
        // Read-only copy so that the tuning logic cannot modify the job configuration.
        Configuration config = new UnmodifiableConfiguration(context.getConfiguration());

        CommonProcessMemorySpec<TaskExecutorFlinkMemory> memSpecs;
        try {
            memSpecs = FLINK_MEMORY_UTILS.memoryProcessSpecFromConfig(config);
        } catch (IllegalConfigurationException e) {
            // Keep the cause so the invalid setting can be identified from the logs.
            LOG.warn("Current memory configuration is not valid. Aborting memory tuning.", e);
            return emptyConfig();
        }

        TaskExecutorFlinkMemory flinkMemory = memSpecs.getFlinkMemory();
        MemorySize specHeapSize = flinkMemory.getJvmHeapMemorySize();
        MemorySize specManagedSize = flinkMemory.getManaged();
        MemorySize specNetworkSize = flinkMemory.getNetwork();
        MemorySize specMetaspaceSize = memSpecs.getJvmMetaspaceSize();
        LOG.info(
                "Spec memory - heap: {}, managed: {}, network: {}, meta: {}",
                specHeapSize.toHumanReadableString(),
                specManagedSize.toHumanReadableString(),
                specNetworkSize.toHumanReadableString(),
                specMetaspaceSize.toHumanReadableString());

        MemorySize maxMemoryBySpec = context.getTaskManagerMemory().orElse(MemorySize.ZERO);
        if (maxMemoryBySpec.compareTo(MemorySize.ZERO) <= 0) {
            LOG.warn("Spec TaskManager memory size could not be determined.");
            return emptyConfig();
        }

        MemoryBudget memBudget = new MemoryBudget(maxMemoryBySpec.getBytes());
        // These pools are not tuned: reserve their spec sizes before distributing the rest.
        memBudget.budget(flinkMemory.getFrameworkOffHeap().getBytes());
        memBudget.budget(flinkMemory.getTaskOffHeap().getBytes());
        memBudget.budget(memSpecs.getJvmOverheadSize().getBytes());

        // NOTE(review): budget() appears to cap each request at the remaining budget, so pools
        // requested earlier take precedence when memory is scarce — confirm in MemoryBudget.
        Map<ScalingMetric, EvaluatedScalingMetric> globalMetrics =
                evaluatedMetrics.getGlobalMetrics();
        MemorySize newNetworkSize =
                adjustNetworkMemory(
                        jobTopology,
                        ResourceCheckUtils.computeNewParallelisms(
                                scalingSummaries, evaluatedMetrics.getVertexMetrics()),
                        config,
                        memBudget);
        MemorySize newMetaspaceSize =
                determineNewSize(
                        getUsage(ScalingMetric.METASPACE_MEMORY_USED, globalMetrics),
                        config,
                        memBudget);
        MemorySize newHeapSize =
                determineNewSize(
                        getUsage(ScalingMetric.HEAP_MEMORY_USED, globalMetrics), config, memBudget);
        MemorySize newManagedSize =
                adjustManagedMemory(
                        getUsage(ScalingMetric.MANAGED_MEMORY_USED, globalMetrics),
                        specManagedSize,
                        config,
                        memBudget);
        newHeapSize =
                MemoryScaling.applyMemoryScaling(
                        newHeapSize, memBudget, context, scalingSummaries, evaluatedMetrics);
        LOG.info(
                "Optimized memory sizes: heap: {} managed: {}, network: {}, meta: {}",
                newHeapSize.toHumanReadableString(),
                newManagedSize.toHumanReadableString(),
                newNetworkSize.toHumanReadableString(),
                newMetaspaceSize.toHumanReadableString());

        long heapDiffBytes = newHeapSize.getBytes() - specHeapSize.getBytes();
        long managedDiffBytes = newManagedSize.getBytes() - specManagedSize.getBytes();
        long networkDiffBytes = newNetworkSize.getBytes() - specNetworkSize.getBytes();
        long flinkMemoryDiffBytes = heapDiffBytes + managedDiffBytes + networkDiffBytes;

        // Everything taken from the budget makes up the new total process memory. Validate the raw
        // value first: MemorySize rejects negative sizes in its constructor, so checking after
        // construction could never observe a negative total.
        long totalMemoryBytes = maxMemoryBySpec.getBytes() - memBudget.getRemaining();
        if (totalMemoryBytes <= 0) {
            LOG.warn("Invalid total memory configuration: {} bytes", totalMemoryBytes);
            return emptyConfig();
        }
        MemorySize totalMemory = new MemorySize(totalMemoryBytes);

        ConfigChanges tuningConfig = new ConfigChanges();
        tuningConfig.addOverride(TaskManagerOptions.TOTAL_PROCESS_MEMORY, totalMemory);
        tuningConfig.addRemoval(TaskManagerOptions.TOTAL_FLINK_MEMORY);
        tuningConfig.addRemoval(TaskManagerOptions.TASK_HEAP_MEMORY);
        tuningConfig.addOverride(TaskManagerOptions.FRAMEWORK_HEAP_MEMORY, MemorySize.ZERO);

        // Managed memory is expressed as a fraction of the adjusted total Flink memory.
        MemorySize flinkMemorySize =
                new MemorySize(
                        memSpecs.getTotalFlinkMemorySize().getBytes() + flinkMemoryDiffBytes);
        tuningConfig.addOverride(
                TaskManagerOptions.MANAGED_MEMORY_FRACTION,
                getFraction(newManagedSize, flinkMemorySize));
        tuningConfig.addRemoval(TaskManagerOptions.MANAGED_MEMORY_SIZE);

        // Pin network memory to exactly the computed size.
        tuningConfig.addOverride(TaskManagerOptions.NETWORK_MEMORY_MIN, newNetworkSize);
        tuningConfig.addOverride(TaskManagerOptions.NETWORK_MEMORY_MAX, newNetworkSize);
        tuningConfig.addOverride(
                TaskManagerOptions.JVM_OVERHEAD_FRACTION,
                getFraction(memSpecs.getJvmOverheadSize(), totalMemory));
        tuningConfig.addOverride(TaskManagerOptions.JVM_METASPACE, newMetaspaceSize);

        // Read the flag once so the event text and the returned result always agree.
        boolean tuningEnabled = config.get(AutoScalerOptions.MEMORY_TUNING_ENABLED);
        eventHandler.handleEvent(
                context,
                AutoScalerEventHandler.Type.Normal,
                "Configuration recommendation",
                String.format(
                        "Memory tuning recommends the following configuration (automatic tuning is %s):\n%s",
                        tuningEnabled ? "enabled" : "disabled",
                        formatConfig(tuningConfig)),
                "MemoryTuning",
                config.get(AutoScalerOptions.SCALING_EVENT_INTERVAL));

        return tuningEnabled ? tuningConfig : emptyConfig();
    }

    /**
     * Returns a new, empty {@link ConfigChanges}.
     *
     * <p>A fresh instance is used instead of a shared constant because {@link ConfigChanges} is
     * mutable. A caller that adds entries to a shared instance would corrupt every later result.
     */
    private static ConfigChanges emptyConfig() {
        return new ConfigChanges();
    }

    /**
     * Sizes a memory pool as {@code usage * (1 + MEMORY_TUNING_OVERHEAD)} and takes that amount
     * from the budget.
     *
     * @return the amount granted by {@code memoryBudget.budget(...)}
     */
    private static MemorySize determineNewSize(
            MemorySize usage, Configuration config, MemoryBudget memoryBudget) {
        double overheadFactor = 1.0 + config.get(AutoScalerOptions.MEMORY_TUNING_OVERHEAD);
        long targetSizeBytes = (long) (usage.getBytes() * overheadFactor);
        return new MemorySize(memoryBudget.budget(targetSizeBytes));
    }

    /**
     * Determines the managed memory size.
     *
     * <p>Returns zero if no managed memory is used. If {@code MEMORY_TUNING_MAXIMIZE_MANAGED_MEMORY}
     * is set, it requests {@code Long.MAX_VALUE} from the budget, i.e. whatever the budget grants.
     * Otherwise it requests the configured managed size.
     */
    private static MemorySize adjustManagedMemory(
            MemorySize managedMemoryUsage,
            MemorySize managedMemoryConfigured,
            Configuration config,
            MemoryBudget memBudget) {
        if (managedMemoryUsage.compareTo(MemorySize.ZERO) <= 0) {
            return MemorySize.ZERO;
        }
        if (config.get(AutoScalerOptions.MEMORY_TUNING_MAXIMIZE_MANAGED_MEMORY)) {
            return new MemorySize(memBudget.budget(Long.MAX_VALUE));
        }
        return new MemorySize(memBudget.budget(managedMemoryConfigured.getBytes()));
    }

    /**
     * Computes the network memory needed for all input and output gates of all vertices at their
     * new parallelism, multiplied by the number of task slots, and takes it from the budget.
     *
     * @param updatedParallelisms new parallelism per vertex; must contain every vertex that has an
     *     input or output edge
     */
    private static MemorySize adjustNetworkMemory(
            JobTopology jobTopology,
            Map<JobVertexID, Integer> updatedParallelisms,
            Configuration config,
            MemoryBudget memBudget) {
        int buffersPerChannel = config.get(NettyShuffleEnvironmentOptions.NETWORK_BUFFERS_PER_CHANNEL);
        int floatingBuffers = config.get(NettyShuffleEnvironmentOptions.NETWORK_EXTRA_BUFFERS_PER_GATE);
        long memorySegmentBytes = config.get(TaskManagerOptions.MEMORY_SEGMENT_SIZE).getBytes();

        long maxNetworkMemory = 0L;
        for (VertexInfo vertexInfo : jobTopology.getVertexInfos().values()) {
            // Kept boxed: unboxing only happens when an edge is actually processed.
            Integer vertexParallelism = updatedParallelisms.get(vertexInfo.getId());
            // One gate per input edge.
            for (Map.Entry<JobVertexID, ShipStrategy> inputEntry :
                    vertexInfo.getInputs().entrySet()) {
                maxNetworkMemory +=
                        (long)
                                        calculateNetworkSegmentNumber(
                                                vertexParallelism,
                                                updatedParallelisms.get(inputEntry.getKey()),
                                                inputEntry.getValue(),
                                                buffersPerChannel,
                                                floatingBuffers)
                                * memorySegmentBytes;
            }
            // One gate per output edge.
            for (Map.Entry<JobVertexID, ShipStrategy> outputEntry :
                    vertexInfo.getOutputs().entrySet()) {
                maxNetworkMemory +=
                        (long)
                                        calculateNetworkSegmentNumber(
                                                vertexParallelism,
                                                updatedParallelisms.get(outputEntry.getKey()),
                                                outputEntry.getValue(),
                                                buffersPerChannel,
                                                floatingBuffers)
                                * memorySegmentBytes;
            }
        }
        // NOTE(review): presumably scaled by slot count because every slot may host a subtask of
        // every vertex — confirm.
        maxNetworkMemory *= config.get(TaskManagerOptions.NUM_TASK_SLOTS);
        return new MemorySize(memBudget.budget(maxNetworkMemory));
    }

    /**
     * Returns the number of network buffers (memory segments) for a single gate between a vertex
     * and a connected vertex.
     *
     * <p>Channel count per gate:
     *
     * <ul>
     *   <li>FORWARD with equal parallelism: 1.
     *   <li>FORWARD or RESCALE otherwise: {@code ceil(connected / current)}.
     *   <li>Any other strategy: one channel per connected subtask.
     * </ul>
     *
     * <p>Each channel gets {@code buffersPerChannel} buffers, and {@code floatingBuffers} are added
     * once per gate.
     */
    @VisibleForTesting
    static int calculateNetworkSegmentNumber(
            int currentVertexParallelism,
            int connectedVertexParallelism,
            ShipStrategy shipStrategy,
            int buffersPerChannel,
            int floatingBuffers) {
        if (currentVertexParallelism == connectedVertexParallelism
                && ShipStrategy.FORWARD.equals(shipStrategy)) {
            return buffersPerChannel + floatingBuffers;
        }
        if (ShipStrategy.FORWARD.equals(shipStrategy) || ShipStrategy.RESCALE.equals(shipStrategy)) {
            int channelCount =
                    (int)
                            Math.ceil(
                                    (double) connectedVertexParallelism
                                            / (double) currentVertexParallelism);
            return channelCount * buffersPerChannel + floatingBuffers;
        }
        return connectedVertexParallelism * buffersPerChannel + floatingBuffers;
    }

    /** Returns the average of the given global metric as a memory size (fraction truncated). */
    private static MemorySize getUsage(
            ScalingMetric scalingMetric, Map<ScalingMetric, EvaluatedScalingMetric> globalMetrics) {
        MemorySize memoryUsed = new MemorySize((long) globalMetrics.get(scalingMetric).getAverage());
        LOG.debug("{}: {}", scalingMetric, memoryUsed);
        return memoryUsed;
    }

    /**
     * Returns the configured {@code TOTAL_PROCESS_MEMORY} if it is set. Otherwise returns the
     * TaskManager memory reported by the context, or zero if that is absent as well.
     */
    public static MemorySize getTotalMemory(Configuration config, JobAutoScalerContext<?> ctx) {
        MemorySize overrideSize = config.get(TaskManagerOptions.TOTAL_PROCESS_MEMORY);
        if (overrideSize != null) {
            return overrideSize;
        }
        return ctx.getTaskManagerMemory().orElse(MemorySize.ZERO);
    }

    /** TaskManager memory options used by {@link #FLINK_MEMORY_UTILS}. */
    private static ProcessMemoryOptions getMemoryOptions() {
        return new ProcessMemoryOptions(
                Arrays.asList(
                        TaskManagerOptions.TASK_HEAP_MEMORY, TaskManagerOptions.MANAGED_MEMORY_SIZE),
                TaskManagerOptions.TOTAL_FLINK_MEMORY,
                TaskManagerOptions.TOTAL_PROCESS_MEMORY,
                new JvmMetaspaceAndOverheadOptions(
                        TaskManagerOptions.JVM_METASPACE,
                        TaskManagerOptions.JVM_OVERHEAD_MIN,
                        TaskManagerOptions.JVM_OVERHEAD_MAX,
                        TaskManagerOptions.JVM_OVERHEAD_FRACTION));
    }

    /**
     * Returns {@code numerator / denominator}, rounded up to three decimal places.
     *
     * <p>Returns 0 when the denominator is not positive. The unguarded division would yield NaN or
     * Infinity, which {@link BigDecimal#valueOf(double)} rejects with a {@link
     * NumberFormatException}.
     */
    private static float getFraction(MemorySize numerator, MemorySize denominator) {
        if (denominator.getBytes() <= 0) {
            return 0f;
        }
        return BigDecimal.valueOf((double) numerator.getBytes() / (double) denominator.getBytes())
                .setScale(3, RoundingMode.CEILING)
                .floatValue();
    }

    /**
     * Renders the overrides as one {@code key: value} line each. If there are removals, it then
     * appends a bracketed, comma-separated list of the keys to remove.
     */
    private static String formatConfig(ConfigChanges config) {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, String> entry : config.getOverrides().entrySet()) {
            sb.append(entry.getKey())
                    .append(": ")
                    .append(entry.getValue())
                    .append(System.lineSeparator());
        }
        if (!config.getRemovals().isEmpty()) {
            sb.append("Remove the following config entries if present: [");
            String separator = "";
            for (String toRemove : config.getRemovals()) {
                sb.append(separator).append(toRemove);
                separator = ", ";
            }
            sb.append("]");
        }
        return sb.toString();
    }
}

