/*
 * Decompiled with CFR 0.152.
 */
package org.apache.knox.gateway.websockets;

import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.security.KeyStore;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import javax.websocket.ClientEndpointConfig;
import org.apache.commons.lang3.StringUtils;
import org.apache.knox.gateway.config.GatewayConfig;
import org.apache.knox.gateway.i18n.messages.MessagesFactory;
import org.apache.knox.gateway.services.GatewayServices;
import org.apache.knox.gateway.services.ServiceType;
import org.apache.knox.gateway.services.registry.ServiceDefEntry;
import org.apache.knox.gateway.services.registry.ServiceDefinitionRegistry;
import org.apache.knox.gateway.services.registry.ServiceRegistry;
import org.apache.knox.gateway.services.security.KeystoreService;
import org.apache.knox.gateway.services.security.KeystoreServiceException;
import org.apache.knox.gateway.webshell.WebshellWebSocketAdapter;
import org.apache.knox.gateway.websockets.JWTValidator;
import org.apache.knox.gateway.websockets.JWTValidatorFactory;
import org.apache.knox.gateway.websockets.ProxyWebSocketAdapter;
import org.apache.knox.gateway.websockets.WebsocketLogMessages;
import org.eclipse.jetty.websocket.server.WebSocketHandler;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;

public class GatewayWebsocketHandler
extends WebSocketHandler
implements WebSocketCreator {
    private static final WebsocketLogMessages LOG = (WebsocketLogMessages)MessagesFactory.get(WebsocketLogMessages.class);
    public static final String WEBSOCKET_PROTOCOL_STRING = "ws://";
    public static final String SECURE_WEBSOCKET_PROTOCOL_STRING = "wss://";
    static final String REGEX_SPLIT_CONTEXT = "^((?:[^/]*/){2}[^/]*)";
    static final String REGEX_SPLIT_SERVICE_PATH = "^((?:[^/]*/){3}[^/]*)";
    static final String REGEX_WEBSHELL_REQUEST_PATH = "^(wss://|ws://)[^/]+/[^/]+/webshell$";
    private static final int POOL_SIZE = 10;
    private final AtomicInteger concurrentWebshells;
    private final ExecutorService pool;
    final GatewayConfig config;
    final GatewayServices services;

    public GatewayWebsocketHandler(GatewayConfig config, GatewayServices services) {
        this.config = config;
        this.services = services;
        this.pool = Executors.newFixedThreadPool(10);
        this.concurrentWebshells = new AtomicInteger(0);
    }

    public void configure(WebSocketServletFactory factory) {
        factory.setCreator((WebSocketCreator)this);
        factory.getPolicy().setMaxTextMessageSize(this.config.getWebsocketMaxTextMessageSize());
        factory.getPolicy().setMaxBinaryMessageSize(this.config.getWebsocketMaxBinaryMessageSize());
        factory.getPolicy().setMaxBinaryMessageBufferSize(this.config.getWebsocketMaxBinaryMessageBufferSize());
        factory.getPolicy().setMaxTextMessageBufferSize(this.config.getWebsocketMaxTextMessageBufferSize());
        factory.getPolicy().setInputBufferSize(this.config.getWebsocketInputBufferSize());
        factory.getPolicy().setAsyncWriteTimeout((long)this.config.getWebsocketAsyncWriteTimeout());
        factory.getPolicy().setIdleTimeout((long)this.config.getWebsocketIdleTimeout());
    }

    private Boolean isWebshellRequest(URI requestURI) {
        return requestURI.toString().matches(REGEX_WEBSHELL_REQUEST_PATH);
    }

    private WebshellWebSocketAdapter handleWebshellRequest(ServletUpgradeRequest req) {
        if (this.config.isWebShellEnabled()) {
            if (this.concurrentWebshells.get() >= this.config.getMaximumConcurrentWebshells()) {
                throw new RuntimeException("Number of allowed concurrent Web Shell sessions exceeded");
            }
            JWTValidator jwtValidator = JWTValidatorFactory.create(req, this.services, this.config);
            if (jwtValidator.validate()) {
                return new WebshellWebSocketAdapter(this.pool, this.config, jwtValidator, this.concurrentWebshells);
            }
            throw new RuntimeException("No valid token found for Web Shell connection");
        }
        throw new RuntimeException("Web Shell not enabled");
    }

    public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) {
        try {
            URI requestURI = req.getRequestURI();
            if (this.isWebshellRequest(requestURI).booleanValue()) {
                return this.handleWebshellRequest(req);
            }
            String backendURL = this.getMatchedBackendURL(requestURI);
            LOG.debugLog("Generated backend URL for websocket connection: " + backendURL);
            ClientEndpointConfig clientConfig = this.getClientEndpointConfig(req);
            clientConfig.getUserProperties().put("org.apache.knox.gateway.websockets.truststore", this.getTruststore());
            return new ProxyWebSocketAdapter(URI.create(backendURL), this.pool, clientConfig, this.config);
        }
        catch (Exception e) {
            LOG.failedCreatingWebSocket(e);
            throw new RuntimeException(e);
        }
    }

    private KeyStore getTruststore() throws KeystoreServiceException {
        KeystoreService ks = (KeystoreService)this.services.getService(ServiceType.KEYSTORE_SERVICE);
        KeyStore trustKeystore = null;
        trustKeystore = ks.getTruststoreForHttpClient();
        if (trustKeystore == null) {
            trustKeystore = ks.getKeystoreForGateway();
        }
        return trustKeystore;
    }

    private ClientEndpointConfig getClientEndpointConfig(final ServletUpgradeRequest req) {
        return ClientEndpointConfig.Builder.create().configurator(new ClientEndpointConfig.Configurator(){

            public void beforeRequest(Map<String, List<String>> headers) {
                req.getHeaders().forEach(headers::putIfAbsent);
                try {
                    URI backendURL = new URI(GatewayWebsocketHandler.this.getMatchedBackendURL(req.getRequestURI()));
                    headers.put("Host", Arrays.asList(backendURL.getHost() + ":" + backendURL.getPort()));
                }
                catch (URISyntaxException e) {
                    LOG.onError(String.format(Locale.ROOT, "Error getting backend url, this could cause 'Host does not match SNI' exception. Cause: ", e.toString()));
                }
            }
        }).build();
    }

    protected synchronized String getMatchedBackendURL(URI requestURI) {
        String[] pathInfo;
        String path = requestURI.getRawPath();
        String query = requestURI.getRawQuery();
        ServiceRegistry serviceRegistryService = (ServiceRegistry)this.services.getService(ServiceType.SERVICE_REGISTRY_SERVICE);
        ServiceDefinitionRegistry serviceDefinitionService = (ServiceDefinitionRegistry)this.services.getService(ServiceType.SERVICE_DEFINITION_REGISTRY);
        ServiceDefEntry entry = serviceDefinitionService.getMatchingService((pathInfo = path.split(REGEX_SPLIT_CONTEXT))[1]);
        if (entry == null) {
            throw new RuntimeException(String.format(Locale.ROOT, "Cannot find service for the given path: %s", path));
        }
        String[] pathService = path.split(REGEX_SPLIT_SERVICE_PATH);
        String backendURL = GatewayWebsocketHandler.urlFromServiceDefinition(serviceRegistryService, entry, path);
        LOG.debugLog("Url obtained from services definition: " + backendURL);
        StringBuilder backend = new StringBuilder();
        try {
            if (StringUtils.containsAny((CharSequence)backendURL, (CharSequence[])new CharSequence[]{WEBSOCKET_PROTOCOL_STRING, SECURE_WEBSOCKET_PROTOCOL_STRING})) {
                LOG.debugLog("ws or wss protocol found in service url");
                URI serviceUri = new URI(backendURL);
                backend.append(serviceUri);
                String pathSuffix = this.generateUrlSuffix(backend.toString(), pathService);
                backend.append(pathSuffix);
            } else if (StringUtils.containsAny((CharSequence)requestURI.toString(), (CharSequence[])new CharSequence[]{WEBSOCKET_PROTOCOL_STRING, SECURE_WEBSOCKET_PROTOCOL_STRING})) {
                LOG.debugLog("ws or wss protocol found in request url");
                URL serviceUrl = new URL(backendURL);
                String protocol = serviceUrl.getProtocol().equals("https") ? "wss" : "ws";
                backend.append(protocol).append("://");
                backend.append(serviceUrl.getHost()).append(':');
                backend.append(serviceUrl.getPort()).append('/');
                backend.append(serviceUrl.getPath());
                String pathSuffix = this.generateUrlSuffix(backend.toString(), pathService);
                backend.append(pathSuffix);
            } else {
                LOG.debugLog("ws or wss protocol not found in service url or request url");
                URL serviceUrl = new URL(backendURL);
                String protocol = serviceUrl.getProtocol().equals("ws") || serviceUrl.getProtocol().equals("wss") ? serviceUrl.getProtocol() : "ws";
                backend.append(protocol).append("://");
                backend.append(serviceUrl.getHost()).append(':');
                backend.append(serviceUrl.getPort()).append('/');
                backend.append(serviceUrl.getPath());
            }
            if (!StringUtils.isBlank((CharSequence)query)) {
                backend.append('?').append(query);
            }
            backendURL = backend.toString();
        }
        catch (MalformedURLException e) {
            LOG.badUrlError(e);
            throw new RuntimeException(e.toString());
        }
        catch (Exception e1) {
            LOG.failedCreatingWebSocket(e1);
            throw new RuntimeException(e1.toString());
        }
        return backendURL;
    }

    private static String urlFromServiceDefinition(ServiceRegistry serviceRegistry, ServiceDefEntry entry, String path) {
        String[] contexts = path.split("/");
        return serviceRegistry.lookupServiceURL(contexts[2], entry.getName().toUpperCase(Locale.ROOT));
    }

    private String generateUrlSuffix(String backendPart, String[] pathService) {
        if (!StringUtils.endsWith((CharSequence)backendPart, (CharSequence)"/ws") && pathService.length > 0 && pathService[1] != null) {
            String newPathSuffix = pathService[1];
            if (backendPart.endsWith("/") && pathService[1].startsWith("/")) {
                newPathSuffix = pathService[1].substring(1);
            }
            return newPathSuffix;
        }
        return "";
    }
}

