/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.runtime.module.troubleshooting.internal.operations;

import static java.lang.Runtime.getRuntime;
import static java.lang.String.format;
import static java.lang.System.lineSeparator;
import static java.lang.Thread.sleep;
import static java.lang.management.ManagementFactory.getThreadMXBean;
import static java.util.Comparator.comparingDouble;
import static java.util.Locale.US;

import org.mule.runtime.module.troubleshooting.api.TroubleshootingOperation;
import org.mule.runtime.module.troubleshooting.api.TroubleshootingOperationCallback;
import org.mule.runtime.module.troubleshooting.api.TroubleshootingOperationDefinition;
import org.mule.runtime.module.troubleshooting.internal.DefaultTroubleshootingOperationDefinition;

import java.lang.management.ThreadInfo;
import java.lang.management.ThreadMXBean;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Operation used to collect thread CPU usage information.
 * <p>
 * The name of the operation is "threadCpuUsage".
 * <p>
 * This operation takes two snapshots of thread CPU time with a small delay between them to calculate the current CPU usage rate,
 * similar to the 'top' command. The rate is computed against the wall-clock time that actually elapsed between the two
 * snapshots (measured with {@link System#nanoTime()}), not against the nominal sleep duration.
 * <p>
 * Side effect: if thread CPU time measurement is supported but disabled, the callback enables it on the platform
 * {@link ThreadMXBean} and leaves it enabled.
 */
public class ThreadCpuUsageOperation implements TroubleshootingOperation {

  public static final String THREAD_CPU_USAGE_OPERATION_NAME = "threadCpuUsage";
  public static final String THREAD_CPU_USAGE_OPERATION_DESCRIPTION = "Collects current CPU usage information for all threads";

  /** Nominal length of the sampling window between the two CPU time snapshots, in milliseconds. */
  private static final long SAMPLE_MILLIS = 3000L;

  /** Conversion factor between nanoseconds and milliseconds. */
  private static final double NANOS_PER_MILLI = 1_000_000.0;

  /** Threads below this percentage would be printed as {@code 0.00} and are therefore left out of the report. */
  private static final double MIN_REPORTED_CPU_PERCENT = 0.005;

  private static final TroubleshootingOperationDefinition definition = createOperationDefinition();

  @Override
  public TroubleshootingOperationDefinition getDefinition() {
    return definition;
  }

  /**
   * Builds the callback that samples per-thread CPU time twice and writes a table of the threads that consumed CPU during the
   * sampling window, sorted by consumption in descending order.
   * <p>
   * The callback blocks the calling thread for the duration of the sampling window. If the thread is interrupted while waiting,
   * its interrupt flag is restored before the failure is reported.
   *
   * @return the callback implementing the operation.
   * @throws RuntimeException (from the callback) wrapping any exception raised while sampling or writing the report.
   */
  @Override
  public TroubleshootingOperationCallback getCallback() {
    return (arguments, writer) -> {
      try {
        ThreadMXBean threadMXBean = getThreadMXBean();

        // Check if CPU time measurement is supported
        if (!threadMXBean.isThreadCpuTimeSupported()) {
          writer.write("Thread CPU time measurement is not supported on this JVM" + lineSeparator());
          return;
        }

        // Enable CPU time measurement if not already enabled
        if (!threadMXBean.isThreadCpuTimeEnabled()) {
          threadMXBean.setThreadCpuTimeEnabled(true);
        }

        // Take first snapshot, remembering when it started so the real window length can be computed
        long[] threadIds = threadMXBean.getAllThreadIds();
        long firstSnapshotNanos = System.nanoTime();
        Map<Long, Long> firstSnapshot = snapshotCpuTimes(threadMXBean, threadIds);

        // Wait for a short period to measure CPU usage
        sleep(SAMPLE_MILLIS);

        // Take second snapshot. Both snapshots walk the ids in the same order, so the distance between their start
        // instants approximates the per-thread observation window better than the nominal sleep time does.
        long secondSnapshotNanos = System.nanoTime();
        Map<Long, Long> secondSnapshot = snapshotCpuTimes(threadMXBean, threadIds);

        // Guard against a zero divisor; after the sleep above the difference is expected to be far larger than 1ns
        double elapsedNanos = Math.max(1L, secondSnapshotNanos - firstSnapshotNanos);

        // Entries of this array are null for threads that are no longer alive
        ThreadInfo[] threadInfos = threadMXBean.getThreadInfo(threadIds, 0);

        // Calculate CPU percentage normalization
        int cpuCores = getRuntime().availableProcessors();

        // Calculate CPU usage for each thread, keeping only the ones with significant usage
        // (less than 0.005% would show as 0.00)
        List<ThreadCpuInfo> activeThreads = new ArrayList<>();
        for (int i = 0; i < threadIds.length; i++) {
          long threadId = threadIds[i];
          ThreadInfo threadInfo = threadInfos[i];
          Long cpuTimeBefore = firstSnapshot.get(threadId);
          Long cpuTimeAfter = secondSnapshot.get(threadId);

          // Skip threads that died or for which CPU time was unavailable in either snapshot
          if (threadInfo == null || cpuTimeBefore == null || cpuTimeAfter == null) {
            continue;
          }

          long cpuTimeDiff = cpuTimeAfter - cpuTimeBefore;
          // Calculate CPU percentage (normalized by cores)
          double cpuPercent = (cpuTimeDiff / elapsedNanos) * 100.0 / cpuCores;

          if (cpuPercent >= MIN_REPORTED_CPU_PERCENT) {
            activeThreads.add(new ThreadCpuInfo(threadId,
                                                threadInfo.getThreadName(),
                                                threadInfo.getThreadState().toString(),
                                                cpuTimeDiff / NANOS_PER_MILLI,
                                                cpuPercent,
                                                cpuTimeAfter / NANOS_PER_MILLI));
          }
        }

        // Sort by CPU usage (descending)
        activeThreads.sort(comparingDouble(ThreadCpuInfo::getCpuUsageMs).reversed());

        // Format output
        writer.write("=== Thread CPU Usage (active threads only, sorted by current consumption) ===" + lineSeparator());
        writer.write(format(US, "%-10s %-7s %-60s %-15s %-20s %-20s%n",
                            "Thread ID", "%CPU", "Thread Name", "State",
                            "CPU Time (" + (SAMPLE_MILLIS / 1000) + "s)", "Total CPU Time"));
        writer.write(format(US, "%-10s %-7s %-60s %-15s %-20s %-20s%n",
                            "----------", "-------", "------------------------------------------------------------",
                            "---------------", "--------------------", "--------------------"));

        if (activeThreads.isEmpty()) {
          writer.write("No threads with significant CPU usage detected during the sampling period." + lineSeparator());
        } else {
          for (ThreadCpuInfo info : activeThreads) {
            writer.write(format(US, "%-10d %-7.2f %-60s %-15s %-20s %-20s%n",
                                info.getThreadId(),
                                info.getCpuPercent(),
                                truncate(info.getThreadName(), 60),
                                truncate(info.getState(), 15),
                                formatTime(info.getCpuUsageMs()),
                                formatTime(info.getTotalCpuTimeMs())));
          }
        }

        writer.write(lineSeparator());
        writer.write("CPU percentage is normalized by the number of CPU cores (" + cpuCores + " cores)." + lineSeparator());

      } catch (InterruptedException e) {
        // Restore the interrupt flag so code further up the stack can still observe the interruption
        Thread.currentThread().interrupt();
        throw new RuntimeException("Failed to get thread CPU usage: " + e.getMessage(), e);
      } catch (Exception e) {
        throw new RuntimeException("Failed to get thread CPU usage: " + e.getMessage(), e);
      }
    };
  }

  private static TroubleshootingOperationDefinition createOperationDefinition() {
    return new DefaultTroubleshootingOperationDefinition(THREAD_CPU_USAGE_OPERATION_NAME, THREAD_CPU_USAGE_OPERATION_DESCRIPTION);
  }

  /**
   * Reads the accumulated CPU time of each given thread.
   *
   * @param threadMXBean the bean to query.
   * @param threadIds    the ids of the threads to sample.
   * @return a map from thread id to CPU time in nanoseconds; threads for which the bean reports {@code -1} are omitted.
   */
  private static Map<Long, Long> snapshotCpuTimes(ThreadMXBean threadMXBean, long[] threadIds) {
    Map<Long, Long> snapshot = new HashMap<>();
    for (long threadId : threadIds) {
      long cpuTime = threadMXBean.getThreadCpuTime(threadId);
      if (cpuTime != -1) {
        snapshot.put(threadId, cpuTime);
      }
    }
    return snapshot;
  }

  /**
   * Shortens a string so it fits in a column of the report.
   *
   * @param str       the text to shorten, may be {@code null}.
   * @param maxLength the maximum number of characters to return.
   * @return an empty string for {@code null}; the text itself if it fits; otherwise the text cut so that, including a trailing
   *         {@code "..."}, it is exactly {@code maxLength} characters long. When {@code maxLength} leaves no room for the
   *         ellipsis the text is cut without it.
   */
  private String truncate(String str, int maxLength) {
    if (str == null) {
      return "";
    }
    if (str.length() <= maxLength) {
      return str;
    }
    if (maxLength <= 3) {
      // No room for "..." - a plain cut avoids a negative substring index
      return str.substring(0, Math.max(0, maxLength));
    }
    return str.substring(0, maxLength - 3) + "...";
  }

  /**
   * Renders a duration for the report.
   *
   * @param timeMs the duration in milliseconds.
   * @return the value with two decimals, in milliseconds when below one second and in seconds otherwise.
   */
  private String formatTime(double timeMs) {
    if (timeMs < 1000) {
      return format(US, "%.2fms", timeMs);
    } else {
      return format(US, "%.2fs", timeMs / 1000.0);
    }
  }

  /**
   * Immutable holder for the figures reported for a single thread.
   */
  private static class ThreadCpuInfo {

    private final long threadId;
    private final String threadName;
    private final String state;
    // CPU time consumed during the sampling window, in milliseconds
    private final double cpuUsageMs;
    // Share of the sampling window spent on CPU, normalized by the number of cores
    private final double cpuPercent;
    // CPU time accumulated by the thread as of the second snapshot, in milliseconds
    private final double totalCpuTimeMs;

    public ThreadCpuInfo(long threadId, String threadName, String state, double cpuUsageMs, double cpuPercent,
                         double totalCpuTimeMs) {
      this.threadId = threadId;
      this.threadName = threadName;
      this.state = state;
      this.cpuUsageMs = cpuUsageMs;
      this.cpuPercent = cpuPercent;
      this.totalCpuTimeMs = totalCpuTimeMs;
    }

    public long getThreadId() {
      return threadId;
    }

    public String getThreadName() {
      return threadName;
    }

    public String getState() {
      return state;
    }

    public double getCpuUsageMs() {
      return cpuUsageMs;
    }

    public double getCpuPercent() {
      return cpuPercent;
    }

    public double getTotalCpuTimeMs() {
      return totalCpuTimeMs;
    }
  }
}
