#!python
#coding: utf-8

# Only Python 3 is supported: stop with exit status 1 on an older interpreter.
import sys
if sys.version_info[0] < 3:
    print("Must be using Python 3")
    sys.exit(1)

import awscli
import ruamel.yaml
from ruamel.yaml.error import YAMLError

import importlib
import os
import time
import re
import json
from urllib.parse import urlparse
from collections.abc import Mapping, Sequence
from collections import OrderedDict
import curses

# Stored prompt defaults: filled by load_default_values() from the
# data_script file and written back by store_value().
data = {}

# Directory holding the AWS CLI files (credentials, config) and data_script.
CONFIG_PATH = '~/.aws/'
# Value passed to `aws apigateway import-rest-api --parameters` in init().
PARAMS = 'endpointConfigurationTypes=REGIONAL'

# Gateway response types configured one by one in gateway_responses().
RESPONSE_TYPES = [
'DEFAULT_4XX',
'DEFAULT_5XX',
'RESOURCE_NOT_FOUND',
'UNAUTHORIZED',
'INVALID_API_KEY',
'ACCESS_DENIED',
'AUTHORIZER_FAILURE',
'AUTHORIZER_CONFIGURATION_ERROR',
'INVALID_SIGNATURE',
'EXPIRED_TOKEN',
'MISSING_AUTHENTICATION_TOKEN',
'INTEGRATION_FAILURE',
'INTEGRATION_TIMEOUT',
'API_CONFIGURATION_ERROR',
'UNSUPPORTED_MEDIA_TYPE',
'BAD_REQUEST_PARAMETERS',
'BAD_REQUEST_BODY',
'REQUEST_TOO_LARGE',
'THROTTLED',
'QUOTA_EXCEEDED',]

# NOTE(review): declared global in execute() but not assigned anywhere in the
# visible code (the --region option in init() is commented out).
region = ""
# Set by get_result(): the parsed JSON reply of the import/put call, or the
# raw output text when that output is not valid JSON.
config_api = ""
# Set to True by get_result() when the user chooses to ignore warnings;
# init() then passes --no-fail-on-warnings instead of --fail-on-warnings.
fail = False

# Load the stored prompt defaults from the data_script file
def load_default_values():
    """Populate the global ``data`` dict with the persisted prompt defaults.

    Ensures the CONFIG_PATH directory exists. When the ``data_script`` file
    is already there, its JSON content replaces ``data``; otherwise the
    built-in defaults are added to ``data`` and written to a new
    ``data_script`` file.
    """
    global data

    config_dir = os.path.expanduser(CONFIG_PATH)
    store_path = config_dir + 'data_script'

    # Create the dir for .aws configuration
    os.makedirs(config_dir, exist_ok=True)

    if os.path.exists(store_path):
        # Use a context manager so the file handle is closed right away.
        with open(store_path, 'r') as f:
            data = json.load(f)
    else:
        data['deploy_name'] = 'prod'
        data['basePath'] = 'v1'
        data['arn_group'] = 'arn:aws:logs:sa-east-1:262209455936:log-group:'
        with open(store_path, 'w') as f:
            json.dump(data, f, sort_keys=True, indent=4)

yaml = ruamel.yaml.YAML()  # new-style ruamel.yaml API; shared by yaml_2_json and export_json_2_yaml

# JSON encoder that also accepts generic mappings and sequences
class OrderlyJSONEncoder(json.JSONEncoder):
    """JSON encoder with fallbacks for Mapping and Sequence objects.

    ``default`` turns any ``Mapping`` into an ``OrderedDict`` and any
    ``Sequence`` into a ``list``; every other object is passed to the base
    class implementation.
    """

    def default(self, o):  # pylint: disable=E0202
        # Mapping is checked first, exactly like the if/elif it replaces.
        conversions = ((Mapping, OrderedDict), (Sequence, list))
        for abstract_type, convert in conversions:
            if isinstance(o, abstract_type):
                return convert(o)
        return super().default(o)

# Check if the object can be a JSON
def is_json(json_object):
    """Return True when *json_object* is a dict or text holding valid JSON.

    Dict instances, including subclasses such as ``OrderedDict``, are
    accepted as they are. Anything else is handed to ``json.loads``; input
    it cannot parse, or cannot accept at all (``None``, numbers, ...),
    yields False instead of raising.
    """
    if isinstance(json_object, dict):
        return True
    try:
        json.loads(json_object)
    except (TypeError, ValueError):
        # ValueError covers JSONDecodeError; TypeError covers non-text input.
        return False
    return True

# Convert yaml to json and return the path of the new file
def yaml_2_json(file):
    """Convert a YAML file into a JSON file stored next to it.

    Args:
        file: path of the input document.

    Returns:
        The absolute path of *file* itself when it already has a ``.json``
        extension. Otherwise the absolute path of the JSON file written: a
        trailing ``.yaml`` is swapped for ``.json``, any other name gets
        ``.json`` appended. ``''`` when the YAML cannot be parsed (the
        error is printed).
    """
    source_path = os.path.abspath(file)
    # Look at the real extension only: a substring test on the whole path
    # also matched directory names containing ".json" or ".yaml".
    root, extension = os.path.splitext(source_path)
    if extension == '.json':
        return source_path
    new_filename = (root if extension == '.yaml' else source_path) + '.json'

    with open(source_path, 'r') as stream:
        try:
            datamap = yaml.load(stream)
        except YAMLError as exc:
            print(exc)
            return ''

    with open(new_filename, 'w') as output:
        output.write(OrderlyJSONEncoder(indent=2).encode(datamap))
    return new_filename

# Convert json to yaml and return the path of the new file
def export_json_2_yaml(file):
    """Write the content of a JSON file as a ``.prod.yaml`` file beside it.

    Args:
        file: path of the JSON document.

    Returns:
        The absolute path of *file* itself when it already has a ``.yaml``
        extension. Otherwise the absolute path of the YAML file written: a
        trailing ``.json`` is swapped for ``.prod.yaml``, any other name
        gets ``.prod.yaml`` appended. ``''`` when a YAMLError is raised
        (the error is printed).
    """
    source_path = os.path.abspath(file)
    # Extension test instead of a substring test on the whole path.
    root, extension = os.path.splitext(source_path)
    if extension == '.yaml':
        return source_path
    new_filename = (root if extension == '.json' else source_path) + '.prod.yaml'

    try:
        # Context manager closes the input handle instead of leaking it.
        with open(source_path) as source:
            document = json.load(source)
        with open(new_filename, 'w', encoding="utf-8") as output:
            yaml.dump(document, output)
    except YAMLError as exc:
        print(exc)
        return ''
    return new_filename

# Remove json used to create yaml
def remove_json(file):
    """Delete *file* when it is a JSON file.

    Args:
        file: path of the file to delete.

    Returns:
        The absolute path that was removed, or None when *file* does not
        have a ``.json`` extension (nothing is deleted in that case).

    Raises:
        OSError: propagated from ``os.remove`` when the file cannot be
            removed.
    """
    target = os.path.abspath(file)
    # Test the extension, not a substring of the path: "'.json' in path"
    # also matched e.g. /tmp/specs.json.d/api.yaml and deleted that file.
    if os.path.splitext(target)[1] != '.json':
        return None
    os.remove(target)
    return target

def domain(stdscr):
    """Show a curses menu of custom domain names and return the chosen one.

    The names come from `aws apigateway get-domain-names`. KEY_UP and
    KEY_DOWN move the highlighted entry; key code 10 (Enter) confirms.

    Args:
        stdscr: curses window used for drawing and reading keys.

    Returns:
        The selected domain name, or ``''`` on KeyboardInterrupt.
    """
    try:
        listing = execute('aws apigateway get-domain-names', True)
        domains_names = [entry['domainName'] for entry in listing['items']]

        # Colour pair 1 draws regular entries, pair 2 the highlighted one.
        curses.init_pair(1, curses.COLOR_WHITE, curses.COLOR_BLACK)
        normal = curses.color_pair(1)
        curses.init_pair(2, curses.COLOR_BLACK, curses.COLOR_WHITE)
        highlighted = curses.color_pair(2)

        last_index = len(domains_names) - 1
        selected = 0  # index of the highlighted entry
        key = 0  # last key code read
        while key != 10:  # Enter in ascii
            stdscr.erase()
            stdscr.addstr("Choose a option to use in custom domain:\n", curses.A_UNDERLINE)
            for position, name in enumerate(domains_names):
                stdscr.addstr("{0}. ".format(position + 1))
                stdscr.addstr(name + '\n', highlighted if position == selected else normal)
            key = stdscr.getch()
            if key == curses.KEY_UP and selected > 0:
                selected -= 1
            elif key == curses.KEY_DOWN and selected < last_index:
                selected += 1

        return domains_names[selected]
    except KeyboardInterrupt:
        return ''

def select_domain():
    """Run the domain() menu inside curses.wrapper and return its result."""
    return curses.wrapper(domain)

# Get the result of the API import; on non-JSON output (errors or warnings)
# ask whether to retry ignoring warnings
def get_result(result):
    """Store the CLI output in the global ``config_api`` and report a retry.

    Args:
        result: iterable of output text chunks; they are joined and
            stripped.

    Returns:
        False when the output is valid JSON (``config_api`` then holds the
        parsed value) or when the user answers ``n``. True when the user
        answers ``y``; the global ``fail`` is set to True in that case.
        While the output is not JSON, ``config_api`` keeps the raw text.
    """
    global config_api, fail
    config_api = ''.join(result).strip()
    try:
        config_api = json.loads(config_api)
    except json.decoder.JSONDecodeError:
        answer = ''
        # Keep asking until the reply is y or n (case-insensitive).
        while answer not in ('y', 'n'):
            answer = input('\n\nIgnore warnings? if it was an error it can not be continued (y / n): ').lower()
        if answer == 'y':
            fail = True
            return True
        return False
    return False

# Get input stored
def get_stored_value(field):
    """Return the stored default for *field*, or ``'unknow'`` when absent."""
    return data.get(field, 'unknow')

# Set new input and store in data_script file
def store_value(field, value):
    """Record *value* under *field* in ``data`` and rewrite ``data_script``.

    The whole ``data`` dict is dumped as JSON (sorted keys, 4-space indent)
    to the ``data_script`` file under CONFIG_PATH.
    """
    global data
    data[field] = value
    store_path = os.path.expanduser(CONFIG_PATH) + 'data_script'
    with open(store_path, 'w') as handle:
        json.dump(data, handle, sort_keys=True, indent=4)

# Function to show input from user
def input_complete(message, field):
    """Prompt with the stored default for *field* shown in brackets.

    An empty answer returns the stored default without saving anything;
    any other answer is saved through store_value() and returned.
    """
    default = get_stored_value(field)
    prompt = "{} [{}]:".format(message, default)
    answer = input(prompt)
    if answer:
        store_value(field, answer)
        return answer
    return default

# Function to wait input value:
def input_wait(message):
    """Repeat the *message* prompt until a non-empty answer is given."""
    answer = input(message)
    while answer == "":
        answer = input(message)
    return answer

# Exclude the OPTIONS from file after configuration
def exclude_options_file(filename):
    """Remove the ``options`` operation from every path of a JSON API file.

    The file is rewritten in place with a 2-space indent.

    Raises:
        KeyError: when the document has no ``paths`` entry.
    """
    # Context manager closes the read handle instead of leaking it.
    with open(filename) as source:
        data_import = json.load(source)

    for operations in data_import["paths"].values():
        operations.pop("options", None)

    with open(filename, 'w') as output:
        output.write(OrderlyJSONEncoder(indent=2).encode(data_import))
# Build the placeholder 200 response shared by the generated operations
def _generated_response(description):
    """Return a 200-response body that references the ``Empty`` schema."""
    return {
        "description": description,
        "content": {
            "application/json": {
                "schema": {
                    "$ref": "#/components/schemas/Empty"
                }
            }
        }
    }


# Include the OPTIONS in every path that does not have it yet
def include_itens_file(filename):
    """Add the generated items to a JSON API file, rewriting it in place.

    For every path: adds an ``options`` operation with a 200 response when
    the path has none, and adds a 200 response to an existing ``post``
    operation that lacks one. Finally sets the ``apiAuth`` API-key scheme
    under ``components.securitySchemes``.

    Raises:
        KeyError: when the document has no ``paths`` entry.
    """
    # Context manager closes the read handle instead of leaking it.
    with open(filename) as source:
        data_import = json.load(source)

    for operations in data_import["paths"].values():
        if "options" not in operations:
            operations["options"] = {
                "responses": {
                    "200": _generated_response("options generated automatically")
                }
            }
        if "post" in operations:
            # setdefault: a post operation without "responses" used to
            # raise KeyError here.
            responses = operations["post"].setdefault("responses", {})
            if "200" not in responses:
                responses["200"] = _generated_response("post generated automatically")

    # Create the containers when the document does not define them yet
    # instead of failing with KeyError.
    schemes = data_import.setdefault("components", {}).setdefault("securitySchemes", {})
    schemes["apiAuth"] = {
        "type": "apiKey",
        "name": "x-api-key",
        "in": "header"
    }

    with open(filename, 'w') as output:
        output.write(OrderlyJSONEncoder(indent=2).encode(data_import))

# Modify servers option into documentation
def include_item_servers_file(filename, domain_name, basePath):
    """Replace the ``servers`` list of a JSON API file with a single entry.

    The entry URL is ``https://<domain_name>/<basePath>`` and its
    description is ``API``. The file is rewritten in place.
    """
    # Context manager closes the read handle instead of leaking it.
    with open(filename) as source:
        data_import = json.load(source)

    data_import["servers"] = [
        {
            "url": "https://{}/{}".format(domain_name, basePath),
            "description": "API"
        }
    ]

    with open(filename, 'w') as output:
        output.write(OrderlyJSONEncoder(indent=2).encode(data_import))

def create_deploy(config_api, overwrite=False):
    """Create a deployment for the API and configure logging on its stage.

    Args:
        config_api: mapping with the API identifier under ``'id'``.
        overwrite: when False the log group ARN prefix and group name are
            asked from the user; otherwise they are taken from the access
            log settings of the first stage returned by ``get-stages``.

    Returns:
        The deploy (stage) name entered by the user.
    """
    
    deploy_name = input_complete('Insert Deploy Name', 'deploy_name')

    execute('aws apigateway create-deployment --rest-api-id {} --stage-name {}'.format(config_api['id'], deploy_name))

    # Printed after the create-deployment command above has already run.
    print('Creating Deploy {}...'.format(deploy_name))

    # Enable Logs CloudWatch Group
    print('Enabling Logs CloudWatch Group...')
    execute('aws apigateway update-stage --rest-api-id {} --stage-name {} --patch-operations op=replace,path=/*/*/logging/loglevel,value=INFO'.format(config_api['id'], deploy_name))

    arn_group = ''
    cloud_groupname = ''
    if overwrite == False:
        # Inserting custom logs
        arn_group = input_complete('Insert Arn Log Group', 'arn_group')
        cloud_groupname = input_wait('Insert CloudWatch Group (example: LawSuit):')
    else:
        # Getting config custom logs
        data = execute('aws apigateway get-stages --rest-api-id {}'.format(config_api['id']), True)
        # Split the destination ARN at its last ':' into prefix and group name.
        data = data['item'][0]['accessLogSettings']['destinationArn'].rsplit(":", 1)
        arn_group = data[0] + ":"
        cloud_groupname = data[1]

    print('Inserting CloudWatch Group...')
    
    # Point the stage access log at arn_group + cloud_groupname.
    execute('''aws apigateway update-stage --rest-api-id {} --stage-name {} --patch-operations {} '''.format(config_api['id'], deploy_name, '\'[ { "op" : "replace", "path" : "/accessLogSettings/destinationArn", "value" : "'+arn_group+cloud_groupname+'" } ]\''))

    # Set the access log format to a JSON line built from $context variables.
    execute('''aws apigateway update-stage --rest-api-id {} --stage-name {} --patch-operations {}'''.format(config_api['id'], deploy_name, '\'[ { "op" : "replace", "path" : "/accessLogSettings/format", "value" : "{ \\"api_id\\": \\"$context.apiId\\", \\"api_key_id\\": \\"$context.identity.apiKeyId\\",\\"http_method\\": \\"$context.httpMethod\\", \\"requestId\\": \\"$context.requestId\\", \\"resource_id\\": \\"$context.resourceId\\", \\"resourcePath\\": \\"$context.resourcePath\\", \\"stage\\": \\"$context.stage\\", \\"status\\": \\"$context.status\\"}" } ]\''))

    return deploy_name

# Create custom domain path
def custom_domain(config_api, deploy_name, filename):
    """Map the deployed stage under a base path of a chosen custom domain.

    The user picks the domain with select_domain() and is asked for a base
    path until the create-base-path-mapping call returns parseable JSON.
    The resulting URL is then written to the ``servers`` entry of
    *filename*.
    """
    print('Configuring custom domain...')
    domain_name = select_domain()
    command = "aws apigateway create-base-path-mapping --domain-name {} --rest-api-id {} --stage {} --base-path {}"

    mapped = False
    while not mapped:
        try:
            basePath = input_complete('Insert Base Path', 'basePath')
            execute(command.format(domain_name, config_api['id'], deploy_name, basePath), True)
        except json.decoder.JSONDecodeError:
            # Output that is not JSON is reported as a path conflict.
            print("This path is already in use. Insert another.")
        else:
            mapped = True

    include_item_servers_file(filename, domain_name, basePath)

def gateway_responses(config_api):
    """Run put-gateway-response for every entry of RESPONSE_TYPES.

    Each response gets the three Access-Control-Allow-* header parameters
    and an ``application/json`` template that wraps
    ``$context.error.messageString`` in a ``msg_errors`` list.

    Args:
        config_api: mapping with the API identifier under ``'id'``.
    """
    print('Configuring gateway responses...')

    for response in RESPONSE_TYPES:
        # Format arguments: api id, response type, header parameters, body template.
        execute('aws apigateway put-gateway-response --rest-api-id {} --response-type {} --response-parameters "{}" --response-templates {}'.format(config_api['id'], response, 
        ''' {\\"gatewayresponse.header.Access-Control-Allow-Headers\\": \\"\'Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token\'\\", \\"gatewayresponse.header.Access-Control-Allow-Methods\\": \\"\''''+'GET,POST,PUT,DELETE,OPTIONS'+'''\'\\", \\"gatewayresponse.header.Access-Control-Allow-Origin\\": \\"\'*\'\\"} ''',
        ''' '{"application/json": "{ \\"msg_errors\\": [{ \\"status\\": \\"GATEWAY_ERROR\\", \\"msg\\": $context.error.messageString }] }"}' '''))

def get_base_url(url):
    """Return the ``scheme://netloc`` prefix of *url*."""
    parts = urlparse(url)
    return '{}://{}'.format(parts.scheme, parts.netloc)
    

# Get the integration config JSON from the old endpoints
def get_old_config(config_api, old_endpoints):
    """Return the first integration that can be read from *old_endpoints*.

    Resources are walked in order and OPTIONS methods are skipped. The
    first get-integration call whose output parses as JSON is returned;
    calls whose output does not parse are ignored.

    Returns:
        The parsed integration, or ``{}`` when none could be read.
    """
    command = 'aws apigateway get-integration --rest-api-id {} --resource-id {} --http-method {}'
    for resource in old_endpoints['items']:
        if "resourceMethods" not in resource:
            continue
        for http_method in resource["resourceMethods"]:
            if http_method == "OPTIONS":
                continue
            try:
                return execute(command.format(config_api['id'], resource['id'], http_method), True)
            except json.decoder.JSONDecodeError:
                pass
    return {}

# Remove malformed endpoints in case of failure
def remove_malformed_paths(endpoints):
    """Run delete-resource for every entry of ``endpoints['items']``.

    The API id is read from the module-level ``config_api``.
    """
    template = 'aws apigateway delete-resource --rest-api-id {} --resource-id {}'
    for resource in endpoints['items']:
        execute(template.format(config_api['id'], resource['id']))

def make_gateway_deploy(config_api):
    """Configure the gateway responses of an API, then create a deployment.

    Args:
        config_api: mapping with the API identifier under ``'id'``.
    """
    gateway_responses(config_api)

    create_deploy(config_api)

# Configure API imported
def configure(config_api, filename, merge=False, overwrite=False, old_endpoints={}):
    """Configure every method of an imported API, deploy it and export YAML.

    Args:
        config_api: parsed reply of the import/put call (needs ``'id'``);
            when it is not JSON the function returns immediately.
        filename: path of the JSON document that was imported.
        merge: when True only resources absent from *old_endpoints* are
            configured, and the key/domain are read from an old integration.
        overwrite: when True the deploy reuses the stage log settings.
        old_endpoints: resources that existed before a merge.

    On an error or interrupt: a first-time import deletes the API, a merge
    deletes the resources listed in ``endpoints``, otherwise only a message
    is printed. The JSON file is removed in each of those cases.
    """
    # NOTE(review): old_endpoints={} is a mutable default; it is not mutated
    # in the visible code.
    # NOTE(review): endpoints stays '' until get-resources returns; an
    # interrupt before that in merge mode passes '' to remove_malformed_paths.
    endpoints = ''
    try:
        if not is_json(config_api):
            return
        endpoints = execute('aws apigateway get-resources --rest-api-id {}'.format(config_api['id']), True)

        # getting domain and key before configure endpoints

        key = ''
        domain = ''

        if not merge:
            key = input_wait('Insert Authorization key (example: asd09aus81923aas112):')
            domain = input_wait('Insert Domain (example: https://example.com):')
        else:
            # Keep only the resources that the merge added.
            for item in old_endpoints['items']:
                endpoints['items'].remove(item)

            data_api = get_old_config(config_api, old_endpoints)

            # Reuse the Authorization header value (quotes stripped) and the
            # base URL of an existing integration.
            key = data_api['requestParameters']['integration.request.header.Authorization'].replace("\'", "")
            domain = get_base_url(data_api['uri'])

        codes = ""

        try:
            for item in endpoints['items']:
                if "resourceMethods" in item:

                    for method in item["resourceMethods"]:
                        print('\nPath: {}, Method: {}...'.format(item['path'], method))
    
                        # configure endpoints

                        # API key is required on every method except OPTIONS.
                        codes = execute('aws apigateway update-method --rest-api-id {} --resource-id {} --http-method {} --patch-operations op="replace",path="/apiKeyRequired",value="{}"'.format(config_api['id'], item['id'], method, 'false' if method == 'OPTIONS' else 'true'), True)

                        print('Configuring endpoins...')

                        # OPTIONS gets a MOCK integration; other methods are
                        # proxied to domain + path with the Authorization header.
                        if method == "OPTIONS":
                            execute('aws apigateway put-integration --rest-api-id {0} --resource-id {1} --http-method {2} --type MOCK --integration-http-method {2} --uri \'{3}{4}\' --request-templates {5}'.format(config_api['id'], item['id'], method, domain, item['path'],
                            ''' '{"application/json": "{\\"statusCode\\": 200}"}' '''), True)
                        else:
                            execute('aws apigateway put-integration --rest-api-id {0} --resource-id {1} --http-method {2} --type HTTP_PROXY --integration-http-method {2} --uri \'{3}{4}\' --request-parameters "{5}" '.format(config_api['id'], item['id'], method, domain, item['path'], '''{ \\"integration.request.header.Authorization\\": \\"\''''+key+'''\'\\" }'''), True)
                            execute('aws apigateway put-integration-response --rest-api-id {} --resource-id {} --http-method {} --status-code 200 --selection-pattern "" --response-templates "{}"'.format(config_api['id'], item['id'], method, '''{ \\"application/json\\": \\"\\" }'''), True)
                        # Enabling CORS for methods

                        print('Enabling CORS...')

                        # Only the 200 response receives the CORS header parameters.
                        for responseCode in codes["methodResponses"]:
                            if responseCode == "200": 
                                execute('aws apigateway update-method-response --rest-api-id {} --resource-id {} --http-method {} --status-code {} --patch-operations op="add",path="/responseParameters/method.response.header.Access-Control-Allow-Origin",value="false"'.format(config_api['id'], item['id'], method, responseCode), True)
                            if responseCode == "200" and method == "OPTIONS":   
                                execute('aws apigateway update-method-response --rest-api-id {} --resource-id {} --http-method {} --status-code {} --patch-operations op="add",path="/responseParameters/method.response.header.Access-Control-Allow-Headers",value="false"'.format(config_api['id'], item['id'], method, responseCode), True)
                                execute('aws apigateway update-method-response --rest-api-id {} --resource-id {} --http-method {} --status-code {} --patch-operations op="add",path="/responseParameters/method.response.header.Access-Control-Allow-Methods",value="false"'.format(config_api['id'], item['id'], method, responseCode), True)

                        # The OPTIONS integration response lists the methods of this resource.
                        if method == "OPTIONS":
                            execute('aws apigateway put-integration-response --rest-api-id {} --resource-id {} --http-method {} --status-code 200 --selection-pattern "-" --response-parameters "{}" --response-templates "{}"'.format(config_api['id'], item['id'], method,
                             ''' {\\"method.response.header.Access-Control-Allow-Headers\\": \\"\'Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token\'\\", \\"method.response.header.Access-Control-Allow-Methods\\": \\"\''''+','.join([x for x in item["resourceMethods"]])+'''\'\\", \\"method.response.header.Access-Control-Allow-Origin\\": \\"\'*\'\\"} ''', '''{\\"application/json\\": \\"\\"} '''), True)

                        print('\nPath: {}, Method: {}... Configured sucessfully!!!'.format(item['path'], method))

            # Create gateway responses if first time
            if not old_endpoints and overwrite == False:
                gateway_responses(config_api)
            
            deploy_name = create_deploy(config_api, overwrite)

            # create custom_domain if first time
            if not old_endpoints and overwrite == False:
                custom_domain(config_api, deploy_name, filename)

                print('\nDeploy {} Executed sucessfully!!!'.format(deploy_name))
            else:
                print('\nUpdate from {} Executed sucessfully!!!'.format(config_api['id']))

            # Publish the final document as YAML and drop the working JSON file.
            exclude_options_file(filename)
            export_json_2_yaml(filename)
            remove_json(filename)
        except Exception as e:
            print(e, ". Some field was not filled correctly.")

            if not old_endpoints and overwrite == False:
                remove_api(config_api['id'])
                print("Some error occurred. The API created was removed. Check the messages and try again.")
            else:
                if merge:
                    remove_malformed_paths(endpoints)
                    print("Some error occurred. New malformed endpoints were automatically removed.")
                else:
                    print("Some error occurred. You need to fix the problem and recovery the API.")
            remove_json(filename)
    except (KeyboardInterrupt, BrokenPipeError):
        if not old_endpoints and overwrite == False:
            remove_api(config_api['id'])
            print("You cancel the operation. The API created was removed.")
        else:
            if merge:
                remove_malformed_paths(endpoints)
                print("You cancel the operation. New malformed endpoints were automatically removed.")
            else:
                print("Some error occurred. You need to fix the problem and recovery the API.")
        remove_json(filename)
        
# Remove api if some error happened
def remove_api(api_id):
    """Run delete-rest-api for *api_id* and print the command output."""
    command = 'aws apigateway delete-rest-api --rest-api-id {}'.format(api_id)
    output = execute(command)
    print(output)

# (Re)create one AWS CLI file holding a single [default] section
def _write_default_profile(name, lines):
    """Write ``[default]`` followed by *lines* to the file *name* under CONFIG_PATH."""
    path = os.path.expanduser(CONFIG_PATH) + name
    if os.path.isfile(path):
        os.remove(path)
    with open(path, "w") as f:
        f.write("[default]\n")
        for line in lines:
            f.write(line + "\n")


# File with configuration to aws application
def configure_aws():
    """Ask for the AWS credentials and region and write the AWS CLI files.

    Prompts, in order, for the access key id, the secret access key and the
    default region name, then writes the ``credentials`` and ``config``
    files under CONFIG_PATH, each with a single ``[default]`` section.

    All answers are collected before any file is touched. The original
    removed and truncated the files before prompting, so an aborted prompt
    left a broken ``credentials`` file that init() later treated as valid.
    """
    access_key = input("AWS Access Key ID:")
    secret_key = input("AWS Secret Access Key:")
    region_name = input("Default region name:")

    _write_default_profile('credentials', [
        "aws_access_key_id = " + access_key,
        "aws_secret_access_key = " + secret_key,
    ])
    _write_default_profile('config', ["region = " + region_name])

def get_param_value(param, args):
    """Return the argument that follows the first usable *param* in *args*.

    Returns None when *args* has at most one entry, when *param* is absent,
    or when *param* only appears as the last entry.
    """
    if len(args) <= 1:
        return None
    # Pair every argument with its successor; the last one has none.
    for current, following in zip(args, args[1:]):
        if current == param:
            return following
    return None

def get_param_config(param, args):
    """Return *param* with every ``-`` removed when it is present in *args*.

    Returns None when *args* has at most one entry or lacks *param*.
    """
    if len(args) <= 1:
        return None
    matches = [arg for arg in args if arg == param]
    if matches:
        return matches[0].replace('-', '')
    return None

def check_param_exists(param, args):
    """Return True when *args* has more than one entry and contains *param*."""
    return len(args) > 1 and any(arg == param for arg in args)

def execute(command, withResponse=False):
    """Run *command* with ``os.popen`` and return its standard output.

    Args:
        command: command line, passed to ``os.popen`` unchanged.
        withResponse: when true, the output is parsed as JSON.

    Returns:
        The stripped output text, or the parsed JSON value when
        *withResponse* is true.

    Raises:
        json.JSONDecodeError: when *withResponse* is true and the output is
            not valid JSON; several callers catch this to detect a failed
            call.
    """
    # NOTE(review): callers build command with str.format from user answers
    # and os.popen runs it through the shell; the values are not escaped.
    # The context manager closes the pipe, which was left open before.
    with os.popen(command) as pipe:
        output = pipe.read().strip()
    if withResponse:
        return json.loads(output)
    return output

# Starts the script
def init():
    """Parse ``sys.argv`` and run the requested workflow.

    Workflows, in order of precedence:
      * no arguments: print the usage text;
      * ``--config``: rewrite the AWS credentials/config files;
      * ``--id`` with ``--gateway-response``: configure gateway responses
        and deploy;
      * ``--id`` with ``--overwrite`` or ``--merge``: update an existing
        API from the file given as first argument;
      * otherwise: import the file given as first argument as a new API.

    The credentials prompt runs first whenever the credentials file is
    missing, whatever the arguments are.
    """
    global config_api, PARAMS, fail

    # Create config API file and file credentials if not exists
    if not os.path.isfile(os.path.expanduser(CONFIG_PATH)+'credentials'):
        configure_aws()

    baseName = sys.argv[0]

    # Show the usage before touching sys.argv[1]: reading it first raised
    # IndexError and made this message unreachable.
    if len(sys.argv) <= 1:
        print("Tip: to set region use `aws configure set default.region [region]`")
        print("Use: {} [File .yaml or .json] [--id] [api-id] [--overwrite|--merge]".format(os.path.basename(baseName)))
        return

    filepath = sys.argv[1]
    #region = get_param_value('--region', sys.argv)
    id_api = get_param_value('--id', sys.argv)
    mode = get_param_config('--overwrite', sys.argv) or get_param_config('--merge', sys.argv)
    config = check_param_exists('--config', sys.argv)
    gateway_response = get_param_config('--gateway-response', sys.argv)

    if config:
        configure_aws()
    elif id_api and gateway_response: # gateway-responses
        make_gateway_deploy({'id': id_api})
    elif id_api: # merge or overwrite
        if mode == 'overwrite' or mode == 'merge':
            if os.path.exists(filepath): # check file exists
                filename = yaml_2_json(filepath)
                include_itens_file(filename)

                old_endpoints = execute('aws apigateway get-resources --rest-api-id {}'.format(id_api), True)

                # The command is rebuilt on every pass: get_result() may set
                # fail, which switches the flag to --no-fail-on-warnings.
                while get_result(os.popen("aws apigateway put-rest-api --rest-api-id {} --mode {} --body 'file://{}' {}".format(id_api, mode, filename, '--fail-on-warnings' if not fail else '--no-fail-on-warnings')).readlines()):
                    continue
                configure(config_api, filename, mode == 'merge', mode == 'overwrite', old_endpoints if mode == 'merge' else {})
            else:
                print("Invalid file.")
        else:
            print("Use: {} [File .yaml or .json] --id {} [--overwrite|--merge]".format(os.path.basename(baseName), id_api))
    else:
        if os.path.exists(filepath): # check file exists
            filename = yaml_2_json(filepath)
            include_itens_file(filename)
            # Same retry loop as above, for a first-time import.
            while get_result(os.popen("aws apigateway import-rest-api --parameters {} --body 'file://{}' {}".format(PARAMS, filename, '--fail-on-warnings' if not fail else '--no-fail-on-warnings')).readlines()):
                continue
            configure(config_api, filename)
        else:
            print("Invalid file.")


# Script entry: load the stored prompt defaults, then dispatch on sys.argv.
# NOTE(review): these calls run at import time too; there is no __main__ guard.
load_default_values()
init()