/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client.auth.ntlm;

import static java.nio.charset.StandardCharsets.US_ASCII;

import static org.slf4j.LoggerFactory.getLogger;

import org.mule.service.http.netty.impl.client.auth.AuthHeaderFactory;
import org.mule.service.http.netty.impl.client.auth.ntlm.message.Type1Message;
import org.mule.service.http.netty.impl.client.auth.ntlm.message.Type2Message;
import org.mule.service.http.netty.impl.client.auth.ntlm.message.Type3Message;

import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.util.Base64;
import java.util.Locale;

import javax.crypto.Cipher;
import javax.crypto.spec.SecretKeySpec;

import org.slf4j.Logger;

/**
 * Implementation of the client-side NTLM message generation.
 * <p>
 * The NTLM authentication is done once per connection, and the mechanism can be summarized as:
 * <ol>
 * <li>The client sends a request without any kind of authentication.</li>
 * <li>The server responds with a 401 error code, including a WWW-AUTHENTICATE header with the value "NTLM".</li>
 * <li>The client sends another request over the same connection, now containing a Type1 message in the AUTHORIZATION
 * header.</li>
 * <li>The server responds with a 401 error code, including a Type2 message in the WWW-AUTHENTICATE header.</li>
 * <li>The client sends a last request with the Type3 message in the AUTHORIZATION header.</li>
 * </ol>
 * <p>
 * This class has methods to generate the client-side messages (Type1 and Type3) and to create the corresponding
 * AUTHORIZATION header values given a WWW-AUTHENTICATE header value.
 */
public class NtlmMessageFactory implements AuthHeaderFactory {


  /**
   * Status of the authentication exchange. Tracking it avoids falling into bad states (such as infinite recursions) when the
   * server does not implement the protocol correctly.
   */
  enum Status {

    /** Initial state: no NTLM headers have been processed yet. */
    NOT_STARTED,

    /**
     * The first WWW-Authenticate header has been processed (a Type 1 message is only produced if the server asked for it), and
     * the next header is expected to carry the server's Type 2 challenge.
     */
    WAITING_FOR_CHALLENGE,

    /** The last message has been answered; even if it failed, there is nothing more to be done. */
    FINISHED,
  }

  private static final Logger LOGGER = getLogger(NtlmMessageFactory.class);

  // Same flags as we used in the legacy AHC implementation.
  private static final int TYPE_1_MESSAGE_FLAGS = 0xA2088201;
  // Constant copied from the legacy AHC implementation: the plaintext that lmHash encrypts with each half of the key.
  private static final byte[] MAGIC_CONSTANT = "KGS!@#$%".getBytes(US_ASCII);
  private static final String NTLM_MESSAGES_PREFIX = "NTLM ";
  // Value sent by server in the WWW-AUTHENTICATE to indicate that NTLM Type1 message is required.
  private static final String STARTING_NTLM_WWW_AUTHENTICATE_HEADER = "NTLM";

  private final String domain;
  private final String workstation;
  private final String username;
  private final String password;

  private Status status;

  /**
   * Creates a factory for a single NTLM authentication exchange.
   *
   * @param domain      the NTLM domain passed to the Type 3 message.
   * @param workstation the workstation name passed to the Type 3 message.
   * @param username    the user name passed to the Type 3 message.
   * @param password    the password used for the Type 3 message and its LM hash.
   */
  public NtlmMessageFactory(String domain, String workstation, String username, String password) {
    this.domain = domain;
    this.workstation = workstation;
    this.username = username;
    this.password = password;
    this.status = Status.NOT_STARTED;
  }

  /**
   * Generates a raw NTLM Type 1 message. The flags used for this are the same as the ones present in our old implementation using
   * Grizzly AHC. The only thing that changes is the NTLM version.
   *
   * @return A fresh copy of the raw Type 1 message (without Base64 encoding nor the "NTLM " prefix needed in the header). The
   *         caller may freely modify the returned array.
   */
  public byte[] createType1Message() {
    // The message is computed once and shared; hand out a copy so no caller can corrupt it for everyone else.
    return RawType1MessageHolder.RAW_TYPE_1_MESSAGE.clone();
  }

  /**
   * Generates a raw NTLM Type 3 message. The flags used for this are the same as the ones present in the Type 2 message passed as
   * parameter.
   *
   * @param type2Material a raw representation of the Type 2 message received from the server.
   * @return A raw representation of the Type 3 message (without Base64 encoding nor the "NTLM " prefix needed in the header).
   * @throws IOException              If an error occurs while parsing the passed type 2 message material, or creating the type 3
   *                                  message.
   * @throws GeneralSecurityException If an error occurs while creating the type 3 message.
   */
  public byte[] createType3Message(byte[] type2Material) throws IOException, GeneralSecurityException {
    Type2Message type2Message = new Type2Message(type2Material);
    // Guard against a missing challenge so the Type 3 construction works on an empty array instead of null.
    if (type2Message.getChallenge() == null) {
      type2Message.setChallenge(new byte[0]);
    }
    Type3Message type3Message = new Type3Message(type2Message, null, lmHash(password), password, domain, username,
                                                 workstation, type2Message.getFlags(), false);
    return type3Message.toByteArray();
  }

  /**
   * Builds the AUTHORIZATION header value answering a Type 2 challenge.
   *
   * @param wwwAuthenticateHeader a WWW-AUTHENTICATE value of the form {@code "NTLM <base64 Type 2 message>"}, or {@code null}.
   * @return the {@code "NTLM <base64 Type 3 message>"} header value, or {@code null} if the header is {@code null}.
   * @throws GeneralSecurityException If an error occurs while creating the type 3 message.
   * @throws IOException              If the type 2 message cannot be parsed or the type 3 message cannot be created.
   */
  protected String secondChallenge(String wwwAuthenticateHeader) throws GeneralSecurityException, IOException {
    if (wwwAuthenticateHeader == null) {
      return null;
    }

    // Drop the "NTLM " scheme prefix and any surrounding whitespace; the Base64 decoder rejects stray spaces.
    String encodedType2 = wwwAuthenticateHeader.trim().substring(NTLM_MESSAGES_PREFIX.length()).trim();
    byte[] type2Material = Base64.getDecoder().decode(encodedType2);
    return createHeaderValue(createType3Message(type2Material));
  }

  /**
   * Returns whether the server sent a bare {@code "WWW-Authenticate: NTLM"} header. Auth-scheme names are case-insensitive
   * (RFC 7235, section 2.1).
   */
  private boolean mustSendType1(String wwwAuthenticateHeader) {
    return STARTING_NTLM_WWW_AUTHENTICATE_HEADER.equalsIgnoreCase(wwwAuthenticateHeader.trim());
  }

  /**
   * Returns whether the server sent a {@code "WWW-Authenticate: NTLM <something>"} header, where "something" is assumed to be a
   * Type 2 challenge. It is not validated here, because {@link #createType3Message(byte[])} throws if it is not one.
   */
  private boolean mustSendType3(String wwwAuthenticateHeader) {
    String trimmed = wwwAuthenticateHeader.trim();
    // Case-insensitive scheme match, consistent with mustSendType1.
    return trimmed.regionMatches(true, 0, NTLM_MESSAGES_PREFIX, 0, NTLM_MESSAGES_PREFIX.length());
  }

  /** Transforms a raw NTLM message into a valid "Authorization" header value ("NTLM " followed by its Base64 encoding). */
  private String createHeaderValue(byte[] rawNtlmMessageContent) {
    return NTLM_MESSAGES_PREFIX + Base64.getEncoder().encodeToString(rawNtlmMessageContent);
  }

  @Override
  public boolean hasFinished() {
    return Status.FINISHED == status;
  }

  /**
   * Advances the NTLM exchange by one step.
   * <p>
   * In {@code NOT_STARTED}, a Type 1 header is returned if the server asked for it, and the state always moves to
   * {@code WAITING_FOR_CHALLENGE}. In {@code WAITING_FOR_CHALLENGE}, a Type 3 header is returned if the server sent an NTLM
   * challenge, and the state always moves to {@code FINISHED}. In {@code FINISHED}, nothing is produced.
   *
   * @param wwwAuthenticateHeader the WWW-AUTHENTICATE value received from the server, possibly {@code null}.
   * @return the AUTHORIZATION header value to send, or {@code null} if there is nothing to send.
   * @throws Exception if the Type 3 message cannot be built from the server's challenge.
   */
  @Override
  public String getNextHeader(String wwwAuthenticateHeader) throws Exception {
    if (wwwAuthenticateHeader == null) {
      return null;
    }

    String authHeader = null;
    if (status == Status.NOT_STARTED) {
      if (mustSendType1(wwwAuthenticateHeader)) {
        // The server told us that we have to send a first challenge message to start authentication.
        authHeader = createHeaderValue(createType1Message());
      }
      status = Status.WAITING_FOR_CHALLENGE;
    } else if (status == Status.WAITING_FOR_CHALLENGE) {
      if (mustSendType3(wwwAuthenticateHeader)) {
        // The server requests a second challenge message
        authHeader = secondChallenge(wwwAuthenticateHeader);
      }
      status = Status.FINISHED;
    }

    return authHeader;
  }

  // Initialization-on-demand holder: the Type 1 message is computed once, the first time the holder class is accessed. This is
  // the alternative to the classic double-checked locking, but it's compliant with the linter rules.
  private static final class RawType1MessageHolder {

    private static final byte[] RAW_TYPE_1_MESSAGE = doCalculateType1Message();

    private static byte[] doCalculateType1Message() {
      Type1Message type1Message = new Type1Message(TYPE_1_MESSAGE_FLAGS);
      return type1Message.toByteArray();
    }
  }

  /**
   * Computes the LM hash of a password (copied from AHC).
   * <p>
   * The password is uppercased and encoded as US-ASCII. It is then truncated or zero-padded to 14 bytes and split into two 7-byte
   * halves. Each half is expanded into a DES key that encrypts {@code MAGIC_CONSTANT}. The two 8-byte ciphertexts are
   * concatenated into the 16-byte result.
   *
   * @param password the password to hash.
   * @return the 16-byte LM hash, or {@code null} if hashing failed for any reason (including a {@code null} password); in that
   *         case a warning is logged.
   */
  // TODO: NTLMv1 uses DES. Should we update it to use NTLMv2?
  private static byte[] lmHash(final String password) {
    try {
      final byte[] oemPassword = password.toUpperCase(Locale.ROOT).getBytes(US_ASCII);
      final int length = Math.min(oemPassword.length, 14);
      final byte[] keyBytes = new byte[14];
      System.arraycopy(oemPassword, 0, keyBytes, 0, length);
      final Key lowKey = createDESKey(keyBytes, 0);
      final Key highKey = createDESKey(keyBytes, 7);
      final Cipher des = Cipher.getInstance("DES/ECB/NoPadding");
      des.init(Cipher.ENCRYPT_MODE, lowKey);
      final byte[] lowHash = des.doFinal(MAGIC_CONSTANT);
      des.init(Cipher.ENCRYPT_MODE, highKey);
      final byte[] highHash = des.doFinal(MAGIC_CONSTANT);
      final byte[] lmHash = new byte[16];
      System.arraycopy(lowHash, 0, lmHash, 0, 8);
      System.arraycopy(highHash, 0, lmHash, 8, 8);
      return lmHash;
    } catch (final Exception e) {
      LOGGER.warn("Error found while calculating the NTLM password hash. Delegating the hashing to JCIFS default cipher", e);
      return null;
    }
  }

  /**
   * Expands the 7 bytes (56 bits) starting at {@code offset} into an 8-byte DES key (copied from AHC).
   * <p>
   * Each key byte carries 7 of the input bits in its upper bits. The lowest bit is then set by {@link #oddParity(byte[])}.
   */
  private static Key createDESKey(final byte[] bytes, final int offset) {
    final byte[] keyBytes = new byte[7];
    System.arraycopy(bytes, offset, keyBytes, 0, 7);
    final byte[] material = new byte[8];
    material[0] = keyBytes[0];
    material[1] = (byte) (keyBytes[0] << 7 | (keyBytes[1] & 0xff) >>> 1);
    material[2] = (byte) (keyBytes[1] << 6 | (keyBytes[2] & 0xff) >>> 2);
    material[3] = (byte) (keyBytes[2] << 5 | (keyBytes[3] & 0xff) >>> 3);
    material[4] = (byte) (keyBytes[3] << 4 | (keyBytes[4] & 0xff) >>> 4);
    material[5] = (byte) (keyBytes[4] << 3 | (keyBytes[5] & 0xff) >>> 5);
    material[6] = (byte) (keyBytes[5] << 2 | (keyBytes[6] & 0xff) >>> 6);
    material[7] = (byte) (keyBytes[6] << 1);
    oddParity(material);
    return new SecretKeySpec(material, "DES");
  }

  /**
   * Sets or clears the lowest bit of each byte, in place, so that every byte has an odd number of set bits (copied from AHC).
   */
  private static void oddParity(final byte[] bytes) {
    for (int i = 0; i < bytes.length; i++) {
      final byte b = bytes[i];
      // XOR of bits 7..1; sign extension from the byte-to-int promotion only affects bits >= 8, which & 0x01 discards.
      final boolean needsParity =
          (((b >>> 7) ^ (b >>> 6) ^ (b >>> 5) ^ (b >>> 4) ^ (b >>> 3) ^ (b >>> 2) ^ (b >>> 1)) & 0x01) == 0;
      if (needsParity) {
        bytes[i] |= (byte) 0x01;
      } else {
        bytes[i] &= (byte) 0xfe;
      }
    }
  }

}

