#!/usr/bin/env python
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# Contributors: Guillaume Destuynder <gdestuynder@mozilla.com>

import argparse
import array
from datetime import datetime
import json
import logging
import os
import requests
import socket
# Python 2 compat: these exception classes only exist on Python 3, so fall back
# to the closest Python 2 equivalents when the builtin name is missing.
try:
    BrokenPipeError2 = BrokenPipeError
except NameError:
    BrokenPipeError2 = socket.error
try:
    FileNotFoundError2 = FileNotFoundError
except NameError:
    # Keep the builtin-style name usable (the rest of the file catches
    # FileNotFoundError directly) and define the "2" alias as well, so that
    # both names exist regardless of the interpreter version.
    FileNotFoundError = IOError
    FileNotFoundError2 = FileNotFoundError
import subprocess
import sys
import time
import webbrowser
import yaml

# Python2 compat: Python 3 text streams expose the underlying binary stream as
# `.buffer`; Python 2 file objects are already byte-oriented. The tunnel needs
# raw bytes in both directions, so pick whichever is available.
stdinpipe, stdoutpipe = (
    (sys.stdin.buffer, sys.stdout.buffer)
    if hasattr(sys.stdin, "buffer")
    else (sys.stdin, sys.stdout)
)

def setup_logging(stream=sys.stdout, level=logging.DEBUG):
    """
    Configure the root logging handler and return this module's logger.

    stream: file-like object the log records are written to.
    level: threshold applied to the returned (module-named) logger.
    """
    logging.basicConfig(
        stream=stream,
        datefmt="%H:%M:%S",
        format="[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s",
    )

    # Enable this to debug the requests module
#    import http.client as http_client
#    http_client.HTTPConnection.debuglevel = 1
#    requests_log = logging.getLogger("requests.packages.urllib3")
#    requests_log.setLevel(logging.DEBUG)
#    requests_log.propagate = True

    app_logger = logging.getLogger(__name__)
    app_logger.setLevel(level)
    return app_logger

class DotDict(dict):
    """
    dict.item notation for dict()'s

    Keys can be read, written and deleted as attributes (``d.key``) as well as
    with the normal mapping syntax (``d['key']``). Nested mappings passed to
    the constructor are converted to DotDict recursively.

    A missing key accessed as an attribute raises AttributeError (not
    KeyError) so that hasattr(), getattr(obj, name, default) and the
    copy/pickle protocols, which all rely on AttributeError, keep working.
    """

    def __init__(self, dct):
        for key, value in dct.items():
            # Anything mapping-like (has .keys) becomes a nested DotDict
            if hasattr(value, 'keys'):
                value = DotDict(value)
            self[key] = value

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so real dict
        # methods (keys, items, get, ...) are never shadowed by this.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name)

class Session():
    """
    The Session class contains all necessary session data for bcorp,
    such as CLI token, access proxy token, etc.

    Tokens are cached as JSON in the file named by config.openssh.cache; the
    SSH private key, public key and certificate are stored next to
    config.openssh.ssh_key_path.
    """
    _CLI_TOKEN_EXP_DELTA = 30 * 24 * 60 * 60 # 30 days, in seconds
    _CLI_TOKEN_LEN = 48 # Number of characters in a generated CLI token
    _API_REQUEST_TIMEOUT = 60 # 1 minute
    _API_REQUEST_DELAY = 1 # Check API every 1 second
    _API_SESSION_NAME = 'session'
    _USER_AGENT = 'Mozilla/5.0 (X11; Linux x86_64; rv:57.0) Gecko/20100101 Firefox/57.0'

    def __init__(self, config, args):
        """
        config: DotDict-like configuration (uses proxy_url, openssh.ssh_key_path, openssh.cache)
        args: object with ssh_host, ssh_port and ssh_user attributes
        """
        self.tokens = DotDict({
            'cli': {'value': '', 'exp': None},
            'access_proxy': {'value': '', 'exp': None},
            })
        self.config = config
        self.ssh_host = args.ssh_host
        self.ssh_port = args.ssh_port
        self.ssh_user = args.ssh_user
        self.ssh_key = os.path.expanduser(self.config.openssh.ssh_key_path)
        self.cache_path = os.path.expanduser(self.config.openssh.cache)

        # Load cache
        try:
            self._load()
        except FileNotFoundError:
            logger.debug('No cache found.')

    def load_credentials(self):
        """
        Load the SSH key into the agent if usable credentials are available.
        Returns True when a key was handed to ssh-add, False otherwise.
        """
        # Check if we still have any valid credentials
        creds = self._get_ssh_credentials()
        if creds.expiration:
            self._load_ssh_key(creds.expiration)
            return True
        return False

    def request_new_credentials(self):
        """
        Run the full authentication flow (CLI token, access proxy token) and
        fetch fresh SSH credentials. Always returns True.
        """
        # Check if we have or need a CLI token
        self._verify_cli_token()
        # Check if we have or need an Access Proxy token
        self._verify_access_proxy_token()
        self._get_ssh_credentials()
        return True

    @staticmethod
    def _open_private(path):
        """
        Open `path` for text writing, creating it with mode 0600 if needed.
        The file is truncated so that a shorter payload never leaves trailing
        bytes from a previous, longer one.
        """
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        return os.fdopen(fd, 'w')

    def _save(self):
        """
        Write the current tokens to the cache file as JSON.
        """
        with self._open_private(self.cache_path) as fd:
            json.dump(self.tokens, fd)

    def _load(self):
        """
        Replace self.tokens with the content of the cache file.
        Raises FileNotFoundError (IOError on Python 2) if there is no cache.
        A cache that is not valid JSON, or lacks the expected entries, is
        ignored and the current (default) tokens are kept.
        """
        with open(self.cache_path, 'r') as fd:
            try:
                cached = DotDict(json.load(fd))
            # ValueError covers every JSON decoder flavour (json on Python 2
            # and 3); AttributeError is raised when the JSON is not an object.
            except (ValueError, AttributeError):
                logger.error('Cache appears corrupted. Starting with new values.')
                return

        for name in ('cli', 'access_proxy'):
            entry = cached.get(name)
            if not hasattr(entry, 'keys') or 'value' not in entry or 'exp' not in entry:
                logger.error('Cache appears corrupted. Starting with new values.')
                return
        self.tokens = cached

    def _get_cookie_jar(self):
        """
        Returns a RequestsCookieJar for the access proxy API
        """
        jar = requests.cookies.RequestsCookieJar()
        jar.set(self._API_SESSION_NAME, self.tokens.access_proxy.value)
        return jar

    @staticmethod
    def _gen_token(length=8, charset="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*()"):
        """
        Returns a random token, to be used as session identifier. Must be using 'true' random as this is
        security-sensitive. Session identifiers are credential-equivalents.

        length: number of characters to generate
        charset: characters the token is drawn from
        """
        # bytearray yields integers on both Python 2 and Python 3
        random_bytes = bytearray(os.urandom(length))
        len_charset = len(charset)
        # Scale each byte (0..255) onto an index in 0..len_charset-1
        return "".join(charset[byte * len_charset // 256] for byte in random_bytes)

    def _verify_cli_token(self):
        """
        Make sure a non-expired CLI token is available, generating and caching
        a new one when the current token is missing or past its expiration.
        """
        expired = self.tokens.cli.exp is None or self.tokens.cli.exp <= time.time()
        if expired or not self.tokens.cli.value:
            logger.debug('Invalid CLI token, re-generating')
            self.tokens.cli.value = self._gen_token(self._CLI_TOKEN_LEN)
            self.tokens.cli.exp = time.time() + self._CLI_TOKEN_EXP_DELTA
            self._save()
            logger.debug('Generated new CLI token: {}'.format(self.tokens.cli.value))
        else:
            logger.debug('Re-using CLI token: {}'.format(self.tokens.cli.value))

    def _verify_access_proxy_token(self):
        """
        Make sure the access proxy token is valid. If it is not, open the
        browser for interactive authentication and poll the API until it
        succeeds. Exits the process with status 127 when
        _API_REQUEST_TIMEOUT elapses first.
        """
        if self._check_api_authenticated():
            logger.debug('Re-using access proxy token: {}'.format(self.tokens.access_proxy.value))
            return

        self._request_access_proxy_token()
        timeout = time.time() + self._API_REQUEST_TIMEOUT
        while True:
            logger.debug('Polling proxy API for another {} seconds until time out'.format((timeout-time.time())))
            if self._poll_proxy_api() and self._check_api_authenticated():
                logger.debug('All tokens are valid')
                self.agent_pid = None
                return
            if time.time() >= timeout:
                logger.warning("connect to access proxy host API {}: Connection timed out. Have you authenticated in the "
                      "web browser window?".format(self.config.proxy_url))
                sys.exit(127)
            # Always pause between attempts so a failing check cannot turn
            # into a busy loop hammering the API
            time.sleep(self._API_REQUEST_DELAY)

    def _poll_proxy_api(self):
        """
        Check if we got a proxy api reply
        Returns True once the CLI token has been authenticated and the access
        proxy token has been stored, False otherwise.
        """
        # The token is passed through `params` so that it gets URL-encoded:
        # the token charset contains characters such as & # and %
        r = requests.get('{}/api/session'.format(self.config.proxy_url),
                         params={'cli_token': self.tokens.cli.value})
        # 202 accepted - means CLI token is not yet authenticated interactively by the user
        if r.status_code == 202:
            return False
        if r.status_code != 200:
            logger.debug('HTTP status code: {}, body: {}'.format(r.status_code, r.text[0:100]))
            return False

        try:
            reply = r.json()
            auth = reply.get('cli_token_authenticated')
        except (ValueError, AttributeError):
            logger.debug('JSON Decoding failed - HTTP status code: {}, body: {}'.format(r.status_code,
                         r.text[0:100]))
            return False
        logger.debug('cli_token_authenticated is {}'.format(str(auth)))
        if not auth:
            return False

        self.tokens.access_proxy.value = reply.get('ap_session')
        logger.debug('Retrieved new access proxy token: {}'.format(self.tokens.access_proxy.value))
        self._save()
        return True

    def _request_access_proxy_token(self):
        """
        Requests an Access Proxy token from the .. Access Proxy
        This opens the authentication URL in the user's web browser and logs
        it so that it can be opened manually.
        """
        try:
            from urllib.parse import urlencode
        except ImportError:
            # Python2 compat
            from urllib import urlencode

        # A list of pairs keeps the parameter order stable; urlencode escapes
        # the special characters the CLI token may contain
        query = urlencode([('type', 'ssh'),
                           ('host', self.ssh_host),
                           ('user', self.ssh_user),
                           ('port', self.ssh_port),
                           ('cli_token', self.tokens.cli.value)])
        url = '{}?{}'.format(self.config.proxy_url, query)
        # Use the application logger: the root logger stays at WARNING and
        # would silently drop this message
        logger.info("If no browser window was opened, please manually authenticate to the "
              "access proxy: {}".format(url))
        webbrowser.open(url, new=0, autoraise=True)

    def _check_api_authenticated(self, r=None):
        """
        Check that the access proxy accepts our session.

        r: optional requests.Response from a previous API call. When omitted,
           the /api/ping endpoint is queried with our session cookie and must
           answer with a truthy 'PONG' JSON field.

        Returns True if the reply could be decoded as JSON (and, for the ping,
        contained PONG), False otherwise.
        """
        # Compare against None explicitly: a Response object is falsy for any
        # 4xx/5xx status, which must not be mistaken for "no response given"
        r_src_self = r is None

        if r_src_self:
            headers = {'User-Agent': self._USER_AGENT}
            r = requests.get('{}/api/ping'.format(self.config.proxy_url), cookies=self._get_cookie_jar(), headers=headers)
            if r.status_code != 200:
                logger.debug('HTTP status code: {}, body: {}'.format(r.status_code, r.text[0:100]))

        #FIXME set expiration for our records from the session cookie
        #Current access proxy does not support this though, hence the FIXME

        try:
            tmp = r.json()
        except ValueError:
            logger.debug('JSON decoding failed. HTTP status code: {}, body: {}'.format(r.status_code, r.text[0:100]))
            return False

        if r_src_self and not (isinstance(tmp, dict) and tmp.get('PONG')):
            logger.debug('access_proxy token is invalid')
            return False

        # Since we don't get a cookie back, and we got a specific API request forwarded to the function:
        # if we could decode any JSON we consider this is valid
        logger.debug('access_proxy token is still valid')
        return True

    def _get_ssh_certificate_expiration(self):
        """
        Get certificate expiration from the certificate file
        Currently done using ssh-keygen though a native implementation would be better
        Returns a string that represent the amount of time until the certificate signature expires or None if invalid
        The time format is OpenSSH's time format
        """
        # See `man sshd_config` `TIME FORMATS` for time syntax
        # Output format contains a line such as:
        #        Valid: from 2017-09-19T16:10:21 to 2017-09-19T17:20:21
        key_validity = '15m'

        try:
            stdout = subprocess.check_output(['ssh-keygen', '-L', '-f', self.ssh_key+'-cert.pub'])
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError: ssh-keygen itself could not be executed
            logger.debug('Could not read previous SSH key: {}'.format(e))
            return None

        try:
            for line in stdout.decode('ascii').split('\n'):
                if 'Valid: ' not in line:
                    continue
                end_time = line.split(' to ')[-1].strip()
                # Expiration is how many seconds from now to end_time
                dt = datetime.strptime(end_time, '%Y-%m-%dT%H:%M:%S')
                # mktime interprets the naive timestamp as local time and,
                # unlike datetime.timestamp(), also exists on Python 2
                exp = int(time.mktime(dt.timetuple()))
                delta = exp - int(time.time())
                if delta <= 0:
                    logger.error('SSH key is already expired. You need a new key. '
                                 'Time remaining was: {} seconds'.format(str(delta)))
                    return None
                key_validity = '{}s'.format(str(delta))
                logger.debug('SSH key is valid for {} seconds'.format(str(delta)))
                break
        except (ValueError, OverflowError):
            # Use default if parsing fails in any way
            logger.debug('Failed to read SSH key expiration time, using default: {}'.format(key_validity))
        return key_validity

    def _get_ssh_credentials(self):
        """
        Retrieves the certificate and private key from the access proxy.

        Returns a DotDict with private_key, public_key, certificate and
        expiration. When the local certificate is still valid only
        `expiration` is set; when new credentials are fetched they are written
        to disk and `expiration` is left as None.
        """
        creds = DotDict({'private_key': None, 'public_key': None, 'certificate': None, 'expiration': None})

        # Are local credentials still valid?
        creds.expiration = self._get_ssh_certificate_expiration()
        if creds.expiration:
            logger.debug('Returning cached SSH credentials (valid for another {})'.format(creds.expiration))
            return creds

        headers = {'User-Agent': self._USER_AGENT}
        # `params` URL-encodes the token, which may contain & # or %
        r = requests.get('{}/api/ssh'.format(self.config.proxy_url),
                         params={'cli_token': self.tokens.cli.value},
                         cookies=self._get_cookie_jar(), headers=headers)
        if not self._check_api_authenticated(r):
            return creds

        if r.status_code != 200:
            logger.debug('HTTP status code: {}, body: {}'.format(r.status_code, r.text[0:100]))
            logger.error('Failed to get SSH credentials, permission will be denied')
            return creds
        try:
            tmp = r.json()
        except ValueError:
            logger.debug('JSON decoding failed. HTTP status code: {}, body: {}'.format(r.status_code, r.text[0:100]))
            logger.error('Failed to get SSH credentials, permission will be denied')
            return creds

        try:
            creds.private_key = tmp['private_key']
            creds.public_key = tmp['public_key']
            creds.certificate = tmp['certificate']
        except (KeyError, TypeError):
            # TypeError: the JSON document was not an object
            logger.error('Could not interpret access proxy data')
            return creds

        logger.debug('Got new SSH credentials')
        self._save_ssh_creds(creds)
        return creds

    def _load_ssh_key(self, key_validity):
        """
        (Re)load the SSH key into ssh-agent for `key_validity` (OpenSSH time format).
        Returns the exit status of `ssh-add`.
        """
        # We need to first unload as we cannot destroy ourselves at the end of the session
        # Because the ssh client kills us first, and we want to make sure we start clean instead of piling up
        # keys in the agent which can cause issues (for ex ssh will try the first 5 keys then the sshd will refuse the
        # client because it has tried too many times!)
        # Bonus: key stay loaded for the user if it's still valid as well
        # Note: we also abuse subprocess.PIPE here, but we know that the buffer is too small to cause deadlocks
        logger.debug('Unloading previous key if any')
        subprocess.call(['ssh-add', '-d', self.ssh_key], stderr=subprocess.PIPE)
        logger.debug('Loading key in ssh-agent')
        return subprocess.call(['ssh-add', '-t', key_validity, self.ssh_key], stderr=subprocess.PIPE)

    def _save_ssh_creds(self, creds):
        """
        Write the private key, public key and certificate from `creds` to
        <ssh_key>, <ssh_key>.pub and <ssh_key>-cert.pub respectively.
        Nothing is written when no private key was received.
        """
        if creds.private_key is None:
            logger.error("Saving SSH Credentials failed: No credentials received.")
            return
        logger.debug('Saving SSH credentials to {}'.format(self.ssh_key))

        # Files are truncated on open: key material of a different length
        # must fully replace whatever was stored before
        with self._open_private(self.ssh_key) as fd:
            fd.write(creds.private_key)
        with self._open_private(self.ssh_key+'.pub') as fd:
            fd.write(creds.public_key)
        with self._open_private(self.ssh_key+'-cert.pub') as fd:
            fd.write(creds.certificate)

def usage():
    """
    Print the expected command line on stderr and exit with status 1.
    """
    # Written straight to stderr: the root logger is left at its default
    # WARNING level, so an INFO record sent through it would never be shown.
    sys.stderr.write("""USAGE: {} remote_hostname:remote_port:remote_user\n""".format(sys.argv[0]))
    sys.exit(1)

def main(args, config):
    """
    Obtain SSH credentials through the access proxy, then relay bytes between
    stdin/stdout and the remote SSH server (OpenSSH ProxyCommand behaviour).

    args: parsed command line; uses `verbose` and `moduleopts`
          (formatted as remote_hostname:remote_port:remote_user)
    config: DotDict configuration handed to Session
    """
    global logger
    import select

    level = logging.DEBUG if args.verbose else logging.INFO
    logger = setup_logging(stream=sys.stderr, level=level)

    # Get arguments from OpenSSH's ProxyCommand
    # this program is usually called as such:
    # %h: remote hostname, %p: remote port, %r: remote user name
    # ssh -oProxyCommand='/usr/bin/bssh.py %h:%p:%r' kang@myhost.com
    try:
        # Split from the right so a host containing ':' (IPv6 literal) is kept whole
        (ssh_host, ssh_port, ssh_user) = args.moduleopts.rsplit(':', 2)
        # Reject a non-numeric port now rather than when connecting
        int(ssh_port)
    except ValueError:
        usage()
    args = DotDict({'ssh_host': ssh_host, 'ssh_port': ssh_port, 'ssh_user': ssh_user})

    # Load session (or create new one)
    ses = Session(config, args)
    if not ses.load_credentials():
        logger.debug('No existing credentials, requesting new ones from access proxy')
        ses.request_new_credentials()
        if not ses.load_credentials():
            logger.debug('Failed to request new credentials')

    logger.debug('Tunneling SSH traffic...')
    # Pass to SSH
    # XXX FIXME figure out ProxyUseFdPass
    #s = socket.create_connection((sys.argv[1], int(sys.argv[2])))
    #fds = array.array("i", [s.fileno()])
    #ancdata = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]
    #socket.socket(fileno = 1).sendmsg([b'\0'], ancdata)
    #See also https://lists.mindrot.org/pipermail/openssh-unix-dev/2013-June/031483.html
    # https://github.com/solrex/netcat/blob/master/netcat.c#L1246
    # http://www.gabriel.urdhr.fr/2016/08/07/openssh-proxyusefdpass/
    s = socket.create_connection((args.ssh_host, int(args.ssh_port)))
    stdin_fd = sys.stdin.fileno()

    try:
        while True:
            # select() only reports descriptors that can be read without
            # blocking, so stdin does not need to be switched to non-blocking
            # mode; os.read() then returns whatever is available, or b'' at EOF.
            (readable, _, _) = select.select([stdin_fd, s], [], [])
            if s in readable:
                try:
                    data = s.recv(8192)
                    if not data:
                        # An empty read means the peer closed the connection
                        logger.debug('Remote host closed connection')
                        break
                    stdoutpipe.write(data)
                    stdoutpipe.flush()
                except BrokenPipeError2:
                    logger.debug('ProxyCommand closed stdin')
                    break
            if stdin_fd in readable:
                data = os.read(stdin_fd, 8192)
                if not data:
                    # EOF on stdin: the ssh client is gone
                    logger.debug('ProxyCommand closed stdin')
                    break
                try:
                    # sendall: send() alone may transmit only part of the buffer
                    s.sendall(data)
                except BrokenPipeError2:
                    logger.debug('Remote host closed connection')
                    break
    finally:
        s.close()

if __name__ == "__main__":
    # Command line entry point: parse options, load the YAML configuration
    # (default: ./bcorp.yml) and hand both over to main().
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--type', required=True, choices=['ssh', 'sts'], help='Select type of credentials to request')
    parser.add_argument('-v', '--verbose', action='store_true', help='Enable debugging / verbose messages')
    parser.add_argument('-c', '--config',  help='Specify a configuration file')
    parser.add_argument('moduleopts', help='Module specific options')
    args = parser.parse_args()

    with open(args.config or 'bcorp.yml') as fd:
        # safe_load only builds plain Python types. yaml.load() without an
        # explicit Loader can construct arbitrary objects and is rejected by
        # recent PyYAML releases.
        config = DotDict(yaml.safe_load(fd))
        # Ensure we have no double / at the end of the URL as this confuses reverse proxies
        config.proxy_url = config.proxy_url.rstrip('/')

    main(args, config)
