/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.common.client;

import static org.mule.runtime.api.util.DataUnit.KB;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.INTERNAL_SERVER_ERROR;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.OK;
import static org.mule.service.http.test.netty.AllureConstants.HttpStory.STREAMING;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static java.lang.Thread.currentThread;
import static java.util.UUID.randomUUID;
import org.mule.runtime.api.util.Reference;
import org.mule.runtime.api.util.concurrent.Latch;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpClientConfiguration;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.domain.entity.InputStreamHttpEntity;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.service.http.test.common.client.sse.FillAndWaitStream;
import org.mule.service.http.test.common.util.ResponseReceivedProbe;
import org.mule.tck.probe.PollingProber;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import io.qameta.allure.Description;
import io.qameta.allure.Story;
import org.slf4j.MDC;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Verifies that MDC entries set on the calling thread are visible inside the completion callback of a non-blocking
 * {@code sendAsync} request, for both streaming and non-streaming clients, on success and on failure.
 */
@Story(STREAMING)
public class HttpClientMdcStreamingTestCase extends AbstractHttpClientTestCase {

  private static final int RESPONSE_TIMEOUT = 3000;
  private static final int TIMEOUT_MILLIS = 1000;
  private static final int POLL_DELAY_MILLIS = 200;

  private static final String TRANSACTION_ID_KEY = "transactionId";
  private static final String CURRENT_THREAD_KEY = "currentThread";
  private static final String EXCEPTION_KEY = "exception";

  /**
   * Gates the streamed response payload ({@link FillAndWaitStream}); recreated before every test. Volatile because it is
   * assigned on the test thread and read by the thread that builds the server response.
   */
  private static volatile Latch latch;

  /**
   * When non-null, the server handler waits on it before building the response. Written by the test thread and read by
   * the handler thread, hence volatile.
   */
  private volatile Latch beforeResponseLatch;

  /**
   * When non-null and {@code true}, the server handler throws instead of responding. Written by the test thread and read
   * by the handler thread, hence volatile.
   */
  private volatile AtomicBoolean serverShouldThrowException;

  private final HttpClientConfiguration.Builder clientBuilder =
      new HttpClientConfiguration.Builder().setName("streaming-test");
  private final PollingProber pollingProber = new PollingProber(TIMEOUT_MILLIS, POLL_DELAY_MILLIS);

  public HttpClientMdcStreamingTestCase(String serviceToLoad) {
    super(serviceToLoad);
  }

  /** Creates a fresh, unreleased payload latch for each test. */
  @BeforeEach
  public void createLatch() {
    latch = new Latch();
  }

  @Test
  @Description("Uses a streaming HTTP client to send a non blocking request and asserts that MDC values are propagated to the response handler when the request fails due to timeout.")
  void nonBlockingStreamingMDCPropagationOnError() throws Exception {
    testMdcPropagation(true, true);
  }

  @Test
  @Description("Uses a non streaming HTTP client to send a non blocking request and asserts that MDC values are propagated to the response handler when the request fails due to timeout.")
  void nonBlockingNoStreamingMDCPropagationOnError() throws Exception {
    testMdcPropagation(false, true);
  }

  @Test
  @Description("Uses a streaming HTTP client to send a non blocking request and asserts that MDC values are propagated to the response handler when the request is executed successfully.")
  void nonBlockingStreamingMDCPropagationNoError() throws Exception {
    testMdcPropagation(true, false);
  }

  @Test
  @Description("Uses a non streaming HTTP client to send a non blocking request and asserts that MDC values are propagated to the response handler when the request is executed successfully.")
  void nonBlockingNoStreamingMDCPropagationNoError() throws Exception {
    testMdcPropagation(false, false);
  }

  /**
   * Builds a minimal request targeting the given URI.
   *
   * @param uri the request URI
   * @return a request with only the URI set
   */
  protected HttpRequest getRequest(String uri) {
    return HttpRequest.builder().uri(uri).build();
  }

  /**
   * Builds a minimal request targeting {@code getUri()}.
   *
   * @return a request with only the URI set
   */
  protected HttpRequest getRequest() {
    return getRequest(getUri());
  }

  /**
   * Builds the server response.
   * <p>
   * It first waits on {@code beforeResponseLatch} (when set). It then either throws (when {@code serverShouldThrowException}
   * is set and {@code true}) or returns a 200 whose body is a {@link FillAndWaitStream} gated by {@code latch}.
   */
  @Override
  protected HttpResponse setUpHttpResponse(HttpRequest request) {
    // Read each volatile field exactly once so the test's finally block resetting it to null cannot race the check.
    Latch gate = beforeResponseLatch;
    if (Objects.nonNull(gate)) {
      try {
        gate.await();
      } catch (InterruptedException e) {
        // Restore the interrupt status before surfacing the failure.
        currentThread().interrupt();
        throw new RuntimeException(e);
      }
    }

    AtomicBoolean shouldThrow = serverShouldThrowException;
    if (Objects.nonNull(shouldThrow) && shouldThrow.get()) {
      throw new RuntimeException("Forced Timeout");
    }
    return HttpResponse.builder().statusCode(OK.getStatusCode()).reasonPhrase(OK.getReasonPhrase())
        .entity(new InputStreamHttpEntity(new FillAndWaitStream(latch))).build();
  }

  /** Request options with only the response timeout set. */
  @Override
  protected HttpRequestOptions getDefaultOptions(int responseTimeout) {
    return HttpRequestOptions.builder().responseTimeout(responseTimeout).build();
  }

  /**
   * Sends an async request with MDC entries set on the calling thread. It then asserts that:
   * <ul>
   * <li>the completion callback saw the same {@code transactionId} MDC value;</li>
   * <li>the callback ran on a different thread;</li>
   * <li>the calling thread's MDC was left intact;</li>
   * <li>the expected success/failure outcome occurred.</li>
   * </ul>
   *
   * @param shouldStream whether the client is configured for streaming
   * @param shouldThrowException whether the server handler should throw instead of responding
   */
  private void testMdcPropagation(boolean shouldStream, boolean shouldThrowException) throws IOException {
    serverShouldThrowException = new AtomicBoolean(shouldThrowException);
    beforeResponseLatch = new Latch();
    HttpClient client = service
        .getClientFactory()
        .create(clientBuilder.setResponseBufferSize(KB.toBytes(10)).setStreaming(shouldStream).build());
    client.start();
    final Reference<HttpResponse> responseReference = new Reference<>();
    if (!shouldThrowException) {
      // Release the lock on the streaming payload in order to finish without throwing an exception.
      latch.release();
    }
    String transactionId = randomUUID().toString();
    // Written from the completion-callback thread and read from the test thread, so access is synchronized.
    // ConcurrentHashMap can't be used: the captured exception is null on success.
    Map<String, Object> capture = Collections.synchronizedMap(new HashMap<>());
    MDC.put(TRANSACTION_ID_KEY, transactionId);
    MDC.put(CURRENT_THREAD_KEY, currentThread().getName());
    try {
      client.sendAsync(getRequest(), getDefaultOptions(RESPONSE_TIMEOUT))
          .whenComplete((response, exception) -> {
            // Assertions can't fail the test from this thread, so record the observations and assert later.
            // Capture BEFORE publishing the response: ResponseReceivedProbe is assumed to succeed as soon as
            // responseReference is set, after which the assertions below read `capture`.
            capture.put(EXCEPTION_KEY, exception);
            capture.put(TRANSACTION_ID_KEY, MDC.get(TRANSACTION_ID_KEY));
            capture.put(CURRENT_THREAD_KEY, currentThread().getName());
            responseReference.set(shouldThrowException
                ? HttpResponse.builder().statusCode(INTERNAL_SERVER_ERROR.getStatusCode()).build()
                : response);
          });

      beforeResponseLatch.release();
      pollingProber.check(new ResponseReceivedProbe(responseReference));
      assertThat(capture.get(EXCEPTION_KEY), shouldThrowException ? notNullValue() : nullValue());
      assertThat(MDC.get(TRANSACTION_ID_KEY), is(transactionId));
      assertThat(capture.get(TRANSACTION_ID_KEY), is(transactionId));
      assertThat(MDC.get(CURRENT_THREAD_KEY), is(not(capture.get(CURRENT_THREAD_KEY))));
      assertThat(responseReference.get().getStatusCode(),
                 shouldThrowException ? is(INTERNAL_SERVER_ERROR.getStatusCode()) : is(OK.getStatusCode()));
    } finally {
      client.stop();
      beforeResponseLatch = null;
      serverShouldThrowException = null;
      // Don't leak this test's MDC entries into later tests that run on the same thread.
      MDC.remove(TRANSACTION_ID_KEY);
      MDC.remove(CURRENT_THREAD_KEY);
    }
  }

}
