/*
 * (c) 2003-2020 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package org.mule.soapkit.soap.server;

import com.google.common.collect.ImmutableMap;
import org.apache.cxf.annotations.SchemaValidation.SchemaValidationType;
import org.apache.cxf.attachment.AttachmentImpl;
import org.apache.cxf.binding.soap.SoapBindingConstants;
import org.apache.cxf.common.util.UrlUtils;
import org.apache.cxf.endpoint.Server;
import org.apache.cxf.interceptor.Fault;
import org.apache.cxf.message.Attachment;
import org.apache.cxf.message.Exchange;
import org.apache.cxf.message.ExchangeImpl;
import org.apache.cxf.message.Message;
import org.apache.cxf.message.MessageImpl;
import org.apache.cxf.transport.Destination;
import org.jetbrains.annotations.NotNull;
import org.mule.runtime.api.component.execution.ComponentExecutionException;
import org.mule.runtime.api.metadata.MediaType;
import org.mule.runtime.api.streaming.bytes.CursorStreamProvider;
import org.mule.soap.api.message.SoapAttachment;
import org.mule.soap.api.client.BadResponseException;
import org.mule.soap.api.message.SoapRequest;
import org.mule.soapkit.soap.api.param.ValidationLevel;
import org.mule.soapkit.soap.api.server.SoapServer;
import org.mule.soapkit.soap.api.server.SoapServerHandler;
import org.mule.soapkit.soap.message.EmptySoapResponse;
import org.mule.soapkit.soap.message.ImmutableSoapResponse;
import org.mule.soapkit.soap.message.SoapResponse;
import org.mule.soapkit.soap.server.support.DelegatingOutputStream;
import org.mule.wsdl.parser.model.PortModel;

import javax.xml.namespace.QName;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.nio.charset.IllegalCharsetNameException;
import java.nio.charset.UnsupportedCharsetException;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Strings.isNullOrEmpty;
import static java.lang.Boolean.TRUE;
import static java.lang.String.format;
import static java.nio.charset.Charset.defaultCharset;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static java.util.Optional.ofNullable;
import static java.util.stream.Collectors.toMap;
import static org.apache.commons.lang3.StringUtils.isNotEmpty;
import static org.apache.cxf.common.util.StringUtils.isEmpty;
import static org.apache.cxf.interceptor.StaxInEndingInterceptor.STAX_IN_NOCLOSE;
import static org.apache.cxf.message.Message.CONTENT_TYPE;
import static org.apache.cxf.message.Message.ENCODING;
import static org.apache.cxf.message.Message.HTTP_REQUEST_METHOD;
import static org.apache.cxf.message.Message.PROTOCOL_HEADERS;
import static org.apache.cxf.message.Message.QUERY_STRING;
import static org.apache.cxf.message.Message.REQUEST_URL;
import static org.apache.cxf.transport.local.LocalConduit.DIRECT_DISPATCH;
import static org.mule.runtime.core.api.util.ExceptionUtils.extractOfType;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.INTERNAL_SERVER_ERROR;
import static org.mule.soapkit.soap.SoapConstants.MULE_HTTP_ATTRIBUTES_CONTENT_TYPE;
import static org.mule.soapkit.soap.SoapConstants.MULE_HTTP_ATTRIBUTES_LOCATION;
import static org.mule.soapkit.soap.SoapConstants.MULE_HTTP_ATTRIBUTES_METHOD;
import static org.mule.soapkit.soap.SoapConstants.MULE_HTTP_ATTRIBUTES_QUERY_STRING;
import static org.mule.soapkit.soap.SoapConstants.MULE_SERVER_HANDLER_KEY;
import static org.mule.soapkit.soap.SoapConstants.MULE_SOAP_ACTION_KEY;
import static org.mule.soapkit.soap.SoapConstants.MULE_TRANSPORT_HEADERS_PREFIX;
import static org.mule.soapkit.soap.SoapConstants.SERVER_RESPONSE_KEY;
import static org.mule.soapkit.soap.util.DataHandlerUtils.toDataHandler;

/**
 * {@link SoapServer} implementation that dispatches {@link SoapRequest}s through an Apache CXF {@link Server} and converts
 * the resulting CXF {@link Exchange} back into a {@link SoapResponse}.
 * <p>
 * Requests are delivered in-process: the inbound CXF message is handed straight to the server destination's message
 * observer (with {@code DIRECT_DISPATCH} set) rather than travelling over a network transport.
 */
public class SoapCxfServer implements SoapServer {

  private final Server server;
  private final PortModel portModel;
  private final boolean mtomEnabled;
  private final boolean validationEnabled;
  private final ValidationLevel validationLevel;

  private static final String WSDL = "wsdl";
  private static final String XSD = "xsd";
  private static final String SOAP_ACTION = SoapBindingConstants.SOAP_ACTION;

  /** Key under which the HTTP status code is stored in the response's transport additional data when a fault is returned. */
  public static final String STATUS_CODE = "statusCode";
  /** Transport header under which the HTTP status code is stored when a fault is returned. */
  public static final String HTTP_STATUS_PROTOCOL_HEADER = "httpStatus";

  /**
   * @param server            the CXF server whose destination receives the requests
   * @param portModel         the WSDL port this server exposes; its name is reported by {@link #getPort()}
   * @param mtomEnabled       when {@code true}, MTOM is enabled on every inbound message regardless of content type
   * @param validationEnabled when {@code true}, CXF schema validation is enabled for inbound requests
   * @param validationLevel   when {@link ValidationLevel#ERROR} (and validation is enabled), client-side validation
   *                          faults are rethrown to the caller instead of being turned into a SOAP fault response
   */
  public SoapCxfServer(Server server, PortModel portModel, boolean mtomEnabled, boolean validationEnabled,
                       ValidationLevel validationLevel) {
    this.server = server;
    this.portModel = portModel;
    this.mtomEnabled = mtomEnabled;
    this.validationEnabled = validationEnabled;
    this.validationLevel = validationLevel;
  }

  /** Static factory equivalent to the public constructor. */
  static SoapCxfServer create(Server server, PortModel portModel, boolean mtomEnabled, boolean validationEnabled,
                              ValidationLevel validationLevel) {
    return new SoapCxfServer(server, portModel, mtomEnabled, validationEnabled, validationLevel);
  }

  /**
   * Dispatches {@code request} through CXF and converts the outcome into a {@link SoapResponse}.
   *
   * @param request the inbound SOAP (or WSDL/XSD GET) request
   * @param handler the handler stored on the exchange under {@code MULE_SERVER_HANDLER_KEY}; its variables seed an empty
   *                response when CXF did not store one
   * @return the response built from the CXF output (or fault) message, or from the handler-provided response
   */
  public SoapResponse serve(SoapRequest request, SoapServerHandler handler) {
    final Exchange exchange = sendThroughCxf(request, handler);
    return toSoapResponse(exchange, handler);
  }

  @Override
  public String getPort() {
    return portModel.getName();
  }

  @Override
  public QName getService() {
    return server.getEndpoint().getService().getName();
  }

  /**
   * Builds the inbound CXF message from {@code request}, pushes it through the server destination's message observer, and
   * returns the populated exchange.
   *
   * @throws ComponentExecutionException if the flow execution failed while CXF was processing the request
   * @throws Fault                       if validation is enabled at {@link ValidationLevel#ERROR} and CXF raised a client
   *                                     fault
   */
  private Exchange sendThroughCxf(SoapRequest request, SoapServerHandler handler) {
    final Exchange exchange = new ExchangeImpl();
    final MessageImpl messageIn = new MessageImpl();
    messageIn.setExchange(exchange);

    if (validationEnabled) {
      // Acting as the server, validate only inbound requests.
      messageIn.put(Message.REQUESTOR_ROLE, Boolean.FALSE);
      messageIn.put(Message.SCHEMA_VALIDATION_ENABLED, SchemaValidationType.REQUEST);
    }

    exchange.put(MULE_SERVER_HANDLER_KEY, handler);

    final Map<String, String> transportHeaders = request.getTransportHeaders();

    // Method
    final String method = transportHeaders.get(MULE_HTTP_ATTRIBUTES_METHOD);
    if (isNotEmpty(method)) {
      messageIn.put(HTTP_REQUEST_METHOD, method);
    }
    final boolean isGetMethod = "GET".equalsIgnoreCase(method);

    // Content-Type
    final String contentType = transportHeaders.get(MULE_HTTP_ATTRIBUTES_CONTENT_TYPE);
    if (isNotEmpty(contentType)) {
      messageIn.put(CONTENT_TYPE, contentType);
    }

    // MTOM is either forced by configuration or detected from an XOP content type.
    final boolean isMtom = mtomEnabled || (contentType != null && contentType.contains("application/xop+xml"));
    messageIn.put(Message.MTOM_ENABLED, Boolean.valueOf(isMtom));

    // WSDL & XSD request analysis
    final String queryString = transportHeaders.get(MULE_HTTP_ATTRIBUTES_QUERY_STRING);
    boolean isWsdlOrXsdRequest = false;
    if (isNotEmpty(queryString)) {
      messageIn.put(QUERY_STRING, queryString);
      isWsdlOrXsdRequest = isGetMethod && isWSDLOrXSDRequest(queryString);
    }

    if (isWsdlOrXsdRequest) {
      // WSDL or XSD response. The location can be null when you use SoapServer directly.
      final String location = transportHeaders.get(MULE_HTTP_ATTRIBUTES_LOCATION);
      if (location != null) {
        messageIn.put(REQUEST_URL, location);
      }
    } else {
      // SOAP operation response
      final Map<String, List<String>> protocolHeaders = toProtocolHeaders(transportHeaders);

      // SoapActionInInterceptor needs this HEADER to get BindingOperationInfo from SOAPAction.
      // Locale.ROOT keeps the lower-cased key stable regardless of the JVM default locale (e.g. Turkish dotless i).
      getSoapAction(transportHeaders).ifPresent(action -> {
        protocolHeaders.remove(SOAP_ACTION.toLowerCase(Locale.ROOT));
        protocolHeaders.put(SOAP_ACTION, singletonList(action));
        messageIn.put(SOAP_ACTION, action);
      });
      messageIn.put(PROTOCOL_HEADERS, protocolHeaders);

      // Attachments
      final Map<String, Attachment> attachments = transformToCxfAttachments(request.getAttachments());
      messageIn.setAttachments(attachments.values());

      setContent(request, messageIn);
    }

    final Destination destination = server.getDestination();

    // Set up a listener for the response
    messageIn.put(DIRECT_DISPATCH, TRUE);
    messageIn.setDestination(destination);

    // Mule will close the stream so don't let cxf, otherwise cxf will close it too early
    exchange.put(STAX_IN_NOCLOSE, TRUE);
    exchange.setInMessage(messageIn);

    // Invoke the actual web service up until right before we serialize the response.
    destination.getMessageObserver().onMessage(messageIn);

    propagateFailure(exchange);

    if (isWsdlOrXsdRequest && exchange.getOutMessage() != null) {
      final String contentTypeOut = (String) exchange.getOutMessage().get(CONTENT_TYPE);
      exchange.put(CONTENT_TYPE, contentTypeOut);
    }
    return exchange;
  }

  /**
   * Selects the transport headers that start with {@code MULE_TRANSPORT_HEADERS_PREFIX} and re-keys them without that
   * prefix. Null or empty values become empty lists. The returned map is always a mutable {@link HashMap}, because the
   * caller adds and removes the SOAPAction entry.
   */
  private Map<String, List<String>> toProtocolHeaders(final Map<String, String> transportHeaders) {
    final int prefixLength = MULE_TRANSPORT_HEADERS_PREFIX.length();
    return transportHeaders.entrySet().stream()
        .filter(e -> e.getKey().startsWith(MULE_TRANSPORT_HEADERS_PREFIX))
        // substring strips only the leading prefix; String#replace would also strip any later occurrence.
        .collect(toMap(e -> e.getKey().substring(prefixLength),
                       e -> isNullOrEmpty(e.getValue()) ? emptyList() : singletonList(e.getValue()),
                       (first, second) -> second,
                       HashMap::new));
  }

  /**
   * Rethrows a failure recorded on the exchange when it must reach the caller instead of becoming a SOAP fault.
   *
   * @throws ComponentExecutionException if one is found in the recorded exception's cause chain
   * @throws Fault                       if validation is enabled at {@link ValidationLevel#ERROR} and the recorded
   *                                     exception contains a client fault
   */
  private void propagateFailure(final Exchange exchange) {
    final Exception exception = exchange.get(Exception.class);
    if (exception == null) {
      return;
    }

    final Optional<ComponentExecutionException> cause = extractOfType(exception, ComponentExecutionException.class);
    if (cause.isPresent()) {
      // Propagate the exception if the flow execution failed
      throw cause.get();
    }

    if (validationEnabled && ValidationLevel.ERROR.equals(validationLevel)) {
      final Optional<Fault> validationFault = extractOfType(exception, Fault.class);
      // Compare from the constant side so a fault without a code does not cause an NPE.
      if (validationFault.isPresent() && Fault.FAULT_CODE_CLIENT.equals(validationFault.get().getFaultCode())) {
        throw validationFault.get();
      }
    }
  }

  /** Returns {@code true} when the parsed query string contains a {@code wsdl} or {@code xsd} key. */
  private boolean isWSDLOrXSDRequest(String queryString) {
    final Map<String, String> queryStringMap = UrlUtils.parseQueryString(queryString);
    return queryStringMap.containsKey(WSDL) || queryStringMap.containsKey(XSD);
  }

  /**
   * Converts the processed exchange into a {@link SoapResponse}.
   * <p>
   * The body comes from the fault message when it carries content, in which case the status is set to 500. Otherwise it
   * comes from the regular out message. If neither provides a {@link DelegatingOutputStream}, the content of the response
   * stored under {@code SERVER_RESPONSE_KEY} (or an empty response) is used.
   */
  private SoapResponse toSoapResponse(Exchange exchange, SoapServerHandler handler) {
    final Message outMessage = exchange.getOutMessage();
    final Message outFaultMessage = exchange.getOutFaultMessage();

    final SoapResponse soapResponse = ofNullable((SoapResponse) exchange.get(SERVER_RESPONSE_KEY))
        .orElseGet(() -> new EmptySoapResponse(handler.getVariables()));

    final Map<String, String> transportAdditionalData = new HashMap<>(soapResponse.getTransportAdditionalData());
    final Map<String, String> transportHeaders = new HashMap<>(soapResponse.getTransportHeaders());

    final Message contentMsg;
    if (outFaultMessage != null && outFaultMessage.getContent(DelegatingOutputStream.class) != null) {
      contentMsg = outFaultMessage;
      final String status = String.valueOf(INTERNAL_SERVER_ERROR.getStatusCode());
      transportAdditionalData.put(STATUS_CODE, status);
      transportHeaders.put(HTTP_STATUS_PROTOCOL_HEADER, status);
    } else {
      contentMsg = outMessage;
    }

    // The out message may be absent (no CXF response was produced); fall back to the stored response then.
    final DelegatingOutputStream response =
        contentMsg != null ? contentMsg.getContent(DelegatingOutputStream.class) : null;

    final InputStream is;
    final MediaType contentType;
    if (response != null) {
      final ByteArrayOutputStream os = (ByteArrayOutputStream) response.getOutputStream();
      is = new ByteArrayInputStream(os.toByteArray());
      contentType = getContentType(contentMsg);
    } else {
      is = soapResponse.getContent();
      contentType = MediaType.parse(soapResponse.getContentType());
    }

    return new ImmutableSoapResponse(is, soapResponse.getSoapHeaders(), transportHeaders,
                                     transportAdditionalData, soapResponse.getAttachments(), contentType,
                                     soapResponse.getVariables());
  }

  /**
   * Resolves the media type of a CXF message. Returns {@link MediaType#XML} when the message has no content type.
   * Otherwise returns its content type with the charset taken from the message encoding (default charset if unset).
   */
  private MediaType getContentType(final Message contentMsg) {
    final String contentType = (String) contentMsg.get(CONTENT_TYPE);
    final String encoding = ofNullable((String) contentMsg.get(ENCODING)).orElse(defaultCharset().name());

    return isNullOrEmpty(contentType) ? MediaType.XML : MediaType.parse(contentType).withCharset(getEncoding(encoding));
  }

  public Server getServer() {
    return server;
  }

  /**
   * Reads the SOAP action from the transport header keyed by the lower-cased {@code MULE_SOAP_ACTION_KEY}. One pair of
   * surrounding double quotes is stripped if present.
   */
  private Optional<String> getSoapAction(final Map<String, String> transportHeaders) {
    // Locale.ROOT: the lookup key must not depend on the JVM default locale.
    String action = transportHeaders.get(MULE_SOAP_ACTION_KEY.toLowerCase(Locale.ROOT));

    if (action != null && action.length() >= 2 && action.startsWith("\"") && action.endsWith("\"")) {
      action = action.substring(1, action.length() - 1);
    }

    return ofNullable(action);
  }

  /**
   * Wraps each SOAP attachment as a CXF {@link Attachment} keyed by its name.
   * Copied from SoapCxfClient.
   *
   * @throws BadResponseException if an attachment's content cannot be converted to a data handler
   */
  private Map<String, Attachment> transformToCxfAttachments(final Map<String, SoapAttachment> attachments) {
    ImmutableMap.Builder<String, Attachment> builder = ImmutableMap.builder();
    attachments.forEach((name, value) -> {
      try {
        builder.put(name,
                    new AttachmentImpl(name, toDataHandler(name, value.getContent(), MediaType.parse(value.getContentType()))));
      } catch (IOException e) {
        throw new BadResponseException(format("Error while preparing attachment [%s]", name), e);
      }
    });
    return builder.build();
  }

  /**
   * Sets the request encoding and body on {@code message}. When the payload is a {@link CursorStreamProvider}, a new
   * cursor is opened and used as the body.
   */
  private void setContent(final SoapRequest request, final Message message) {
    final InputStream payload = request.getContent();

    message.put(ENCODING, getEncoding(request).name());
    final Object content =
        (payload instanceof CursorStreamProvider) ? ((CursorStreamProvider) payload).openCursor() : payload;
    message.setContent(InputStream.class, content);
  }

  /**
   * Resolves a charset by name, falling back to the JVM default charset when the name is unsupported or syntactically
   * illegal.
   */
  private Charset getEncoding(final String encoding) {
    try {
      return Charset.forName(encoding);
    } catch (IllegalCharsetNameException | UnsupportedCharsetException e) {
      return Charset.defaultCharset();
    }
  }

  /**
   * Extracts the charset from the request's content-type transport header. Falls back to the JVM default charset when
   * the header is missing, has no charset, or cannot be parsed.
   */
  @NotNull
  private Charset getEncoding(final SoapRequest request) {
    final String contentType = request.getTransportHeaders().get(MULE_HTTP_ATTRIBUTES_CONTENT_TYPE);

    if (isEmpty(contentType)) {
      return defaultCharset();
    }

    try {
      return MediaType.parse(contentType).getCharset().orElse(defaultCharset());
    } catch (IllegalArgumentException e) {
      return defaultCharset();
    }
  }

  /** Starts the underlying CXF server. */
  @Override
  public void start() {
    getServer().start();
  }

  /** Destroys the underlying CXF server. */
  @Override
  public void stop() {
    getServer().destroy();
  }
}
