#!/usr/bin/env python3

import argparse
import atexit
import logging
import os
import pwd
import socket
import subprocess
import sys
import textwrap
import threading
import time
from collections import OrderedDict

from fuse import FUSE

from basefs import loop, utils
from basefs.fs import FileSystem
from basefs.keys import Key
from basefs.logs import Log
from basefs.messages import SerfClient
from basefs.state import BlockState
from basefs.views import View


# Home directory of the invoking user, looked up in the password database.
pw_dir = pwd.getpwuid(os.getuid()).pw_dir
# Private key used by commands when no key path is given on the command line.
default_keypath = os.path.join(pw_dir, '.basefs', 'id_ec')
# Mount information from utils.get_mount_info(); may be None (see send_command).
mount_info = utils.get_mount_info()
# Prefer the log of the detected mount when it reports one; otherwise ~/.basefs/log.
if mount_info and mount_info.logpath:
    default_logpath = mount_info.logpath
else:
    default_logpath = os.path.join(pw_dir, '.basefs', 'log')


def send_command(cmd, *args):
    """Send a control command to the local basefs instance and return its reply.

    The request is ``'c' + cmd`` followed by *args*, space separated and
    encoded to bytes. It is sent to 127.0.0.1 on the port recorded in the
    module-level ``mount_info``. Exits with status 2 when ``mount_info`` is None.

    :returns: the concatenated reply chunks yielded by ``utils.netcat``.
    """
    if mount_info is None:
        sys.stderr.write("Error: could not find a mounted basefs filesystem.\n")
        sys.exit(2)
    request = ' '.join(('c' + cmd,) + args).encode()
    # netcat yields the reply in chunks; join them instead of repeated +=.
    return ''.join(utils.netcat('127.0.0.1', mount_info.port, request))


def file_exists(parser, arg, name='The', exec=None):
    """Validate that *arg* is an existing regular file (argparse ``type`` helper).

    :param parser: parser whose ``error`` method reports validation failures.
    :param arg: the path given on the command line.
    :param name: label used at the start of error messages.
    :param exec: when ``True``, also require execute permission on *arg*.
        (The name shadows the builtin but is kept because callers pass
        ``exec=True`` by keyword.)
    :returns: *arg* unchanged when every check passes.
    """
    if not os.path.exists(arg):
        parser.error("%s file %s does not exist" % (name, arg))
    elif not os.path.isfile(arg):
        parser.error("%s path %s is not a file" % (name, arg))
    elif exec is True and not os.access(arg, os.X_OK):
        # Check the validated path itself; parser.error adds its own newline.
        parser.error("%s %s has no execution permissions" % (name, arg))
    else:
        return arg


def dir_exists(parser, arg, name='The'):
    """argparse ``type`` helper: accept *arg* only if it is an existing directory.

    Failures are reported through ``parser.error``; on success *arg* is
    returned unchanged.
    """
    if os.path.isdir(arg):
        return arg
    if os.path.exists(arg):
        parser.error("%s path %s is not a directory" % (name, arg))
    else:
        parser.error("%s dir %s does not exist" % (name, arg))


def fingerprint(parser, arg, name='Key', logpath=None):
    """Resolve a key fingerprint given on the command line to a known key.

    :param parser: parser whose ``error`` method reports failures.
    :param arg: candidate fingerprint; it must contain exactly 15 ``:`` separators.
    :param name: label used at the start of error messages.
    :param logpath: log in which to look up the fingerprint; ``default_logpath``
        when omitted.
    :returns: the value stored under *arg* in the loaded log's ``keys`` mapping.
    """
    if arg.count(':') != 15:
        parser.error("%s %s not a valid fingerprint" % (name, arg))
        return None
    basefs_log = Log(logpath or default_logpath)
    basefs_log.load()
    # NOTE(review): assumes Log exposes a ``keys`` mapping keyed by fingerprint,
    # as the previous implementation intended — confirm against basefs.logs.
    try:
        return basefs_log.keys[arg]
    except KeyError:
        parser.error("%s %s fingerprint not found." % (name, arg))


def key(parser, arg, name='Key'):
    """Resolve *arg* to a key (argparse ``type`` helper).

    *arg* may be a path to a key file, loaded with ``Key.load``. It may also be
    a fingerprint (15 ``:`` separators), resolved through ``fingerprint``.
    Anything else is reported through ``parser.error``.
    """
    # Test for a file directly: file_exists() would call parser.error (which
    # exits) for non-files and make the fingerprint branch unreachable.
    if os.path.isfile(arg):
        try:
            return Key.load(arg)
        except Exception as exc:
            # Any failure to parse the file is a user error, not a crash.
            parser.error("%s '%s' %s" % (name, arg, exc))
    elif arg.count(':') == 15:
        return fingerprint(parser, arg)
    parser.error("%s %s not a valid key fingerprint nor key path." % (name, arg))


mount_parser = argparse.ArgumentParser(
    description='Mount an existing filesystem',
    prog='basefs mount')

def mount():
    """Mount a basefs log on a directory through FUSE.

    First loads the log and builds a view with the user's key. Unless
    ``--single-node`` is given, it then:

    - starts a ``serf agent`` subprocess, killed at interpreter exit;
    - connects a ``SerfClient`` to the agent's RPC port;
    - joins the members listed in the filesystem's ``/.cluster`` file;
    - runs ``loop.run`` in a background thread.

    Finally runs FUSE in the foreground, so the call blocks until the
    filesystem is unmounted.

    Raises RuntimeError when joining the Serf cluster reports an error.
    """
    mount_parser.add_argument('logpath', nargs='?', default=default_logpath,
        type=lambda v: file_exists(mount_parser, v, name='logpath'))
    mount_parser.add_argument('mountpoint',
        type=lambda v: dir_exists(mount_parser, v, name='mountpoint'))
    mount_parser.add_argument('-k', '--keys', dest='keypath',
        default=default_keypath,
        help='Path to the EC private key. %s by default. Use genkey for creating one.' % default_keypath,
        type=lambda v: file_exists(mount_parser, v, name='keypath'))
    # TODO: a --fs-handler option (custom script run on filesystem updates) is not wired in yet.
    mount_parser.add_argument('-p', '--port', dest='port', type=int, default=7372,
        help='Serf agent port (Serf RPC uses port+1, sync server port+2), defaults to 7372 (7373, 7374)')
    mount_parser.add_argument('-n', '--hostname', dest='hostname', default=socket.gethostname(),
        help='Name of this node. Must be unique in the cluster.')
    mount_parser.add_argument('-d', '--debug', dest='debug', action='store_true',
        help='Enables debugging information.')
    mount_parser.add_argument('-s', '--single-node', dest='serf', action='store_false',
        help='Disables Serf agent (testing purposes).')

    args = mount_parser.parse_args()
    logpath = os.path.normpath(args.logpath)
    keypath = os.path.normpath(args.keypath)
    mountpoint = args.mountpoint
    # Port layout: Serf bind port = --port, Serf RPC = --port+1, sync server = --port+2.
    rpc_port = args.port + 1
    sync_port = args.port + 2
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    # The view is needed in both modes, so build it before deciding whether to
    # start Serf. Previously --single-node failed with a NameError on view/serf.
    basefs_log = Log(logpath)
    basefs_log.load()
    view = View(basefs_log, Key.load(keypath))
    view.build()

    serf = None
    if args.serf:
        context = {
            'debug': '-log-level=debug' if args.debug else '',
            'port': args.port,
            'rpc_port': rpc_port,
            'sync_port': sync_port,
            'hostname': args.hostname,
        }
        # Serf's event handler pipes the user event name followed by its payload
        # (stdin) to the local sync server through nc.
        cmd = textwrap.dedent("""\
            serf agent \\
                -node %(hostname)s \\
                -bind 0.0.0.0:%(port)s \\
                -rpc-addr=127.0.0.1:%(rpc_port)s \\
                -replay %(debug)s \\
                -event-handler='{ echo -n "$SERF_USER_EVENT" && cat -; } | nc 127.0.0.1 %(sync_port)s'""") % context
        serf_agent = subprocess.Popen(cmd, shell=True)
        # Do not leave the agent running after this process exits.
        atexit.register(serf_agent.kill)

        serf = SerfClient(basefs_log, BlockState(basefs_log), port=rpc_port)
        cluster = view.get('/.cluster')
        peers = [line.strip() for line in cluster.content.splitlines() if line.strip()]
        join_result = serf.join(peers)
        if join_result.head[b'Error']:
            raise RuntimeError("Couldn't connect to serf cluster %s." % peers)

        event_loop = threading.Thread(target=loop.run, args=(view, serf, sync_port))
        event_loop.start()
    # NOTE(review): in single-node mode FileSystem receives serf=None; confirm it
    # tolerates a missing cluster client.
    sys.stdout.write('Mounting %s into %s\n' % (logpath, mountpoint))
    fsname = '%s:%i' % (logpath, sync_port)
    FUSE(FileSystem(view, serf), mountpoint, fsname=fsname, nothreads=True, foreground=True)


bootstrap_parser = argparse.ArgumentParser(
    description='Create a new self-contained filesystem',
    prog='basefs bootstrap')

def bootstrap():
    """Create a new log file from the given root keys and bootstrap IPs.

    Exit status:

    - 1 if the log already exists and ``--force`` was not given;
    - 2 if any key path is not a file;
    - 0 after the new log has been written.
    """
    bootstrap_parser.add_argument('logpath', nargs='?', default=default_logpath,
        help='Path to the basefs log file, uses %s by default.' % default_logpath)
    bootstrap_parser.add_argument('-k', '--keys', dest='keypaths',
        default=default_keypath,
        help='Comma separated list of paths containing the root keys. %s by default.' % default_keypath)
    bootstrap_parser.add_argument('-i', '--ips', dest='ips', required=True,
        help='comma separated ips used as bootstrapping nodes.')
    bootstrap_parser.add_argument('-f', '--force', dest='force', action='store_true',
        help='Rewrite log file if present')

    args = bootstrap_parser.parse_args()
    logpath = args.logpath
    log_exists = os.path.exists(logpath)
    if log_exists and not args.force:
        sys.stderr.write("Error: logpath %s already exists and --force argument was not provided\n" % logpath)
        sys.exit(1)
    root_keys = []
    for keypath in args.keypaths.split(','):
        if not os.path.isfile(keypath):
            sys.stderr.write("Error: bootstrapping keypath %s does not exist.\n" % keypath)
            sys.exit(2)
        root_keys.append(Key.load(keypath))
    # Remove the old log only once every key has loaded, so a typo in --keys
    # combined with --force no longer destroys the existing log.
    if log_exists:
        os.remove(logpath)
    basefs_log = Log(logpath)
    basefs_log.bootstrap(root_keys, args.ips.split(','))
    sys.stdout.write('Created log file %s\n' % logpath)
    sys.exit(0)


handler_parser = argparse.ArgumentParser(
    description='Run as Serf handler',
    prog='basefs handler')

def handler():
    """Entry point run by Serf for cluster events.

    Only the user event ``logentry`` is processed: the log is loaded and the
    payload read from stdin is passed to ``SerfClient.receive``. Any other
    event is ignored. Always ends the process with ``sys.exit()``.
    """
    handler_parser.add_argument('logpath', nargs='?', default=default_logpath,
        help='Path to the basefs log file, uses %s by default.' % default_logpath,
        type=lambda v: file_exists(handler_parser, v, name='logpath'))
    handler_parser.add_argument('--fs-handler', dest='handler', nargs=1,
        help='custom handler, used for executing custom actions when a path is updated.',
        type=lambda v: file_exists(handler_parser, v, name='fs handler', exec=True))
    args = handler_parser.parse_args()
    # Validated above but not used yet.
    custom_handler = args.handler[0] if args.handler else None

    is_log_entry = (os.environ.get('SERF_EVENT') == 'user'
                    and os.environ.get('SERF_USER_EVENT') == 'logentry')
    if is_log_entry:
        basefs_log = Log(args.logpath)
        basefs_log.load()
        SerfClient(basefs_log).receive(sys.stdin.buffer.read())
    sys.exit()


genkey_parser = argparse.ArgumentParser(
    description='Generate a new EC private key',
    prog='basefs genkey')

def genkey():
    """Generate a new EC private key and save it at *keypath*.

    The default key directory is created on demand with owner-only
    permissions. Any other missing directory is an error (exit 2). An
    existing key is overwritten only with ``--force`` (exit 2 otherwise).
    """
    genkey_parser.add_argument('keypath', nargs='?',
        default=default_keypath,
        help='Path to the EC private key. %s by default.' % default_keypath)
    genkey_parser.add_argument('-f', '--force', dest='force', action='store_true',
        help='Rewrite key file if present.')
    args = genkey_parser.parse_args()
    keypath = args.keypath
    # dirname() is '' for a bare file name, meaning the current directory;
    # previously that was reported as a missing directory.
    keydir = os.path.dirname(keypath) or os.curdir
    if not os.path.exists(keydir):
        if keypath == default_keypath:
            # This directory holds private keys: keep it owner-only.
            os.mkdir(keydir, 0o700)
        else:
            sys.stderr.write("Error: %s keypath directory doesn't exist, create it first.\n" % keydir)
            sys.exit(2)
    elif not args.force and os.path.exists(keypath):
        sys.stderr.write('Error: %s key already exists, use --force to override it.\n' % keypath)
        sys.exit(2)
    new_key = Key.generate()
    new_key.save(keypath)
    sys.stdout.write("Generate EC key on %s\n" % keypath)
    sys.exit()


keys_parser = argparse.ArgumentParser(
    description='List keys and their directories',
    prog='basefs keys')

def keys():
    """Print the result of ``View.get_keys`` as an indented listing.

    Each top-level entry is printed on its own line, followed by its values
    indented by four spaces. ``--by-dir`` is forwarded as ``by_dir``.
    """
    keys_parser.add_argument('logpath', nargs='?', default=default_logpath,
        help='Path to the basefs log file, uses %s by default.' % default_logpath,
        type=lambda v: file_exists(keys_parser, v, name='logpath'))
    keys_parser.add_argument('-p', '--path', dest='path', default='/',
        help='Base path.')
    keys_parser.add_argument('-d', '--by-dir', dest='by_dir', action='store_true',
        help='List keys by dir instead of by key.')
    args = keys_parser.parse_args()
    basefs_log = Log(args.logpath)
    basefs_log.load()
    view = View(basefs_log)
    view.build()
    listing = view.get_keys(path=args.path, by_dir=args.by_dir)
    for heading, entries in listing.items():
        sys.stdout.write(heading + '\n')
        sys.stdout.writelines('    ' + entry + '\n' for entry in entries)


log_parser = argparse.ArgumentParser(
    description='Show a log file using a tree representation',
    prog='basefs log')

def log():
    """Print the revision tree of a file or directory stored in the log.

    When a basefs mount is detected and *path* lies under its mountpoint, the
    host path is first translated to the in-filesystem path. If that lookup
    fails, *path* is looked up verbatim. Exits 2 when nothing is found.
    """
    log_parser.add_argument('path', nargs='?', default=os.sep,
        help='Path of the file or directory to show, uses / by default.')
    log_parser.add_argument('logpath', nargs='?', default=default_logpath,
        help='Path to the basefs log file, uses %s by default.' % default_logpath,
        type=lambda v: file_exists(log_parser, v, name='logpath'))
    log_parser.add_argument('-a', '--ascii', dest='ascii', action='store_true',
        help='use ASCII line drawing characters')
    log_parser.add_argument('-c', '--color', dest='color', action='store_true',
        help='use terminal coloring')
    args = log_parser.parse_args()
    basefs_log = Log(args.logpath)
    basefs_log.load()
    view = View(basefs_log)
    view.build()

    entry = None
    if mount_info:
        rel = os.path.relpath(os.path.abspath(args.path), mount_info.mountpoint)
        # Translate only paths under the mountpoint. The old str.replace()
        # removed the mountpoint anywhere in the string and mapped the
        # mountpoint itself to ''.
        if rel != os.pardir and not rel.startswith(os.pardir + os.sep):
            entry = basefs_log.find('/' if rel == os.curdir else '/' + rel)
    if entry is None:
        entry = basefs_log.find(args.path)
    if not entry:
        sys.stderr.write("Error: '%s' path does not exist on the log.\n" % args.path)
        sys.exit(2)

    tree = basefs_log.print_tree(entry=entry, view=view, color=args.color, ascii=args.ascii)
    if args.ascii:
        sys.stdout.buffer.write(tree.encode('ascii', errors='replace'))
    else:
        sys.stdout.write(tree)


grant_parser = argparse.ArgumentParser(
    description='Grant key write permission',
    prog='basefs grant')

def grant():
    """Grant write permission on a path inside a mounted basefs to a key.

    The command is sent to the local basefs instance of the mount that contains
    *grantpath*, and its reply is printed. Exits 2 if *grantpath* is not under
    a basefs mountpoint.
    """
    grant_parser.add_argument('grantpath',
        help='Path where the permission should be granted.')
    grant_parser.add_argument('grantkey',
        help='Key fingerprint, if exists on lskeys, or path to a public key.',
        type=lambda v: key(grant_parser, v))
    args = grant_parser.parse_args()
    grant_mount = utils.get_mount_info(os.path.abspath(args.grantpath))
    if grant_mount is None:
        sys.stderr.write("Error: grantpath '%s' is not a basefs mountpoint subdir\n" % args.grantpath)
        sys.exit(2)
    path = os.path.relpath(args.grantpath, grant_mount.mountpoint)
    # relpath() yields '.' for the mountpoint itself. Only that case maps to the
    # root; stripping any leading '.' used to turn '/.cluster' into '/cluster'.
    if path == os.curdir:
        path = ''
    path = '/' + path
    grant_cmd = 'c GRANT %s %s' % (path, args.grantkey.oneliner())
    # netcat takes the request as bytes (as in send_command) and yields reply chunks.
    response = ''.join(utils.netcat('127.0.0.1', grant_mount.port, grant_cmd.encode()))
    print(response)
    sys.exit()


revoke_parser = argparse.ArgumentParser(
    description='Revoke key write permission',
    prog='basefs revoke')

def revoke():
    """Revoke a key's write permission on a path (``/`` by default).

    Loads the log given by ``-l``/``--logpath`` and builds a view with the
    user's key (``-k``/``--key``). It then calls ``View.revoke`` with the
    fingerprint of the key being revoked.
    """
    # Arguments belong to revoke_parser; the first one was previously added to
    # grant_parser, and logpath/key were commented out, so args lacked them.
    revoke_parser.add_argument('revokekey',
        help='Key fingerprint, if exists on lskeys, or path to a public key.',
        type=lambda v: key(revoke_parser, v))
    revoke_parser.add_argument('revokepath', nargs='?', default='/',
        help='Path where the permission should be revoked. Defaults to /.')
    revoke_parser.add_argument('-l', '--logpath', dest='logpath', default=default_logpath,
        help='Path to the basefs log file, uses %s by default.' % default_logpath,
        type=lambda v: file_exists(revoke_parser, v, name='logpath'))
    revoke_parser.add_argument('-k', '--key', dest='key',
        default=default_keypath,
        help='Path to your EC private key. %s by default.' % default_keypath,
        type=lambda v: key(revoke_parser, v))
    args = revoke_parser.parse_args()
    basefs_log = Log(args.logpath)
    basefs_log.load()
    view = View(basefs_log, args.key)
    # Build the view from the loaded log before using it, as mount() does.
    view.build()
    view.revoke(args.revokepath, args.revokekey.fingerprint)
    sys.exit()


revert_parser = argparse.ArgumentParser(
    description="Revert object to previous state, 'log' command lists all revisions",
    prog='basefs revert')

def revert():
    """Revert a file or directory to the revision identified by *hash*.

    Loads the log given by ``-l``/``--logpath`` and builds a view with the
    user's key (``-k``/``--key``) before calling ``View.revert``.
    """
    revert_parser.add_argument('path',
        help='Path of the directory or file to revert')
    revert_parser.add_argument('hash',
        help="Hash of a previous revision, use 'basefs log path' for showing all revisions")
    # These options were commented out, so args.logpath / args.key did not exist.
    revert_parser.add_argument('-l', '--logpath', dest='logpath', default=default_logpath,
        help='Path to the basefs log file, uses %s by default.' % default_logpath,
        type=lambda v: file_exists(revert_parser, v, name='logpath'))
    revert_parser.add_argument('-k', '--key', dest='key',
        default=default_keypath,
        help='Path to your EC private key. %s by default.' % default_keypath,
        type=lambda v: key(revert_parser, v))
    args = revert_parser.parse_args()
    basefs_log = Log(args.logpath)
    basefs_log.load()
    view = View(basefs_log, args.key)
    # Build the view from the loaded log before using it, as mount() does.
    view.build()
    view.revert(args.path, args.hash)
    sys.exit()


blocks_parser = argparse.ArgumentParser(
    description="Block state",
    prog='basefs blocks')

def blocks():
    """Print the mounted filesystem's block state once per second.

    Runs until interrupted; Ctrl-C exits cleanly instead of printing a
    traceback.
    """
    # Parse so that 'basefs blocks --help' works as help() advertises.
    blocks_parser.parse_args()
    try:
        while True:
            print(send_command('BLOCKSTATE'))
            time.sleep(1)
    except KeyboardInterrupt:
        pass
    sys.exit()

members_parser = argparse.ArgumentParser(
    description="List cluster members",
    prog='basefs members')

def members():
    """Print the cluster members reported by the local basefs instance."""
    # Parse so that 'basefs members --help' works as help() advertises.
    members_parser.parse_args()
    sys.stdout.write(send_command('MEMBERS') + '\n')
    sys.exit()


getlog_parser = argparse.ArgumentParser(
    description="Get log from peer address",
    prog='basefs getlog')

def getlog():
    """Download a log from a peer at ``ip:port`` into *logpath*.

    Exit status:

    - 1 if *logpath* exists and ``--force`` was not given;
    - 2 (via ``parser.error``) when *addr* is not of the form ``ip:port``.
    """
    # These arguments were previously registered on bootstrap_parser by mistake.
    getlog_parser.add_argument('addr',
        help='Peer address, as ip:port.')
    getlog_parser.add_argument('logpath', nargs='?', default=default_logpath,
        help='Path to the basefs log file, uses %s by default.' % default_logpath)
    getlog_parser.add_argument('-f', '--force', dest='force', action='store_true',
        help='Rewrite log file if present')

    args = getlog_parser.parse_args()
    # Validate the address before touching any existing log file.
    ip, sep, port = args.addr.rpartition(':')
    if not sep or not port.isdigit():
        getlog_parser.error("addr %s is not of the form ip:port" % args.addr)
    logpath = args.logpath
    if os.path.exists(logpath):
        if not args.force:
            sys.stderr.write("Error: logpath %s already exists and --force argument was not provided\n" % logpath)
            sys.exit(1)
        os.remove(logpath)
    with open(logpath, 'w') as logfile:
        for chunk in utils.netcat(ip, int(port), b'g'):
            logfile.write(chunk)
    sys.stdout.write('Fetched log from %s into %s\n' % (args.addr, logpath))


def help():
    """Write the top-level usage text to stdout, then exit with status 0.

    Each command listed in ``methods`` gets one line. The line is indented by
    four spaces and tab-padded so the parser descriptions line up in a
    column. Commands without a parser have no description.
    """
    widest = max((len(name) for name in methods), default=0)
    tabs = (widest + 4) // 8
    lines = []
    for name, (_, parser) in methods.items():
        line = '    ' + name
        if parser:
            line += '\t' * (tabs - len(line) // 8 + 1) + parser.description
        lines.append(line)
    usage = textwrap.dedent("""\
        Usage: basefs COMMAND [arg...]
               basefs [ --help | -v | --version ]

        Basically Available, Soft state, Eventually consistent File System.

        Commands:
        %s

        Run 'basefs COMMAND --help' for more information on a command
        """)
    sys.stdout.write(usage % '\n'.join(lines))
    sys.exit()


# Command dispatch table: name -> (implementation, argparse parser or None).
# Insertion order is the order in which help() lists the commands.
methods = OrderedDict(
    (name, (function, parser))
    for name, function, parser in (
        ('mount', mount, mount_parser),
        ('handler', handler, handler_parser),
        ('bootstrap', bootstrap, bootstrap_parser),
        ('genkey', genkey, genkey_parser),
        ('keys', keys, keys_parser),
        ('grant', grant, grant_parser),
        ('revoke', revoke, revoke_parser),
        ('log', log, log_parser),
        ('revert', revert, revert_parser),
        ('blocks', blocks, blocks_parser),
        ('members', members, members_parser),
        ('getlog', getlog, getlog_parser),
        ('help', help, None),
    )
)


if __name__ == '__main__':
    # The first argument selects the command. Remove it from argv so the
    # command's own parser sees only the remaining arguments.
    command = sys.argv.pop(1) if len(sys.argv) > 1 else None
    if command in ('-h', '--help'):
        command = 'help'
    if command not in methods:
        if command is not None:
            sys.stderr.write("Error: not recognized argument %s\n" % command)
        # help() always ends with sys.exit() (status 0). Swallow that exit so
        # a missing or unknown command exits with status 1 as intended.
        try:
            help()
        except SystemExit:
            pass
        sys.exit(1)
    methods[command][0]()
