/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client.auth.ntlm.message;

import static org.mule.service.http.netty.impl.client.auth.ntlm.NtlmFlags.NTLMSSP_NEGOTIATE_NTLM;
import static org.mule.service.http.netty.impl.client.auth.ntlm.NtlmFlags.NTLMSSP_NEGOTIATE_OEM;
import static org.mule.service.http.netty.impl.client.auth.ntlm.NtlmFlags.NTLMSSP_NEGOTIATE_UNICODE;
import static org.mule.service.http.netty.impl.client.auth.ntlm.NtlmFlags.NTLMSSP_NEGOTIATE_VERSION;
import static org.mule.service.http.netty.impl.client.auth.ntlm.smb.SmbConstants.MILLISECONDS_BETWEEN_1970_AND_1601;

import org.mule.service.http.netty.impl.client.auth.ntlm.NtlmFlags;
import org.mule.service.http.netty.impl.client.auth.ntlm.av.AvFlags;
import org.mule.service.http.netty.impl.client.auth.ntlm.av.AvPair;
import org.mule.service.http.netty.impl.client.auth.ntlm.av.AvPairs;
import org.mule.service.http.netty.impl.client.auth.ntlm.av.AvSingleHost;
import org.mule.service.http.netty.impl.client.auth.ntlm.av.AvTargetName;
import org.mule.service.http.netty.impl.client.auth.ntlm.av.AvTimestamp;
import org.mule.service.http.netty.impl.client.auth.ntlm.util.Crypto;
import org.mule.service.http.netty.impl.client.auth.ntlm.smb.NtlmUtil;

import javax.crypto.Cipher;
import java.io.UnsupportedEncodingException;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;

/**
 * Represents a Type 3 NTLM message used in NTLM authentication.
 *
 * <p>
 * The constructor derives the LMv2/NTLMv2 responses (and, when key exchange is negotiated, an encrypted session key) from a
 * previously received {@link Type2Message}; {@link #toByteArray()} serializes the result.
 * </p>
 *
 * <p>
 * This implementation is based on the jcifs library, available at <a href="https://github.com/kevinboone/jcifs">jcifs GitHub
 * Repository</a>.
 * </p>
 */
public class Type3Message extends NtlmMessage {

  // Set by the constructor: a random key when NTLMSSP_NEGOTIATE_KEY_EXCH is set, otherwise the user session key.
  // NOTE(review): write-only within this class (no accessor reads it).
  private byte[] masterKey;
  private String domain;
  private String user;
  private String workstation;
  private byte[] lmResponse;
  private byte[] ntResponse;

  // Allocated (zero-filled) by makeAvPairs when a MIC is announced; null otherwise.
  private byte[] mic = null;
  // NOTE(review): write-only within this class (no accessor reads it).
  private boolean micRequired;
  // The encrypted session key emitted in the message; null unless key exchange was negotiated.
  private byte[] sessionKey;

  /**
   * Builds a Type 3 message answering the given Type 2 challenge.
   *
   * <p>
   * When no credentials are supplied (both {@code password} and {@code passwordHash} are {@code null}), or when
   * {@code nonAnonymous} is {@code false} and {@code password} is empty, both responses are set to {@code null} and no further
   * computation happens.
   * </p>
   *
   * @param type2        the Type 2 message being answered; its flags, challenge and target information are read.
   * @param targetName   when not {@code null}, written into the AV pairs as the target name.
   * @param passwordHash the password hash to use; when {@code null} it is derived from {@code password}.
   * @param password     the clear-text password, used only when {@code passwordHash} is {@code null}.
   * @param domain       the domain in which the user has an account.
   * @param user         the user name.
   * @param workstation  the workstation from which authentication is performed.
   * @param flags        flags to set on this message; they are ORed with the defaults derived from {@code type2}.
   * @param nonAnonymous when {@code false}, an empty password is treated like absent credentials.
   * @throws GeneralSecurityException if one of the underlying hash or cipher operations fails.
   */
  public Type3Message(Type2Message type2, String targetName, byte[] passwordHash, String password, String domain, String user,
                      String workstation, int flags, boolean nonAnonymous)
      throws GeneralSecurityException {
    setFlags(flags | getDefaultFlags(type2));
    setWorkstation(workstation);
    setDomain(domain);
    setUser(user);

    boolean noCredentials = password == null && passwordHash == null;
    boolean anonymousEmptyPassword = !nonAnonymous && password != null && password.isEmpty();
    if (noCredentials || anonymousEmptyPassword) {
      setLMResponse(null);
      setNTResponse(null);
      return;
    }

    // Created only once we know random material is needed; shared by every random draw below.
    SecureRandom secureRandom = new SecureRandom();

    if (passwordHash == null) {
      passwordHash = NtlmUtil.getNTHash(password);
    }

    byte[] ntlmClientChallengeInfo = type2.getTargetInformation();
    List<AvPair> avPairs = ntlmClientChallengeInfo != null ? AvPairs.decode(ntlmClientChallengeInfo) : null;

    // if targetInfo has an MsvAvTimestamp
    // client should not send LmChallengeResponse
    boolean haveTimestamp = AvPairs.contains(avPairs, AvPair.MsvAvTimestamp);
    if (!haveTimestamp) {
      byte[] lmClientChallenge = new byte[8];
      secureRandom.nextBytes(lmClientChallenge);
      setLMResponse(getLMv2Response(type2, domain, user, passwordHash, lmClientChallenge));
    } else {
      // all-zero placeholder instead of a real LM response
      setLMResponse(new byte[24]);
    }

    if (avPairs != null) {
      // make sure to set the TARGET_INFO flag as we are sending
      setFlag(NtlmFlags.NTLMSSP_NEGOTIATE_TARGET_INFO, true);
    }

    byte[] responseKeyNT = NtlmUtil.nTOWFv2(domain, user, passwordHash);
    byte[] ntlmClientChallenge = new byte[8];
    secureRandom.nextBytes(ntlmClientChallenge);

    // Prefer the server supplied timestamp; otherwise shift the current time to the 1601 epoch and scale
    // milliseconds to 100-nanosecond units.
    long ts;
    if (haveTimestamp) {
      ts = ((AvTimestamp) AvPairs.get(avPairs, AvPair.MsvAvTimestamp)).getTimestamp();
    } else {
      ts = (System.currentTimeMillis() + MILLISECONDS_BETWEEN_1970_AND_1601) * 10000;
    }

    byte[] clientChallengeInfo = makeAvPairs(targetName, avPairs, haveTimestamp, ts, secureRandom);
    setNTResponse(getNTLMv2Response(type2, responseKeyNT, ntlmClientChallenge, clientChallengeInfo, ts));

    MessageDigest hmac = Crypto.getHMACT64(responseKeyNT);
    hmac.update(this.ntResponse, 0, 16); // only first 16 bytes of ntResponse
    byte[] userSessionKey = hmac.digest();

    if (getFlag(NtlmFlags.NTLMSSP_NEGOTIATE_KEY_EXCH)) {
      // Key exchange: generate a random master key and send it RC4-encrypted under the user session key.
      this.masterKey = new byte[16];
      secureRandom.nextBytes(this.masterKey);

      byte[] encryptedKey = new byte[16];
      Cipher rc4 = Crypto.getArcfour(userSessionKey);
      rc4.update(this.masterKey, 0, 16, encryptedKey, 0);
      setEncryptedSessionKey(encryptedKey);
    } else {
      this.masterKey = userSessionKey;
    }
  }

  /**
   * Computes the LMv2 response for the given Type 2 challenge.
   *
   * @param type2           the Type 2 message whose challenge is answered.
   * @param domain          the domain in which the user has an account.
   * @param user            the user name.
   * @param passwordHash    the password hash.
   * @param clientChallenge the client challenge.
   * @return the LMv2 response, or {@code null} if any argument is {@code null}.
   * @throws GeneralSecurityException if the underlying computation fails.
   */
  public static byte[] getLMv2Response(Type2Message type2, String domain, String user, byte[] passwordHash,
                                       byte[] clientChallenge)
      throws GeneralSecurityException {
    if (type2 == null || domain == null || user == null || passwordHash == null || clientChallenge == null) {
      return null;
    }
    return NtlmUtil.getLMv2Response(domain, user, passwordHash, type2.getChallenge(), clientChallenge);
  }

  /**
   * Derives the flags always set on this message: NTLM and VERSION, plus UNICODE when the Type 2 message has it set and OEM
   * otherwise.
   */
  private int getDefaultFlags(Type2Message type2) {
    int encodingFlag = type2.getFlag(NTLMSSP_NEGOTIATE_UNICODE) ? NTLMSSP_NEGOTIATE_UNICODE : NTLMSSP_NEGOTIATE_OEM;
    return NTLMSSP_NEGOTIATE_NTLM | NTLMSSP_NEGOTIATE_VERSION | encodingFlag;
  }

  /**
   * Returns the domain in which the user has an account.
   *
   * @return A <code>String</code> containing the domain for the user.
   */
  public String getDomain() {
    return this.domain;
  }

  /**
   * Sets the domain for this message.
   *
   * @param domain The domain.
   */
  public void setDomain(String domain) {
    this.domain = domain;
  }

  /**
   * Returns the username for the authenticating user.
   *
   * @return A <code>String</code> containing the user for this message.
   */
  public String getUser() {
    return this.user;
  }

  /**
   * Sets the user for this message.
   *
   * @param user The user.
   */
  public void setUser(String user) {
    this.user = user;
  }

  /**
   * Returns the workstation from which authentication is being performed.
   *
   * @return A <code>String</code> containing the workstation.
   */
  public String getWorkstation() {
    return this.workstation;
  }

  /**
   * Sets the workstation for this message.
   *
   * @param workstation The workstation.
   */
  public void setWorkstation(String workstation) {
    this.workstation = workstation;
  }

  /**
   * Returns the LanManager/LMv2 response.
   *
   * @return A <code>byte[]</code> containing the LanManager response.
   */
  public byte[] getLMResponse() {
    return this.lmResponse;
  }

  /**
   * Sets the LanManager/LMv2 response for this message.
   *
   * @param lmResponse The LanManager response.
   */
  public void setLMResponse(byte[] lmResponse) {
    this.lmResponse = lmResponse;
  }

  /**
   * Returns the NT/NTLMv2 response.
   *
   * @return A <code>byte[]</code> containing the NT/NTLMv2 response.
   */
  public byte[] getNTResponse() {
    return this.ntResponse;
  }

  /**
   * Sets the NT/NTLMv2 response for this message.
   *
   * @param ntResponse The NT/NTLMv2 response.
   */
  public void setNTResponse(byte[] ntResponse) {
    this.ntResponse = ntResponse;
  }

  /**
   * Computes the NTLMv2 response for the given Type 2 challenge.
   *
   * @param type2               the Type 2 message whose challenge is answered.
   * @param responseKeyNT       the NT response key.
   * @param clientChallenge     the client challenge.
   * @param clientChallengeInfo the encoded AV pairs to include; not checked for {@code null} here.
   * @param ts                  the timestamp to include.
   * @return the NTLMv2 response, or {@code null} if {@code type2}, {@code responseKeyNT} or {@code clientChallenge} is
   *         {@code null}.
   * @throws NoSuchAlgorithmException if a required algorithm is unavailable.
   */
  public static byte[] getNTLMv2Response(Type2Message type2, byte[] responseKeyNT, byte[] clientChallenge,
                                         byte[] clientChallengeInfo, long ts)
      throws NoSuchAlgorithmException {
    if (type2 == null || responseKeyNT == null || clientChallenge == null) {
      return null;
    }
    return NtlmUtil.getNTLMv2Response(responseKeyNT, type2.getChallenge(), clientChallenge, ts, clientChallengeInfo);
  }

  /**
   * Builds the encoded AV pair list sent back to the server, starting from the server's own pairs (modified in place) or from
   * an empty list when there are none.
   *
   * @param targetName          when not {@code null}, stored as the target name pair.
   * @param serverAvPairs       the pairs decoded from the Type 2 target information, may be {@code null}.
   * @param haveServerTimestamp whether the server pairs carried a timestamp.
   * @param ts                  the timestamp to store.
   * @param random              source of the random machine id.
   * @return the encoded AV pairs.
   */
  private byte[] makeAvPairs(String targetName, List<AvPair> serverAvPairs, boolean haveServerTimestamp, long ts,
                             SecureRandom random) {
    if (serverAvPairs == null) {
      serverAvPairs = new LinkedList<>();
    }

    if (getFlag(NtlmFlags.NTLMSSP_NEGOTIATE_SIGN) && haveServerTimestamp) {
      // should provide MIC
      // NOTE(review): the MIC buffer is only allocated here; nothing in this class computes it, so toByteArray()
      // serializes 16 zero bytes while the flag below announces a MIC. Confirm whether it must be computed over the
      // Type 1/2/3 messages before sending.
      this.micRequired = true;
      this.mic = new byte[16];

      AvFlags existingFlags = (AvFlags) AvPairs.get(serverAvPairs, AvPair.MsvAvFlags);
      int avFlags = existingFlags != null ? existingFlags.getFlags() : 0;
      avFlags |= 0x2; // MAC present
      AvPairs.replace(serverAvPairs, new AvFlags(avFlags));
    }

    AvPairs.replace(serverAvPairs, new AvTimestamp(ts));

    if (targetName != null) {
      AvPairs.replace(serverAvPairs, new AvTargetName(targetName));
    }

    // possibly add channel bindings
    AvPairs.replace(serverAvPairs, new AvPair(0xa, new byte[16]));
    AvPairs.replace(serverAvPairs, new AvSingleHost(getMachineId(random)));

    return AvPairs.encode(serverAvPairs);
  }

  /**
   * Generates a fresh random 32-byte machine id.
   */
  private byte[] getMachineId(SecureRandom random) {
    byte[] machineId = new byte[32];
    random.nextBytes(machineId);
    return machineId;
  }

  /**
   * Sets the session key.
   *
   * @param sessionKey The session key.
   */
  public void setEncryptedSessionKey(byte[] sessionKey) {
    this.sessionKey = sessionKey;
  }

  /**
   * Serializes this message.
   *
   * <p>
   * Layout: signature, message type, six security buffer headers (LM response, NT response, domain, user, workstation, session
   * key), flags, optional version block, optional MIC, then the buffer contents in the same order as their headers. Strings
   * are encoded as Unicode when the UNICODE flag is set and in the OEM encoding otherwise; in OEM mode the user and workstation
   * names are upper-cased first.
   * </p>
   *
   * @return the encoded message.
   * @throws UnsupportedEncodingException if the string encoding in use is not supported.
   */
  public byte[] toByteArray() throws UnsupportedEncodingException {
    boolean unicode = getFlag(NTLMSSP_NEGOTIATE_UNICODE);
    String oemCp = unicode ? null : getOEMEncoding();

    byte[] domainBytes = encodeName(getDomain(), unicode, oemCp, false);
    byte[] userBytes = encodeName(getUser(), unicode, oemCp, true);
    byte[] workstationBytes = encodeName(getWorkstation(), unicode, oemCp, true);

    byte[] micBytes = getMic();
    boolean versionPresent = getFlag(NTLMSSP_NEGOTIATE_VERSION);
    byte[] lmResponseBytes = getLMResponse();
    byte[] ntResponseBytes = getNTResponse();
    byte[] sessionKeyBytes = getEncryptedSessionKey();

    // Fixed header: 8 (signature) + 4 (type) + 6 * 8 (security buffers) + 4 (flags).
    int size = 64;
    if (micBytes != null) {
      // A MIC always follows the 8-byte version slot, whether or not the version is actually written.
      size += 8 + 16;
    } else if (versionPresent) {
      size += 8;
    }
    size += lengthOf(domainBytes) + lengthOf(userBytes) + lengthOf(workstationBytes);
    size += lengthOf(lmResponseBytes) + lengthOf(ntResponseBytes) + lengthOf(sessionKeyBytes);

    byte[] type3 = new byte[size];
    int pos = 0;

    System.arraycopy(NTLMSSP_SIGNATURE, 0, type3, pos, 8);
    pos += 8;

    writeULong(type3, pos, NTLMSSP_TYPE3);
    pos += 4;

    // Security buffer headers occupy offsets 12, 20, 28, 36, 44 and 52.
    int lmOff = writeSecurityBuffer(type3, pos, lmResponseBytes);
    pos += 8;
    int ntOff = writeSecurityBuffer(type3, pos, ntResponseBytes);
    pos += 8;
    int domOff = writeSecurityBuffer(type3, pos, domainBytes);
    pos += 8;
    int userOff = writeSecurityBuffer(type3, pos, userBytes);
    pos += 8;
    int wsOff = writeSecurityBuffer(type3, pos, workstationBytes);
    pos += 8;
    int skOff = writeSecurityBuffer(type3, pos, sessionKeyBytes);
    pos += 8;

    writeULong(type3, pos, getFlags());
    pos += 4;

    if (versionPresent) {
      System.arraycopy(NTLMSSP_VERSION, 0, type3, pos, NTLMSSP_VERSION.length);
      pos += NTLMSSP_VERSION.length;
    } else if (micBytes != null) {
      // Leave the version slot zeroed so the MIC lands at its expected offset.
      pos += NTLMSSP_VERSION.length;
    }

    if (micBytes != null) {
      System.arraycopy(micBytes, 0, type3, pos, 16);
      pos += 16;
    }

    pos += writeSecurityBufferContent(type3, pos, lmOff, lmResponseBytes);
    pos += writeSecurityBufferContent(type3, pos, ntOff, ntResponseBytes);
    pos += writeSecurityBufferContent(type3, pos, domOff, domainBytes);
    pos += writeSecurityBufferContent(type3, pos, userOff, userBytes);
    pos += writeSecurityBufferContent(type3, pos, wsOff, workstationBytes);
    writeSecurityBufferContent(type3, pos, skOff, sessionKeyBytes);

    return type3;
  }

  /**
   * Encodes a name for the wire, or returns {@code null} when the name is {@code null} or empty.
   *
   * @param value        the name to encode.
   * @param unicode      whether to use the Unicode encoding instead of the OEM code page.
   * @param oemCp        the OEM code page name; only used when {@code unicode} is {@code false}.
   * @param upperCaseOem whether to upper-case the name before OEM encoding.
   */
  private byte[] encodeName(String value, boolean unicode, String oemCp, boolean upperCaseOem)
      throws UnsupportedEncodingException {
    if (value == null || value.isEmpty()) {
      return null;
    }
    if (unicode) {
      return value.getBytes(UNI_ENCODING);
    }
    // Locale.ROOT keeps the wire bytes independent of the JVM default locale (e.g. Turkish dotted/dotless i).
    String oemValue = upperCaseOem ? value.toUpperCase(Locale.ROOT) : value;
    return oemValue.getBytes(oemCp);
  }

  /**
   * Returns the length of the array, treating {@code null} as empty.
   */
  private static int lengthOf(byte[] bytes) {
    return bytes != null ? bytes.length : 0;
  }

  private byte[] getEncryptedSessionKey() {
    return sessionKey;
  }

  private byte[] getMic() {
    return mic;
  }
}
