/*
 * Decompiled with CFR 0.152.
 */
package wiremock.org.eclipse.jetty.server.handler;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.time.Duration;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import wiremock.org.eclipse.jetty.http.HttpStatus;
import wiremock.org.eclipse.jetty.io.CyclicTimeouts;
import wiremock.org.eclipse.jetty.server.Handler;
import wiremock.org.eclipse.jetty.server.Request;
import wiremock.org.eclipse.jetty.server.Response;
import wiremock.org.eclipse.jetty.server.Server;
import wiremock.org.eclipse.jetty.server.handler.ConditionalHandler;
import wiremock.org.eclipse.jetty.util.Callback;
import wiremock.org.eclipse.jetty.util.NanoTime;
import wiremock.org.eclipse.jetty.util.TypeUtil;
import wiremock.org.eclipse.jetty.util.annotation.ManagedObject;
import wiremock.org.eclipse.jetty.util.annotation.Name;
import wiremock.org.eclipse.jetty.util.thread.AutoLock;
import wiremock.org.eclipse.jetty.util.thread.Scheduler;
import wiremock.org.slf4j.Logger;
import wiremock.org.slf4j.LoggerFactory;

/**
 * A handler that protects the next handler from request floods. It tracks the request rate of
 * each client and rejects requests from clients whose {@link Tracker} does not allow them.
 * <p>
 * Clients are identified by a configurable id function (by default {@link #ID_FROM_REMOTE_ADDRESS}).
 * A {@link Tracker} is created lazily per client id by a {@link Tracker.Factory}. It is dropped from
 * the tracker map once it expires. At most {@code maxTrackers} trackers are kept. When that limit is
 * reached, requests from untracked clients are either passed to the next handler or rejected,
 * depending on {@code rejectUntracked}.
 */
@ManagedObject(value="DoS Prevention Handler")
public class DoSHandler
extends ConditionalHandler.ElseNext {
    private static final Logger LOG = LoggerFactory.getLogger(DoSHandler.class);

    /**
     * Identifies clients by remote address and port (the {@code toString()} of the remote socket
     * address). Returns {@code null} when the connection has no remote address; this handler maps a
     * {@code null} id to the shared empty id.
     */
    public static final Function<Request, String> ID_FROM_REMOTE_ADDRESS_PORT = request -> {
        SocketAddress remote = request.getConnectionMetaData().getRemoteSocketAddress();
        return remote == null ? null : remote.toString();
    };

    /**
     * Identifies clients by remote IP address only, so all connections from one host share a
     * tracker. Non-IP socket addresses use their {@code toString()}. Returns {@code null} when the
     * connection has no remote address.
     */
    public static final Function<Request, String> ID_FROM_REMOTE_ADDRESS = request -> {
        SocketAddress remote = request.getConnectionMetaData().getRemoteSocketAddress();
        if (remote instanceof InetSocketAddress inet) {
            // An unresolved address has no InetAddress (getAddress() is null), so fall back to the host string.
            return inet.isUnresolved() ? inet.getHostString() : inet.getAddress().toString();
        }
        return remote == null ? null : remote.toString();
    };

    /**
     * Identifies clients by remote port only. Non-IP socket addresses use their {@code toString()}.
     * Returns {@code null} when the connection has no remote address.
     */
    public static final Function<Request, String> ID_FROM_REMOTE_PORT = request -> {
        SocketAddress remote = request.getConnectionMetaData().getRemoteSocketAddress();
        if (remote instanceof InetSocketAddress inet) {
            return Integer.toString(inet.getPort());
        }
        return remote == null ? null : remote.toString();
    };

    /** Identifies clients by connection id, so every connection gets its own tracker. */
    public static final Function<Request, String> ID_FROM_CONNECTION = request -> request.getConnectionMetaData().getId();

    /** Live trackers keyed by client id; expired entries are removed by {@code _cyclicTimeouts}. */
    private final Map<String, Tracker> _trackers = new ConcurrentHashMap<String, Tracker>();
    private final Function<Request, String> _clientIdFn;
    private final Tracker.Factory _trackerFactory;
    private final Request.Handler _rejectHandler;
    /** Maximum number of trackers kept at once; 0 means unlimited. */
    private final int _maxTrackers;
    /** Whether requests that cannot get a tracker are rejected (true) or passed on (false). */
    private final boolean _rejectUntracked;
    /** Expires idle trackers; non-null only while this handler is started. */
    private CyclicTimeouts<Tracker> _cyclicTimeouts;

    /**
     * Creates a handler with default client identification, rejection and tracker limit.
     *
     * @param trackerFactory creates a tracker per client; must not be null
     */
    public DoSHandler(@Name(value="trackerFactory") Tracker.Factory trackerFactory) {
        this(null, trackerFactory, null, -1);
    }

    /**
     * Creates a handler without a next handler.
     *
     * @see #DoSHandler(Handler, Function, Tracker.Factory, Request.Handler, int, boolean)
     */
    public DoSHandler(@Name(value="clientIdFn") Function<Request, String> clientIdFn, @Name(value="trackerFactory") Tracker.Factory trackerFactory, @Name(value="rejectHandler") Request.Handler rejectHandler, @Name(value="maxTrackers") int maxTrackers) {
        this(null, clientIdFn, trackerFactory, rejectHandler, maxTrackers);
    }

    /**
     * Creates a handler that passes untracked requests to the next handler.
     *
     * @see #DoSHandler(Handler, Function, Tracker.Factory, Request.Handler, int, boolean)
     */
    public DoSHandler(@Name(value="handler") Handler handler, @Name(value="clientIdFn") Function<Request, String> clientIdFn, @Name(value="trackerFactory") Tracker.Factory trackerFactory, @Name(value="rejectHandler") Request.Handler rejectHandler, @Name(value="maxTrackers") int maxTrackers) {
        this(handler, clientIdFn, trackerFactory, rejectHandler, maxTrackers, false);
    }

    /**
     * Creates a fully configured handler.
     *
     * @param handler the next handler, or {@code null}
     * @param clientIdFn maps a request to a client id, or {@code null} for {@link #ID_FROM_REMOTE_ADDRESS}
     * @param trackerFactory creates a tracker per client id; must not be null
     * @param rejectHandler handles rejected requests, or {@code null} for a {@link StatusRejectHandler}
     *                      with its default status
     * @param maxTrackers maximum number of trackers; negative means 100000, 0 means unlimited
     * @param rejectUntracked whether requests that cannot be tracked (tracker limit reached) are rejected
     * @throws NullPointerException if {@code trackerFactory} is null
     */
    public DoSHandler(@Name(value="handler") Handler handler, @Name(value="clientIdFn") Function<Request, String> clientIdFn, @Name(value="trackerFactory") Tracker.Factory trackerFactory, @Name(value="rejectHandler") Request.Handler rejectHandler, @Name(value="maxTrackers") int maxTrackers, @Name(value="rejectUntracked") boolean rejectUntracked) {
        super(handler);
        this.installBean(this._trackers);
        this._clientIdFn = Objects.requireNonNullElse(clientIdFn, ID_FROM_REMOTE_ADDRESS);
        this.installBean(this._clientIdFn);
        this._trackerFactory = Objects.requireNonNull(trackerFactory, "trackerFactory");
        this.installBean(this._trackerFactory);
        this._maxTrackers = maxTrackers < 0 ? 100000 : maxTrackers;
        this._rejectHandler = Objects.requireNonNullElseGet(rejectHandler, StatusRejectHandler::new);
        this.installBean(this._rejectHandler);
        this._rejectUntracked = rejectUntracked;
    }

    /** Sets the server on this handler and, when it is itself a {@link Handler}, on the reject handler. */
    @Override
    public void setServer(Server server) {
        super.setServer(server);
        if (this._rejectHandler instanceof Handler rejectHandler) {
            rejectHandler.setServer(server);
        }
    }

    /**
     * Looks up (or creates) the tracker for the request's client. Forwards the request when the
     * tracker allows it, and rejects it otherwise.
     */
    @Override
    protected boolean onConditionsMet(Request request, Response response, Callback callback) throws Exception {
        // A null id (e.g. no remote address) is tracked under the shared empty id.
        String id = Objects.requireNonNullElse(this._clientIdFn.apply(request), "");
        Tracker tracker = this._trackers.computeIfAbsent(id, this::newTracker);
        if (tracker == null) {
            // The tracker limit was reached, so this client cannot be rate-limited individually.
            return this._rejectUntracked
                ? this._rejectHandler.handle(request, response, callback)
                : this.nextHandler(request, response, callback);
        }
        boolean allowed = tracker.onRequest(NanoTime.now());
        if (LOG.isDebugEnabled()) {
            LOG.debug("allowed={} {}", (Object)allowed, (Object)tracker);
        }
        return allowed
            ? this.nextHandler(request, response, callback)
            : this._rejectHandler.handle(request, response, callback);
    }

    /**
     * Creates a tracker for {@code id} and schedules it for expiry.
     *
     * @return the new tracker, or {@code null} when the tracker limit has been reached
     */
    Tracker newTracker(String id) {
        // The size check is approximate under concurrency; the limit is a soft bound.
        if (this._maxTrackers > 0 && this._trackers.size() >= this._maxTrackers) {
            return null;
        }
        Tracker tracker = this._trackerFactory.newTracker(id);
        this._cyclicTimeouts.schedule(tracker);
        return tracker;
    }

    @Override
    protected void doStart() throws Exception {
        this._cyclicTimeouts = new CyclicTimeouts<Tracker>(this.getServer().getScheduler()){

            @Override
            protected Iterator<Tracker> iterator() {
                return DoSHandler.this._trackers.values().iterator();
            }

            @Override
            protected boolean onExpired(Tracker tracker) {
                // Returning true asks CyclicTimeouts to remove the expired tracker via the iterator above.
                return true;
            }
        };
        this.addBean(this._cyclicTimeouts);
        super.doStart();
    }

    @Override
    protected void doStop() throws Exception {
        super.doStop();
        this.removeBean(this._cyclicTimeouts);
        this._cyclicTimeouts.destroy();
        this._cyclicTimeouts = null;
        // Trackers are not carried across a restart; the next start creates a fresh CyclicTimeouts.
        this._trackers.clear();
    }

    /** Tracks the requests of a single client and decides whether each one is allowed. */
    public static interface Tracker
    extends CyclicTimeouts.Expirable {
        /**
         * Records a request arriving at {@code now}.
         *
         * @param now the request time, as {@link NanoTime#now()}
         * @return whether the request is allowed
         */
        public boolean onRequest(long now);

        /** Creates trackers for client ids. */
        public static interface Factory {
            /**
             * @param id the client id
             * @return a new tracker for that client
             */
            public Tracker newTracker(String id);
        }
    }

    /**
     * A reject handler that holds rejected exchanges for a delay before passing them to a
     * delegate reject handler. This slows down abusive clients. At most {@code maxDelayQueue}
     * exchanges are held; when the queue is full, the oldest is rejected immediately.
     */
    public static class DelayedRejectHandler
    extends Handler.Abstract {
        private final AutoLock _lock = new AutoLock();
        private final Deque<Exchange> _delayQueue = new ArrayDeque<Exchange>();
        private final int _maxDelayQueue;
        private final long _delayMs;
        private final Request.Handler _reject;
        private Scheduler _scheduler;

        /** Creates a handler with a 1000 ms delay, a 1000-entry queue and a 429 status reject. */
        public DelayedRejectHandler() {
            this(-1L, -1, null);
        }

        /**
         * @param delayMs how long to hold an exchange; negative means 1000
         * @param maxDelayQueue maximum number of held exchanges; negative means 1000, 0 rejects immediately
         * @param reject the handler finally invoked for held exchanges, or {@code null} for a 429 status reject
         */
        public DelayedRejectHandler(@Name(value="delayMs") long delayMs, @Name(value="maxDelayQueue") int maxDelayQueue, @Name(value="reject") Request.Handler reject) {
            this._delayMs = delayMs >= 0L ? delayMs : 1000L;
            this._maxDelayQueue = maxDelayQueue >= 0 ? maxDelayQueue : 1000;
            this._reject = Objects.requireNonNullElseGet(reject, () -> new StatusRejectHandler(429));
        }

        @Override
        protected void doStart() throws Exception {
            super.doStart();
            this._scheduler = this.getServer().getScheduler();
            this.addBean(this._scheduler);
        }

        @Override
        protected void doStop() throws Exception {
            // Complete exchanges still waiting in the queue rather than leaking their callbacks.
            List<Exchange> pending;
            try (AutoLock ignored = this._lock.lock();){
                pending = new ArrayList<Exchange>(this._delayQueue);
                this._delayQueue.clear();
            }
            this.reject(pending);
            super.doStop();
            this.removeBean(this._scheduler);
            this._scheduler = null;
        }

        /** Queues the exchange for delayed rejection, evicting (and rejecting) the oldest ones if the queue is full. */
        @Override
        public boolean handle(Request request, Response response, Callback callback) throws Exception {
            List<Exchange> rejects = new ArrayList<Exchange>();
            Exchange exchange = new Exchange(request, response, callback);
            try (AutoLock ignored = this._lock.lock();){
                // A non-empty queue always has a tick pending. Decide before evicting, so that evicting
                // down to empty does not start a second, never-ending tick chain.
                boolean tickPending = !this._delayQueue.isEmpty();
                while (!this._delayQueue.isEmpty() && this._delayQueue.size() >= this._maxDelayQueue) {
                    rejects.add(this._delayQueue.removeFirst());
                }
                if (this._maxDelayQueue > 0) {
                    if (!tickPending) {
                        this.scheduleTick();
                    }
                    this._delayQueue.addLast(exchange);
                } else {
                    // A zero-length queue cannot hold anything: reject immediately.
                    rejects.add(exchange);
                }
            }
            this.reject(rejects);
            return true;
        }

        /** Schedules {@link #onTick()} after half the delay, at least 1 ms to avoid a zero-delay reschedule loop. */
        private void scheduleTick() {
            this._scheduler.schedule(this::onTick, Math.max(1L, this._delayMs / 2L), TimeUnit.MILLISECONDS);
        }

        /** Rejects every queued exchange whose request began at least the delay ago; reschedules while work remains. */
        private void onTick() {
            long expired = NanoTime.now() - TimeUnit.MILLISECONDS.toNanos(this._delayMs);
            List<Exchange> rejects = new ArrayList<Exchange>();
            try (AutoLock ignored = this._lock.lock();){
                Iterator<Exchange> iterator = this._delayQueue.iterator();
                while (iterator.hasNext()) {
                    Exchange exchange = iterator.next();
                    if (NanoTime.isBeforeOrSame(exchange.request().getBeginNanoTime(), expired)) {
                        iterator.remove();
                        rejects.add(exchange);
                    }
                }
                // The scheduler is null once stopped; do not reschedule then.
                if (!this._delayQueue.isEmpty() && this._scheduler != null) {
                    this.scheduleTick();
                }
            }
            this.reject(rejects);
        }

        /**
         * Passes each exchange to the delegate reject handler. Fails the callback when the delegate
         * does not handle the exchange or throws.
         */
        private void reject(List<Exchange> rejects) {
            if (rejects == null) {
                return;
            }
            for (Exchange exchange : rejects) {
                try {
                    if (!this._reject.handle(exchange.request(), exchange.response(), exchange.callback())) {
                        exchange.callback().failed(new RejectedExecutionException());
                    }
                }
                catch (Throwable t) {
                    exchange.callback().failed(t);
                }
            }
        }

        private record Exchange(Request request, Response response, Callback callback) {
        }
    }

    /**
     * A reject handler that responds with a fixed status. Status 0 instead fails the callback
     * with a {@link RejectedExecutionException}.
     */
    public static class StatusRejectHandler
    implements Request.Handler {
        private final int _status;

        /** Creates a handler that responds with status 429. */
        public StatusRejectHandler() {
            this(-1);
        }

        /**
         * @param status the response status; negative means 429, 0 means fail the callback
         * @throws IllegalArgumentException unless the status is 0, 200, a client error or a server error
         */
        public StatusRejectHandler(int status) {
            this._status = status >= 0 ? status : 429;
            boolean acceptable = this._status == 0
                || this._status == 200
                || HttpStatus.isClientError(this._status)
                || HttpStatus.isServerError(this._status);
            if (!acceptable) {
                throw new IllegalArgumentException("status must be a client or server error");
            }
        }

        @Override
        public boolean handle(Request request, Response response, Callback callback) throws Exception {
            if (this._status == 0) {
                callback.failed(new RejectedExecutionException());
            } else {
                Response.writeError(request, response, callback, this._status);
            }
            return true;
        }
    }

    /**
     * Creates leaky-bucket trackers. Each request adds one drip to the client's bucket. The bucket
     * leaks one drip every {@code 1s / maxRequestsPerSecond}. A request is allowed while the bucket,
     * including its drip, does not exceed {@code bucketSize}.
     */
    public static class LeakingBucketTrackerFactory
    implements Tracker.Factory {
        private final int _maxRequestsPerSecond;
        private final int _bucketSize;
        private final long _nanosPerDrip;
        private final long _idleTimeout;

        /** Creates a factory whose bucket size equals the rate and with no extra idle timeout. */
        public LeakingBucketTrackerFactory(@Name(value="maxRequestsPerSecond") int maxRequestsPerSecond) {
            this(maxRequestsPerSecond, -1, null);
        }

        /**
         * @param maxRequestsPerSecond the sustained request rate allowed; must be positive
         * @param bucketSize the burst capacity; negative means {@code maxRequestsPerSecond}
         * @param idleTimeout extra time an idle tracker is kept after its bucket drains; null or negative means 0
         * @throws IllegalArgumentException if {@code maxRequestsPerSecond} is not positive
         */
        public LeakingBucketTrackerFactory(@Name(value="maxRequestsPerSecond") int maxRequestsPerSecond, @Name(value="bucketSize") int bucketSize, @Name(value="idleTimeout") Duration idleTimeout) {
            if (maxRequestsPerSecond <= 0) {
                throw new IllegalArgumentException("maxRequestsPerSecond must be positive: " + maxRequestsPerSecond);
            }
            this._maxRequestsPerSecond = maxRequestsPerSecond;
            // Clamp to 1ns: rates above 1e9/s would otherwise give a zero drip period (division by zero).
            this._nanosPerDrip = Math.max(1L, TimeUnit.SECONDS.toNanos(1L) / (long)maxRequestsPerSecond);
            this._bucketSize = bucketSize < 0 ? maxRequestsPerSecond : bucketSize;
            this._idleTimeout = idleTimeout == null || idleTimeout.isNegative() ? 0L : idleTimeout.toNanos();
        }

        @Override
        public Tracker newTracker(String id) {
            return new LeakingBucketTracker(id);
        }

        /**
         * Per the CyclicTimeouts.Expirable contract, Long.MAX_VALUE means "never expires". Step past
         * that value (nano times wrap around) so that trackers always expire.
         */
        private static long expirableNanoTime(long nanoTime) {
            return nanoTime == Long.MAX_VALUE ? nanoTime + 1L : nanoTime;
        }

        private class LeakingBucketTracker
        implements Tracker {
            private final AutoLock _lock = new AutoLock();
            private final String _id;
            private long _lastDripNanoTime;
            private long _expireNanoTime;
            private int _bucket;

            public LeakingBucketTracker(String id) {
                long now = NanoTime.now();
                this._id = id;
                this._lastDripNanoTime = now;
                this._expireNanoTime = expirableNanoTime(now + LeakingBucketTrackerFactory.this._nanosPerDrip + LeakingBucketTrackerFactory.this._idleTimeout);
            }

            @Override
            public long getExpireNanoTime() {
                try (AutoLock ignored = this._lock.lock();){
                    return this._expireNanoTime;
                }
            }

            @Override
            public boolean onRequest(long now) {
                try (AutoLock ignored = this._lock.lock();){
                    long nanosPerDrip = LeakingBucketTrackerFactory.this._nanosPerDrip;
                    int bucketSize = LeakingBucketTrackerFactory.this._bucketSize;
                    // Leak the whole drips elapsed since the last drip. 'now' is taken before the lock,
                    // so it may predate _lastDripNanoTime; never leak a negative amount.
                    long drips = Math.max(0L, NanoTime.elapsed(this._lastDripNanoTime, now) / nanosPerDrip);
                    this._lastDripNanoTime += drips * nanosPerDrip;
                    // Add this request's drip; the request is allowed unless the bucket would overflow.
                    long level = Math.max(0L, (long)this._bucket - drips) + 1L;
                    boolean allowed = level <= bucketSize;
                    this._bucket = (int)Math.min(level, (long)bucketSize);
                    // Keep the tracker until the bucket has drained, plus the idle timeout.
                    this._expireNanoTime = expirableNanoTime(now + (long)this._bucket * nanosPerDrip + LeakingBucketTrackerFactory.this._idleTimeout);
                    return allowed;
                }
            }

            @Override
            public String toString() {
                try (AutoLock ignored = this._lock.lock();){
                    return "%s@%s{%d/%d}".formatted(TypeUtil.toShortName(this.getClass()), this._id, this._bucket, LeakingBucketTrackerFactory.this._maxRequestsPerSecond);
                }
            }
        }
    }
}

