#!/usr/bin/env python
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# Contributors: Guillaume Destuynder <gdestuynder@mozilla.com>

import argparse
import array
from datetime import datetime
import json
import logging
import os
import requests
import socket
import secrets
import subprocess
import sys
import time
import webbrowser
import yaml

def setup_logging(stream=sys.stdout, level=logging.DEBUG):
    """
    Configure logging output and return this module's logger.

    stream: file-like object the log handler writes to.
    level: threshold applied to the returned logger (the root logger's own
           level is left untouched).
    """
    logging.basicConfig(
        format="[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s",
        datefmt="%H:%M:%S",
        stream=stream,
    )

    # To debug the requests module, uncomment the following lines:
#    import http.client as http_client
#    http_client.HTTPConnection.debuglevel = 1
#    requests_log = logging.getLogger("requests.packages.urllib3")
#    requests_log.setLevel(logging.DEBUG)
#    requests_log.propagate = True

    app_logger = logging.getLogger(__name__)
    app_logger.setLevel(level)
    return app_logger

class DotDict(dict):
    """
    dict.item notation for dict()'s

    Keys can be read, written and deleted as attributes (d.key) as well as
    items (d['key']). Nested mappings passed to the constructor are converted
    to DotDict recursively.

    Reading or deleting a missing attribute raises AttributeError (not
    KeyError) so that hasattr(), getattr(obj, name, default) and code that
    probes for optional dunder methods (e.g. copy.deepcopy) behave normally.
    """
    __setattr__ = dict.__setitem__

    def __init__(self, dct):
        super().__init__()
        for key, value in dct.items():
            # Anything that looks like a mapping is wrapped so that chained
            # attribute access (a.b.c) works
            if hasattr(value, 'keys'):
                value = DotDict(value)
            self[key] = value

    def __getattr__(self, name):
        # Only called when regular attribute lookup fails: fall back to items
        try:
            return self[name]
        except KeyError:
            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name)) from None

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name)) from None

class Session():
    """
    The Session class contains all necessary session data for bcorp,
    such as CLI token, access proxy token, etc.

    Tokens are cached as JSON in the file named by config.openssh.cache, and
    the SSH key material is written to config.openssh.ssh_key_path (plus the
    '.pub' and '-cert.pub' siblings).
    Every method logs through the module-level `logger`, which main() creates
    before instantiating this class.
    """
    _CLI_TOKEN_EXP_DELTA = 2538000 # CLI token lifetime in seconds (~29.4 days; exactly 30 days would be 2592000)
    _CLI_TOKEN_LEN = 48 # 48 random bytes (~64 URL-safe chars)
    _API_REQUEST_TIMEOUT = 60 # 1 minute
    _API_REQUEST_DELAY = 1 # Check API every 1 second
    _API_SESSION_NAME = 'session'
    _API_HEADERS = {'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:57.0) Gecko/20100101 Firefox/57.0'}

    def __init__(self, config, args):
        """
        config: DotDict-like configuration providing proxy_url, openssh.ssh_key_path and openssh.cache
        args: object providing ssh_host, ssh_port and ssh_user
        """
        self.tokens = DotDict({
            'cli': {'value': '', 'exp': None},
            'access_proxy': {'value': '', 'exp': None},
            })
        self.config = config
        self.ssh_host = args.ssh_host
        self.ssh_port = args.ssh_port
        self.ssh_user = args.ssh_user
        self.ssh_key = os.path.expanduser(self.config.openssh.ssh_key_path)
        self.cache_path = os.path.expanduser(self.config.openssh.cache)

        # Load cache
        try:
            self._load()
        except FileNotFoundError:
            logger.debug('No cache found.')

    def load_credentials(self):
        """
        Load the SSH key into the ssh-agent if a still-valid certificate is on disk.
        Returns True when a key was handed to ssh-add, False otherwise.

        Note: _get_ssh_credentials() only reports an expiration for credentials that were already
        on disk, so right after new credentials were downloaded this must be called again.
        """
        # Check if we still have any valid credentials
        creds = self._get_ssh_credentials()
        if creds.expiration:
            self._load_ssh_key(creds.expiration)
            return True
        return False

    def request_new_credentials(self):
        """
        Run the full authentication flow (CLI token, access proxy token) then fetch SSH credentials.
        Always returns True; may exit the program if the access proxy authentication times out.
        """
        # Check if we have or need a CLI token
        self._verify_cli_token()
        # Check if we have or need an Access Proxy token
        self._verify_access_proxy_token()
        self._get_ssh_credentials()
        return True

    @staticmethod
    def _write_private_file(path, data):
        """
        Write the string `data` to `path`, replacing any previous content.
        The file is created with mode 0600 if it does not exist yet.
        """
        # O_TRUNC is required: without it, writing shorter content over an existing file leaves the tail of
        # the old content behind and corrupts the file
        fd = os.open(path, os.O_WRONLY|os.O_CREAT|os.O_TRUNC, mode=0o600)
        with open(fd, 'w') as fh:
            fh.write(data)

    @staticmethod
    def _is_valid_cache(data):
        """
        Returns True if `data` (decoded cache JSON) has the structure self.tokens is expected to have
        """
        if not isinstance(data, dict):
            return False
        for name in ('cli', 'access_proxy'):
            entry = data.get(name)
            if not isinstance(entry, dict) or 'value' not in entry or 'exp' not in entry:
                return False
        return True

    def _save(self):
        """
        Persist the tokens to the cache file
        """
        # Serialize first so that a serialization failure cannot leave a truncated cache behind
        self._write_private_file(self.cache_path, json.dumps(self.tokens))

    def _load(self):
        """
        Load the tokens from the cache file. A corrupted or malformed cache is ignored (current tokens are kept).
        Raises FileNotFoundError if there is no cache file.
        """
        with open(self.cache_path, 'r') as fd:
            try:
                loaded = json.load(fd)
            except ValueError:
                # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses
                loaded = None
        if not self._is_valid_cache(loaded):
            logger.error('Cache appears corrupted. Starting with new values.')
            return
        self.tokens = DotDict(loaded)

    def _get_cookie_jar(self):
        """
        Returns a RequestsCookieJar for the access proxy API
        """
        jar = requests.cookies.RequestsCookieJar()
        jar.set(self._API_SESSION_NAME, self.tokens.access_proxy.value)
        return jar

    def _verify_cli_token(self):
        """
        Re-use the cached CLI token if it has not expired, otherwise generate and save a new one
        """
        exp = self.tokens.cli.exp
        # `exp` is the absolute expiration timestamp: the token is usable only while it lies in the future
        still_valid = bool(self.tokens.cli.value) and isinstance(exp, (int, float)) and exp > time.time()
        if still_valid:
            logger.debug('Re-using CLI token: {}'.format(self.tokens.cli.value))
            return

        logger.debug('Invalid CLI token, re-generating')
        self.tokens.cli.value = secrets.token_urlsafe(self._CLI_TOKEN_LEN)
        self.tokens.cli.exp = time.time() + self._CLI_TOKEN_EXP_DELTA
        self._save()
        logger.debug('Generated new CLI token: {}'.format(self.tokens.cli.value))

    def _verify_access_proxy_token(self):
        """
        Make sure we hold a valid access proxy token. If not, open the authentication page in a browser
        and poll the access proxy API until the user has authenticated.
        Exits the program with status 127 if this takes longer than _API_REQUEST_TIMEOUT seconds.
        """
        if self._check_api_authenticated():
            logger.debug('Re-using access proxy token: {}'.format(self.tokens.access_proxy.value))
            return

        self._request_access_proxy_token()
        timeout = time.time()+self._API_REQUEST_TIMEOUT
        while True:
            logger.debug('Polling proxy API for another {} seconds until time out'.format((timeout-time.time())))
            if self._poll_proxy_api() and self._check_api_authenticated():
                logger.debug('All tokens are valid')
                self.agent_pid = None
                return
            if time.time() >= timeout:
                logger.warning("connect to access proxy host API {}: Connection timed out. Have you authenticated in the "
                      "web browser window?".format(self.config.proxy_url))
                sys.exit(127)
            # Always pause between attempts so that we never hammer the API
            time.sleep(self._API_REQUEST_DELAY)

    def _poll_proxy_api(self):
        """
        Check if we got a proxy api reply
        Returns True (and saves the access proxy token) once the CLI token has been authenticated
        """
        r = requests.get('{}/api/session?cli_token={}'.format(self.config.proxy_url, self.tokens.cli.value))
        # 202 accepted - means CLI token is not yet authenticated interactively by the user
        if r.status_code == 202:
            return False
        if r.status_code != 200:
            logger.debug('HTTP status code: {}, body: {}'.format(r.status_code, r.text[0:100]))
            return False

        try:
            body = r.json()
            auth = body.get('cli_token_authenticated')
        except (ValueError, AttributeError):
            # ValueError: body is not JSON. AttributeError: body is JSON but not an object
            logger.debug('JSON Decoding failed - HTTP status code: {}, body: {}'.format(r.status_code,
                         r.text[0:100]))
            return False
        logger.debug('cli_token_authenticated is {}'.format(str(auth)))
        if not auth:
            return False

        ap_session = body.get('ap_session')
        if not ap_session:
            # Never store an empty token, keep polling instead
            logger.debug('No ap_session in reply - HTTP status code: {}, body: {}'.format(r.status_code,
                         r.text[0:100]))
            return False
        self.tokens.access_proxy.value = ap_session
        logger.debug('Retrieved new access proxy token: {}'.format(self.tokens.access_proxy.value))
        self._save()
        return True

    def _request_access_proxy_token(self):
        """
        Requests an Access Proxy token from the .. Access Proxy
        This opens the authentication URL in the user's web browser
        """
        from urllib.parse import urlencode

        # Host and user come from the command line: encode them so they cannot alter the query string
        parameters = "?" + urlencode([('type', 'ssh'),
                                      ('host', self.ssh_host),
                                      ('user', self.ssh_user),
                                      ('port', self.ssh_port),
                                      ('cli_token', self.tokens.cli.value)])
        # Use the application logger: the root logger filters out INFO messages
        logger.info("If no browser window was opened, please manually authenticate to the "
              "access proxy: {}{}".format(self.config.proxy_url, parameters))
        webbrowser.open(self.config.proxy_url+parameters, new=0, autoraise=True)

    def _check_api_authenticated(self, r=None):
        """
        Check that the access proxy considers us authenticated.

        r: optional requests response of an API call that was already made. When it is missing, or when it
           carries an HTTP error status, the /api/ping endpoint is queried instead and must reply with 'PONG'.
           Otherwise `r` is considered authenticated as soon as its body is valid JSON.
        Returns True if authenticated, False otherwise.
        """
        # A requests response is falsy when its status is an error (4xx/5xx): spell out both cases
        ping = r is None or not r.ok
        if ping:
            r = requests.get('{}/api/ping'.format(self.config.proxy_url), cookies=self._get_cookie_jar(),
                             headers=self._API_HEADERS)
            if r.status_code != 200:
                logger.debug('HTTP status code: {}, body: {}'.format(r.status_code, r.text[0:100]))

        #FIXME set expiration for our records from the session cookie (r.cookies)
        #Current access proxy does not support this though, hence the FIXME

        try:
            body = r.json()
        except ValueError:
            logger.debug('JSON decoding failed. HTTP status code: {}, body: {}'.format(r.status_code, r.text[0:100]))
            return False

        if ping and not (isinstance(body, dict) and body.get('PONG')):
            logger.debug('access_proxy token is invalid')
            return False

        # Since we don't get a cookie back, and we got a specific API request forwarded to the function:
        # if we could decode any JSON we consider this is valid
        logger.debug('access_proxy token is still valid')
        return True

    def _get_ssh_certificate_expiration(self):
        """
        Get certificate expiration from the certificate file
        Currently done using ssh-keygen though a native implementation would be better
        Returns a string that represent the amount of time until the certificate signature expires or None if invalid
        The time format is OpenSSH's time format
        """
        # See `man sshd_config` `TIME FORMATS` for time syntax
        # Output format contains a line such as:
        #        Valid: from 2017-09-19T16:10:21 to 2017-09-19T17:20:21
        key_validity = '15m'

        try:
            stdout = subprocess.check_output(['ssh-keygen', '-L', '-f', self.ssh_key+'-cert.pub'])
        except subprocess.CalledProcessError as e:
            logger.debug('Could not read previous SSH key: {}'.format(e))
            return None

        try:
            for line in stdout.decode('ascii').split('\n'):
                if 'Valid: ' not in line:
                    continue
                end_time = line.split(' to ')[-1].strip()
                # Expiration is how many seconds from now to end_time
                dt = datetime.strptime(end_time, '%Y-%m-%dT%H:%M:%S')
                delta = int(dt.timestamp()) - int(time.time())
                if delta <= 0:
                    logger.error('SSH key is already expired. You need a new key. '
                                 'Time remaining was: {} seconds'.format(str(delta)))
                    return None
                key_validity = '{}s'.format(str(delta))
                logger.debug('SSH key is valid for {} seconds'.format(str(delta)))
                break
        except Exception:
            # Use default if parsing fails in any way
            logger.debug('Failed to read SSH key expiration time, using default: {}'.format(key_validity))
        return key_validity

    def _get_ssh_credentials(self):
        """
        Retrieves the certificate and private key from the access proxy.

        Returns a DotDict with private_key, public_key, certificate and expiration.
        When a valid certificate is already on disk only `expiration` is set and no request is made.
        When new credentials are downloaded they are saved to disk and `expiration` stays None.
        """
        creds = DotDict({'private_key': None, 'public_key': None, 'certificate': None, 'expiration': None})

        # Are local credentials still valid?
        creds.expiration = self._get_ssh_certificate_expiration()
        if creds.expiration:
            logger.debug('Returning cached SSH credentials (valid for another {})'.format(creds.expiration))
            return creds

        r = requests.get('{}/api/ssh?cli_token={}'.format(self.config.proxy_url, self.tokens.cli.value),
                         cookies=self._get_cookie_jar(), headers=self._API_HEADERS)
        if not self._check_api_authenticated(r):
            return creds

        if r.status_code != 200:
            logger.debug('HTTP status code: {}, body: {}'.format(r.status_code, r.text[0:100]))
            logger.error('Failed to get SSH credentials, permission will be denied')
            return creds
        try:
            tmp = r.json()
        except ValueError:
            logger.debug('JSON decoding failed. HTTP status code: {}, body: {}'.format(r.status_code, r.text[0:100]))
            logger.error('Failed to get SSH credentials, permission will be denied')
            return creds

        try:
            private_key = tmp['private_key']
            public_key = tmp['public_key']
            certificate = tmp['certificate']
        except (KeyError, TypeError):
            logger.error('Could not interpret access proxy data')
            return creds

        # Only fill in the result once all three parts are known to be present
        creds.private_key = private_key
        creds.public_key = public_key
        creds.certificate = certificate

        logger.debug('Got new SSH credentials')
        self._save_ssh_creds(creds)
        return creds

    def _load_ssh_key(self, key_validity):
        """
        (Re)load the SSH key in the ssh-agent for `key_validity` (OpenSSH time format).
        Returns the exit status of ssh-add.
        """
        # We need to first unload as we cannot destroy ourselves at the end of the session
        # Because the ssh client kills us first, and we want to make sure we start clean instead of piling up
        # keys in the agent which can cause issues (for ex ssh will try the first 5 keys then the sshd will refuse the
        # client because it has tried too many times!)
        # Bonus: key stay loaded for the user if it's still valid as well
        # Note: the output of ssh-add is discarded
        logger.debug('Unloading previous key if any')
        subprocess.call(['ssh-add', '-d', self.ssh_key], stderr=subprocess.DEVNULL)
        logger.debug('Loading key in ssh-agent')
        return subprocess.call(['ssh-add', '-t', key_validity, self.ssh_key], stderr=subprocess.DEVNULL)

    def _save_ssh_creds(self, creds):
        """
        Write the private key, public key and certificate found in `creds` to disk.
        Nothing is written if any of them is missing.
        """
        if creds.private_key is None or creds.public_key is None or creds.certificate is None:
            logger.error("Saving SSH Credentials failed: No credentials received.")
            return
        logger.debug('Saving SSH credentials to {}'.format(self.ssh_key))

        self._write_private_file(self.ssh_key, creds.private_key)
        self._write_private_file(self.ssh_key+'.pub', creds.public_key)
        self._write_private_file(self.ssh_key+'-cert.pub', creds.certificate)

def usage():
    """
    Print the expected command line format to stderr and exit with status 1
    """
    # Written directly to stderr rather than through logging.info(): the root logger only lets
    # WARNING and above through by default, so the message would never be displayed
    print("""USAGE: {} remote_hostname:remote_port:remote_user""".format(sys.argv[0]), file=sys.stderr)
    sys.exit(1)

def main(args, config):
    """
    Entry point: make sure valid SSH credentials are loaded, then relay traffic between
    stdin/stdout (the OpenSSH client) and the remote SSH server.

    args: parsed command line; uses `verbose` and `moduleopts` ("host:port:user")
    config: DotDict configuration handed over to Session
    """
    global logger

    level = logging.DEBUG if args.verbose else logging.INFO
    logger = setup_logging(stream=sys.stderr, level=level)

    # Get arguments from OpenSSH's ProxyCommand
    # this program is usually called as such:
    # %h: remote hostname, %p: remote port, %r: remote user name
    # ssh -oProxyCommand='/usr/bin/bssh.py %h:%p:%r' kang@myhost.com
    try:
        # Split from the right so that hosts containing ':' (IPv6 literals) are kept intact
        (ssh_host, ssh_port, ssh_user) = args.moduleopts.rsplit(':', 2)
        int(ssh_port)
    except ValueError:
        # Wrong number of fields, or the port is not a number
        usage()
    args = DotDict({'ssh_host': ssh_host, 'ssh_port': ssh_port, 'ssh_user': ssh_user})

    # Load session (or create new one)
    ses = Session(config, args)
    if not ses.load_credentials():
        logger.debug('No existing credentials, requesting new ones from access proxy')
        ses.request_new_credentials()
        if not ses.load_credentials():
            logger.debug('Failed to request new credentials')

    logger.debug('Tunneling SSH traffic...')
    # Pass to SSH
    # XXX FIXME figure out ProxyUseFdPass
    #s = socket.create_connection((sys.argv[1], int(sys.argv[2])))
    #fds = array.array("i", [s.fileno()])
    #ancdata = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]
    #socket.socket(fileno = 1).sendmsg([b'\0'], ancdata)
    #See also https://lists.mindrot.org/pipermail/openssh-unix-dev/2013-June/031483.html
    # https://github.com/solrex/netcat/blob/master/netcat.c#L1246
    # http://www.gabriel.urdhr.fr/2016/08/07/openssh-proxyusefdpass/
    s = socket.create_connection((args.ssh_host, int(args.ssh_port)))
    import select
    import fcntl
    stdin_fd = sys.stdin.fileno()
    sock_fd = s.fileno()
    # Add O_NONBLOCK to the flags already set on stdin instead of overwriting them
    flags = fcntl.fcntl(stdin_fd, fcntl.F_GETFL)
    fcntl.fcntl(stdin_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
    try:
        while True:
            (readable, _, _) = select.select([stdin_fd, sock_fd], [], [])
            if sock_fd in readable:
                try:
                    data = s.recv(8192)
                except ConnectionError:
                    data = b''
                if not data:
                    # An empty read means the peer closed: stop instead of spinning on select()
                    logger.debug('Remote host closed connection')
                    break
                try:
                    sys.stdout.buffer.write(data)
                    sys.stdout.flush()
                except BrokenPipeError:
                    logger.debug('ProxyCommand closed stdout')
                    break
            if stdin_fd in readable:
                try:
                    # Read the raw descriptor: select() knows nothing about Python-level buffering
                    data = os.read(stdin_fd, 8192)
                except BlockingIOError:
                    continue
                if not data:
                    logger.debug('ProxyCommand closed stdin')
                    break
                try:
                    # sendall: send() may transmit only part of the buffer
                    s.sendall(data)
                except ConnectionError:
                    logger.debug('Remote host closed connection')
                    break
    finally:
        s.close()

# Script entry point: parse the command line, load the YAML configuration and run main()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--type', required=True, choices=['ssh', 'sts'], help='Select type of credentials to request')
    parser.add_argument('-v', '--verbose', action='store_true', help='Enable debugging / verbose messages')
    parser.add_argument('-c', '--config',  help='Specify a configuration file')
    parser.add_argument('moduleopts', help='Module specific options')
    args = parser.parse_args()

    with open(args.config or 'bcorp.yml') as fd:
        # safe_load only builds plain Python types; yaml.load() without an explicit Loader can construct
        # arbitrary objects and is rejected by recent PyYAML releases
        config = DotDict(yaml.safe_load(fd))
    # Strip trailing slashes so that URLs built from proxy_url never contain a double /, which confuses
    # reverse proxies
    config.proxy_url = config.proxy_url.rstrip('/')

    main(args, config)
