/*
 * Copyright 2024 the original author or authors.
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * https://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.openrewrite.remote.java;

import com.fasterxml.jackson.dataformat.cbor.CBORFactory;
import lombok.SneakyThrows;
import lombok.Value;
import lombok.experimental.NonFinal;
import lombok.extern.java.Log;
import org.jspecify.annotations.Nullable;
import org.openrewrite.ExecutionContext;
import org.openrewrite.InMemoryExecutionContext;
import org.openrewrite.Recipe;
import org.openrewrite.SourceFile;
import org.openrewrite.internal.Throwing;
import org.openrewrite.remote.*;
import org.openrewrite.scheduling.WatchableExecutionContext;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import static org.openrewrite.remote.RemotingMessageType.Request;
import static org.openrewrite.remote.java.CommonHandler.createHandlersMapping;

/**
 * Loopback-only TCP server that answers remoting requests using the handlers built by
 * {@code CommonHandler.createHandlersMapping}.
 * <p>
 * The accept loop runs on a single background thread ({@link #service}). Connections are
 * served one at a time: each accepted socket is read request-by-request until the peer
 * closes it, and is then closed. The loop ends when the configured lifetime
 * ({@code timeout}/{@code unit}, measured from when the socket is bound) elapses, when the
 * worker thread is interrupted, or when {@link #stop()} closes the server socket.
 */
@Value
@Log
public class RemotingServer {
    /** Passed to newly created {@link RemotingContext}s; set by the {@code --debug} flag in {@link #main}. */
    private static boolean debug = false;
    // Not referenced in this class.
    private static final byte[] MESSAGE_END = new byte[]{(byte) 0x81, (byte) 0x17};
    private static final int BUFFER_SIZE = 8192;

    // The next three fields are not referenced in this class; @Value still exposes them via getters.
    ByteBuffer receiveBuffer = ByteBuffer.allocate(BUFFER_SIZE);
    ByteBuffer sendBuffer = ByteBuffer.allocate(BUFFER_SIZE);
    byte[] bytes = new byte[BUFFER_SIZE];

    /** Single worker thread that owns the accept loop. */
    ExecutorService service = Executors.newSingleThreadExecutor();

    int port;
    RemotingContext context;
    /** Maximum lifetime of the accept loop, measured from the moment the socket is bound. */
    long timeout;
    TimeUnit unit;

    // Not referenced in this class; @Value still exposes it via a getter.
    CBORFactory factory = new CBORFactory();
    List<Recipe> recipes = new ArrayList<>();

    // serverSocket, activeSocket and started are written by the worker thread and read by
    // callers of ensureStarted()/stop(), so they are volatile for cross-thread visibility.
    @NonFinal
    volatile @Nullable ServerSocket serverSocket;

    @NonFinal
    volatile @Nullable Socket activeSocket;

    /** Counts down once the server socket is bound; replaced by a fresh latch when the server stops. */
    @NonFinal
    volatile CountDownLatch started = new CountDownLatch(1);

    // Not referenced in this class; @Value still exposes it via a getter.
    @NonFinal
    @Nullable
    SourceFile remoteState;

    Map<String, Supplier<RemotingMessenger.RequestHandler<?>>> handlers;


    /**
     * Creates an unstarted server; call {@link #ensureStarted()} to begin accepting connections.
     *
     * @param port    loopback port to bind
     * @param context remoting context shared with request handlers
     * @param timeout maximum lifetime of the accept loop
     * @param unit    unit of {@code timeout}
     */
    public RemotingServer(int port,
                          RemotingContext context,
                          long timeout,
                          TimeUnit unit) {
        this.port = port;
        this.context = context;
        this.timeout = timeout;
        this.unit = unit;
        this.handlers = createHandlersMapping(context, recipes);
    }

    /** Same as {@link #create(ExecutionContext, ClassLoader, int, long, TimeUnit)} with port 65432 and a 2-minute lifetime. */
    public static RemotingServer create(ExecutionContext ctx,
                                        ClassLoader classLoader) {
        return create(ctx, classLoader, 65432, 2, TimeUnit.MINUTES);
    }

    /** Same as {@link #create(ExecutionContext, ClassLoader, int, long, TimeUnit)} with a 2-minute lifetime. */
    public static RemotingServer create(ExecutionContext ctx,
                                        ClassLoader classLoader,
                                        int port) {
        return create(ctx, classLoader, port, 2, TimeUnit.MINUTES);
    }

    /**
     * Returns the server cached in {@code ctx} under this class's name, creating and caching a
     * new one if none is present. The returned server is not started by this method. When a
     * cached server exists, the remaining arguments are ignored.
     */
    @SneakyThrows
    public static RemotingServer create(ExecutionContext ctx,
                                        ClassLoader classLoader,
                                        int port,
                                        long timeout,
                                        TimeUnit unit) {
        RemotingServer remotingServer = ctx.getMessage(RemotingServer.class.getName());
        if (remotingServer == null) {
            RemotingContext context = new RemotingContext(classLoader, debug);
            remotingServer = new RemotingServer(port,
                    context,
                    timeout,
                    unit);
            ctx.putMessage(RemotingServer.class.getName(), remotingServer);
        }

        return remotingServer;
    }

    /**
     * Starts the accept loop on the worker thread unless it is already running. Then waits up to
     * five seconds for the server socket to be bound and registers a cleanup action with
     * {@code RemoteUtils.cleaner}.
     *
     * @throws IllegalStateException if the socket is not bound within five seconds (the server
     *                               is stopped first)
     */
    void ensureStarted() {
        if (started.getCount() == 0) {
            return;
        }

        service.execute(this::serve);

        try {
            if (!started.await(5, TimeUnit.SECONDS)) {
                stop();
                throw new IllegalStateException("Failed to start RemotingServer on " + port);
            }
            RemoteUtils.cleaner.put(this, () -> {
                try {
                    ServerSocket socket = serverSocket;
                    if (socket != null) {
                        socket.close();
                    }
                    log.info("terminating server " + service.shutdownNow());
                    service.awaitTermination(5, TimeUnit.SECONDS);
                } catch (IOException | InterruptedException ignored) {
                    // best-effort cleanup; nothing useful can be done on failure
                }
            });
        } catch (InterruptedException e) {
            // Preserve the interrupt for the caller instead of silently discarding it.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Worker-thread body. Binds the loopback server socket, signals {@link #started}, then serves
     * connections one at a time. The loop stops when the lifetime elapses, the thread is
     * interrupted, or the server socket is closed.
     */
    private void serve() {
        try (ServerSocket server = new ServerSocket(port, 50, InetAddress.getLoopbackAddress())) {
            this.serverSocket = server;
            started.countDown();
            System.out.println("Remoting server started on " + port + " ...");

            // Measure elapsed monotonic time rather than computing an absolute deadline,
            // which could overflow for very large timeouts.
            long lifetimeMillis = unit.toMillis(timeout);
            long startNanos = System.nanoTime();
            while (!Thread.currentThread().isInterrupted()) {
                long elapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
                long remainingMillis = lifetimeMillis - elapsedMillis;
                if (remainingMillis <= 0) {
                    break;
                }

                Socket socket;
                try {
                    // Bound accept() by the remaining lifetime so an idle server still times out.
                    // remainingMillis >= 1 here, so this never becomes 0 (which would mean "infinite").
                    server.setSoTimeout((int) Math.min(remainingMillis, Integer.MAX_VALUE));
                    socket = server.accept();
                } catch (SocketTimeoutException e) {
                    break;
                } catch (SocketException e) {
                    if (server.isClosed()) {
                        // Closed by stop() or the cleaner: this is an orderly shutdown.
                        break;
                    }
                    throw e;
                }

                synchronized (this) {
                    serveConnection(socket);
                }
            }

            started = new CountDownLatch(1);
        } catch (IOException e) {
            Throwing.sneakyThrow(e);
        }
    }

    /**
     * Processes requests on one accepted connection until the peer closes it, then closes the
     * socket. A {@link SocketException} (for example, the socket being closed by
     * {@link #stop()}) ends the connection and is logged rather than propagated.
     */
    private void serveConnection(Socket socket) throws IOException {
        try (Socket client = socket) {
            activeSocket = client;
            log.fine("Remoting server accepted " + client);
            ResponseBuffer response = processRequest(client);
            while (response != null) {
                writeResponse(client, response);
                response = processRequest(client);
            }
        } catch (SocketException e) {
            log.info("Connection closed");
        } finally {
            activeSocket = null;
        }
    }

    /**
     * Closes the server socket and any active connection, shuts down the worker thread (waiting
     * up to five seconds), and resets {@link #started} so the server reports itself as not started.
     */
    @SneakyThrows
    public void stop() {
        // Read the volatile fields once so the null checks and the close() calls see the same objects.
        ServerSocket server = serverSocket;
        if (server != null) {
            server.close();
            Socket active = activeSocket;
            if (active != null) {
                active.close();
            }
        }
        service.shutdownNow();
        service.awaitTermination(5, TimeUnit.SECONDS);
        this.started = new CountDownLatch(1);
    }

    /** Writes {@code response} to {@code client}, ignoring I/O failures (typically a disconnected client). */
    private void writeResponse(Socket client, ResponseBuffer response) {
        try {
            response.toSocket(client);
        } catch (IOException ignored) {
            // the client probably disconnected
        }
    }

    /**
     * Reads one message-type byte from {@code socket} and, if a message is present, dispatches it
     * through a new {@link RemotingMessenger}.
     * <p>
     * Each request gets a fresh {@link InMemoryExecutionContext} carrying the remoting context and
     * a {@link RemotingClient} bound to this socket.
     *
     * @return {@code null} when the peer has closed the stream or the messenger reports that
     * processing should stop; otherwise an (always empty) {@link ResponseBuffer}.
     * NOTE(review): the buffer is never written here, so the messenger presumably writes its
     * reply directly to the socket. TODO confirm.
     */
    private @Nullable ResponseBuffer processRequest(Socket socket)
            throws IOException {
        byte[] header = new byte[1];
        int read = socket.getInputStream().read(header);
        if (read <= 0) {
            return null;
        }
        RemotingMessageType messageType = RemotingMessageType.of(header[0]);
        assert messageType == Request;
        ResponseBuffer outputStream = new ResponseBuffer();
        return new RemotingMessenger((CBORFactory) context.objectMapper().getFactory(), handlers,
                (messenger) -> {
                    ExecutionContext ctx = new InMemoryExecutionContext();
                    RemotingExecutionContextView view = RemotingExecutionContextView.view(ctx);
                    view.setRemotingContext(context);
                    view.putMessage(RemotingClient.REMOTING_CLIENT,
                            RemotingClient.create(context, messenger, socket));
                    return ctx;
                }).processRequest(socket) ? outputStream : null;
    }

    /**
     * Command-line entry point. {@code --debug} enables debug mode. An argument starting with a
     * digit is parsed as the port and shortens the lifetime from one day to two minutes. Empty
     * arguments are ignored.
     */
    public static void main(String[] args) {
        WatchableExecutionContext ctx =
                new WatchableExecutionContext(new InMemoryExecutionContext());
        int port = 65432;
        long timeoutMinutes = TimeUnit.DAYS.toMinutes(1);
        for (String arg : args) {
            if (arg.equals("--debug")) {
                debug = true;
            } else if (!arg.isEmpty() && arg.charAt(0) >= '0' && arg.charAt(0) <= '9') {
                port = Integer.parseInt(arg);
                timeoutMinutes = 2;
            }
        }

        RemotingServer server = RemotingServer.create(ctx,
                RemotingServer.class.getClassLoader(),
                port,
                timeoutMinutes,
                TimeUnit.MINUTES);
        server.ensureStarted();
    }

    /** Byte buffer whose accumulated contents can be written to a socket in one call. */
    private static final class ResponseBuffer extends ByteArrayOutputStream {

        ResponseBuffer() {
            super(4096);
        }

        /** Writes the bytes accumulated so far to {@code socket}'s output stream. */
        public void toSocket(Socket socket) throws IOException {
            socket.getOutputStream().write(buf, 0, count);
        }
    }
}
