/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.driver.internal.async.pool;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;
import org.neo4j.driver.Logger;
import org.neo4j.driver.Logging;
import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.async.connection.ChannelAttributes;
import org.neo4j.driver.internal.messaging.BoltProtocol;
import org.neo4j.driver.internal.metrics.ListenerEvent;
import org.neo4j.driver.internal.metrics.MetricsListener;
import org.neo4j.driver.internal.shaded.io.netty.channel.Channel;
import org.neo4j.driver.internal.shaded.io.netty.channel.ChannelFutureListener;
import org.neo4j.driver.internal.shaded.io.netty.channel.group.ChannelGroup;
import org.neo4j.driver.internal.shaded.io.netty.channel.group.DefaultChannelGroup;
import org.neo4j.driver.internal.shaded.io.netty.channel.pool.ChannelPoolHandler;
import org.neo4j.driver.internal.shaded.io.netty.util.concurrent.EventExecutor;
import org.neo4j.driver.net.ServerAddress;

/**
 * Netty {@link ChannelPoolHandler} that keeps, per server address, a count of channels currently
 * in use and a count of channels currently idle, and forwards channel lifecycle events to a
 * {@link MetricsListener}.
 *
 * <p>Both count maps are plain {@link HashMap}s. Every read of them goes through the read lock and
 * every mutation through the write lock of a single {@link ReentrantReadWriteLock}.
 */
public class NettyChannelTracker implements ChannelPoolHandler {
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
    private final Lock read = this.lock.readLock();
    private final Lock write = this.lock.writeLock();
    private final Map<ServerAddress, Integer> addressToInUseChannelCount = new HashMap<ServerAddress, Integer>();
    private final Map<ServerAddress, Integer> addressToIdleChannelCount = new HashMap<ServerAddress, Integer>();
    private final Logger log;
    private final MetricsListener metricsListener;
    // Attached to a channel's close future while the channel is idle in the pool (channelReleased)
    // and detached when it is acquired (channelAcquired), so only closing an idle channel
    // decrements the idle count.
    private final ChannelFutureListener closeListener = future -> this.channelClosed(future.channel());
    // Every channel reported through channelCreated(Channel, ListenerEvent); iterated by
    // prepareToCloseChannels().
    private final ChannelGroup allChannels;

    /**
     * Creates a tracker backed by a new {@link DefaultChannelGroup} named {@code "all-connections"}.
     *
     * @param metricsListener receiver of channel creation/close metric events
     * @param eventExecutor executor handed to the {@link DefaultChannelGroup}
     * @param logging source of this tracker's logger
     */
    public NettyChannelTracker(MetricsListener metricsListener, EventExecutor eventExecutor, Logging logging) {
        this(metricsListener, new DefaultChannelGroup("all-connections", eventExecutor), logging);
    }

    /**
     * Creates a tracker that records created channels in the given group.
     *
     * @param metricsListener receiver of channel creation/close metric events
     * @param channels group every created channel is added to
     * @param logging source of this tracker's logger
     */
    public NettyChannelTracker(MetricsListener metricsListener, ChannelGroup channels, Logging logging) {
        this.metricsListener = metricsListener;
        this.log = logging.getLog(this.getClass());
        this.allChannels = channels;
    }

    /** Runs {@code work} while holding the write lock. */
    private void doInWriteLock(Runnable work) {
        // Acquire before entering the try: if lock() itself failed inside the try, the finally
        // block would call unlock() on a lock this thread does not hold and mask the real error.
        this.write.lock();
        try {
            work.run();
        } finally {
            this.write.unlock();
        }
    }

    /** Evaluates {@code work} while holding the read lock and returns its result. */
    private <T> T retrieveInReadLock(Supplier<T> work) {
        // Same lock-before-try ordering as doInWriteLock.
        this.read.lock();
        try {
            return work.get();
        } finally {
            this.read.unlock();
        }
    }

    /**
     * Moves the channel from the "in use" count to the "idle" count of its server address and
     * starts listening for it being closed while idle.
     *
     * @throws IllegalStateException if no "in use" count exists for the channel's address
     */
    @Override
    public void channelReleased(Channel channel) {
        this.doInWriteLock(() -> {
            this.decrementInUse(channel);
            this.incrementIdle(channel);
            channel.closeFuture().addListener(this.closeListener);
        });
        this.log.debug("Channel [0x%s] released back to the pool", channel.id());
    }

    /**
     * Moves the channel from the "idle" count to the "in use" count of its server address and
     * stops listening for it being closed.
     *
     * @throws IllegalStateException if no "idle" count exists for the channel's address
     */
    @Override
    public void channelAcquired(Channel channel) {
        this.doInWriteLock(() -> {
            this.incrementInUse(channel);
            this.decrementIdle(channel);
            channel.closeFuture().removeListener(this.closeListener);
        });
        this.log.debug("Channel [0x%s] acquired from the pool. Local address: %s, remote address: %s", channel.id(), channel.localAddress(), channel.remoteAddress());
    }

    /**
     * Not supported. Creation must be reported through
     * {@link #channelCreated(Channel, ListenerEvent)} so the creating metrics event is available.
     *
     * @throws IllegalStateException always
     */
    @Override
    public void channelCreated(Channel channel) {
        throw new IllegalStateException("Untraceable channel created.");
    }

    /**
     * Records a newly created channel as idle, reports it to the metrics listener and adds it to
     * the group of all channels.
     *
     * @param channel the created channel
     * @param creatingEvent event previously obtained from {@link #channelCreating(String)}
     */
    public void channelCreated(Channel channel, ListenerEvent<?> creatingEvent) {
        this.doInWriteLock(() -> this.incrementIdle(channel));
        this.metricsListener.afterCreated(ChannelAttributes.poolId(channel), creatingEvent);
        this.allChannels.add(channel);
        this.log.debug("Channel [0x%s] created. Local address: %s, remote address: %s", channel.id(), channel.localAddress(), channel.remoteAddress());
    }

    /**
     * Notifies the metrics listener that a channel is about to be created for the given pool.
     *
     * @param poolId identifier of the pool the channel is created for
     * @return the event to pass to {@link #channelCreated(Channel, ListenerEvent)}
     */
    public ListenerEvent<?> channelCreating(String poolId) {
        ListenerEvent<?> creatingEvent = this.metricsListener.createListenerEvent();
        this.metricsListener.beforeCreating(poolId, creatingEvent);
        return creatingEvent;
    }

    /** Notifies the metrics listener that creating a channel for the given pool failed. */
    public void channelFailedToCreate(String poolId) {
        this.metricsListener.afterFailedToCreate(poolId);
    }

    /**
     * Decrements the idle count for the channel's address and notifies the metrics listener.
     * Invoked by the close listener attached while the channel is idle.
     *
     * @throws IllegalStateException if no "idle" count exists for the channel's address
     */
    public void channelClosed(Channel channel) {
        this.doInWriteLock(() -> this.decrementIdle(channel));
        this.metricsListener.afterClosed(ChannelAttributes.poolId(channel));
    }

    /** Returns the number of in-use channels tracked for {@code address}, or 0 if none. */
    public int inUseChannelCount(ServerAddress address) {
        return this.retrieveInReadLock(() -> this.addressToInUseChannelCount.getOrDefault(address, 0));
    }

    /** Returns the number of idle channels tracked for {@code address}, or 0 if none. */
    public int idleChannelCount(ServerAddress address) {
        return this.retrieveInReadLock(() -> this.addressToIdleChannelCount.getOrDefault(address, 0));
    }

    /**
     * Asks the Bolt protocol of every tracked channel to prepare that channel for closing.
     * Best-effort: a failure for one channel is logged at debug level and the remaining channels
     * are still processed.
     */
    public void prepareToCloseChannels() {
        for (Channel channel : this.allChannels) {
            try {
                // Protocol lookup is inside the try as well, so a channel whose protocol cannot
                // be resolved does not abort preparation of the channels after it.
                BoltProtocol protocol = BoltProtocol.forChannel(channel);
                protocol.prepareToCloseChannel(channel);
            } catch (Throwable e) {
                this.log.debug("Failed to prepare to close Channel %s due to error %s. It is safe to ignore this error as the channel will be closed despite if it is successfully prepared to close or not.", channel, e.getMessage());
            }
        }
    }

    private void incrementInUse(Channel channel) {
        this.increment(channel, this.addressToInUseChannelCount);
    }

    private void decrementInUse(Channel channel) {
        this.decrement(channel, this.addressToInUseChannelCount, "in use");
    }

    private void incrementIdle(Channel channel) {
        this.increment(channel, this.addressToIdleChannelCount);
    }

    private void decrementIdle(Channel channel) {
        this.decrement(channel, this.addressToIdleChannelCount, "idle");
    }

    /** Adds one to the count for the channel's server address, starting from 0 if absent. Caller holds the write lock. */
    private void increment(Channel channel, Map<ServerAddress, Integer> countMap) {
        BoltServerAddress address = ChannelAttributes.serverAddress(channel);
        countMap.merge(address, 1, Integer::sum);
    }

    /**
     * Subtracts one from the count for the channel's server address. Caller holds the write lock.
     *
     * @param countName label used in the error message ("in use" or "idle")
     * @throws IllegalStateException if no count exists for the address
     */
    private void decrement(Channel channel, Map<ServerAddress, Integer> countMap, String countName) {
        BoltServerAddress address = ChannelAttributes.serverAddress(channel);
        Integer count = countMap.get(address);
        // Values are never stored as null, so a null lookup means the key is absent.
        if (count == null) {
            throw new IllegalStateException("No count exists for address '" + address + "' in the '" + countName + "' count");
        }
        countMap.put(address, count - 1);
    }
}

