/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.netty.impl.util;

import static org.mule.service.http.test.netty.AllureConstants.HttpStory.STREAMING;
import static org.mule.tck.probe.PollingProber.probe;

import static java.lang.Thread.sleep;
import static java.nio.charset.StandardCharsets.UTF_8;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertThrows;
import static org.junit.internal.matchers.ThrowableCauseMatcher.hasCause;
import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.service.http.netty.impl.streaming.BlockingBidirectionalStream;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.atomic.AtomicBoolean;

import io.qameta.allure.Issue;
import io.qameta.allure.Story;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;

/**
 * Tests for {@link BlockingBidirectionalStream}: the test thread writes through the stream's output side
 * while a background {@link TestConsumer} thread reads from its input side.
 */
@Story(STREAMING)
class BlockingBufferTestCase {

  private static final String TEST_PAYLOAD =
      "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";

  private static final Logger LOGGER = getLogger(BlockingBufferTestCase.class);

  /** Upper bound for each wait in {@link #tearDown()}, so a stuck consumer cannot hang the test suite. */
  private static final long CONSUMER_JOIN_TIMEOUT_MILLIS = 5000;

  private final BlockingBidirectionalStream blockingBuffer = new BlockingBidirectionalStream();
  private TestConsumer consumer;

  /** Starts a consumer thread that reads {@link #blockingBuffer} in chunks of 8 bytes. */
  private void startConsumer() {
    consumer = new TestConsumer(blockingBuffer.getInputStream(), 8);
    consumer.start();
  }

  /**
   * Waits for the consumer thread (if one was started) to terminate.
   * <p>
   * If it is still alive after {@link #CONSUMER_JOIN_TIMEOUT_MILLIS} (for example because a test failed
   * before closing the stream, leaving the consumer blocked on a read), it is interrupted and waited for
   * once more, instead of blocking forever.
   */
  @AfterEach
  void tearDown() throws InterruptedException {
    if (consumer == null) {
      return;
    }
    consumer.join(CONSUMER_JOIN_TIMEOUT_MILLIS);
    if (consumer.isAlive()) {
      LOGGER.warn("Consumer thread did not finish within {} ms, interrupting it", CONSUMER_JOIN_TIMEOUT_MILLIS);
      consumer.interrupt();
      consumer.join(CONSUMER_JOIN_TIMEOUT_MILLIS);
    }
  }

  @Test
  @Issue("W-19584546")
  void readShouldReturnValuesBetween_0_and_255() throws Exception {
    var bidiStream = new BlockingBidirectionalStream();

    // Decimals 139 and -117 are both binary 10001011 when stored in a byte. An InputStream should always
    // return a number in the range [0,255] when a byte is read, and -1 at end of stream.
    bidiStream.write(new byte[] {(byte) 139, (byte) -117}, 0, 2);
    bidiStream.close();

    assertThat(bidiStream.read(), is(139));
    assertThat(bidiStream.read(), is(139));
    assertThat(bidiStream.read(), is(-1));
  }

  @Test
  void consumerBlocksWhenBufferIsEmpty() throws InterruptedException {
    startConsumer();

    // The test has to pass regardless of this sleep(), but sleeping makes the bug
    // more likely to surface if it exists.
    sleep(500);

    assertThat(consumer.finishedReading(), is(false));
    assertThat(consumer.getConsumedData().length, is(0));

    consumer.interrupt();

    probe(() -> {
      Exception error = consumer.getErrorWhileConsuming();
      assertThat(error, instanceOf(IOException.class));
      assertThat(error.getCause(), instanceOf(InterruptedException.class));
      return true;
    });
  }

  @Test
  void consumerUnblocksWhenBufferIsClosedAndEmpty() throws InterruptedException, IOException {
    startConsumer();

    // The test has to pass regardless of this sleep(), but sleeping makes the bug
    // more likely to surface if it exists.
    sleep(500);

    assertThat(consumer.finishedReading(), is(false));
    assertThat(consumer.getConsumedData().length, is(0));

    blockingBuffer.getOutputStream().close();

    probe(() -> {
      assertThat(consumer.finishedReading(), is(true));
      assertThat(consumer.getErrorWhileConsuming(), is(nullValue()));
      assertThat(consumer.getConsumedData().length, is(0));
      return true;
    });
  }

  @Test
  void cancellationErrorIsPropagatedToReader() {
    startConsumer();

    assertThat(consumer.finishedReading(), is(false));
    assertThat(consumer.getConsumedData().length, is(0));

    RuntimeException cancellationException = new RuntimeException("Expected!!");
    blockingBuffer.getOutputStream().cancel(cancellationException);

    probe(() -> {
      assertThat(consumer.finishedReading(), is(true));
      assertThat(consumer.getConsumedData().length, is(0));
      assertThat(consumer.getErrorWhileConsuming(), allOf(instanceOf(IOException.class),
                                                          hasMessage(is("Streaming canceled by writer")),
                                                          hasCause(is(cancellationException))));
      return true;
    });
  }

  @Test
  void cancellationErrorPushesBackTheNextWrite() {
    startConsumer();

    assertThat(consumer.finishedReading(), is(false));
    assertThat(consumer.getConsumedData().length, is(0));

    RuntimeException cancellationException = new RuntimeException("Expected!!");
    blockingBuffer.getOutputStream().cancel(cancellationException);

    var error = assertThrows(IOException.class, () -> blockingBuffer.getOutputStream().write(TEST_PAYLOAD.getBytes(UTF_8)));
    assertThat(error, allOf(instanceOf(IOException.class),
                            hasMessage(is("Trying to write in a canceled stream")),
                            hasCause(is(cancellationException))));
  }

  @Test
  void writeAndReadAPayloadWithDifferentChunkSizes() throws IOException {
    startConsumer();

    String[] words = TEST_PAYLOAD.split(" ");
    for (String word : words) {
      // Each write has the length of the word being written.
      blockingBuffer.getOutputStream().write(word.getBytes(UTF_8));
    }
    blockingBuffer.getOutputStream().close();

    // As we split by space, the expected consumed data doesn't contain spaces.
    final String testPayloadWithoutSpaces = TEST_PAYLOAD.replace(" ", "");
    probe(() -> {
      assertThat(new String(consumer.getConsumedData(), UTF_8), is(testPayloadWithoutSpaces));
      assertThat(consumer.finishedReading(), is(true));
      assertThat(consumer.getErrorWhileConsuming(), is(nullValue()));
      return true;
    });
  }

  @Test
  void writeBytePerByte() throws IOException {
    startConsumer();

    for (byte b : TEST_PAYLOAD.getBytes(UTF_8)) {
      blockingBuffer.getOutputStream().write(b);
    }
    blockingBuffer.getOutputStream().close();

    probe(() -> {
      assertThat(new String(consumer.getConsumedData(), UTF_8), is(TEST_PAYLOAD));
      assertThat(consumer.finishedReading(), is(true));
      assertThat(consumer.getErrorWhileConsuming(), is(nullValue()));
      return true;
    });
  }

  @Test
  void writeLessBytesThanBufferSize() throws IOException {
    startConsumer();

    byte[] abc = "abc".getBytes(UTF_8);

    blockingBuffer.getOutputStream().write(abc);

    probe(() -> {
      // The data is read and the consumer is blocked waiting for another chunk.
      assertThat(new String(consumer.getConsumedData(), UTF_8), is("abc"));
      assertThat(consumer.finishedReading(), is(false));
      return true;
    });

    blockingBuffer.getOutputStream().close();

    probe(() -> {
      assertThat(consumer.finishedReading(), is(true));
      assertThat(consumer.getErrorWhileConsuming(), is(nullValue()));
      return true;
    });
  }

  /**
   * Background reader that drains an {@link InputStream} in chunks of at most {@code consumerBufferSize}
   * bytes and accumulates everything it reads. It stops on end of stream ({@code -1}), on a zero-length
   * read, or on an {@link IOException}. In the IOException case the exception is recorded.
   * <p>
   * The test thread reads this thread's state while it is running, so the state is published through
   * thread-safe constructs: an atomic flag, a volatile error field, and a private lock guarding the
   * accumulated data.
   */
  private static class TestConsumer extends Thread {

    private final InputStream inputStream;
    private final int consumerBufferSize;
    // Guards consumedData. A private lock is used instead of the Thread instance's own monitor,
    // which the JDK uses to implement Thread.join().
    private final Object consumedDataLock = new Object();
    private final ByteArrayOutputStream consumedData;
    private final AtomicBoolean continueReading;
    // Written by this thread and read by the test thread, hence volatile.
    private volatile Exception errorWhileConsuming;

    public TestConsumer(InputStream inputStream, int consumerBufferSize) {
      super("blocking-buffer-test-consumer");
      // A consumer stuck on a blocking read must not keep the JVM alive.
      setDaemon(true);
      this.inputStream = inputStream;
      this.consumerBufferSize = consumerBufferSize;
      this.consumedData = new ByteArrayOutputStream();
      this.continueReading = new AtomicBoolean(true);
    }

    /** Returns a copy of all bytes consumed so far. */
    public byte[] getConsumedData() {
      synchronized (consumedDataLock) {
        return consumedData.toByteArray();
      }
    }

    /** Returns {@code true} once the read loop has stopped, whether normally or because of an error. */
    public boolean finishedReading() {
      return !continueReading.get();
    }

    /** Returns the {@link IOException} that stopped the read loop, or {@code null} if there was none. */
    public Exception getErrorWhileConsuming() {
      return errorWhileConsuming;
    }

    @Override
    public void run() {
      // Reusing the buffer is safe: each chunk is copied into consumedData before the next read.
      byte[] consumerBuffer = new byte[consumerBufferSize];
      while (continueReading.get()) {
        try {
          int bytesRead = inputStream.read(consumerBuffer, 0, consumerBufferSize);
          if (bytesRead == -1 || bytesRead == 0) {
            continueReading.set(false);
          } else {
            synchronized (consumedDataLock) {
              LOGGER.debug("Reading this chunk [{}]", new String(consumerBuffer, 0, bytesRead, UTF_8));
              consumedData.write(consumerBuffer, 0, bytesRead);
            }
          }
        } catch (IOException e) {
          LOGGER.debug("Found error while consuming", e);
          // Record the error before clearing the flag, so a reader that observes
          // finishedReading() == true also observes the error.
          this.errorWhileConsuming = e;
          continueReading.set(false);
        }
      }
    }
  }
}
