/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.tck.junit5;

import static java.lang.Thread.currentThread;
import static java.lang.reflect.Modifier.isStatic;

import static org.slf4j.LoggerFactory.getLogger;

import java.io.Closeable;
import java.lang.reflect.Field;
import java.util.concurrent.ConcurrentHashMap;

import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.slf4j.Logger;

/**
 * JUnit 5 extension that automatically handles classloader cleanup for test classes.
 *
 * <p>
 * This extension provides automatic cleanup focused specifically on classloader management:
 * <ul>
 * <li>Closeable ClassLoaders (close to release resources and prevent leaks)</li>
 * <li>Context classloader restoration to original state</li>
 * </ul>
 *
 * <p>
 * The extension works by:
 * <ol>
 * <li>Storing the original context classloader before test execution</li>
 * <li>After test execution: cleaning up annotated classloaders and restoring the original context classloader</li>
 * </ol>
 *
 * <p>
 * Fields annotated with {@link ClassLoaderCleanup} will be automatically closed if they implement {@link Closeable}. Unlike
 * resource injection patterns, this extension cleans up existing field values rather than injecting new ones.
 *
 * <p>
 * Supports both instance fields (cleaned after each test method) and static fields (cleaned after all test methods in the class).
 */
public class ClassLoaderCleanupExtension implements BeforeEachCallback, AfterEachCallback, BeforeAllCallback, AfterAllCallback {

  private static final Logger LOGGER = getLogger(ClassLoaderCleanupExtension.class);

  // Store original classloaders per test instance to handle parallel test execution
  private final ConcurrentHashMap<Object, ClassLoader> originalInstanceClassLoaders = new ConcurrentHashMap<>();

  // Store original classloader per test class for static fields
  private final ConcurrentHashMap<Class<?>, ClassLoader> originalStaticClassLoaders = new ConcurrentHashMap<>();

  @Override
  public void beforeEach(ExtensionContext context) throws Exception {
    // Store the original classloader for this test instance
    Object testInstance = context.getTestInstance().orElse(null);
    if (testInstance != null) {
      originalInstanceClassLoaders.put(testInstance, currentThread().getContextClassLoader());
      LOGGER.debug("Stored original context classloader for test: {}", context.getDisplayName());
    }
  }

  @Override
  public void afterEach(ExtensionContext context) throws Exception {
    Object testInstance = context.getTestInstance().orElse(null);

    if (testInstance == null) {
      return;
    }

    try {
      // Clean up annotated instance classloader fields
      cleanupClassLoaderFields(testInstance.getClass(), testInstance, false);
    } finally {
      // Always restore the original classloader, even if cleanup fails
      ClassLoader originalClassLoader = originalInstanceClassLoaders.remove(testInstance);
      if (originalClassLoader != null) {
        currentThread().setContextClassLoader(originalClassLoader);
        LOGGER.debug("Restored original context classloader for test: {}", context.getDisplayName());
      }
    }
  }

  @Override
  public void beforeAll(ExtensionContext context) throws Exception {
    // Store the original classloader for static fields in this test class
    Class<?> testClass = context.getRequiredTestClass();
    originalStaticClassLoaders.put(testClass, currentThread().getContextClassLoader());
    LOGGER.debug("Stored original context classloader for test class: {}", testClass.getSimpleName());
  }

  @Override
  public void afterAll(ExtensionContext context) throws Exception {
    Class<?> testClass = context.getRequiredTestClass();

    try {
      // Clean up annotated static classloader fields
      cleanupClassLoaderFields(testClass, null, true);
    } finally {
      // Always restore the original classloader, even if cleanup fails
      ClassLoader originalClassLoader = originalStaticClassLoaders.remove(testClass);
      if (originalClassLoader != null) {
        currentThread().setContextClassLoader(originalClassLoader);
        LOGGER.debug("Restored original context classloader for test class: {}", testClass.getSimpleName());
      }
    }
  }


  /**
   * Clean up classloader fields for a specific test class.
   *
   * @param testClass     the test class to process
   * @param testInstance  the test instance for instance fields, null for static fields
   * @param processStatic true to process static fields, false for instance fields
   */
  private void cleanupClassLoaderFields(Class<?> testClass, Object testInstance, boolean processStatic) {
    try {
      Class<?> currentClass = testClass;
      while (currentClass != null && currentClass != Object.class) {
        for (Field field : currentClass.getDeclaredFields()) {
          if (shouldCleanupField(field, processStatic)) {
            cleanupField(testInstance, field);
          }
        }
        currentClass = currentClass.getSuperclass();
      }
    } catch (Exception e) {
      LOGGER.warn("Error during classloader field cleanup", e);
    }
  }

  private boolean shouldCleanupField(Field field, boolean processStatic) {
    return field.isAnnotationPresent(ClassLoaderCleanup.class) && isStatic(field.getModifiers()) == processStatic;
  }

  private void cleanupField(Object testInstance, Field field) {
    field.setAccessible(true);
    try {
      Object fieldValue = field.get(testInstance);
      if (fieldValue instanceof Closeable closeable) {
        closeable.close();
        LOGGER.debug("Closed annotated Closeable field '{}': {}",
                     field.getName(),
                     fieldValue.getClass().getSimpleName());
      } else if (fieldValue != null) {
        LOGGER.debug("Skipped non-Closeable field '{}': {}",
                     field.getName(),
                     fieldValue.getClass().getSimpleName());
      }
    } catch (Exception e) {
      LOGGER.warn("Failed to cleanup annotated classloader field: {}", field.getName(), e);
    }
  }
}
