/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.streaming.api.operators.collect;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.UUID;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.accumulators.SerializedListAccumulator;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.CheckpointListener;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.sink.legacy.RichSinkFunction;
import org.apache.flink.streaming.api.functions.sink.legacy.SinkFunction;
import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
import org.apache.flink.streaming.api.operators.collect.CollectCoordinationRequest;
import org.apache.flink.streaming.api.operators.collect.CollectCoordinationResponse;
import org.apache.flink.streaming.api.operators.collect.CollectSinkAddressEvent;
import org.apache.flink.streaming.api.operators.collect.CollectSinkOperatorFactory;
import org.apache.flink.util.NetUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A sink function that buffers serialized records in memory and serves them, in batches, to a
 * coordinator that connects to a server socket opened by this function.
 *
 * <p>How it works (as implemented below):
 *
 * <ul>
 *   <li>{@link #open(OpenContext)} opens a server socket, generates a random {@code version}
 *       string for this run and sends the socket address to the coordinator as a {@link
 *       CollectSinkAddressEvent}.
 *   <li>Each request carries a version and an offset. If the version does not match, or the
 *       offset is behind this sink's offset, an empty batch is returned. Otherwise the records
 *       before the requested offset are dropped from the buffer and the next batch of at most
 *       {@code maxBytesPerBatch} bytes is returned.
 *   <li>{@link #invoke} blocks while appending the record would push the buffer above {@code 2 *
 *       maxBytesPerBatch} bytes.
 *   <li>The buffer and the offset are written to operator state on every snapshot. Every response
 *       carries the offset recorded by the latest completed checkpoint.
 * </ul>
 *
 * <p>The parallelism of this function must be 1.
 *
 * @param <IN> type of the collected records
 */
@Internal
public class CollectSinkFunction<IN> extends RichSinkFunction<IN>
        implements CheckpointedFunction, CheckpointListener {

    private static final long serialVersionUID = 1L;

    private static final Logger LOG = LoggerFactory.getLogger(CollectSinkFunction.class);

    private final TypeSerializer<IN> serializer;
    private final long maxBytesPerBatch;
    private final long bufferSizeLimitBytes;
    private final String accumulatorName;

    private transient OperatorEventGateway eventGateway;

    // ------------------------------------------------------------------------
    //  Buffer shared between invoke() and the server thread; guarded by bufferLock.
    // ------------------------------------------------------------------------

    private transient LinkedList<byte[]> buffer;
    private transient long currentBufferBytes;
    private transient ReentrantLock bufferLock;
    private transient Condition bufferCanAddNextResultCondition;

    /** Serialized size of the record {@link #invoke} is currently trying to append. */
    private transient long invokingRecordBytes;

    /** Random identifier of this run; requests carrying another version are rejected. */
    private transient String version;

    /**
     * Offset of the first record still in the buffer; restored from state and advanced by the
     * server thread each time it drops an acknowledged record.
     */
    private transient long offset;

    /**
     * Offset captured by the latest completed checkpoint. Written by {@link
     * #notifyCheckpointComplete} and read by the server thread, hence volatile.
     */
    private transient volatile long lastCheckpointedOffset;

    private transient ServerThread serverThread;

    private transient ListState<byte[]> bufferState;
    private transient ListState<Long> offsetState;

    /** Checkpoint id mapped to the offset captured when that checkpoint's snapshot was taken. */
    private transient SortedMap<Long, Long> uncompletedCheckpointMap;

    public CollectSinkFunction(
            TypeSerializer<IN> serializer, long maxBytesPerBatch, String accumulatorName) {
        this.serializer = serializer;
        this.maxBytesPerBatch = maxBytesPerBatch;
        this.bufferSizeLimitBytes = maxBytesPerBatch * 2L;
        this.accumulatorName = accumulatorName;
    }

    public long getMaxBytesPerBatch() {
        return maxBytesPerBatch;
    }

    /** Lazily creates the buffer and its lock; a no-op if the buffer already exists. */
    private void initBuffer() {
        if (buffer != null) {
            return;
        }
        buffer = new LinkedList<>();
        currentBufferBytes = 0L;
        bufferLock = new ReentrantLock();
        bufferCanAddNextResultCondition = bufferLock.newCondition();
        lastCheckpointedOffset = offset = 0L;
    }

    /** Restores the buffered records and the offset from operator state. */
    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        initBuffer();

        bufferState =
                context.getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<>(
                                        "bufferState", BytePrimitiveArraySerializer.INSTANCE));
        for (byte[] value : bufferState.get()) {
            buffer.add(value);
            currentBufferBytes += value.length;
        }

        offsetState =
                context.getOperatorStateStore()
                        .getListState(new ListStateDescriptor<>("offsetState", Long.class));
        // snapshotState always writes a single-element list; keep the last value seen.
        for (Long value : offsetState.get()) {
            offset = value;
        }
        lastCheckpointedOffset = offset;

        LOG.info(
                "Initializing collect sink state with offset = {}, buffered results bytes = {}",
                lastCheckpointedOffset,
                currentBufferBytes);

        uncompletedCheckpointMap = new TreeMap<>();
    }

    /**
     * Writes the buffer and current offset to state and remembers the offset for this checkpoint
     * until it completes.
     */
    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        bufferLock.lock();
        try {
            bufferState.update(buffer);
            offsetState.update(Collections.singletonList(offset));
            uncompletedCheckpointMap.put(context.getCheckpointId(), offset);
            LOG.debug(
                    "Checkpoint begin with checkpointId = {}, lastCheckpointedOffset = {}, buffered results bytes = {}",
                    context.getCheckpointId(),
                    lastCheckpointedOffset,
                    currentBufferBytes);
        } finally {
            bufferLock.unlock();
        }
    }

    /**
     * Starts the server thread and announces its address to the coordinator.
     *
     * @throws IllegalStateException if the parallelism is not 1
     * @throws NullPointerException if {@link #setOperatorEventGateway} has not been called
     */
    @Override
    public void open(OpenContext openContext) throws Exception {
        Preconditions.checkState(
                getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks() == 1,
                "The parallelism of CollectSinkFunction must be 1");
        // Validate before opening the server socket so that a missing gateway does not leave a
        // listening server thread behind.
        Preconditions.checkNotNull(eventGateway, "Operator event gateway hasn't been set");

        initBuffer();
        version = UUID.randomUUID().toString();

        serverThread = new ServerThread(serializer);
        serverThread.start();

        InetSocketAddress address = serverThread.getServerSocketAddress();
        LOG.info("Collect sink server established, address = {}", address);
        eventGateway.sendEventToCoordinator(new CollectSinkAddressEvent(address));
    }

    /**
     * Serializes {@code value} and appends it to the buffer, blocking while the buffer has no room
     * for it.
     *
     * @throws RuntimeException if the serialized record is larger than {@code maxBytesPerBatch}
     */
    @Override
    public void invoke(IN value, SinkFunction.Context context) throws Exception {
        // Serialization only touches a local stream, so it is done before taking the lock.
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        serializer.serialize(value, new DataOutputViewStreamWrapper(baos));
        byte[] bytes = baos.toByteArray();

        if (bytes.length > maxBytesPerBatch) {
            throw new RuntimeException(
                    "Record size is too large for CollectSinkFunction. Record size is "
                            + bytes.length
                            + " bytes, but max bytes per batch is only "
                            + maxBytesPerBatch
                            + " bytes. Please consider increasing max bytes per batch value by setting "
                            + CollectSinkOperatorFactory.MAX_BATCH_SIZE.key());
        }

        bufferLock.lock();
        try {
            invokingRecordBytes = bytes.length;
            // A loop, not a single check: Condition#await may return spuriously, and the record
            // must only be appended once the buffer actually has room for it.
            while (currentBufferBytes + invokingRecordBytes > bufferSizeLimitBytes) {
                bufferCanAddNextResultCondition.await();
            }
            buffer.add(bytes);
            currentBufferBytes += bytes.length;
        } finally {
            bufferLock.unlock();
        }
    }

    /**
     * Stops the server thread and waits for it to terminate. Safe to call even if {@link
     * #open(OpenContext)} failed before the server thread was created.
     */
    @Override
    public void close() throws Exception {
        if (serverThread != null) {
            serverThread.close();
            serverThread.join();
        }
    }

    /**
     * Registers an accumulator holding the current offset, the version, the last checkpointed
     * offset and all still-buffered records (see {@link #serializeAccumulatorResult}).
     */
    public void accumulateFinalResults() throws Exception {
        bufferLock.lock();
        try {
            SerializedListAccumulator<byte[]> accumulator = new SerializedListAccumulator<>();
            accumulator.add(
                    serializeAccumulatorResult(offset, version, lastCheckpointedOffset, buffer),
                    BytePrimitiveArraySerializer.INSTANCE);
            getRuntimeContext().addAccumulator(accumulatorName, accumulator);
        } finally {
            bufferLock.unlock();
        }
    }

    /**
     * Publishes the offset captured by the snapshot of {@code checkpointId} as the last
     * checkpointed offset and forgets that checkpoint and all older pending ones.
     *
     * <p>If no snapshot for {@code checkpointId} was recorded since this instance was
     * (re)initialized, the last checkpointed offset is left unchanged instead of failing with a
     * {@link NullPointerException}.
     */
    @Override
    public void notifyCheckpointComplete(long checkpointId) {
        Long checkpointedOffset = uncompletedCheckpointMap.get(checkpointId);
        if (checkpointedOffset != null) {
            lastCheckpointedOffset = checkpointedOffset;
        } else {
            LOG.debug(
                    "No pending snapshot for checkpointId = {}, keeping lastCheckpointedOffset = {}",
                    checkpointId,
                    lastCheckpointedOffset);
        }
        uncompletedCheckpointMap.headMap(checkpointId + 1L).clear();
        LOG.debug(
                "Checkpoint complete with checkpointId = {}, lastCheckpointedOffset = {}",
                checkpointId,
                lastCheckpointedOffset);
    }

    @Override
    public void notifyCheckpointAborted(long checkpointId) {}

    /** Sets the gateway used to announce the server address; must be set before {@code open}. */
    public void setOperatorEventGateway(OperatorEventGateway eventGateway) {
        this.eventGateway = eventGateway;
    }

    /**
     * Encodes {@code offset} followed by a {@link CollectCoordinationResponse} built from the
     * remaining arguments.
     */
    @VisibleForTesting
    public static byte[] serializeAccumulatorResult(
            long offset, String version, long lastCheckpointedOffset, List<byte[]> buffer)
            throws IOException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        DataOutputViewStreamWrapper wrapper = new DataOutputViewStreamWrapper(baos);
        wrapper.writeLong(offset);
        new CollectCoordinationResponse(version, lastCheckpointedOffset, buffer).serialize(wrapper);
        return baos.toByteArray();
    }

    /** Inverse of {@link #serializeAccumulatorResult}: returns the offset and the response. */
    public static Tuple2<Long, CollectCoordinationResponse> deserializeAccumulatorResult(
            byte[] serializedAccResults) throws IOException {
        DataInputViewStreamWrapper wrapper =
                new DataInputViewStreamWrapper(new ByteArrayInputStream(serializedAccResults));
        long token = wrapper.readLong();
        CollectCoordinationResponse finalResponse = new CollectCoordinationResponse(wrapper);
        return Tuple2.of(token, finalResponse);
    }

    /**
     * Background thread that serves one coordinator connection at a time and answers each request
     * with the next batch from the buffer.
     */
    private class ServerThread extends Thread {

        private final TypeSerializer<IN> serializer;
        private final ServerSocket serverSocket;

        // Cleared by close(), which is invoked from outside this thread; volatile for visibility.
        private volatile boolean running;
        // Also closed by close() from outside this thread.
        private volatile Socket connection;
        private DataInputViewStreamWrapper inStream;
        private DataOutputViewStreamWrapper outStream;

        private ServerThread(TypeSerializer<IN> serializer) throws Exception {
            this.serializer = serializer.duplicate();
            this.serverSocket = new ServerSocket(getPort(), 0, getBindAddress());
            this.running = true;
        }

        @Override
        public void run() {
            while (running) {
                try {
                    if (connection == null) {
                        acceptConnection();
                    }
                    handleRequest(new CollectCoordinationRequest(inStream));
                } catch (Exception e) {
                    // Any failure drops the current connection; while still running, the next
                    // iteration waits for a new one.
                    LOG.debug("Collect sink server encounters an exception", e);
                    closeCurrentConnection();
                }
            }
        }

        /** Blocks until a coordinator connects and wraps the socket's streams. */
        private void acceptConnection() throws IOException {
            Socket socket = NetUtils.acceptWithoutTimeout(serverSocket);
            connection = socket;
            inStream = new DataInputViewStreamWrapper(socket.getInputStream());
            outStream = new DataOutputViewStreamWrapper(socket.getOutputStream());
            LOG.info("Coordinator connection received");
        }

        /** Validates one request and answers it with either an empty or the next batch. */
        private void handleRequest(CollectCoordinationRequest request) throws IOException {
            String requestVersion = request.getVersion();
            long requestOffset = request.getOffset();
            LOG.debug("Request received, version = {}, offset = {}", requestVersion, requestOffset);
            LOG.debug(
                    "Expecting version = {}, offset = {}",
                    CollectSinkFunction.this.version,
                    CollectSinkFunction.this.offset);

            if (!CollectSinkFunction.this.version.equals(requestVersion)
                    || requestOffset < CollectSinkFunction.this.offset) {
                LOG.info(
                        "Invalid request. Received version = {}, offset = {}, while expected version = {}, offset = {}",
                        requestVersion,
                        requestOffset,
                        CollectSinkFunction.this.version,
                        CollectSinkFunction.this.offset);
                sendBackResults(Collections.emptyList());
                return;
            }
            sendBackResults(takeNextBatch(requestOffset));
        }

        /**
         * Drops the records acknowledged by {@code requestOffset} from the buffer and returns the
         * longest prefix of the remaining records whose total size fits in {@code
         * maxBytesPerBatch}. Signals a blocked {@code invoke} once enough room has been freed.
         */
        private List<byte[]> takeNextBatch(long requestOffset) {
            List<byte[]> nextBatch = new ArrayList<>();
            CollectSinkFunction.this.bufferLock.lock();
            try {
                // Kept as long: the offset gap comes from the network and may not fit an int.
                long ackedNum = requestOffset - CollectSinkFunction.this.offset;
                for (long i = 0;
                        i < ackedNum && !CollectSinkFunction.this.buffer.isEmpty();
                        i++) {
                    byte[] removed = CollectSinkFunction.this.buffer.removeFirst();
                    CollectSinkFunction.this.currentBufferBytes -= removed.length;
                    CollectSinkFunction.this.offset++;
                }

                long totalBytes = 0L;
                for (byte[] value : CollectSinkFunction.this.buffer) {
                    if (totalBytes + value.length > CollectSinkFunction.this.maxBytesPerBatch) {
                        break;
                    }
                    nextBatch.add(value);
                    totalBytes += value.length;
                }

                if (CollectSinkFunction.this.currentBufferBytes
                                + CollectSinkFunction.this.invokingRecordBytes
                        <= CollectSinkFunction.this.bufferSizeLimitBytes) {
                    CollectSinkFunction.this.bufferCanAddNextResultCondition.signal();
                }
            } finally {
                CollectSinkFunction.this.bufferLock.unlock();
            }
            return nextBatch;
        }

        /** Stops the loop and closes the server socket and the current connection. */
        private void close() {
            running = false;
            closeServerSocket();
            closeCurrentConnection();
        }

        private InetSocketAddress getServerSocketAddress() {
            String taskManagerAddress =
                    getStreamingRuntimeContext()
                            .getTaskManagerRuntimeInfo()
                            .getTaskManagerExternalAddress();
            return new InetSocketAddress(taskManagerAddress, serverSocket.getLocalPort());
        }

        /**
         * Resolves the task manager bind address. Returns {@code null} (which makes the server
         * socket bind to all local addresses) when no address is configured or it cannot be
         * resolved.
         */
        private InetAddress getBindAddress() {
            String bindAddress =
                    getStreamingRuntimeContext()
                            .getTaskManagerRuntimeInfo()
                            .getTaskManagerBindAddress();
            if (bindAddress == null) {
                return null;
            }
            try {
                return InetAddress.getByName(bindAddress);
            } catch (UnknownHostException e) {
                LOG.warn(
                        "Could not resolve bind address {}, binding collect sink server to all local addresses",
                        bindAddress,
                        e);
                return null;
            }
        }

        private int getPort() {
            return getStreamingRuntimeContext()
                    .getTaskManagerRuntimeInfo()
                    .getConfiguration()
                    .get(TaskManagerOptions.COLLECT_PORT);
        }

        private StreamingRuntimeContext getStreamingRuntimeContext() {
            RuntimeContext context = CollectSinkFunction.this.getRuntimeContext();
            Preconditions.checkState(
                    context instanceof StreamingRuntimeContext,
                    "CollectSinkFunction can only be used in StreamTask");
            return (StreamingRuntimeContext) context;
        }

        private void sendBackResults(List<byte[]> serializedResults) throws IOException {
            LOG.debug("Sending back {} results", serializedResults.size());
            CollectCoordinationResponse response =
                    new CollectCoordinationResponse(
                            CollectSinkFunction.this.version,
                            CollectSinkFunction.this.lastCheckpointedOffset,
                            serializedResults);
            response.serialize(outStream);
        }

        /**
         * Closes the current connection, if any. The field is cleared even if closing fails so
         * that the loop accepts a fresh connection instead of reusing a broken one.
         */
        private void closeCurrentConnection() {
            Socket current = connection;
            connection = null;
            if (current == null) {
                return;
            }
            try {
                current.close();
            } catch (Exception e) {
                LOG.warn("Error occurs when closing client connections in CollectSinkFunction", e);
            }
        }

        private void closeServerSocket() {
            try {
                serverSocket.close();
            } catch (Exception e) {
                LOG.warn("Error occurs when closing server in CollectSinkFunction", e);
            }
        }
    }
}

