#!/usr/bin/env python

"""
Thin wrapper around the "aws" command line interface (CLI) for use
with LocalStack.

The "awslocalw" CLI allows you to easily interact with your local services
without having to specify "--endpoint-url=http://..." for every single command.

Example:
Instead of the following command ...
aws --endpoint-url=https://localhost:4568 --no-verify-ssl kinesis list-streams
... you can simply use this:
awslocalw kinesis list-streams

Options:
  Run "aws help" for more details on the aws CLI subcommands.
"""

import os
import sys
import subprocess
import re
from threading import Thread

from boto3.session import Session

# Directory one level above the folder containing this script.
PARENT_FOLDER = os.path.realpath(os.path.join(os.path.dirname(__file__), '..'))
# Hostname checked in run_in_process(): when the user-supplied endpoint URL contains it,
# the S3 addressing-style warning is skipped.
S3_VIRTUAL_ENDPOINT_HOSTNAME = 's3.localhost.localstack.cloud'
# If the parent folder contains a ".venv" directory, put the parent folder first on the
# import path (presumably so a source checkout takes precedence over an installed
# package - TODO confirm).
if os.path.isdir(os.path.join(PARENT_FOLDER, '.venv')):
    sys.path.insert(0, PARENT_FOLDER)

# names of additional environment variables to pass to subprocess
# (run_as_separate_process() forwards these in addition to all "AWS_*" variables)
ENV_VARS_TO_PASS = ['PATH', 'PYTHONPATH', 'SYSTEMROOT', 'HOME', 'TERM', 'PAGER']

# service names without endpoints
# (prepare_cmd_args() does not report a missing endpoint as an error for these)
NO_ENDPOINT_SERVICES = ('help', 'configure')

from localstack_client import config  # noqa: E402


def get_service():
    """
    Returns the first command line argument (after the program name) that does
    not start with a dash, or None if there is no such argument.
    """
    positional_args = (arg for arg in sys.argv[1:] if not arg.startswith('-'))
    return next(positional_args, None)


def get_service_endpoint(localstack_host=None):
    """
    Looks up the LocalStack endpoint for the service named on the command line.

    The "s3api" command is resolved via the "s3" entry. If the service has no
    entry of its own, the endpoint registered for "sts" is returned instead.

    :param localstack_host: passed through to config.get_service_endpoints()
    """
    requested = get_service()
    lookup_key = 's3' if requested == 's3api' else requested
    known_endpoints = config.get_service_endpoints(localstack_host=localstack_host)

    own_endpoint = known_endpoints.get(lookup_key)
    if own_endpoint:
        return own_endpoint
    # Fall back to the STS endpoint (any other listed service would do as well), so that
    # services newly added to LocalStack work without first being registered in the
    # _service_ports dict in localstack_client.
    return known_endpoints.get("sts")


def usage():
    """Prints this module's docstring, stripped of surrounding whitespace, as help text."""
    help_text = __doc__.strip()
    print(help_text)


def run(cmd, env=None):
    """
    Runs the AWS CLI with the given command and environment. Does not return normally.

    On non-Windows platforms the current process is replaced by the CLI process via
    os.execvpe(). On Windows the CLI is started as a child process, its stdout and
    stderr are relayed to the streams of this process, and this process then exits
    with the return code of the child.

    :param cmd: list of command line arguments; cmd[0] is the executable to run
    :param env: dict of environment variables for the CLI process; a falsy value
        results in an empty environment
    """
    if not env:
        env = {}

    # check for windows OS
    if sys.platform == 'win32':
        def output_reader(pipe, out):
            # Copy the child's output line by line, as raw bytes, to the file
            # descriptor behind the given stream of this process.
            out_binary = os.fdopen(out.fileno(), 'wb')
            with pipe:
                for line in iter(pipe.readline, b''):
                    out_binary.write(line)
                    out_binary.flush()

        # stdin is deliberately NOT redirected to a pipe: nothing in this process would
        # ever write to or close such a pipe, so any command reading from stdin (e.g.
        # interactive prompts or "s3 cp - ...") would block forever. Leaving it unset
        # makes the child inherit the stdin of this process.
        process = subprocess.Popen(
            cmd, env=env,
            stderr=subprocess.PIPE, stdout=subprocess.PIPE)

        # one relay thread per output stream, so that neither pipe fills up and blocks the child
        readers = [
            Thread(target=output_reader, args=[process.stdout, sys.stdout]),
            Thread(target=output_reader, args=[process.stderr, sys.stderr]),
        ]
        for reader in readers:
            reader.start()

        process.wait()
        # make sure all remaining output has been relayed before exiting
        for reader in readers:
            reader.join()
        sys.exit(process.returncode)
    else:
        # The way that works on Linux, MacOS and WSL (which is a virtual machine running Linux)
        # Stdlib says this works on windows, but SO and my own experience says it doesn't
        os.execvpe(cmd[0], cmd, env)


def awscli_is_v1():
    """
    Determines whether the available AWS CLI is a version 1.x release.

    The version of the importable "awscli" package is checked first. If that fails
    for any reason (e.g. the package is not installed in this interpreter), the
    output of "aws --version" is inspected instead.

    :return: True if the detected version is 1.x, False otherwise
    :raises OSError, subprocess.CalledProcessError: if the fallback "aws --version"
        call cannot be executed or fails
    """
    try:
        from awscli import __version__ as awscli_version
        # The dots are escaped so that only a literal "1.<minor>.<patch>" matches;
        # an unescaped "." would also accept strings like "1x18y0" or "1023.4".
        return bool(re.match(r'^1\.\d+\.\d+$', awscli_version))
    except Exception:
        version = subprocess.check_output(['aws', '--version'])
        version = version.decode('UTF-8') if isinstance(version, bytes) else version
        # include the dot after the major version, so that e.g. "aws-cli/10.x" is not taken for v1
        return 'aws-cli/1.' in version


def main():
    """
    Entry point of the script.

    Prints the usage text if the first argument is "-h". Otherwise the AWS CLI is run
    inside this process if the "awscli" package can be imported, and as an external
    "aws" process if it cannot.
    """
    if sys.argv[1:2] == ['-h']:
        return usage()

    try:
        import awscli.clidriver  # noqa: F401
    except Exception:
        # awscli is not usable from this interpreter - delegate to the "aws" executable
        return run_as_separate_process()
    else:
        patch_awscli_libs()
        run_in_process()


def prepare_environment():
    """
    Builds the environment variables for the AWS CLI and applies them to this process.

    Starting from a copy of the current environment:
    - PYTHONWARNINGS gets a default that ignores "Unverified HTTPS request" warnings,
      unless it is already set
    - AWS_DATA_PATH is dropped from the returned dict
    - dummy credentials ("test"/"test") are set if boto3 cannot find any credentials
    - AWS_DEFAULT_REGION is set to "us-east-1" if boto3 has no region configured

    :return: the resulting dict of environment variables
    """
    environment = dict(os.environ)

    # keep a user-defined value, only fill in the default if the variable is absent
    environment.setdefault('PYTHONWARNINGS', 'ignore:Unverified HTTPS request')

    if 'AWS_DATA_PATH' in environment:
        del environment['AWS_DATA_PATH']

    boto_session = Session()

    if not boto_session.get_credentials():
        environment.update(AWS_ACCESS_KEY_ID='test', AWS_SECRET_ACCESS_KEY='test')

    if not boto_session.region_name:
        environment['AWS_DEFAULT_REGION'] = 'us-east-1'

    # Apply the values to the current process as well. Note that update() only sets
    # keys: a removed AWS_DATA_PATH is still present in os.environ afterwards.
    os.environ.update(environment)

    return environment


def prepare_cmd_args():
    """
    Adds the LocalStack specific arguments to the command line.

    sys.argv is modified IN PLACE (run_in_process() relies on this, as it discards the
    return value and lets the AWS CLI read sys.argv). Exits the process with status 1
    if no endpoint can be determined for a service that requires one.

    :return: a copy of the updated argument list
    """
    # get service and endpoint
    endpoint = get_service_endpoint(localstack_host=os.environ.get('LOCALSTACK_HOST'))
    service = get_service()

    endpoint_missing = not endpoint and service and service not in NO_ENDPOINT_SERVICES
    if endpoint_missing:
        print('ERROR: %s' % ('Unable to find LocalStack endpoint for service "%s"' % service))
        return sys.exit(1)

    # intentionally an alias of sys.argv, not a copy (see docstring)
    argv = sys.argv
    if not endpoint:
        return list(argv)

    argv.insert(1, '--endpoint-url=%s' % endpoint)
    if 'https' in endpoint:
        argv.insert(2, '--no-verify-ssl')

    # TODO: check the logic below and make it more resilient
    is_cf_package_or_deploy = (
        'cloudformation' in argv and any(cmd in argv for cmd in ['deploy', 'package']))
    if is_cf_package_or_deploy:
        if awscli_is_v1():
            argv.insert(2, '--s3-endpoint-url=%s' % endpoint)
        else:
            print('!NOTE! awslocalw does not currently work with the cloudformation package/deploy commands supplied by '
                  'the AWS CLI v2. Please run "pip install awscli" to install version 1 of the AWS CLI')

    return list(argv)


def run_as_separate_process():
    """
    Builds the argument list and a reduced environment, then hands both to run()
    in order to call "aws" as an external process.
    """
    # only forward AWS_* variables and the explicitly listed ones to the external process
    child_env = {}
    for name, value in prepare_environment().items():
        if name in ENV_VARS_TO_PASS or name.startswith('AWS_'):
            child_env[name] = value

    command = prepare_cmd_args()
    # replace the name of this script with the actual executable
    command[0] = 'aws'

    # run the command
    run(command, child_env)


def _get_cli_option(argv, option):
    """
    Returns the value given for `option` in `argv`, or '' if there is none.

    Both the "--option value" and the "--option=value" forms are recognized; the first
    occurrence wins. An option given as the very last argument without a value yields ''.
    """
    prefix = option + '='
    for index, arg in enumerate(argv):
        if arg == option:
            return argv[index + 1] if index + 1 < len(argv) else ''
        if arg.startswith(prefix):
            return arg[len(prefix):]
    return ''


def run_in_process():
    """
    Modifies the command line args in sys.argv and calls the AWS cli
    method directly in this process.

    Before doing so, a warning is printed if the selected profile in the AWS config
    uses the S3 addressing style 'virtual' or 'auto', unless the user-supplied endpoint
    URL contains S3_VIRTUAL_ENDPOINT_HOSTNAME. The process always exits with the return
    code of the AWS CLI.
    """
    # Note: these are the values passed by the user; the LocalStack endpoint
    # is only added to sys.argv further below, in prepare_cmd_args().
    profile_name = _get_cli_option(sys.argv, '--profile') or 'default'
    endpoint_url = (
        _get_cli_option(sys.argv, '--endpoint-url') or _get_cli_option(sys.argv, '--endpoint'))

    import botocore.session

    session = botocore.session.get_session()

    if S3_VIRTUAL_ENDPOINT_HOSTNAME not in endpoint_url:
        try:
            profiles = session.full_config.get('profiles')
            if profiles:
                current_profile = profiles.get(profile_name) or {}
                addressing_style = current_profile.get('s3', {}).get('addressing_style')
                if addressing_style in ['virtual', 'auto']:
                    msg = ("Addressing style is set to 'virtual' or 'auto' in the aws config. "
                           "Please change it to 'path'.")
                    print('WARNING: %s' % msg)
        except KeyError:
            # the check is best-effort only - never fail the command because of it
            pass

    import awscli.clidriver
    # a bare "UTF-8" value is replaced with a complete locale name before the CLI starts
    if os.environ.get('LC_CTYPE', '') == 'UTF-8':
        os.environ['LC_CTYPE'] = 'en_US.UTF-8'
    prepare_environment()
    # return value not needed: sys.argv is updated in place and read by the CLI driver
    prepare_cmd_args()
    sys.exit(awscli.clidriver.main())


def patch_awscli_libs():
    """
    Monkeypatches awscli/botocore in this process, so that the CloudFormation "package"
    and "deploy" commands accept an "s3-endpoint-url" argument and use its value as
    endpoint for the S3 clients they create.

    Additionally applies config.patch_expand_host_prefix() if the botocore Serializer
    class has an "_expand_host_prefix" attribute.
    """
    # TODO: Temporary fix until this PR is merged: https://github.com/aws/aws-cli/pull/3309

    import inspect
    from awscli import paramfile
    from awscli.customizations.cloudformation import deploy, package
    from botocore.serialize import Serializer

    # add parameter definitions
    # NOTE(review): presumably this keeps the CLI from treating the URL value as a
    # reference to a parameter file - confirm against awscli.paramfile
    if awscli_is_v1():
        paramfile.PARAMFILE_DISABLED.add('custom.package.s3-endpoint-url')
        paramfile.PARAMFILE_DISABLED.add('custom.deploy.s3-endpoint-url')

    s3_endpoint_arg = {
        'name': 's3-endpoint-url',
        'help_text': (
            'URL of storage service where packaged templates and artifacts'
            ' will be uploaded. Useful for testing and local development'
            ' or when uploading to a non-AWS storage service that is'
            ' nonetheless S3-compatible.'
        )
    }

    # add argument definition for S3 endpoint to use for CF package/deploy
    # (skipped if an argument with that name is already present in the table)
    for arg_table in [deploy.DeployCommand.ARG_TABLE, package.PackageCommand.ARG_TABLE]:
        existing = [a for a in arg_table if a.get('name') == 's3-endpoint-url']
        if not existing:
            arg_table.append(s3_endpoint_arg)

    def wrap_create_client(_init_orig):
        """
        Returns a new constructor that wraps the S3 client creation to use the custom endpoint for CF.

        :param _init_orig: the original __init__ of the command class, which is called
            by the returned constructor after the session has been patched
        """

        def new_init(self, session, *args, **kwargs):
            def create_client(*args, **kwargs):
                # only calls with 's3' as first positional argument are of interest
                if args and args[0] == 's3':
                    # get stack frame of caller
                    # (entry 1 of the outer frames is the direct caller of this function;
                    # this depends on the exact call depth, so do not add wrappers here)
                    curframe = inspect.currentframe()
                    calframe = inspect.getouterframes(curframe, 2)
                    fname = calframe[1].filename

                    # check if we are executing within the target method
                    # (i.e. the calling code is located in cloudformation/deploy.py or package.py)
                    is_target = (os.path.join('cloudformation', 'deploy.py') in fname
                                 or os.path.join('cloudformation', 'package.py') in fname)
                    if is_target:
                        # an endpoint explicitly passed by the caller takes precedence
                        if 'endpoint_url' not in kwargs:
                            # read the local variable "parsed_args" from the caller's frame
                            # (presumably the parsed command line arguments, which carry the
                            # value of the s3-endpoint-url argument registered above)
                            args_passed = inspect.getargvalues(calframe[1].frame).locals
                            kwargs['endpoint_url'] = args_passed['parsed_args'].s3_endpoint_url
                return create_client_orig(*args, **kwargs)

            # patch each session object only once; the marker attribute prevents
            # create_client from being wrapped again for the same session
            if not hasattr(session, '_s3_endpoint_patch_applied'):
                create_client_orig = session.create_client
                session.create_client = create_client
                session._s3_endpoint_patch_applied = True
            _init_orig(self, session, *args, **kwargs)

        return new_init

    deploy.DeployCommand.__init__ = wrap_create_client(deploy.DeployCommand.__init__)
    package.PackageCommand.__init__ = wrap_create_client(package.PackageCommand.__init__)

    # Apply a patch to botocore, to skip adding `data-` host prefixes to endpoint URLs, e.g. for:
    #  awslocalw servicediscovery discover-instances --service-name s1 --namespace-name ns1
    if hasattr(Serializer, "_expand_host_prefix"):
        config.patch_expand_host_prefix()


# run the CLI wrapper only when executed as a script, not when imported as a module
if __name__ == '__main__':
    main()
