/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.common.http2;

import static org.mule.service.http.test.netty.AllureConstants.HTTP_2;

import static java.util.concurrent.TimeUnit.MILLISECONDS;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.core.IsInstanceOf.instanceOf;
import static org.junit.internal.matchers.ThrowableCauseMatcher.hasCause;
import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.mule.runtime.api.lifecycle.CreateException;
import org.mule.runtime.api.tls.TlsContextFactory;
import org.mule.runtime.api.util.concurrent.Latch;
import org.mule.runtime.http.api.Http2ProtocolConfig;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpClientConfiguration;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.domain.entity.InputStreamHttpEntity;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.http.api.domain.request.HttpRequestContext;
import org.mule.runtime.http.api.server.HttpServer;
import org.mule.runtime.http.api.server.HttpServerConfiguration;
import org.mule.runtime.http.api.server.ServerCreationException;
import org.mule.runtime.http.api.server.async.ResponseStatusCallback;
import org.mule.service.http.common.server.sse.FutureCompleterCallback;
import org.mule.service.http.test.common.AbstractHttpServiceTestCase;
import org.mule.service.http.test.common.client.sse.FixedSizeStream;
import org.mule.service.http.test.common.client.sse.ThrowingStream;
import org.mule.tck.junit5.DynamicPort;

import java.io.IOException;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;

import io.qameta.allure.Feature;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * End-to-end tests sending requests from an HTTP/2 {@link HttpClient} to an HTTP/2 {@link HttpServer} over TLS. They check
 * what the server observes about each request, and what outcome is reported to the server-side
 * {@link ResponseStatusCallback} when sending the response succeeds or fails.
 * <p>
 * Every blocking wait performed by the test thread is bounded, so a broken client or server makes a test fail instead of
 * hanging the build.
 */
@Feature(HTTP_2)
class Http2ClientServerRequestsTestCase extends AbstractHttpServiceTestCase {

  /** Maximum time to wait for the server-side response status callback to be notified. */
  private static final int RESP_TIMEOUT_MS = 4000;
  /** Size in bytes of the streamed response body sent by the {@code /testLatched} endpoint. */
  private static final int RESP_SIZE = 64 * 1024;
  /**
   * Upper bound for every other wait done by the test thread (client response future, request reaching the server), so a
   * regression fails the test rather than blocking forever.
   */
  private static final int MAX_WAIT_MS = 10_000;

  @DynamicPort(systemProperty = "serverPort")
  Integer serverPort;

  private HttpServer httpServer;
  private HttpClient httpClient;

  // Completed by the server handlers with the context of the request they receive (only the first completion takes effect).
  private final CompletableFuture<HttpRequestContext> seenRequestCtx = new CompletableFuture<>();
  // Completed through FutureCompleterCallback with the outcome the server reports for sending the response.
  private final CompletableFuture<Void> responseStatusFuture = new CompletableFuture<>();
  private final ResponseStatusCallback responseStatusCallback = new FutureCompleterCallback(responseStatusFuture);
  // Released by the latched handlers once the request has reached the server.
  private final Latch receivedRequest = new Latch();
  // Released by the test (and always in tearDown) to let the latched handlers send their response.
  private final Latch readyToSendResponse = new Latch();

  public Http2ClientServerRequestsTestCase(String serviceToLoad) {
    super(serviceToLoad);
  }

  @BeforeEach
  void setUp() throws Exception {
    httpServer = createServer();
    httpClient = createClient();
  }

  @AfterEach
  void tearDown() {
    // Unblock any handler still waiting for permission to respond, so no server thread is left blocked.
    readyToSendResponse.release();
    // Null checks: setUp may have failed half-way. try/finally: the server (and its port) must be released even if
    // stopping the client fails.
    try {
      if (httpClient != null) {
        httpClient.stop();
      }
    } finally {
      if (httpServer != null) {
        httpServer.stop().dispose();
      }
    }
  }

  @Test
  void sendGet() throws ExecutionException, InterruptedException, TimeoutException {
    var request = HttpRequest.builder()
        .uri(testUri("/test"))
        .build();

    httpClient.sendAsync(request).get(MAX_WAIT_MS, MILLISECONDS);
    var requestInServer = awaitSeenRequest();
    assertThat(requestInServer.getMethod(), is("GET"));
    assertThat(requestInServer.getPath(), is("/test"));
    assertThat(requestInServer.getProtocol().asString(), is("HTTP/2"));
    assertDoesNotThrow(() -> responseStatusFuture.get(RESP_TIMEOUT_MS, MILLISECONDS));
  }

  @Test
  void sendGetWithQueryParams() throws ExecutionException, InterruptedException, TimeoutException {
    var request = HttpRequest.builder()
        .uri(testUri("/test"))
        .addQueryParam("queryParam1", "value1.1")
        .addQueryParam("queryParam1", "value1.2")
        .addQueryParam("queryParam2", "value2")
        .build();

    httpClient.sendAsync(request).get(MAX_WAIT_MS, MILLISECONDS);
    var requestInServer = awaitSeenRequest();
    assertThat(requestInServer.getQueryParams().getAll("queryParam1"), contains("value1.1", "value1.2"));
    assertThat(requestInServer.getQueryParams().get("queryParam2"), is("value2"));
    assertDoesNotThrow(() -> responseStatusFuture.get(RESP_TIMEOUT_MS, MILLISECONDS));
  }

  @Test
  void sendGetWithHeaders() throws ExecutionException, InterruptedException, TimeoutException {
    var request = HttpRequest.builder()
        .uri(testUri("/test"))
        .addHeader("header1", "value1")
        .addHeader("header2", "value2")
        .build();

    httpClient.sendAsync(request).get(MAX_WAIT_MS, MILLISECONDS);
    var requestInServer = awaitSeenRequest();
    assertThat(requestInServer.getHeaders().get("header1"), is("value1"));
    assertThat(requestInServer.getHeaders().get("header2"), is("value2"));
    assertDoesNotThrow(() -> responseStatusFuture.get(RESP_TIMEOUT_MS, MILLISECONDS));
  }

  @Test
  void sendGetWithMultimapHeaders() throws ExecutionException, InterruptedException, TimeoutException {
    var request = HttpRequest.builder()
        .uri(testUri("/test"))
        .addHeader("header1", "value1.1")
        .addHeader("header1", "value1.2")
        .addHeader("header2", "value2")
        .build();

    httpClient.sendAsync(request).get(MAX_WAIT_MS, MILLISECONDS);
    var requestInServer = awaitSeenRequest();
    assertThat(requestInServer.getHeaders().getAll("header1"), contains("value1.1", "value1.2"));
    assertThat(requestInServer.getHeaders().get("header2"), is("value2"));
    assertDoesNotThrow(() -> responseStatusFuture.get(RESP_TIMEOUT_MS, MILLISECONDS));
  }

  @Test
  void whenSendingResponseFailsFlushingThenCallbackIsCalled() throws InterruptedException {
    var request = HttpRequest.builder()
        .uri(testUri("/testLatched"))
        .build();

    // Sends a request with a short response timeout, making sure we fail
    CompletableFuture<HttpResponse> resp =
        httpClient.sendAsync(request, HttpRequestOptions.builder().responseTimeout(100).build());

    // Makes sure the request has been received by the server
    // (The server will wait for our signal before sending the response)
    awaitRequestReceivedByServer();

    // Wait for the client-side response timeout (and make sure it happens). The outer bound only guards against the
    // client never timing out: in that case get() throws TimeoutException directly and assertThrows fails clearly.
    assertThat(assertThrows(ExecutionException.class, () -> resp.get(MAX_WAIT_MS, MILLISECONDS)),
               hasCause(instanceOf(TimeoutException.class)));

    // Signals the server to try sending the response now (it should fail)
    readyToSendResponse.release();
    assertThat(assertThrows(ExecutionException.class, () -> responseStatusFuture.get(RESP_TIMEOUT_MS, MILLISECONDS)),
               hasCause(instanceOf(ClosedChannelException.class)));
  }

  @Test
  void whenSendingResponseFailsBeforeWriteThenCallbackIsCalled() throws InterruptedException {
    var request = HttpRequest.builder()
        .uri(testUri("/testStreamThrows"))
        .build();

    // The client-side outcome is irrelevant here; only the server-side status callback is asserted.
    httpClient.sendAsync(request);

    // Makes sure the request has been received by the server
    // (The server will wait for our signal before sending the response)
    awaitRequestReceivedByServer();

    // Signals the server to try sending the response now (it should fail)
    readyToSendResponse.release();
    assertThat(assertThrows(ExecutionException.class, () -> responseStatusFuture.get(RESP_TIMEOUT_MS, MILLISECONDS)),
               hasCause(hasMessage(is("Kaboom!"))));
  }

  /**
   * Builds an {@code https} URI for the given path on the test server.
   *
   * @param path absolute path, starting with {@code /}
   * @return the full URI string
   */
  private String testUri(String path) {
    return "https://localhost:%d%s".formatted(serverPort, path);
  }

  /**
   * Returns the request as seen by the server, failing if no request reached a handler within {@link #MAX_WAIT_MS}.
   */
  private HttpRequest awaitSeenRequest() throws ExecutionException, InterruptedException, TimeoutException {
    return seenRequestCtx.get(MAX_WAIT_MS, MILLISECONDS).getRequest();
  }

  /**
   * Blocks until a latched handler reports that the request reached the server, failing after {@link #MAX_WAIT_MS}.
   */
  private void awaitRequestReceivedByServer() throws InterruptedException {
    assertThat("The server did not receive the request in time",
               receivedRequest.await(MAX_WAIT_MS, MILLISECONDS), is(true));
  }

  /**
   * Server-side step shared by the latched handlers. Signals the test that the request arrived, then blocks until the test
   * allows the response to be sent. If interrupted, it restores the thread's interrupt flag before failing.
   */
  private void notifyReceivedAndAwaitPermissionToRespond() {
    receivedRequest.release();
    try {
      readyToSendResponse.await();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    }
  }

  private HttpClient createClient() throws CreateException {
    var client = service.getClientFactory().create(new HttpClientConfiguration.Builder()
        .setName("HTTP/2 Client")
        .setHttp2Config(new Http2ProtocolConfig(true))
        .setTlsContextFactory(TlsContextFactory.builder().trustStorePath("trustStore")
            .trustStorePassword("mulepassword").insecureTrustStore(true).build())
        .build());
    client.start();
    return client;
  }

  /**
   * Creates and starts the HTTP/2 TLS server, registering three handlers:
   * <ul>
   * <li>{@code /test}: records the request and answers immediately with an empty response;</li>
   * <li>{@code /testLatched}: records the request, waits for the test's signal, then streams a {@link #RESP_SIZE}-byte
   * body;</li>
   * <li>{@code /testStreamThrows}: records the request, waits for the test's signal, then answers with a body whose stream
   * throws.</li>
   * </ul>
   */
  private HttpServer createServer()
      throws ServerCreationException, IOException, CreateException {
    var server = service.getServerFactory().create(new HttpServerConfiguration.Builder()
        .setName("HTTP/2 Server")
        .setHost("localhost")
        .setPort(serverPort)
        .setHttp2Config(new Http2ProtocolConfig(true))
        .setTlsContextFactory(TlsContextFactory.builder()
            .keyStorePath("serverKeystore")
            .keyStorePassword("mulepassword").keyAlias("muleserver").keyPassword("mulepassword").keyStoreAlgorithm("PKIX")
            .build())
        .build());

    server.start();
    server.addRequestHandler("/test", (ctx, callback) -> {
      seenRequestCtx.complete(ctx);
      callback.responseReady(HttpResponse.builder().build(), responseStatusCallback);
    });

    server.addRequestHandler("/testLatched", (ctx, callback) -> {
      seenRequestCtx.complete(ctx);
      notifyReceivedAndAwaitPermissionToRespond();

      HttpResponse resp = HttpResponse.builder()
          .entity(new InputStreamHttpEntity(new FixedSizeStream(RESP_SIZE)))
          .build();
      callback.responseReady(resp, responseStatusCallback);
    });

    server.addRequestHandler("/testStreamThrows", (ctx, callback) -> {
      seenRequestCtx.complete(ctx);
      notifyReceivedAndAwaitPermissionToRespond();

      HttpResponse resp = HttpResponse.builder()
          .entity(new InputStreamHttpEntity(new ThrowingStream()))
          .build();
      callback.responseReady(resp, responseStatusCallback);
    });

    return server;
  }
}
