/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.io.network.netty;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.core.memory.MemorySegmentProvider;
import org.apache.flink.runtime.io.network.NetworkClientHandler;
import org.apache.flink.runtime.io.network.TestingPartitionRequestClient;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.FullyFilledBuffer;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.netty.CreditBasedPartitionRequestClientHandler;
import org.apache.flink.runtime.io.network.netty.NettyBufferPool;
import org.apache.flink.runtime.io.network.netty.NettyMessage;
import org.apache.flink.runtime.io.network.netty.NettyMessageClientDecoderDelegate;
import org.apache.flink.runtime.io.network.netty.NettyTestUtil;
import org.apache.flink.runtime.io.network.partition.InputChannelTestUtils;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler;
import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
import org.apache.flink.util.ExceptionUtils;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

/**
 * Tests the client-side decoding done by {@link NettyMessageClientDecoderDelegate}.
 *
 * <p>Each test encodes a list of {@link NettyMessage.BufferResponse}s, merges the encoded bytes into
 * two buffers, re-splits them into chunks of different sizes, feeds the chunks through an {@link
 * EmbeddedChannel} and compares the decoded messages with the originals.
 */
class NettyMessageClientDecoderDelegateTest {

    private static final int BUFFER_SIZE = 1024;

    private static final int NUMBER_OF_BUFFER_RESPONSES = 5;

    private static final NettyBufferPool ALLOCATOR = new NettyBufferPool(1);

    private EmbeddedChannel channel;

    private NetworkBufferPool networkBufferPool;

    private SingleInputGate inputGate;

    private InputChannelID inputChannelId;

    private InputChannelID releasedInputChannelId;

    /**
     * Creates the channel under test and registers two input channels with the client handler.
     * One is a regular channel whose data buffers are expected to be decoded. The other is
     * intended to stand for a released channel.
     *
     * <p>This is called explicitly by each test rather than via {@code @BeforeEach}, because the
     * buffer-pool segment size depends on the test parameter.
     *
     * @param numOfPartialBuffers number of partial buffers per fully-filled buffer; the network
     *     buffer pool's segment size is {@code numOfPartialBuffers * BUFFER_SIZE}
     */
    private void setup(int numOfPartialBuffers) throws IOException, InterruptedException {
        CreditBasedPartitionRequestClientHandler handler =
                new CreditBasedPartitionRequestClientHandler();
        networkBufferPool =
                new NetworkBufferPool(
                        NUMBER_OF_BUFFER_RESPONSES, numOfPartialBuffers * BUFFER_SIZE);
        channel = new EmbeddedChannel(new NettyMessageClientDecoderDelegate(handler));

        inputGate = InputChannelTestUtils.createSingleInputGate(1, networkBufferPool);
        RemoteInputChannel inputChannel =
                InputChannelTestUtils.createRemoteInputChannel(
                        inputGate, new TestingPartitionRequestClient(), NUMBER_OF_BUFFER_RESPONSES);
        inputGate.setInputChannels(new InputChannel[] {inputChannel});
        inputGate.setup();
        inputChannel.requestSubpartitions();
        handler.addInputChannel(inputChannel);
        inputChannelId = inputChannel.getInputChannelId();

        SingleInputGate releasedInputGate =
                InputChannelTestUtils.createSingleInputGate(1, networkBufferPool);
        // NOTE(review): this channel is built on this.inputGate (not releasedInputGate) and is never
        // passed to releasedInputGate.setInputChannels(...). Confirm that closing releasedInputGate
        // below really puts it into the "released" state the test name implies.
        RemoteInputChannel releasedInputChannel =
                new InputChannelBuilder().buildRemoteChannel(inputGate);
        releasedInputGate.close();
        handler.addInputChannel(releasedInputChannel);
        releasedInputChannelId = releasedInputChannel.getInputChannelId();
    }

    /**
     * Closes the input gate, destroys the network buffer pool and closes the channel. The cleanup
     * steps are nested in try/finally so that a failure in an earlier step does not skip the later
     * ones.
     */
    @AfterEach
    void tearDown() throws IOException {
        try {
            if (inputGate != null) {
                inputGate.close();
            }
        } finally {
            try {
                if (networkBufferPool != null) {
                    networkBufferPool.destroyAllBufferPools();
                    networkBufferPool.destroy();
                }
            } finally {
                if (channel != null) {
                    channel.close();
                }
            }
        }
    }

    /** Test parameters in the form {@code (isFullyFilled, numOfPartialBuffers)}. */
    private static Stream<Arguments> bufferDescriptors() {
        return Stream.of(Arguments.of(false, 1), Arguments.of(true, 1), Arguments.of(true, 3));
    }

    @ParameterizedTest(name = "{index} => isFullyFilled={0}, numOfPartialBuffers={1}")
    @MethodSource("bufferDescriptors")
    void testClientMessageDecode(boolean isFullyFilled, int numOfPartialBuffers) throws Exception {
        setup(numOfPartialBuffers);
        testNettyMessageClientDecoding(isFullyFilled, numOfPartialBuffers, false, false, false);
    }

    @ParameterizedTest(name = "{index} => isFullyFilled={0}, numOfPartialBuffers={1}")
    @MethodSource("bufferDescriptors")
    void testClientMessageDecodeWithEmptyBuffers(boolean isFullyFilled, int numOfPartialBuffers)
            throws Exception {
        setup(numOfPartialBuffers);
        testNettyMessageClientDecoding(isFullyFilled, numOfPartialBuffers, true, false, false);
    }

    @ParameterizedTest(name = "{index} => isFullyFilled={0}, numOfPartialBuffers={1}")
    @MethodSource("bufferDescriptors")
    void testClientMessageDecodeWithReleasedInputChannel(
            boolean isFullyFilled, int numOfPartialBuffers) throws Exception {
        setup(numOfPartialBuffers);
        testNettyMessageClientDecoding(isFullyFilled, numOfPartialBuffers, false, true, false);
    }

    @ParameterizedTest(name = "{index} => isFullyFilled={0}, numOfPartialBuffers={1}")
    @MethodSource("bufferDescriptors")
    void testClientMessageDecodeWithRemovedInputChannel(
            boolean isFullyFilled, int numOfPartialBuffers) throws Exception {
        setup(numOfPartialBuffers);
        testNettyMessageClientDecoding(isFullyFilled, numOfPartialBuffers, false, false, true);
    }

    /**
     * Runs one encode → repartition → decode → verify round trip.
     *
     * <p>The encoded messages and all decoded buffer responses are released in a {@code finally}
     * block, so each is released exactly once whether or not verification fails.
     *
     * @param hasEmptyBuffer whether to include a zero-sized data buffer
     * @param hasBufferForReleasedChannel whether to include a buffer addressed to the released
     *     channel
     * @param hasBufferForRemovedChannel whether to include a buffer addressed to a channel id the
     *     handler does not know
     */
    private void testNettyMessageClientDecoding(
            boolean isFullyFilled,
            int numOfPartialBuffers,
            boolean hasEmptyBuffer,
            boolean hasBufferForReleasedChannel,
            boolean hasBufferForRemovedChannel)
            throws Exception {
        ByteBuf[] encodedMessages = null;
        List<NettyMessage> decodedMessages = null;
        try {
            List<NettyMessage.BufferResponse> messages =
                    createMessageList(
                            isFullyFilled,
                            numOfPartialBuffers,
                            hasEmptyBuffer,
                            hasBufferForReleasedChannel,
                            hasBufferForRemovedChannel);

            encodedMessages = encodeMessages(messages);
            List<ByteBuf> partitionedBuffers = repartitionMessages(encodedMessages);
            decodedMessages = decodeMessages(channel, partitionedBuffers);
            verifyDecodedMessages(messages, decodedMessages);
        } finally {
            releaseBuffers(encodedMessages);
            if (decodedMessages != null) {
                for (NettyMessage nettyMessage : decodedMessages) {
                    if (nettyMessage instanceof NettyMessage.BufferResponse) {
                        ((NettyMessage.BufferResponse) nettyMessage).releaseBuffer();
                    }
                }
            }
        }
    }

    /**
     * Builds the messages to encode, with consecutive sequence numbers starting at 1.
     *
     * <p>The list contains {@code NUMBER_OF_BUFFER_RESPONSES - 1} full-sized data buffers for the
     * regular channel. It then contains the optional empty, released-channel and removed-channel
     * buffers, in that order. It ends with a 32-byte event and one more full-sized data buffer.
     *
     * <p>The parameter order matches {@link #testNettyMessageClientDecoding}: the released-channel
     * flag comes before the removed-channel flag. In the decompiled original the two were
     * declared the other way round, which silently swapped the two scenarios.
     */
    private List<NettyMessage.BufferResponse> createMessageList(
            boolean isFullyFilled,
            int numOfPartialBuffers,
            boolean hasEmptyBuffer,
            boolean hasBufferForReleasedChannel,
            boolean hasBufferForRemovedChannel) {
        int seqNumber = 1;
        List<NettyMessage.BufferResponse> messages = new ArrayList<>();

        for (int i = 0; i < NUMBER_OF_BUFFER_RESPONSES - 1; i++) {
            addBufferResponse(
                    isFullyFilled,
                    numOfPartialBuffers,
                    messages,
                    inputChannelId,
                    Buffer.DataType.DATA_BUFFER,
                    BUFFER_SIZE,
                    seqNumber++);
        }

        if (hasEmptyBuffer) {
            addBufferResponse(
                    isFullyFilled,
                    numOfPartialBuffers,
                    messages,
                    inputChannelId,
                    Buffer.DataType.DATA_BUFFER,
                    0,
                    seqNumber++);
        }
        if (hasBufferForReleasedChannel) {
            addBufferResponse(
                    isFullyFilled,
                    numOfPartialBuffers,
                    messages,
                    releasedInputChannelId,
                    Buffer.DataType.DATA_BUFFER,
                    BUFFER_SIZE,
                    seqNumber++);
        }
        if (hasBufferForRemovedChannel) {
            // A fresh id that was never registered with the handler.
            addBufferResponse(
                    isFullyFilled,
                    numOfPartialBuffers,
                    messages,
                    new InputChannelID(),
                    Buffer.DataType.DATA_BUFFER,
                    BUFFER_SIZE,
                    seqNumber++);
        }

        addBufferResponse(
                isFullyFilled,
                numOfPartialBuffers,
                messages,
                inputChannelId,
                Buffer.DataType.EVENT_BUFFER,
                32,
                seqNumber++);
        addBufferResponse(
                isFullyFilled,
                numOfPartialBuffers,
                messages,
                inputChannelId,
                Buffer.DataType.DATA_BUFFER,
                BUFFER_SIZE,
                seqNumber);

        return messages;
    }

    /**
     * Creates a buffer and appends a {@link NettyMessage.BufferResponse} wrapping it to {@code
     * messages}.
     *
     * <p>For fully-filled buffers, {@code numOfPartialBuffers} partial sizes of {@code bufferSize}
     * are also recorded on the response.
     */
    private void addBufferResponse(
            boolean isFullyFilled,
            int numOfPartialBuffers,
            List<NettyMessage.BufferResponse> messages,
            InputChannelID inputChannelId,
            Buffer.DataType dataType,
            int bufferSize,
            int seqNumber) {
        Buffer buffer = createDataBuffer(isFullyFilled, numOfPartialBuffers, bufferSize, dataType);
        NettyMessage.BufferResponse bufferResponse =
                new NettyMessage.BufferResponse(
                        buffer,
                        seqNumber,
                        inputChannelId,
                        0,
                        isFullyFilled ? numOfPartialBuffers : 0,
                        1);
        if (isFullyFilled) {
            for (int i = 0; i < numOfPartialBuffers; i++) {
                bufferResponse.getPartialBufferSizes().add(bufferSize);
            }
        }
        messages.add(bufferResponse);
    }

    /**
     * Creates either a single filled {@link NetworkBuffer} of {@code size} bytes or, when {@code
     * isFullyFilled}, a {@link FullyFilledBuffer} made of {@code numOfPartialBuffers} such buffers.
     */
    private Buffer createDataBuffer(
            boolean isFullyFilled, int numOfPartialBuffers, int size, Buffer.DataType dataType) {
        if (!isFullyFilled) {
            return createFilledNetworkBuffer(size, dataType);
        }

        FullyFilledBuffer fullyFilledBuffer =
                new FullyFilledBuffer(dataType, numOfPartialBuffers * size, false);
        for (int i = 0; i < numOfPartialBuffers; i++) {
            fullyFilledBuffer.addPartialBuffer(createFilledNetworkBuffer(size, dataType));
        }
        return fullyFilledBuffer;
    }

    /**
     * Allocates an unpooled {@link NetworkBuffer} backed by a {@code size}-byte segment. It then
     * writes the ints {@code 0, 1, 2, ...} until {@code size / Integer.BYTES} ints have been
     * written.
     */
    private static Buffer createFilledNetworkBuffer(int size, Buffer.DataType dataType) {
        MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(size);
        NetworkBuffer buffer = new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE, dataType);
        for (int value = 0; value < size / Integer.BYTES; value++) {
            buffer.writeInt(value);
        }
        return buffer;
    }

    /** Serializes every message with the shared {@link #ALLOCATOR}, preserving order. */
    private ByteBuf[] encodeMessages(List<NettyMessage.BufferResponse> messages) throws Exception {
        ByteBuf[] encodedMessages = new ByteBuf[messages.size()];
        for (int i = 0; i < messages.size(); i++) {
            encodedMessages[i] = messages.get(i).write(ALLOCATOR);
        }
        return encodedMessages;
    }

    /**
     * Merges the first and second halves of the encoded messages into two buffers. It then splits
     * the first into chunks of {@code 2 * BUFFER_SIZE} bytes and the second into chunks of {@code
     * BUFFER_SIZE / 4} bytes, so that the decoder sees message boundaries at arbitrary offsets.
     *
     * <p>The merged buffers are always released. The partitioned chunks are released only if
     * repartitioning fails; on success ownership passes to the caller.
     */
    private List<ByteBuf> repartitionMessages(ByteBuf[] encodedMessages) {
        List<ByteBuf> result = new ArrayList<>();

        ByteBuf mergedBuffer1 = null;
        ByteBuf mergedBuffer2 = null;
        try {
            mergedBuffer1 = mergeBuffers(encodedMessages, 0, encodedMessages.length / 2);
            mergedBuffer2 =
                    mergeBuffers(
                            encodedMessages, encodedMessages.length / 2, encodedMessages.length);

            result.addAll(partitionBuffer(mergedBuffer1, BUFFER_SIZE * 2));
            result.addAll(partitionBuffer(mergedBuffer2, BUFFER_SIZE / 4));
        } catch (Throwable t) {
            releaseBuffers(result.toArray(new ByteBuf[0]));
            ExceptionUtils.rethrow(t);
        } finally {
            releaseBuffers(mergedBuffer1, mergedBuffer2);
        }

        return result;
    }

    /**
     * Copies the readable bytes of {@code buffers[start, end)} into a newly allocated buffer.
     * Writing from each source advances that source's reader index.
     */
    private ByteBuf mergeBuffers(ByteBuf[] buffers, int start, int end) {
        ByteBuf mergedBuffer = ALLOCATOR.buffer();
        for (int i = start; i < end; i++) {
            mergedBuffer.writeBytes(buffers[i]);
        }
        return mergedBuffer;
    }

    /**
     * Copies {@code buffer}'s bytes from index 0 up to {@code readableBytes()} into new buffers of
     * at most {@code partitionSize} bytes each. Any chunks already allocated are released if
     * copying fails.
     */
    private List<ByteBuf> partitionBuffer(ByteBuf buffer, int partitionSize) {
        List<ByteBuf> result = new ArrayList<>();
        try {
            int bufferSize = buffer.readableBytes();
            for (int position = 0; position < bufferSize; position += partitionSize) {
                int endPosition = Math.min(position + partitionSize, bufferSize);
                ByteBuf partitionedBuffer = ALLOCATOR.buffer(endPosition - position);
                partitionedBuffer.writeBytes(buffer, position, endPosition - position);
                result.add(partitionedBuffer);
            }
        } catch (Throwable t) {
            releaseBuffers(result.toArray(new ByteBuf[0]));
            ExceptionUtils.rethrow(t);
        }
        return result;
    }

    /**
     * Writes every input buffer into {@code channel}'s inbound side and runs pending tasks. It then
     * drains all inbound objects, asserting that each one is a {@link NettyMessage}.
     */
    private List<NettyMessage> decodeMessages(EmbeddedChannel channel, List<ByteBuf> inputBuffers) {
        for (ByteBuf buffer : inputBuffers) {
            channel.writeInbound(buffer);
        }
        channel.runPendingTasks();

        List<NettyMessage> decodedMessages = new ArrayList<>();
        Object input;
        while ((input = channel.readInbound()) != null) {
            Assertions.assertThat(input).isInstanceOf(NettyMessage.class);
            decodedMessages.add((NettyMessage) input);
        }
        return decodedMessages;
    }

    /**
     * Checks the decoded messages against the expected ones, position by position.
     *
     * <p>Messages that are empty, or that are addressed to any channel other than {@link
     * #inputChannelId}, must be decoded without a buffer. All others must carry a buffer equal to
     * the original one; for a {@link FullyFilledBuffer}, equality is checked against its {@code
     * asByteBuf()} view.
     */
    private void verifyDecodedMessages(
            List<NettyMessage.BufferResponse> expectedMessages, List<NettyMessage> decodedMessages) {
        Assertions.assertThat(decodedMessages).hasSameSizeAs(expectedMessages);

        for (int i = 0; i < expectedMessages.size(); i++) {
            NettyMessage.BufferResponse expected = expectedMessages.get(i);
            Assertions.assertThat(decodedMessages.get(i)).isInstanceOf(expected.getClass());

            NettyMessage.BufferResponse actual = (NettyMessage.BufferResponse) decodedMessages.get(i);
            NettyTestUtil.verifyBufferResponseHeader(expected, actual);

            if (expected.bufferSize == 0 || !expected.receiverId.equals(inputChannelId)) {
                Assertions.assertThat(actual.getBuffer()).isNull();
                continue;
            }

            Buffer expectedBuffer = expected.getBuffer();
            if (expectedBuffer instanceof FullyFilledBuffer) {
                Assertions.assertThat(actual.getBuffer()).isEqualTo(expectedBuffer.asByteBuf());
            } else {
                Assertions.assertThat(actual.getBuffer()).isEqualTo(expectedBuffer);
            }
        }
    }

    /** Releases every non-null buffer once; a {@code null} array is ignored. */
    private void releaseBuffers(ByteBuf... buffers) {
        if (buffers == null) {
            return;
        }
        for (ByteBuf buffer : buffers) {
            if (buffer != null) {
                buffer.release();
            }
        }
    }
}

