/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.netty.impl.server;

import static org.mule.runtime.http.api.domain.HttpProtocol.HTTP_1_1;
import static org.mule.service.http.netty.impl.streaming.StreamingEntitySender.ENTITY_STREAMING_BUFFER_SIZE;
import static org.mule.tck.junit4.matcher.Eventually.eventually;
import static org.mule.tck.junit4.matcher.FunctionExpressionMatcher.expressionMatches;
import static org.mule.tck.probe.PollingProber.probe;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.TimeUnit.SECONDS;

import static io.netty.buffer.ByteBufAllocator.DEFAULT;
import static io.netty.handler.codec.http.LastHttpContent.EMPTY_LAST_CONTENT;
import static org.apache.commons.lang3.RandomStringUtils.secure;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.http.api.domain.entity.HttpEntity;
import org.mule.runtime.http.api.domain.entity.InputStreamHttpEntity;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.http.api.server.async.ResponseStatusCallback;
import org.mule.service.http.netty.impl.server.Http1Writer;
import org.mule.service.http.netty.impl.server.StreamingResponseSender;
import org.mule.service.http.test.common.AbstractHttpTestCase;
import org.mule.service.http.test.netty.tck.ExecutorRule;
import org.mule.tck.junit4.matcher.Eventually;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;

import io.netty.buffer.ByteBufUtil;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.channel.DefaultEventLoop;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpObject;
import org.junit.After;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

/**
 * Unit tests for {@link StreamingResponseSender}.
 * <p>
 * The {@link ChannelHandlerContext} is a mock whose {@code writeAndFlush} records every written {@link HttpObject} into
 * {@link #writtenObjects} and completes the supplied promise right away. This lets each test assert on the exact sequence
 * of chunks produced for a given response entity.
 */
public class StreamingResponseSenderTestCase extends AbstractHttpTestCase {

  @ClassRule
  public static ExecutorRule executorRule = new ExecutorRule();

  @Rule
  public MockitoRule rule = MockitoJUnit.rule();

  @Mock
  private ChannelHandlerContext ctx;

  @Mock
  private HttpRequest request;

  @Mock
  private HttpResponse response;

  @Mock
  private Channel channel;

  @Mock
  private ResponseStatusCallback statusCallback;

  /**
   * Entity returned by the mocked {@link HttpResponse#getEntity()}. Each test has to set this variable before creating the
   * sender.
   */
  private HttpEntity entity;

  /**
   * Objects written to the mocked context, in write order.
   * <p>
   * Writes may happen on an executor thread while the test thread polls this list. It must therefore be safe for concurrent
   * access. A plain {@link ArrayList} could expose a stale size, or throw while {@link #concatContents(List)} iterates it.
   */
  private final List<HttpObject> writtenObjects = new CopyOnWriteArrayList<>();

  /** Event loop returned by the mocked channel. It is created per test and shut down in {@link #tearDown()}. */
  private DefaultEventLoop eventLoop;

  @Before
  public void setUp() {
    eventLoop = new DefaultEventLoop();
    when(channel.eventLoop()).thenReturn(eventLoop);

    when(request.getProtocol()).thenReturn(HTTP_1_1);
    when(request.getMethod()).thenReturn("GET");

    // Answer lazily so each test can assign `entity` after setUp has run.
    when(response.getEntity()).thenAnswer(inv -> entity);
    when(response.getStatusCode()).thenReturn(200);
    when(response.getReasonPhrase()).thenReturn("OK");
    setUpHeaders(response);

    when(ctx.channel()).thenReturn(channel);
    when(ctx.alloc()).thenReturn(DEFAULT);
    when(ctx.newPromise()).thenAnswer(inv -> new DefaultChannelPromise(channel));
    saveDataWrittenToChannel(ctx);
  }

  /**
   * Releases the buffers captured during the test and stops the per-test event loop. Without the shutdown, its thread could
   * keep running after the test finishes.
   */
  @After
  public void tearDown() {
    releaseWrittenContents();
    if (eventLoop != null) {
      // No quiet period: nothing is expected to be submitted once the test body has finished.
      eventLoop.shutdownGracefully(0, 0, SECONDS);
    }
  }

  @Test
  public void createSenderWithNonStreamingEntityFails() {
    entity = mock(HttpEntity.class);
    when(entity.isStreaming()).thenReturn(false);

    ExecutorService executor = executorRule.getExecutor();
    IllegalArgumentException thrown =
        assertThrows(IllegalArgumentException.class,
                     () -> new StreamingResponseSender(request, ctx, response, statusCallback, executor, new Http1Writer(ctx)));
    assertThat(thrown.getMessage(), is("Response entity must be streaming to use a StreamingResponseSender"));
  }

  @Test
  public void sendInSeveralPartsWithoutContentLength() throws IOException {
    int bufferSize = ENTITY_STREAMING_BUFFER_SIZE;
    int extraChunkSize = 10;
    // When we don't specify the entity content-length, then the buffer size is the default.
    entity = new InputStreamHttpEntity(new ByteArrayInputStream(secure().nextAlphanumeric(2 * bufferSize + extraChunkSize)
        .getBytes(UTF_8)));
    new StreamingResponseSender(request, ctx, response, statusCallback, executorRule.getExecutor(), new Http1Writer(ctx)).send();

    probe(() -> {
      assertThat(writtenObjects.size(), is(4));
      assertThat(((DefaultHttpContent) writtenObjects.get(0)).content().readableBytes(), is(bufferSize));
      assertThat(((DefaultHttpContent) writtenObjects.get(1)).content().readableBytes(), is(bufferSize));
      assertThat(((DefaultHttpContent) writtenObjects.get(2)).content().readableBytes(), is(extraChunkSize));
      assertThat(writtenObjects.get(3), is(EMPTY_LAST_CONTENT));
      return true;
    });
  }

  @Test
  public void sendStreamWithContentLengthShorterThanDefault() throws IOException {
    int contentLength = ENTITY_STREAMING_BUFFER_SIZE / 2;
    entity =
        new InputStreamHttpEntity(new ByteArrayInputStream(secure().nextAlphanumeric(contentLength).getBytes(UTF_8)),
                                  contentLength);
    new StreamingResponseSender(request, ctx, response, statusCallback, executorRule.getExecutor(), new Http1Writer(ctx)).send();

    probe(() -> {
      assertThat(writtenObjects.size(), is(1));
      assertThat(((DefaultHttpContent) writtenObjects.get(0)).content().readableBytes(), is(contentLength));
      return true;
    });
  }

  @Test
  public void sendStreamWithContentLengthLongerThanDefault() throws IOException {
    int bufferSize = ENTITY_STREAMING_BUFFER_SIZE;
    int extraLength = 20;
    int contentLength = bufferSize + extraLength;
    entity =
        new InputStreamHttpEntity(new ByteArrayInputStream(secure().nextAlphanumeric(contentLength).getBytes(UTF_8)),
                                  contentLength);
    new StreamingResponseSender(request, ctx, response, statusCallback, executorRule.getExecutor(), new Http1Writer(ctx)).send();

    probe(() -> {
      assertThat(writtenObjects.size(), is(2));
      assertThat(((DefaultHttpContent) writtenObjects.get(0)).content().readableBytes(), is(bufferSize));
      assertThat(((DefaultHttpContent) writtenObjects.get(1)).content().readableBytes(), is(extraLength));
      return true;
    });
  }

  @Test
  public void sendEmptyStreamingEntity() throws IOException {
    entity = new InputStreamHttpEntity(new ByteArrayInputStream(new byte[0]), 0);
    new StreamingResponseSender(request, ctx, response, statusCallback, executorRule.getExecutor(), new Http1Writer(ctx)).send();

    probe(() -> {
      assertThat(writtenObjects.size(), is(1));
      assertThat(writtenObjects.get(0), is(EMPTY_LAST_CONTENT));
      return true;
    });
  }

  @Test
  public void sendStreamThatFailsToReadWithClosedState() throws IOException {
    entity = new InputStreamHttpEntity(streamFailingWith(new IllegalStateException("Buffer is closed")));
    new StreamingResponseSender(request, ctx, response, statusCallback, executorRule.getExecutor(), new Http1Writer(ctx)).send();
    probe(() -> {
      assertThat(writtenObjects.size(), is(1));
      assertThat(writtenObjects.get(0), is(EMPTY_LAST_CONTENT));
      return true;
    });
  }

  @Test
  public void sendStreamThatFailsToReadWithOtherException() {
    RuntimeException expected = new IllegalStateException("Some unexpected error");
    entity = new InputStreamHttpEntity(streamFailingWith(expected));

    var sender =
        new StreamingResponseSender(request, ctx, response, statusCallback, executorRule.getExecutor(), new Http1Writer(ctx));
    RuntimeException thrown = assertThrows(RuntimeException.class, sender::send);
    assertThat(thrown.getMessage(), is(expected.getMessage()));
  }

  @Test
  public void whenSendingSlowStream_itIsEventuallyScheduledToIoExecutor() throws IOException {
    AtomicBoolean executeWasCalled = new AtomicBoolean(false);

    // A payload that will be consumed slowly...
    String testPayload = "Test payload";
    entity = new InputStreamHttpEntity(new SlowStringStream(testPayload));

    // Can't use Mockito#spy() because of some final fields, so spying it "by hand".
    ExecutorService ioExecutor = mock(ExecutorService.class);
    doAnswer(invocation -> {
      executorRule.getExecutor().execute(invocation.getArgument(0));
      executeWasCalled.set(true);
      return null;
    }).when(ioExecutor).execute(any(Runnable.class));
    var sender = new StreamingResponseSender(request, ctx, response, statusCallback, ioExecutor, new Http1Writer(ctx));

    // As the payload is consumed slowly, the send method will eventually delegate the consumption to the io executor.
    sender.send();
    assertThat(executeWasCalled, eventuallyBecomes(true));

    // And the payload is completely written into the ctx.
    assertThat(writtenObjects, eventuallyHasSize(testPayload.length()));
    String fullContent = new String(concatContents(writtenObjects), UTF_8);
    assertThat(fullContent, is(testPayload));
  }

  @Test
  public void whenSendingSlowStreamButTheSchedulerIsBusy_thenAllThePayloadIsWrittenWithoutScheduling() throws IOException {
    // A payload that will be consumed slowly...
    String testPayload = "Test payload";
    entity = new InputStreamHttpEntity(new SlowStringStream(testPayload));

    // The executor is too busy, it will raise exception.
    ExecutorService ioExecutor = mock(ExecutorService.class);
    doThrow(RejectedExecutionException.class).when(ioExecutor).execute(any(Runnable.class));
    var sender = new StreamingResponseSender(request, ctx, response, statusCallback, ioExecutor, new Http1Writer(ctx));

    // The payload is consumed slowly, but the io executor rejects any attempt to delegate the consumption to it.
    sender.send();

    // Even so, the payload is completely written into the ctx.
    assertThat(writtenObjects, eventuallyHasSize(testPayload.length()));
    String fullContent = new String(concatContents(writtenObjects), UTF_8);
    assertThat(fullContent, is(testPayload));
  }

  /**
   * Concatenates the bytes of every {@link HttpContent} in {@code objects}, in order, without consuming the buffers.
   */
  private static byte[] concatContents(List<HttpObject> objects) throws IOException {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    for (HttpObject writtenObject : objects) {
      if (writtenObject instanceof HttpContent content) {
        byte[] bytes = ByteBufUtil.getBytes(content.content());
        baos.write(bytes);
      }
    }
    return baos.toByteArray();
  }

  private static Eventually<AtomicBoolean> eventuallyBecomes(boolean value) {
    return eventually(expressionMatches(AtomicBoolean::get, is(value)));
  }

  private static Eventually<Collection<?>> eventuallyHasSize(int expectedSize) {
    return eventually(expressionMatches(Collection::size, greaterThanOrEqualTo(expectedSize)));
  }

  /**
   * Stream that sleeps 70 ms on every {@code read(byte[])} call and then returns at most one byte. It simulates a payload
   * that is produced slowly.
   */
  private static class SlowStringStream extends InputStream {

    private final byte[] bytes;
    private int pos = 0;

    SlowStringStream(String string) {
      this.bytes = string.getBytes(UTF_8);
    }

    @Override
    public int read() throws IOException {
      byte[] b = new byte[1];
      int res = read(b);
      if (-1 == res) {
        return -1;
      } else {
        return b[0] & 0xFF;
      }
    }

    @Override
    public int read(byte[] buf) throws IOException {
      if (buf.length == 0) {
        // InputStream contract: a zero-length read returns 0.
        return 0;
      }
      try {
        Thread.sleep(70);
      } catch (InterruptedException e) {
        // Restore the interrupt flag so callers can still observe the interruption.
        Thread.currentThread().interrupt();
        throw new IOException(e);
      }
      if (pos >= bytes.length) {
        return -1;
      } else {
        buf[0] = bytes[pos++];
        return 1;
      }
    }
  }

  /** Makes the mocked response return an empty header map. */
  private static void setUpHeaders(HttpResponse response) {
    MultiMap<String, String> headers = new MultiMap<>();
    when(response.getHeaders()).thenReturn(headers);
  }

  /**
   * Stubs {@code writeAndFlush(HttpObject, ChannelPromise)} so that:
   * <ul>
   * <li>every written object is appended to {@link #writtenObjects};</li>
   * <li>the supplied promise is completed successfully right away.</li>
   * </ul>
   */
  private void saveDataWrittenToChannel(ChannelHandlerContext context) {
    when(context.writeAndFlush(any(HttpObject.class), any(ChannelPromise.class))).thenAnswer(invocation -> {
      HttpObject toWrite = invocation.getArgument(0, HttpObject.class);
      writtenObjects.add(toWrite);
      ChannelPromise promise = invocation.getArgument(1, ChannelPromise.class);
      promise.setSuccess().get();
      return null;
    });
  }

  /**
   * Releases the buffers of the captured contents.
   * <p>
   * The mocked context swallows the writes, so nothing downstream frees them. The {@code refCnt()} guard skips buffers that
   * were already released, and it also skips a buffer that shows up in more than one written object once it has been freed.
   */
  private void releaseWrittenContents() {
    for (HttpObject written : writtenObjects) {
      if (written instanceof HttpContent content && content.refCnt() > 0) {
        content.release();
      }
    }
  }

  /** Returns a stream whose single-byte {@code read()} always throws {@code exception}. */
  private static InputStream streamFailingWith(RuntimeException exception) {
    return new InputStream() {

      @Override
      public int read() {
        throw exception;
      }
    };
  }
}
