#!python

import argparse
import os
import string
import random
import subprocess
import sys
import platform
import shutil
import requests
import zipfile
import platform
import socket
import subprocess
import sys
import time
import yaml
import shutil

import distro

# Root of the local Aqueduct installation; the server and UI each live in
# their own sub-directory beneath it.
base_directory = os.path.join(os.environ["HOME"], ".aqueduct")
server_directory = os.path.join(base_directory, "server")
ui_directory = os.path.join(base_directory, "ui")

# Version of this package; it is interpolated into the S3 asset URLs below.
package_version = "0.0.4"
# First port tried when the user does not request a specific one.
default_server_port = 8080

# Versioned S3 locations of the server and UI assets.
s3_server_prefix = "https://aqueduct-ai.s3.us-east-2.amazonaws.com/assets/%s/server" % (
    package_version,
)
s3_ui_prefix = "https://aqueduct-ai.s3.us-east-2.amazonaws.com/assets/%s/ui" % (
    package_version,
)

# Banner template; formatted with (API key, host, port).
welcome_message = """
***************************************************
Your API Key: %s

The Web UI and the backend server are accessible at: http://%s:%d
***************************************************
"""


def update_config_yaml(file):
    """Fill in the placeholders of a downloaded server config file.

    Rewrites ``file`` in place, substituting ``<BASE_PATH>`` with the server
    directory and ``<ENCRYPTION_KEY>`` / ``<API_KEY>`` with freshly generated
    32-character keys drawn from uppercase letters and digits.

    Args:
        file: Path of the config file to rewrite.
    """
    # Function-scope import keeps this block self-contained. The keys are
    # secrets, so they are drawn from the OS CSPRNG instead of the
    # predictable `random` generator (which also never repeated a character
    # when used through `random.sample`).
    import secrets

    alphabet = string.ascii_uppercase + string.digits
    encryption_key = "".join(secrets.choice(alphabet) for _ in range(32))
    api_key = "".join(secrets.choice(alphabet) for _ in range(32))

    with open(file, "r") as sources:
        lines = sources.readlines()
    with open(file, "w") as sources:
        for line in lines:
            # At most one placeholder kind is substituted per line.
            if "<BASE_PATH>" in line:
                sources.write(line.replace("<BASE_PATH>", server_directory))
            elif "<ENCRYPTION_KEY>" in line:
                sources.write(line.replace("<ENCRYPTION_KEY>", encryption_key))
            elif "<API_KEY>" in line:
                sources.write(line.replace("<API_KEY>", api_key))
            else:
                sources.write(line)
    print("Updated configurations.")


def execute_command(args, cwd=None):
    """Run a command to completion, forwarding its output to this process.

    Args:
        args: Command and arguments, passed straight to ``subprocess.Popen``.
        cwd: Optional working directory for the child process.

    Raises:
        Exception: If the command exits with a non-zero return code.
    """
    child = subprocess.Popen(args, stdout=sys.stdout, stderr=sys.stderr, cwd=cwd)
    with child:
        child.communicate()
        exit_code = child.returncode
    if exit_code != 0:
        raise Exception("Error executing command: %s" % args)


def generate_version_file(file_path):
    """Write the current package version to ``file_path``, replacing any content."""
    with open(file_path, mode="w") as version_file:
        version_file.write(package_version)


# Returns a bool indicating whether we need to perform a version upgrade.
def require_update(file_path, target_version=None):
    """Decide whether the installation recorded at ``file_path`` must be updated.

    Versions are compared component-wise as integers, so "0.0.10" is newer
    than "0.0.4" (a plain string comparison gets this wrong). Surrounding
    whitespace in the version file is ignored.

    Args:
        file_path: Path of the ``__version__`` file of an installed component.
        target_version: Version being installed; defaults to ``package_version``.

    Returns:
        True if the version file is missing or records an older version,
        False if it records exactly the target version.

    Raises:
        Exception: If the version file records a newer version than the target.
    """
    if target_version is None:
        target_version = package_version

    if not os.path.isfile(file_path):
        return True
    with open(file_path, "r") as f:
        current_version = f.read().strip()

    def version_key(version):
        # "0.0.4" -> (0, 0, 4); raises ValueError on a non-numeric component.
        return tuple(int(part) for part in version.split("."))

    try:
        target_key = version_key(target_version)
        current_key = version_key(current_version)
    except ValueError:
        # Not a purely numeric dotted version (e.g. an empty file): fall
        # back to comparing the raw strings.
        target_key, current_key = target_version, current_version

    if target_key < current_key:
        raise Exception(
            "Attempting to install an older version %s but found existing newer version %s"
            % (target_version, current_version)
        )
    return target_key != current_key


def update_executable_permissions():
    """Set mode 0o755 on the server, executor and migrator binaries."""
    bin_directory = os.path.join(server_directory, "bin")
    for binary_name in ("server", "executor", "migrator"):
        os.chmod(os.path.join(bin_directory, binary_name), 0o755)


def _download_file(url, destination):
    """Download ``url`` into ``destination``, raising on an HTTP error status."""
    response = requests.get(url)
    # Without this check the body of an HTTP error response would be
    # written to disk in place of the requested file.
    response.raise_for_status()
    with open(destination, "wb") as f:
        f.write(response.content)


def download_server_binaries(architecture):
    """Download the server binaries and helper scripts into the server bin directory.

    Args:
        architecture: Name of the build directory under ``bin/`` on S3
            (e.g. "linux_amd64") holding the server, executor and migrator.

    Raises:
        requests.HTTPError: If any download returns an HTTP error status.
    """
    bin_directory = os.path.join(server_directory, "bin")
    # Paths relative to the server S3 prefix; each file is stored locally
    # under its base name.
    remote_paths = [
        f"bin/{architecture}/server",
        f"bin/{architecture}/executor",
        f"bin/{architecture}/migrator",
        "bin/start-function-executor.sh",
        "bin/install_sqlserver_ubuntu.sh",
    ]
    for remote_path in remote_paths:
        file_name = remote_path.rsplit("/", 1)[-1]
        # URLs always use "/", independent of the local path separator.
        _download_file(
            f"{s3_server_prefix}/{remote_path}",
            os.path.join(bin_directory, file_name),
        )
    print("Downloaded server binaries.")


def setup_server_binaries():
    """Recreate the server bin directory and download the binaries for this platform.

    Raises:
        Exception: If the OS / architecture combination is not supported.
    """
    print("Downloading server binaries.")
    bin_directory = os.path.join(server_directory, "bin")
    shutil.rmtree(bin_directory, ignore_errors=True)
    os.mkdir(bin_directory)

    # (system, machine) -> (message to print, build name to download)
    supported_platforms = {
        ("Linux", "x86_64"): (
            "Operating system is Linux with architecture amd64.",
            "linux_amd64",
        ),
        ("Darwin", "x86_64"): (
            "Operating system is Mac with architecture amd64.",
            "darwin_amd64",
        ),
        ("Darwin", "arm64"): (
            "Operating system is Mac with architecture arm64.",
            "darwin_arm64",
        ),
    }

    system, arch = platform.system(), platform.machine()
    match = supported_platforms.get((system, arch))
    if match is None:
        raise Exception(
            "Unsupported operating system and architecture combination: %s, %s" % (system, arch)
        )

    message, build_name = match
    print(message)
    download_server_binaries(build_name)


def update_ui_version():
    """Replace the UI directory with the UI assets for the current package version.

    Downloads ``ui.zip`` from S3, extracts it into the UI directory and
    removes the archive. On any failure the error is printed, the UI
    directory is removed and the process exits with status 1.
    """
    print("Updating UI version to %s" % package_version)
    try:
        shutil.rmtree(ui_directory, ignore_errors=True)
        os.mkdir(ui_directory)
        generate_version_file(os.path.join(ui_directory, "__version__"))

        # We detect whether the server is running on a SageMaker instance by checking if the
        # directory /home/ec2-user/SageMaker exists. This is hacky but we couldn't find a better
        # solution at the moment.
        on_sagemaker = os.path.isdir(os.path.join(os.sep, "home", "ec2-user", "SageMaker"))
        variant = "sagemaker" if on_sagemaker else "default"

        response = requests.get("%s/%s/ui.zip" % (s3_ui_prefix, variant))
        # Fail here rather than trying to unzip an HTTP error body.
        response.raise_for_status()

        ui_zip_path = os.path.join(ui_directory, "ui.zip")
        with open(ui_zip_path, "wb") as f:
            f.write(response.content)
        with zipfile.ZipFile(ui_zip_path, "r") as ui_archive:
            ui_archive.extractall(ui_directory)
        os.remove(ui_zip_path)
    except Exception as e:
        print(e)
        # Leave no half-installed UI behind so the next run starts clean.
        shutil.rmtree(ui_directory, ignore_errors=True)
        sys.exit(1)


def update_server_version():
    """Install the server binaries for the current package version.

    Rewrites the server ``__version__`` file, downloads the binaries, makes
    them executable and then runs the migrator against the sqlite database.
    """
    print("Updating server version to %s" % package_version)

    version_path = os.path.join(server_directory, "__version__")
    already_recorded = os.path.isfile(version_path)
    if already_recorded:
        os.remove(version_path)
    generate_version_file(version_path)

    setup_server_binaries()
    update_executable_permissions()

    migrator_binary = os.path.join(server_directory, "bin", "migrator")
    migrate_command = [migrator_binary, "--type", "sqlite", "goto", "9"]
    execute_command(migrate_command)


def _fetch_server_asset(relative_path):
    """Download one file under the server S3 prefix to the same path in the server directory."""
    response = requests.get("%s/%s" % (s3_server_prefix, relative_path))
    # Refuse to store the body of an HTTP error response as the asset.
    response.raise_for_status()
    with open(os.path.join(server_directory, relative_path), "wb") as f:
        f.write(response.content)


def update():
    """Bring the local Aqueduct installation up to the current package version.

    Creates the base directory if needed, refreshes the UI when it is
    missing or outdated, initializes the server directory on first use and
    upgrades the server binaries when their recorded version is outdated.
    On failure the error is printed, the partially written state is removed
    and the process exits with status 1.
    """
    os.makedirs(base_directory, exist_ok=True)

    if not os.path.isdir(ui_directory) or require_update(os.path.join(ui_directory, "__version__")):
        update_ui_version()

    if not os.path.isdir(server_directory):
        try:
            directories = [
                server_directory,
                os.path.join(server_directory, "db"),
                os.path.join(server_directory, "storage"),
                os.path.join(server_directory, "storage", "operators"),
                os.path.join(server_directory, "vault"),
                os.path.join(server_directory, "bin"),
                os.path.join(server_directory, "config"),
                os.path.join(server_directory, "logs"),
            ]

            for directory in directories:
                os.mkdir(directory)

            update_server_version()

            _fetch_server_asset("config/config.yml")
            update_config_yaml(os.path.join(server_directory, "config", "config.yml"))

            _fetch_server_asset("db/demo.db")

            print("Finished initializing Aqueduct base directory.")
        except Exception as e:
            print(e)
            # Remove the half-initialized directory so the next run retries
            # the full initialization.
            shutil.rmtree(server_directory, ignore_errors=True)
            sys.exit(1)

    version_file = os.path.join(server_directory, "__version__")
    if require_update(version_file):
        try:
            update_server_version()
        except Exception as e:
            print(e)
            # Drop the version marker so the upgrade is attempted again.
            if os.path.isfile(version_file):
                os.remove(version_file)
            sys.exit(1)

def execute_command(args, cwd=None):
    """Run ``args`` as a child process, wait for it, and fail on a non-zero exit.

    The child writes directly to this process's stdout and stderr. ``cwd``
    optionally sets its working directory. An Exception naming the command
    is raised when the return code is not 0.
    """
    with subprocess.Popen(args, stdout=sys.stdout, stderr=sys.stderr, cwd=cwd) as process:
        process.communicate()
        succeeded = process.returncode == 0
    if not succeeded:
        raise Exception("Error executing command: %s" % args)

def execute_command_nonblocking(args, cwd=None):
    """Start ``args`` as a child process and return its Popen handle without waiting."""
    popen_options = {"stdout": sys.stdout, "stderr": sys.stderr, "cwd": cwd}
    return subprocess.Popen(args, **popen_options)

def generate_welcome_message(expose, port):
    """Render the welcome banner with the API key and the server address.

    When ``expose`` is truthy the host is shown as the literal placeholder
    "<IP_ADDRESS>"; otherwise it is "localhost".
    """
    host = "<IP_ADDRESS>" if expose else "localhost"
    return welcome_message % (get_apikey(), host, port)

def is_port_in_use(port: int) -> bool:
    """Report whether a TCP connection to ``port`` on localhost succeeds."""
    target = ("localhost", port)
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # connect_ex reports failure through its return value; 0 means connected.
        connect_status = probe.connect_ex(target)
    return connect_status == 0

def start(expose, port):
    """Update the installation and launch the server without waiting for it.

    Args:
        expose: When truthy, ``--expose`` is passed to the server binary.
        port: Port to run on, or None to use the first free port counting
            up from the default.

    Returns:
        A ``(popen_handle, server_port)`` tuple for the launched server.
    """
    update()

    if port is not None:
        server_port = int(port)
        print("Server will use the user-specified port %d" % server_port)
    else:
        server_port = default_server_port
        while is_port_in_use(server_port):
            server_port += 1
        if server_port != default_server_port:
            print("Default port %d is occupied. Next available port is %d" % (default_server_port, server_port))

    # The command is identical in both modes except for the --expose flag,
    # which sits between the config and port options.
    command = [
        os.path.join(server_directory, "bin", "server"),
        "--config",
        os.path.join(server_directory, "config", "config.yml"),
    ]
    if expose:
        command.append("--expose")
    command.extend(["--port", str(server_port)])

    return execute_command_nonblocking(command), server_port

def install_postgres():
    """Pip-install ``psycopg2-binary`` into the running interpreter's environment."""
    pip_command = [sys.executable, "-m", "pip", "install", "psycopg2-binary"]
    execute_command(pip_command)

def install_bigquery():
    """Pip-install ``google-cloud-bigquery`` into the running interpreter's environment."""
    pip_command = [sys.executable, "-m", "pip", "install", "google-cloud-bigquery"]
    execute_command(pip_command)

def install_snowflake():
    """Pip-install ``snowflake-sqlalchemy`` into the running interpreter's environment."""
    pip_command = [sys.executable, "-m", "pip", "install", "snowflake-sqlalchemy"]
    execute_command(pip_command)

def install_s3():
    """Pip-install ``pyarrow`` into the running interpreter's environment."""
    pip_command = [sys.executable, "-m", "pip", "install", "pyarrow"]
    execute_command(pip_command)

def install_mysql():
    """Install the system packages for MySQL support, then pip-install ``mysqlclient``.

    An unsupported OS or Linux distribution is only reported; the pip
    install is attempted regardless.
    """
    system = platform.system()
    if system == "Darwin":
        execute_command(["brew", "install", "mysql"])
    elif system != "Linux":
        print("Unsupported operating system:", system)
    elif distro.id() in ("ubuntu", "debian"):
        execute_command(["sudo", "apt-get", "install", "-y", "python3-dev", "default-libmysqlclient-dev", "build-essential"])
    elif distro.id() in ("centos", "rhel"):
        execute_command(["sudo", "yum", "install", "-y", "python3-devel", "mysql-devel"])
    else:
        print("Unsupported distribution:", distro.id())

    pip_command = [sys.executable, "-m", "pip", "install", "mysqlclient"]
    execute_command(pip_command)

def install_sqlserver():
    """Install the system packages for SQL Server support, then pip-install ``pyodbc``.

    On Ubuntu the downloaded ``install_sqlserver_ubuntu.sh`` script is run;
    on macOS the packages are installed through Homebrew. An unsupported OS
    or distribution is only reported; the pip install is attempted regardless.
    """
    system = platform.system()
    if system == "Linux":
        if distro.id() == "ubuntu":
            execute_command(["bash", os.path.join(server_directory, "bin", "install_sqlserver_ubuntu.sh")])
        else:
            print("Unsupported distribution:", distro.id())
    elif system == "Darwin":
        execute_command(["brew", "tap", "microsoft/mssql-release", "https://github.com/Microsoft/homebrew-mssql-release"])
        execute_command(["brew", "update"])
        # The command runs without a shell, so leading VAR=value tokens are
        # not interpreted as environment assignments: the first token would
        # be looked up as the program to execute. Going through `env` sets
        # the variables and then runs brew.
        execute_command(
            [
                "env",
                "HOMEBREW_NO_ENV_FILTERING=1",
                "ACCEPT_EULA=Y",
                "brew",
                "install",
                "msodbcsql17",
                "mssql-tools",
            ]
        )
    else:
        print("Unsupported operating system:", system)

    execute_command([sys.executable, "-m", "pip", "install", "pyodbc"])

def install(system):
    """Install the dependencies of the connector named ``system``.

    Raises:
        Exception: If ``system`` is not one of the supported integrations.
    """
    supported_systems = ("postgres", "bigquery", "snowflake", "s3", "mysql", "sqlserver")
    if system not in supported_systems:
        raise Exception("Unsupported system: %s" % system)

    installers = {
        "postgres": install_postgres,
        "bigquery": install_bigquery,
        "snowflake": install_snowflake,
        "s3": install_s3,
        "mysql": install_mysql,
        "sqlserver": install_sqlserver,
    }
    installers[system]()

def get_apikey():
    """Return the ``apiKey`` value from the server config file.

    If the config file is not valid YAML, the parse error is printed and
    the process exits with status 1.
    """
    config_file = os.path.join(server_directory, "config", "config.yml")
    with open(config_file, "r") as f:
        try:
            return yaml.safe_load(f)['apiKey']
        except yaml.YAMLError as exc:
            print(exc)
            # sys.exit rather than the `exit` helper, which the `site`
            # module provides for interactive use only.
            sys.exit(1)

def apikey():
    """Print the API key read from the server config file."""
    key = get_apikey()
    print(key)

def clear():
    """Delete the Aqueduct base directory and everything in it, ignoring errors."""
    install_root = base_directory
    shutil.rmtree(install_root, ignore_errors=True)

def version():
    """Print the package version."""
    current_version = package_version
    print(current_version)

if __name__ == "__main__":
    # Build the CLI: one sub-command per user-facing action.
    parser = argparse.ArgumentParser(description='The Aqueduct CLI')
    subparsers = parser.add_subparsers(dest="command")

    start_args = subparsers.add_parser('start', help=
                               '''This starts the Aqueduct server and the UI in a blocking
                               fashion. To background the process run aqueduct start &.

                               Add --expose <IP_ADDRESS> to access the Aqueduct service from
                               an external server, where <IP_ADDRESS> is the
                               public IP of the server running the Aqueduct service.
                               ''')
    start_args.add_argument('--expose', default=False, action='store_true',
                    help="Use this option to expose the server to the public.")
    # type=int lets argparse reject a non-numeric port with a usage error
    # instead of failing later inside start().
    start_args.add_argument('--port', dest="port", type=int,
                    help="Specify the port on which the Aqueduct server runs.")
    server_args = subparsers.add_parser('server', help=
                                   '''[DEPRECATED] This starts the Aqueduct server in a
                                   blocking fashion. To background the process,
                                   run aqueduct server &.''')
    server_args.add_argument('--expose', default=False, action='store_true',
                    help="Use this option to expose the server to the public.")
    ui_args = subparsers.add_parser('ui', help=
                               '''[DEPRECATED] This starts the Aqueduct UI in a blocking
                               fashion. To background the process run aqueduct
                               ui &.

                               Add --expose <IP_ADDRESS> to access the UI from
                               an external server, where <IP_ADDRESS> is the
                               public IP of the server you are running on.
                               ''')
    ui_args.add_argument('--expose', dest="expose",
                    help="The IP address of the server running Aqueduct.")

    install_args = subparsers.add_parser('install', help=
                             '''Install the required library dependencies for
                             an Aqueduct connector to a third-party system.''')
    install_args.add_argument('system', nargs=1, help="Supported integrations: postgres, mysql, sqlserver, s3, snowflake, bigquery.")

    apikey_args = subparsers.add_parser('apikey', help="Display your Aqueduct API key.")
    clear_args = subparsers.add_parser('clear', help="Erase your Aqueduct installation.")
    version_args = subparsers.add_parser('version', help="Retrieve the package version number.")

    args = parser.parse_args()

    if args.command == "start":
        # Stays None until the server process has actually been spawned, so
        # the error handler below never touches an unbound name.
        popen_handle = None
        try:
            popen_handle, server_port = start(args.expose, args.port)
            # Give the server a moment to fail before announcing it.
            time.sleep(1)
            # poll() returns None while the process is still running; any
            # return code, including 0, means the server has already exited.
            if popen_handle.poll() is not None:
                print("Server terminated due to an error.")
            else:
                print(generate_welcome_message(args.expose, server_port))
                popen_handle.wait()
        except (Exception, KeyboardInterrupt) as e:
            print(e)
            if popen_handle is not None:
                print('\nTerminating Aqueduct service...')
                popen_handle.kill()
                print('Aqueduct service successfully terminated.')
    elif args.command == "server":
        print("aqueduct ui and aqueduct server have been deprecated; please use aqueduct start to run both the UI and backend servers")
    elif args.command == "install":
        install(args.system[0]) # argparse makes this an array so only pass in value [0].
    elif args.command == "ui":
        print("aqueduct ui and aqueduct server have been deprecated; please use aqueduct start to run both the UI and backend servers")
    elif args.command == "apikey":
        apikey()
    elif args.command == "clear":
        clear()
    elif args.command == "version":
        version()
    elif args.command is None:
        parser.print_help()
    else:
        print("Unsupported command:", args.command)
        sys.exit(1)
