/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.server;

import static java.lang.Math.min;
import static java.lang.System.nanoTime;
import static java.lang.Thread.currentThread;
import static java.util.concurrent.TimeUnit.NANOSECONDS;

import static org.slf4j.LoggerFactory.getLogger;

import java.util.concurrent.TimeUnit;

import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.slf4j.Logger;

/**
 * Keeps track of the amount of active connections, and has a method to {@link #waitForConnectionsToBeClosed(long, TimeUnit) wait
 * for the counter to become zero}.
 * <p>
 * The handler is {@link Sharable}, so a single instance may be added to multiple channel pipelines. The counter is guarded by
 * this instance's monitor, and every decrement notifies the threads blocked in
 * {@link #waitForConnectionsToBeClosed(long, TimeUnit)}.
 */
@Sharable
public class ConnectionsCounterHandler extends ChannelInboundHandlerAdapter {

  private static final Logger LOGGER = getLogger(ConnectionsCounterHandler.class);

  /**
   * Upper bound, in milliseconds, of a single {@code wait} slice while waiting for connections to be closed. Waiters are already
   * woken on every decrement; this cap only bounds how long a waiter sleeps between re-checks and debug log lines.
   */
  private static final long MAX_WAIT_SLICE_MILLIS = 50;

  /** Number of channels that became active and have not yet become inactive. Guarded by {@code this}. */
  private int connectionsCount;

  @Override
  public void channelActive(ChannelHandlerContext ctx) throws Exception {
    incrementConnectionsCount();
    super.channelActive(ctx);
  }

  @Override
  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
    try {
      super.channelInactive(ctx);
    } finally {
      // Always release the slot taken in channelActive, even if propagating the event throws. Otherwise the counter would stay
      // above zero and waitForConnectionsToBeClosed would block until its timeout and report a connection that is gone.
      decrementConnectionsCount();
    }
  }

  private synchronized void incrementConnectionsCount() {
    this.connectionsCount += 1;
  }

  private synchronized void decrementConnectionsCount() {
    this.connectionsCount -= 1;
    // Wake any thread blocked in waitForConnectionsToBeClosed so it can re-check the counter.
    notifyAll();
  }

  /**
   * Blocks the calling thread until the active connections count reaches zero or the given timeout elapses, whichever comes
   * first. If connections are still open when the wait ends, a warning is logged. If the calling thread is interrupted while
   * waiting, the wait stops early and the thread's interrupt status is restored.
   *
   * @param timeout maximum time to wait. {@code 0} returns immediately, without waiting or logging. A negative value does not
   *                wait, but still logs the warning if connections remain open.
   * @param unit    the unit of {@code timeout}.
   */
  public void waitForConnectionsToBeClosed(long timeout, TimeUnit unit) {
    if (timeout == 0) {
      return;
    }

    // nanoTime differences stay correct even if the sum below wraps around.
    final long stopNanos = nanoTime() + unit.toNanos(timeout);
    synchronized (this) {
      try {
        long remainingNanos = stopNanos - nanoTime();
        // Loop guards against spurious wakeups and against being woken by a decrement that did not reach zero.
        while (connectionsCount > 0 && remainingNanos > 0) {
          long remainingMillis = NANOSECONDS.toMillis(remainingNanos);
          long millisToWait = min(remainingMillis, MAX_WAIT_SLICE_MILLIS);
          LOGGER.debug("There are still {} open connections on server stop. Waiting {} milliseconds", connectionsCount,
                       millisToWait);
          if (millisToWait > 0) {
            wait(millisToWait);
          } else {
            // Less than one millisecond left: remainingNanos < 1_000_000, so the int cast is within wait's nanos range.
            wait(0, (int) remainingNanos);
          }
          remainingNanos = stopNanos - nanoTime();
        }
      } catch (InterruptedException e) {
        currentThread().interrupt();
      }
      if (connectionsCount > 0) {
        LOGGER.warn("There are still {} open connections on server stop.", connectionsCount);
      }
    }
  }
}
