/*
 * Decompiled with CFR 0.152.
 */
package org.voltdb.client;

import com.google_voltpatches.common.annotations.VisibleForTesting;
import com.google_voltpatches.common.base.Function;
import com.google_voltpatches.common.base.Optional;
import com.google_voltpatches.common.base.Predicates;
import com.google_voltpatches.common.collect.FluentIterable;
import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousCloseException;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Principal;
import java.security.PrivilegedAction;
import java.util.HashMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import javax.net.ssl.SSLEngine;
import javax.security.auth.Subject;
import org.ietf.jgss.GSSContext;
import org.ietf.jgss.GSSException;
import org.ietf.jgss.GSSManager;
import org.ietf.jgss.GSSName;
import org.ietf.jgss.MessageProp;
import org.ietf.jgss.Oid;
import org.voltcore.network.ReverseDNSCache;
import org.voltcore.network.util.ssl.MessagingChannel;
import org.voltdb.ClientResponseImpl;
import org.voltdb.client.ClientAuthScheme;
import org.voltdb.client.ClientResponse;
import org.voltdb.client.DelegatePrincipal;
import org.voltdb.client.ProcedureInvocation;
import org.voltdb.client.TLSHandshaker;
import org.voltdb.utils.SerializationHelper;

public class ConnectionUtil {
    private static final TF m_tf = new TF();
    private static final HashMap<SocketChannel, ExecutorPair> m_executors = new HashMap();
    private static final AtomicLong m_handle = new AtomicLong(Long.MIN_VALUE);
    private static final GSSManager m_gssManager = GSSManager.getInstance();
    private static final Function<Principal, DelegatePrincipal> narrowPrincipal = new Function<Principal, DelegatePrincipal>(){

        @Override
        public DelegatePrincipal apply(Principal input) {
            return (DelegatePrincipal)DelegatePrincipal.class.cast(input);
        }
    };

    public static byte[] getHashedPassword(String password) {
        return ConnectionUtil.getHashedPassword(ClientAuthScheme.HASH_SHA256, password);
    }

    public static byte[] getHashedPassword(ClientAuthScheme scheme, String password) {
        if (password == null) {
            return null;
        }
        MessageDigest md = null;
        try {
            md = MessageDigest.getInstance(ClientAuthScheme.getDigestScheme(scheme));
        }
        catch (NoSuchAlgorithmException e) {
            e.printStackTrace();
            System.exit(-1);
        }
        byte[] hashedPassword = null;
        hashedPassword = md.digest(password.getBytes(StandardCharsets.UTF_8));
        return hashedPassword;
    }

    public static Object[] getAuthenticatedConnection(String host, String username, byte[] hashedPassword, int port, Subject subject, ClientAuthScheme scheme, long timeoutMillis) throws IOException {
        String service = subject == null ? "database" : "kerberos";
        InetSocketAddress address = new InetSocketAddress(host, port);
        return ConnectionUtil.getAuthenticatedConnectionImpl(service, address, username, hashedPassword, subject, scheme, null, timeoutMillis);
    }

    public static Object[] getAuthenticatedConnection(String host, String username, byte[] hashedPassword, int port, Subject subject, ClientAuthScheme scheme, SSLEngine sslEngine, long timeoutMillis) throws IOException {
        String service = subject == null ? "database" : "kerberos";
        InetSocketAddress address = new InetSocketAddress(host, port);
        return ConnectionUtil.getAuthenticatedConnectionImpl(service, address, username, hashedPassword, subject, scheme, sslEngine, timeoutMillis);
    }

    public static final Optional<DelegatePrincipal> getDelegate(Subject s) {
        if (s == null) {
            return Optional.absent();
        }
        return FluentIterable.from(s.getPrincipals()).filter(Predicates.instanceOf(DelegatePrincipal.class)).transform(narrowPrincipal).first();
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private static Object[] getAuthenticatedConnectionImpl(String service, InetSocketAddress addr, String username, byte[] hashedPassword, Subject subject, ClientAuthScheme scheme, SSLEngine sslEngine, long timeoutMillis) throws IOException {
        DelayedExecutionThread timeoutThread;
        Object[] returnArray = new Object[3];
        boolean success = false;
        if (addr.isUnresolved()) {
            throw new UnknownHostException(addr.getHostName());
        }
        final SocketChannel aChannel = SocketChannel.open(addr);
        returnArray[0] = aChannel;
        assert (aChannel.isConnected());
        if (!aChannel.isConnected()) {
            throw new IOException("Failed to open host " + ReverseDNSCache.hostnameOrAddress(addr.getAddress()));
        }
        if (timeoutMillis > 0L) {
            timeoutThread = new DelayedExecutionThread(timeoutMillis, TimeUnit.MILLISECONDS, new Runnable(){

                @Override
                public void run() {
                    try {
                        aChannel.close();
                    }
                    catch (IOException iOException) {
                        // empty catch block
                    }
                }
            });
            timeoutThread.start();
        } else {
            timeoutThread = null;
        }
        MessagingChannel messagingChannel = null;
        try {
            ByteBuffer loginResponse;
            Object object = aChannel.blockingLock();
            synchronized (object) {
                aChannel.configureBlocking(false);
                aChannel.socket().setTcpNoDelay(true);
            }
            if (sslEngine != null) {
                TLSHandshaker handshaker = new TLSHandshaker(aChannel, sslEngine);
                boolean shookHands = false;
                try {
                    shookHands = handshaker.handshake();
                }
                catch (IOException e) {
                    aChannel.close();
                    throw new IOException("TLS/SSL handshake failed", e);
                }
                if (!shookHands) {
                    aChannel.close();
                    throw new IOException("TLS/SSL handshake failed");
                }
            }
            long[] retvals = new long[4];
            returnArray[1] = retvals;
            messagingChannel = MessagingChannel.get(aChannel, sslEngine);
            Object shookHands = aChannel.blockingLock();
            synchronized (shookHands) {
                aChannel.configureBlocking(true);
                aChannel.socket().setTcpNoDelay(true);
            }
            byte[] serviceBytes = service == null ? null : service.getBytes(StandardCharsets.UTF_8);
            byte[] usernameBytes = username == null ? null : username.getBytes(StandardCharsets.UTF_8);
            int requestSize = 4;
            requestSize += 2;
            requestSize += serviceBytes == null ? 4 : 4 + serviceBytes.length;
            requestSize += usernameBytes == null ? 4 : 4 + usernameBytes.length;
            ByteBuffer b = ByteBuffer.allocate(requestSize += hashedPassword.length);
            b.putInt(requestSize - 4);
            b.put((byte)2);
            b.put((byte)scheme.getValue());
            SerializationHelper.writeVarbinary(serviceBytes, b);
            SerializationHelper.writeVarbinary(usernameBytes, b);
            b.put(hashedPassword);
            b.flip();
            try {
                messagingChannel.writeMessage(b);
            }
            catch (IOException e) {
                throw new IOException("Failed to write authentication message to server.", e);
            }
            if (b.hasRemaining()) {
                throw new IOException("Failed to write authentication message to server.");
            }
            try {
                loginResponse = messagingChannel.readMessage();
            }
            catch (IOException e) {
                throw new IOException("Authentication rejected", e);
            }
            byte version = loginResponse.get();
            byte loginResponseCode = loginResponse.get();
            if (version == 2) {
                byte tag = loginResponseCode;
                if (subject == null) {
                    aChannel.close();
                    throw new IOException("Server requires an authenticated JAAS principal");
                }
                if (tag != 4) {
                    aChannel.close();
                    throw new IOException("Wire protocol format violation error");
                }
                String servicePrincipal = SerializationHelper.getString(loginResponse);
                loginResponse = ConnectionUtil.performAuthenticationHandShake(messagingChannel, subject, servicePrincipal);
                loginResponseCode = loginResponse.get();
            }
            if (loginResponseCode != 0) {
                aChannel.close();
                String reason = "Authentication rejected";
                switch (loginResponseCode) {
                    case 1: {
                        reason = "Server has too many connections";
                        break;
                    }
                    case 2: {
                        reason = "Connection timed out during authentication. The database server may be overloaded.";
                        break;
                    }
                    case 3: {
                        reason = "Wire protocol format violation error";
                        break;
                    }
                    case 4: {
                        reason = "Failed to authenticate to rejoining node";
                        break;
                    }
                    case 5: {
                        reason = "Export not enabled for server";
                        break;
                    }
                    case 6: {
                        reason = "Server requires use of TLS/SSL";
                    }
                }
                throw new IOException(reason);
            }
            retvals[0] = loginResponse.getInt();
            retvals[1] = loginResponse.getLong();
            retvals[2] = loginResponse.getLong();
            retvals[3] = loginResponse.getInt();
            int buildStringLength = loginResponse.getInt();
            byte[] buildStringBytes = new byte[buildStringLength];
            loginResponse.get(buildStringBytes);
            returnArray[2] = new String(buildStringBytes, StandardCharsets.UTF_8);
            Object object2 = aChannel.blockingLock();
            synchronized (object2) {
                aChannel.configureBlocking(false);
                aChannel.socket().setKeepAlive(true);
            }
            success = true;
        }
        catch (AsynchronousCloseException asynchronousCloseException) {
        }
        finally {
            if (messagingChannel != null) {
                messagingChannel.cleanUp();
            }
            if (timeoutThread != null && !timeoutThread.cancel()) {
                throw new IOException("Authentication timed out");
            }
            if (!success) {
                aChannel.close();
            }
        }
        return returnArray;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public static boolean isClientLoginHeader(ByteBuffer buf) {
        boolean hasHeader = false;
        try {
            if (buf.remaining() > 4) {
                int msgSize = buf.getInt();
                if (buf.remaining() > 2) {
                    int serviceSize;
                    byte protocolVersion = buf.get();
                    byte passwordHashVersion = buf.get();
                    if (!(protocolVersion != 1 && protocolVersion != 2 || passwordHashVersion != 0 && passwordHashVersion != 1 || buf.remaining() <= 4 || (serviceSize = buf.getInt()) != "database".length() || buf.remaining() <= serviceSize)) {
                        byte[] serviceBytes = new byte[serviceSize];
                        for (int i = 0; i < serviceSize; ++i) {
                            serviceBytes[i] = buf.get();
                        }
                        String service = new String(serviceBytes, StandardCharsets.UTF_8);
                        if (service.equals("database")) {
                            hasHeader = true;
                        }
                    }
                }
            }
        }
        finally {
            buf.rewind();
        }
        return hasHeader;
    }

    private static final void establishSecurityContext(MessagingChannel channel, GSSContext context, Optional<DelegatePrincipal> delegate) throws IOException, GSSException {
        byte[] token;
        ByteBuffer writeBuffer;
        ByteBuffer readBuffer = writeBuffer = ByteBuffer.allocate(4096);
        int msgSize = 0;
        writeBuffer.limit(msgSize);
        while (!context.isEstablished()) {
            token = context.initSecContext(readBuffer.array(), readBuffer.arrayOffset() + readBuffer.position(), readBuffer.remaining());
            if (token != null) {
                msgSize = 6 + token.length;
                writeBuffer.clear().limit(msgSize);
                writeBuffer.putInt(msgSize - 4).put((byte)2).put((byte)5);
                writeBuffer.put(token).flip();
                channel.writeMessage(writeBuffer);
            }
            if (context.isEstablished()) break;
            readBuffer = channel.readMessage();
            byte version = readBuffer.get();
            if (version != 2) {
                throw new IOException("Encountered unexpected authentication protocol version " + version);
            }
            byte tag = readBuffer.get();
            if (tag == 5) continue;
            throw new IOException("Encountered unexpected authentication protocol tag " + tag);
        }
        if (!context.getMutualAuthState()) {
            throw new IOException("Authentication Handshake Failed");
        }
        if (delegate.isPresent() && !context.getConfState()) {
            throw new IOException("Cannot transmit delegate user name securely");
        }
        if (delegate.isPresent()) {
            MessageProp mprop = new MessageProp(0, true);
            writeBuffer.clear().limit(delegate.get().wrappedSize());
            delegate.get().wrap(writeBuffer);
            writeBuffer.flip();
            token = context.wrap(writeBuffer.array(), writeBuffer.arrayOffset() + writeBuffer.position(), writeBuffer.remaining(), mprop);
            msgSize = 6 + token.length;
            writeBuffer.clear().limit(msgSize);
            writeBuffer.putInt(msgSize - 4).put((byte)2).put((byte)5);
            writeBuffer.put(token).flip();
            while (writeBuffer.hasRemaining()) {
                channel.writeMessage(writeBuffer);
            }
        }
    }

    private static final ByteBuffer performAuthenticationHandShake(final MessagingChannel channel, Subject subject, final String serviceName) throws IOException {
        try {
            String subjectPrincipal = subject.getPrincipals().iterator().next().getName();
            final Optional<DelegatePrincipal> delegate = ConnectionUtil.getDelegate(subject);
            if (delegate.isPresent() && !subjectPrincipal.equals(serviceName)) {
                throw new IOException("Delegate authentication is not allowed for user " + delegate.get().getName());
            }
            Subject.doAs(subject, new PrivilegedAction<GSSContext>(){

                @Override
                public GSSContext run() {
                    GSSContext context = null;
                    try {
                        Oid krb5Oid = new Oid("1.2.840.113554.1.2.2");
                        Oid krb5PrincipalNameType = new Oid("1.2.840.113554.1.2.2.1");
                        GSSName serverName = m_gssManager.createName(serviceName, krb5PrincipalNameType);
                        context = m_gssManager.createContext(serverName, krb5Oid, null, Integer.MAX_VALUE);
                        context.requestMutualAuth(true);
                        context.requestConf(true);
                        context.requestInteg(true);
                        ConnectionUtil.establishSecurityContext(channel, context, delegate);
                        context.dispose();
                        context = null;
                    }
                    catch (GSSException ex) {
                        throw new RuntimeException(ex);
                    }
                    catch (IOException ex) {
                        throw new RuntimeException(ex);
                    }
                    finally {
                        if (context != null) {
                            try {
                                context.dispose();
                            }
                            catch (Exception exception) {}
                        }
                    }
                    return null;
                }
            });
        }
        catch (SecurityException ex) {
            Throwable cause = ex.getCause();
            if (cause != null && cause instanceof RuntimeException && cause.getCause() != null) {
                cause = cause.getCause();
            } else if (cause == null) {
                cause = ex;
            }
            if (cause instanceof IOException) {
                throw (IOException)IOException.class.cast(cause);
            }
            throw new IOException("Authentication Handshake Failed", cause);
        }
        ByteBuffer loginResponse = channel.readMessage();
        byte version = loginResponse.get();
        if (version != 0) {
            throw new IOException("Encountered unexpected version for the login response message: " + version);
        }
        return loginResponse;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public static void closeConnection(SocketChannel connection) throws InterruptedException, IOException {
        HashMap<SocketChannel, ExecutorPair> hashMap = m_executors;
        synchronized (hashMap) {
            ExecutorPair p = m_executors.remove(connection);
            assert (p != null);
            p.shutdown();
        }
        connection.close();
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private static ExecutorPair getExecutorPair(SocketChannel channel) {
        HashMap<SocketChannel, ExecutorPair> hashMap = m_executors;
        synchronized (hashMap) {
            ExecutorPair p = m_executors.get(channel);
            if (p == null) {
                p = new ExecutorPair();
                m_executors.put(channel, p);
            }
            return p;
        }
    }

    public static Future<Long> sendInvocation(SocketChannel channel, String procName, Object ... parameters) {
        ExecutorPair p = ConnectionUtil.getExecutorPair(channel);
        return ConnectionUtil.sendInvocation(p.m_writeExecutor, channel, procName, parameters);
    }

    public static Future<Long> sendInvocation(ExecutorService executor, final SocketChannel channel, final String procName, final Object ... parameters) {
        return executor.submit(new Callable<Long>(){

            @Override
            public Long call() throws Exception {
                long handle = m_handle.getAndIncrement();
                ProcedureInvocation invocation = new ProcedureInvocation(handle, procName, parameters);
                ByteBuffer buf = ByteBuffer.allocate(4 + invocation.getSerializedSize());
                buf.position(4);
                invocation.flattenToBuffer(buf);
                buf.putInt(0, buf.capacity() - 4);
                buf.flip();
                do {
                    channel.write(buf);
                    if (!buf.hasRemaining()) continue;
                    Thread.yield();
                } while (buf.hasRemaining());
                return handle;
            }
        });
    }

    public static Future<ClientResponse> readResponse(SocketChannel channel) {
        ExecutorPair p = ConnectionUtil.getExecutorPair(channel);
        return ConnectionUtil.readResponse(p.m_readExecutor, channel);
    }

    public static Future<ClientResponse> readResponse(ExecutorService executor, final SocketChannel channel) {
        return executor.submit(new Callable<ClientResponse>(){

            @Override
            public ClientResponse call() throws Exception {
                ByteBuffer lengthBuffer = ByteBuffer.allocate(4);
                do {
                    int read;
                    if ((read = channel.read(lengthBuffer)) == -1) {
                        throw new EOFException();
                    }
                    if (!lengthBuffer.hasRemaining()) continue;
                    Thread.yield();
                } while (lengthBuffer.hasRemaining());
                lengthBuffer.flip();
                ByteBuffer message = ByteBuffer.allocate(lengthBuffer.getInt());
                do {
                    int read;
                    if ((read = channel.read(message)) == -1) {
                        throw new EOFException();
                    }
                    if (!lengthBuffer.hasRemaining()) continue;
                    Thread.yield();
                } while (message.hasRemaining());
                message.flip();
                ClientResponseImpl response = new ClientResponseImpl();
                response.initFromBuffer(message);
                return response;
            }
        });
    }

    @VisibleForTesting
    static final class DelayedExecutionThread
    extends Thread {
        private final long m_runAtNanos;
        private final Runnable m_runnable;
        private volatile State m_state = State.NOT_STARTED;

        public DelayedExecutionThread(long delay, TimeUnit unit, Runnable onTimeout) {
            super(null, null, "Delayed Execution Thread " + unit.toMillis(delay) + "ms", 262144L);
            this.m_runAtNanos = unit.toNanos(delay) + System.nanoTime();
            this.m_runnable = onTimeout;
            this.setDaemon(true);
        }

        @Override
        public synchronized void run() {
            long now;
            if (this.m_state == State.CANCELED) {
                return;
            }
            if (this.m_state != State.NOT_STARTED) {
                throw new IllegalStateException("Not in state " + State.NOT_STARTED + ": " + this.m_state);
            }
            this.setState(State.WAITING);
            while (this.m_state == State.WAITING && (now = System.nanoTime()) <= this.m_runAtNanos) {
                try {
                    this.wait(Math.max(1L, TimeUnit.NANOSECONDS.toMillis(now - this.m_runAtNanos)));
                }
                catch (InterruptedException interruptedException) {}
            }
            if (!this.m_state.m_done) {
                this.setState(State.RUNNING);
                if (this.m_runnable != null) {
                    this.m_runnable.run();
                }
                this.setState(State.COMPLETED);
            }
        }

        public synchronized boolean cancel() {
            if (!this.m_state.m_done) {
                this.setState(State.CANCELED);
            }
            return this.m_state == State.CANCELED;
        }

        public synchronized State waitUntilDone() throws InterruptedException {
            while (!this.m_state.m_done) {
                this.wait();
            }
            return this.m_state;
        }

        public State state() {
            return this.m_state;
        }

        private void setState(State state) {
            this.m_state = state;
            if (state.m_done) {
                this.notifyAll();
            }
        }

        public static enum State {
            NOT_STARTED,
            WAITING,
            RUNNING,
            COMPLETED(true),
            CANCELED(true);

            public final boolean m_done;

            private State() {
                this(false);
            }

            private State(boolean done) {
                this.m_done = done;
            }
        }
    }

    public static class ExecutorPair {
        public final ExecutorService m_writeExecutor = Executors.newSingleThreadExecutor(m_tf);
        public final ExecutorService m_readExecutor = Executors.newSingleThreadExecutor(m_tf);

        private void shutdown() throws InterruptedException {
            this.m_readExecutor.shutdownNow();
            this.m_writeExecutor.shutdownNow();
            this.m_readExecutor.awaitTermination(1L, TimeUnit.DAYS);
            this.m_writeExecutor.awaitTermination(1L, TimeUnit.DAYS);
        }
    }

    private static class TF
    implements ThreadFactory {
        private TF() {
        }

        @Override
        public Thread newThread(Runnable r) {
            return new Thread(null, r, "Yet another thread", 65536L);
        }
    }
}

