/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.netty.impl.client;

import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;

import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;

import org.mule.runtime.http.api.domain.entity.EmptyHttpEntity;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.service.http.netty.impl.client.NettyHttpClient;
import org.mule.service.http.netty.impl.client.ReactorNettyClient;
import org.mule.service.http.test.common.AbstractHttpTestCase;
import org.mule.tck.junit4.rule.DynamicPort;
import org.mule.tck.junit4.rule.SystemProperty;

import java.lang.reflect.Field;
import java.security.cert.CertificateException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;

import io.netty.channel.Channel;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.handler.timeout.IdleStateHandler;
import io.qameta.allure.Issue;
import org.apache.commons.io.IOUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import reactor.core.publisher.Mono;
import reactor.netty.DisposableServer;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.server.HttpServer;

/**
 * Exercises {@link SslHandler#renegotiate()} on connections opened through the Reactor Netty {@link HttpClient} that backs a
 * {@link NettyHttpClient}, against an HTTPS server restricted to TLSv1.2 or TLSv1.3.
 * <p>
 * Each test first performs a regular request through the Mule client. It then reaches the underlying Reactor Netty client via
 * reflection and triggers a renegotiation from its {@code doOnConnected} callback.
 */
public class NettyHttpClientTLSRenegotiateTestCase
    extends AbstractHttpTestCase {

  /** Upper bound for waiting on the asynchronous renegotiation / response callbacks. */
  private static final long LATCH_TIMEOUT_SECONDS = 10;

  @Rule
  public final SystemProperty allowUnsafeCertChange = new SystemProperty("jdk.tls.allowUnsafeServerCertChange", "true");
  @Rule
  public final SystemProperty allowLegacyHelloMessages = new SystemProperty("jdk.tls.allowLegacyHelloMessages", "true");
  @Rule
  public final SystemProperty rejectClientInitiatedRenegotiation =
      new SystemProperty("jdk.tls.rejectClientInitiatedRenegotiation", "false");
  @Rule
  public DynamicPort serverPort = new DynamicPort("serverPort");

  // TODO W-17900306 [Netty] [tech debt] Migrate netty server to wiremock server in unit test
  private NettyHttpClient httpClientUsingTLSv12;
  private NettyHttpClient httpClientUsingTLSv13;

  @Before
  public void setUp() throws Exception {
    httpClientUsingTLSv12 = createClient("TLSv1.2");
    httpClientUsingTLSv13 = createClient("TLSv1.3");

    httpClientUsingTLSv12.start();
    httpClientUsingTLSv13.start();
  }

  /**
   * Stops the clients started in {@link #setUp()}; the second client is stopped even if stopping the first one fails.
   */
  @After
  public void tearDown() {
    try {
      if (httpClientUsingTLSv12 != null) {
        httpClientUsingTLSv12.stop();
      }
    } finally {
      if (httpClientUsingTLSv13 != null) {
        httpClientUsingTLSv13.stop();
      }
    }
  }

  @Test
  @Issue("W-17849284")
  public void testTLSv1_2RenegotiationShouldBeAttempted() throws Exception {
    DisposableServer server = startServer("TLSv1.2", "/test-tls-v1_2", "renegotiation");

    try {
      String serverUrl = format("https://localhost:%d/test-tls-v1_2", serverPort.getNumber());
      HttpRequest httpRequest = HttpRequest.builder().uri(serverUrl).method("GET").entity(new EmptyHttpEntity()).build();
      HttpResponse response = httpClientUsingTLSv12.sendAsync(httpRequest).get();

      assertThat(response, is(notNullValue()));
      String responseBody = IOUtils.toString(response.getEntity().getContent(), UTF_8);
      assertThat(responseBody, containsString("renegotiation"));

      CountDownLatch latch = new CountDownLatch(1);
      AtomicReference<Throwable> renegotiationError = new AtomicReference<>();

      getHttpClient(getReactorNettyClient(httpClientUsingTLSv12)).doOnConnected(conn -> {
        try {
          Channel channel = conn.channel();
          ChannelPipeline pipeline = channel.pipeline();
          SslHandler sslHandler = pipeline.get(SslHandler.class);

          if (sslHandler != null) {
            sslHandler.renegotiate();
          } else {
            renegotiationError.set(new IllegalStateException("SslHandler not found in pipeline"));
          }
        } catch (Throwable t) {
          renegotiationError.set(t);
        } finally {
          // NOTE(review): the latch is released as soon as renegotiate() has been invoked; the returned future is not
          // awaited, so a failure that surfaces asynchronously may not be recorded yet when it is inspected below.
          latch.countDown();
        }
      }).baseUrl(serverUrl).get().response().doOnTerminate(latch::countDown).subscribe(ignored -> {
      }, error -> renegotiationError.compareAndSet(null, error));

      assertThat("Neither the connection callback nor the response completed in time",
                 latch.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS), is(true));

      Throwable renegotiationThrowable = renegotiationError.get();
      if (renegotiationThrowable != null) {
        Throwable cause = renegotiationThrowable.getCause() != null ? renegotiationThrowable.getCause() : renegotiationThrowable;
        // A renegotiation was attempted; with this JDK configuration any recorded error must be the resulting
        // handshake rejection.
        assertThat(cause, instanceOf(SSLHandshakeException.class));
        assertThat(cause.getMessage(), containsString("no_renegotiation"));
      }
    } finally {
      server.disposeNow();
    }
  }

  @Test
  @Issue("W-17849285")
  public void testTLSv1_3RenegotiationShouldNotBeAttempted() throws Exception {
    DisposableServer server = startServer("TLSv1.3", "/test-tls-v1_3", "no renegotiation");

    try {
      String serverUrl = format("https://localhost:%d/test-tls-v1_3", serverPort.getNumber());
      HttpRequest httpRequest = HttpRequest.builder().uri(serverUrl).method("GET").entity(new EmptyHttpEntity()).build();
      HttpResponse response = httpClientUsingTLSv13.sendAsync(httpRequest).get();

      assertThat(response, is(notNullValue()));
      assertThat(IOUtils.toString(response.getEntity().getContent(), UTF_8), containsString("no renegotiation"));

      // Released only once the renegotiation outcome has been recorded. The response completing first therefore cannot
      // release it and make the assertions below read a status that has not been set yet.
      CountDownLatch renegotiationDone = new CountDownLatch(1);
      AtomicReference<Throwable> renegotiationFailure = new AtomicReference<>();
      AtomicReference<Throwable> responseFailure = new AtomicReference<>();
      AtomicReference<SSLEngineResult.HandshakeStatus> status = new AtomicReference<>();

      getHttpClient(getReactorNettyClient(httpClientUsingTLSv13)).doOnConnected(conn -> {
        try {
          Channel channel = conn.channel();
          ChannelPipeline pipeline = channel.pipeline();
          SslHandler sslHandler = pipeline.get(SslHandler.class);

          if (sslHandler == null) {
            renegotiationFailure.set(new IllegalStateException("SslHandler not found in pipeline"));
            renegotiationDone.countDown();
            return;
          }

          SSLEngine sslEngine = sslHandler.engine();
          sslHandler.renegotiate().addListener(renegotiateFuture -> {
            try {
              if (renegotiateFuture.isSuccess()) {
                status.set(sslEngine.getHandshakeStatus());
              } else {
                renegotiationFailure.set(renegotiateFuture.cause());
              }
            } catch (Throwable t) {
              renegotiationFailure.set(t);
            } finally {
              renegotiationDone.countDown();
            }
          });
        } catch (Throwable t) {
          renegotiationFailure.set(t);
          renegotiationDone.countDown();
        }
      }).baseUrl(serverUrl).get().response().subscribe(ignored -> {
      }, responseFailure::set);

      if (!renegotiationDone.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
        // Attach any response error to help diagnose why the connection callback never reported back.
        throw new AssertionError("Renegotiation did not complete within " + LATCH_TIMEOUT_SECONDS + " seconds",
                                 responseFailure.get());
      }

      Throwable renegotiationThrowable = renegotiationFailure.get();
      if (renegotiationThrowable != null) {
        throw new AssertionError("Renegotiation assertion failed", renegotiationThrowable);
      }

      assertThat(status.get(), is(not(NEED_UNWRAP)));
      assertThat(status.get(), is(NOT_HANDSHAKING));
    } finally {
      server.disposeNow();
    }
  }

  /**
   * Builds a persistent-connection {@link NettyHttpClient} that trusts any server certificate and only enables
   * {@code protocol}.
   *
   * @param protocol the TLS protocol name to enable, e.g. {@code "TLSv1.2"}
   * @return a built, not yet started, client
   * @throws SSLException if the client SSL context cannot be built
   */
  private static NettyHttpClient createClient(String protocol) throws SSLException {
    SslContext clientSslContext = SslContextBuilder
        .forClient()
        .sslProvider(SslProvider.JDK)
        .trustManager(InsecureTrustManagerFactory.INSTANCE)
        .protocols(protocol)
        .build();

    return NettyHttpClient.builder()
        .withConnectionIdleTimeout(10000)
        .withSslContext(clientSslContext)
        .withUsingPersistentConnections(true)
        .build();
  }

  /**
   * Binds an HTTPS server on {@link #serverPort} that only enables {@code protocol} and answers {@code GET path} with
   * {@code body} and a {@code Connection: keep-alive} header.
   *
   * @param protocol the TLS protocol name the server enables
   * @param path     the route served
   * @param body     the response payload
   * @return the bound server; callers must dispose it
   * @throws Exception if the SSL context cannot be built or the server cannot bind
   */
  private DisposableServer startServer(String protocol, String path, String body) throws Exception {
    // Built eagerly so that certificate errors surface as checked exceptions instead of being wrapped inside the lambda.
    SslContext serverSslContext = buildSslContext(protocol);

    return HttpServer.create()
        .port(serverPort.getNumber())
        .secure(spec -> spec.sslContext(serverSslContext))
        .doOnConnection(conn -> conn.addHandlerLast(new IdleStateHandler(0, 0, 30)))
        .route(routes -> routes.get(path,
                                    (request, response) -> response
                                        .header("Connection", "keep-alive")
                                        .sendString(Mono.just(body))))
        .bindNow();
  }

  /**
   * Builds a server {@link SslContext} backed by a freshly generated self-signed certificate.
   *
   * @param protocol the TLS protocol name to enable
   * @return the server SSL context
   */
  private SslContext buildSslContext(String protocol) throws CertificateException, SSLException {
    SelfSignedCertificate cert = new SelfSignedCertificate();
    try {
      return SslContextBuilder.forServer(cert.certificate(), cert.privateKey()).protocols(protocol).build();
    } finally {
      // The builder has consumed the generated certificate/key files by the time build() returns, so remove them
      // instead of leaving them in the temp directory.
      cert.delete();
    }
  }

  /** Reads the private {@code reactorNettyClient} field of {@code client} via reflection. */
  private ReactorNettyClient getReactorNettyClient(NettyHttpClient client) throws Exception {
    Field field = NettyHttpClient.class.getDeclaredField("reactorNettyClient");
    field.setAccessible(true);
    return (ReactorNettyClient) field.get(client);
  }

  /** Reads the private {@code httpClient} field of {@code reactorNettyClient} via reflection. */
  private HttpClient getHttpClient(ReactorNettyClient reactorNettyClient) throws Exception {
    Field field = ReactorNettyClient.class.getDeclaredField("httpClient");
    field.setAccessible(true);
    return (HttpClient) field.get(reactorNettyClient);
  }
}
