/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.common.proxy;

import static org.mule.service.http.test.common.util.HttpResponseStatusMatcher.hasStatusCode;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.TimeUnit.SECONDS;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.any;
import static com.github.tomakehurst.wiremock.client.WireMock.anyRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
import static com.github.tomakehurst.wiremock.client.WireMock.exactly;
import static com.github.tomakehurst.wiremock.client.WireMock.get;
import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.matching;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.MatcherAssert.assertThat;

import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpClientConfiguration;
import org.mule.runtime.http.api.client.proxy.AuthHeaderFactory;
import org.mule.runtime.http.api.client.proxy.ProxyConfig;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.service.http.test.common.AbstractHttpServiceTestCase;
import org.mule.tck.junit5.DynamicPort;

import java.io.IOException;
import java.util.Collection;
import java.util.Locale;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;

import com.github.tomakehurst.wiremock.WireMockServer;

import org.apache.commons.io.IOUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Tests custom proxy authentication with a preemptive auth flow.
 * <p>
 * A WireMock "proxy" answers {@code 407} to every request, except requests carrying the expected
 * {@code Proxy-Authorization} header. Those are forwarded to a WireMock target server. The client uses a custom
 * {@link AuthHeaderFactory} that supplies the header before any challenge is received. As a result, the proxy is expected
 * to see exactly one request.
 */
public class CustomProxyAuthPreemptiveTestCase extends AbstractHttpServiceTestCase {

  // Single-quoted, so not strictly valid JSON; it is only checked by substring below.
  private static final String TEST_JSON = "{'method': 'text'}";

  /** Basic credentials for {@code test:test}; shared by the proxy stub, the auth factory and the verifications. */
  private static final String PROXY_AUTH_HEADER = "Basic dGVzdDp0ZXN0";

  /** Upper bound for the request, so a broken proxy flow fails the test instead of hanging the build. */
  private static final long RESPONSE_TIMEOUT_SECONDS = 30;

  @DynamicPort(systemProperty = "proxyPort")
  Integer proxyPort;
  @DynamicPort(systemProperty = "targetPort")
  Integer targetPort;

  private HttpClient httpClient;
  private WireMockServer proxyServer;
  private WireMockServer targetServer;

  public CustomProxyAuthPreemptiveTestCase(String serviceToLoad) {
    super(serviceToLoad);
  }

  @BeforeEach
  void setUp() throws Exception {
    // start target server
    targetServer = new WireMockServer(wireMockConfig().port(targetPort));
    targetServer.start();
    targetServer.stubFor(get(urlEqualTo("/target-endpoint"))
        .willReturn(aResponse()
            .withStatus(200)
            .withHeader("Content-Type", "application/json")
            .withBody(TEST_JSON)));

    // start proxy server
    proxyServer = new WireMockServer(wireMockConfig()
        .port(proxyPort)
        .enableBrowserProxying(true));
    proxyServer.start();

    // Configure the proxy. Priorities are explicit (lower value wins) rather than relying on the
    // "most recently added stub wins" tie-break, so the authenticated route always beats the 407 fallback.
    proxyServer.stubFor(any(anyUrl())
        .atPriority(10)
        .willReturn(aResponse()
            .withStatus(407)
            .withHeader("Proxy-Authenticate", "Basic realm=\"Test Proxy\"")));

    proxyServer.stubFor(any(anyUrl())
        .atPriority(1)
        .withHeader("Proxy-Authorization", matching(PROXY_AUTH_HEADER))
        .willReturn(aResponse()
            .proxiedFrom("http://localhost:" + targetPort)));

    AuthHeaderFactory customAuthFactory = new AuthHeaderFactory() {

      // Written from the supplyAsync worker thread and read via hasFinished() (possibly from another thread),
      // so it must be thread-safe; compareAndSet also guarantees the credentials are handed out at most once.
      private final AtomicBoolean finished = new AtomicBoolean(false);

      @Override
      public boolean hasFinished() {
        return finished.get();
      }

      @Override
      public CompletableFuture<String> generateHeader(HttpResponse response) {
        return CompletableFuture.supplyAsync(() -> {
          // Preemptive request (response == null) or a Basic 407 challenge: provide credentials, but only once.
          // Once finished, no more headers are provided (null).
          boolean shouldAuthenticate = response == null || isBasicProxyChallenge(response);
          return shouldAuthenticate && finished.compareAndSet(false, true) ? PROXY_AUTH_HEADER : null;
        });
      }
    };

    ProxyConfig customProxyConfig = new ProxyConfig() {

      @Override
      public String getHost() {
        return "localhost";
      }

      @Override
      public int getPort() {
        return proxyPort;
      }

      @Override
      public String getUsername() {
        return null;
      }

      @Override
      public String getPassword() {
        return null;
      }

      @Override
      public String getNonProxyHosts() {
        return "";
      }

      @Override
      public AuthHeaderFactory getAuthHeaderFactory() {
        return customAuthFactory;
      }
    };

    httpClient = service.getClientFactory().create(new HttpClientConfiguration.Builder()
        .setName("Client")
        .setProxyConfig(customProxyConfig)
        .setStreaming(true)
        .build());
    httpClient.start();
  }

  @AfterEach
  void tearDown() {
    // setUp may have failed part-way; guard every resource so the original failure is not masked by an NPE here.
    if (httpClient != null) {
      httpClient.stop();
    }
    if (proxyServer != null) {
      proxyServer.stop();
    }
    if (targetServer != null) {
      targetServer.stop();
    }
  }

  @Test
  void proxyHttpTest() throws ExecutionException, InterruptedException, TimeoutException, IOException {
    String uri = "http://localhost:" + targetPort + "/target-endpoint"; // request goes to target, routed via proxy

    HttpRequest request = HttpRequest.builder()
        .uri(uri)
        .build();

    HttpResponse response = httpClient.sendAsync(request).get(RESPONSE_TIMEOUT_SECONDS, SECONDS);

    // 1) verify response body; the client is streaming, so close the entity stream once consumed
    assertThat(response, hasStatusCode(200));
    String responseBody;
    try (var content = response.getEntity().getContent()) {
      responseBody = IOUtils.toString(content, UTF_8);
    }
    assertThat(responseBody, containsString("'method': 'text'"));

    // 2) verify preemptive auth flow (only 1 request needed)
    proxyServer.verify(exactly(1), anyRequestedFor(anyUrl()));

    // 3) verify that at least one request had the Proxy-Authorization header
    proxyServer.verify(anyRequestedFor(anyUrl())
        .withHeader("Proxy-Authorization", matching(PROXY_AUTH_HEADER)));

    // 4) verify that target server received the request
    targetServer.verify(getRequestedFor(urlEqualTo("/target-endpoint")));
  }

  /**
   * Tells whether {@code response} is a 407 carrying a {@code Proxy-Authenticate} header for the Basic scheme.
   *
   * @param response a non-null response received from the proxy
   * @return {@code true} if a Basic proxy challenge is present
   */
  private static boolean isBasicProxyChallenge(HttpResponse response) {
    if (response.getStatusCode() != 407 || response.getHeaders() == null) {
      return false;
    }
    Collection<String> authHeaders = response.getHeaders().getAll("Proxy-Authenticate");
    // Locale.ROOT keeps the scheme comparison independent of the JVM default locale (e.g. Turkish dotless i).
    return authHeaders != null && authHeaders.stream()
        .anyMatch(header -> header != null && header.toLowerCase(Locale.ROOT).startsWith("basic"));
  }

}
