#!python
#
# -*- coding: utf-8 -*-
"""
=========================================================================
AWS cli configurator for multi-account/multi-role environments
=========================================================================

This program is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.

This program is free software: you can redistribute it and/or modify it
under the terms of the GNU Lesser General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

You should have received a copy of the GNU Lesser General Public
License along with this program. If not, see <http://www.gnu.org/licenses/>.
==========================================================================

"""
import argparse
import configparser
import logging
import os
import random
import string
import sys
from pathlib import Path

import boto3
import yaml

# Authorship information
__project__ = 'AWS cli configurator for multi-account/multi-role environments'
__product__ = 'aws-cli-config'
__editor__ = 'PyCharm'
__author__ = 'Lorenzo Gatti'
__maintainer__ = 'Lorenzo Gatti'
__email__ = 'lg@lorenzogatti.me'
__credits__ = ["Lorenzo Gatti"]
__copyright__ = "Copyright 2019 Lorenzo Gatti. All Rights Reserved"
# NOTE(review): the module docstring grants the GNU LGPL v3, while this field
# says MIT -- confirm which license actually applies.
__license__ = "MIT"
__date__ = '30/11/19'
__status__ = "Development"

# Runtime defaults: log verbosity/format, plus the AWS CLI output format and
# region used when the entry profile in the credential file defines none.
LOGLEVEL = logging.INFO
CUSTOM_FORMAT = "%(asctime)s | %(filename)s:%(lineno)d | %(funcName)s | %(levelname)s | %(message)s "
AWS_OUTPUT = "json"
AWS_DEFAULT_REGION = "eu-central-1"
VERSION = ()  # NOTE(review): appears to be unused -- confirm before relying on it
# -------------------------------------------------------------------------------------
# Initialize logger: a logger named after the product, emitting CUSTOM_FORMAT
# records on stdout at LOGLEVEL (raised to DEBUG by --verbose).
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(CUSTOM_FORMAT)
handler.setFormatter(formatter)
logger = logging.getLogger(__product__)
logger.addHandler(handler)
logger.setLevel(LOGLEVEL)

parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

# ----------------------------------------
# Define CLI arguments (registration order fixes the positional order)

# Optional positional #1: which parent profile of the YML file to process.
parser.add_argument(
    'profile',
    nargs="?",
    type=str,
    default="",
    help='the name of the AWS parent profile / name of the AWS Organization group',
)

# Optional positional #2: a one-time MFA code; an empty value means "no MFA".
parser.add_argument(
    'mfa',
    nargs="?",
    type=str,
    default='',
    help='the MFA code generated with an external hardware/virtual device',
)

# Flag: only print the configured accounts/roles.
parser.add_argument(
    '-l', '--list',
    action='store_true',
    dest="list_roles_accounts",
    help='list accounts and roles for an organization/profile',
)

# Location of the YML multi-account description (defaults to ~/.aws-cli-config.yml).
parser.add_argument(
    '--aws-cli-config-filepath',
    type=str,
    help='filepath of the YML config file containing the multi-account/multi-role structure',
    default=os.path.join(str(Path.home()), '.aws-cli-config.yml'),
)

# Requested lifetime, in seconds, of the temporary STS credentials.
parser.add_argument(
    '--max-role-duration',
    type=int,
    help='the duration (in seconds) of the AWS IAM role session',
    default=86400,
)

# Flag: switch the logger to DEBUG.
parser.add_argument(
    '-v', '--verbose',
    action='store_true',
    dest="verbose",
    help='verbose mode',
)


# -------------------------------------------------------------------------------------
def randomStringDigits(stringLength=6):
    """Return ``stringLength`` characters drawn uniformly from ASCII letters and digits.

    Uses the non-cryptographic ``random`` module; the result is only used to
    make STS role-session names unique.
    """
    alphabet = string.ascii_letters + string.digits
    picked = [random.choice(alphabet) for _ in range(stringLength)]
    return ''.join(picked)


def print_accounts_and_roles(accounts):
    """Print every configured account with its numbered roles, then exit with status 0.

    For each account a header ``<profile_prefix> (<account_id>)`` is printed,
    followed by one ``[n] <profile_prefix>_<role key> (<role name>)`` line per
    entry of its ``account_roles`` mapping.

    :param accounts: list of account dicts from the YML config (``None`` or an
        account without ``account_roles`` is tolerated and prints nothing extra)
    :raises SystemExit: always, with exit code 0
    """
    # Loop over the accounts configured in the YML file
    for account in accounts or []:
        profile_prefix = str(account.get('profile_prefix'))
        print("%s (%s)" % (profile_prefix, account.get('account_id')))
        # An account without roles would otherwise make enumerate(None) fail
        account_roles = account.get("account_roles") or {}
        for idx, role_key in enumerate(account_roles, start=1):
            # The profile keyword mirrors the one generated by main()
            profile_name = "%s_%s" % (profile_prefix, role_key)
            print("[%i] %s (%s)" % (idx, profile_name, account_roles.get(role_key)))

    # sys.exit rather than the site-provided exit(), which is not guaranteed to exist
    sys.exit(0)


# STS AssumeRole rejects DurationSeconds above 12 hours whatever the role's
# MaxSessionDuration, so longer --max-role-duration values are clamped for it.
_MAX_ASSUME_ROLE_DURATION = 43200


def _read_ini_file(filepath, description):
    """Parse the INI file at ``filepath`` and return it as a ConfigParser.

    ``ConfigParser.read`` silently skips files that do not exist, so a missing
    file yields an empty parser. Parse/decoding errors are logged and end the
    process with exit status 1.

    :param filepath: path of the INI file to parse
    :param description: human-readable label of the file used in log messages
    :return: the populated ``configparser.ConfigParser``
    """
    content = configparser.ConfigParser()
    try:
        content.read(filepath)
    except (configparser.Error, OSError, UnicodeError) as e:
        logger.error("An error occurred while opening the %s file: %s" % (description, e))
        sys.exit(1)
    logger.debug("Successfully parsed %s file at %s " % (description, filepath))
    return content


def _write_ini_file(content, filepath):
    """Serialize the ConfigParser ``content`` to ``filepath``, overwriting it."""
    with open(filepath, 'w') as file:
        content.write(file)


def main(args):
    """Create/refresh AWS CLI profiles for every account/role of a parent profile.

    Reads the YML config, then for the parent profile named by ``args.profile``:

    * with ``args.list_roles_accounts`` set, prints its accounts/roles and exits;
    * with an MFA code (``args.mfa``), requests a session token through
      ``sts.get_session_token`` and stores it in the MFA profile and in every
      generated role profile;
    * without an MFA code, calls ``sts.assume_role`` for every role and stores
      the temporary credentials in the generated role profile.

    Generated profiles are named ``<profile_prefix>_<role key>``. Credentials,
    region and output go into ``~/.aws/credentials``; ``role_arn`` and
    ``source_profile`` go into ``~/.aws/config``. Both files are rewritten once,
    after all accounts have been processed.

    :param args: ``argparse.Namespace`` from the module-level parser (uses
        ``profile``, ``mfa``, ``list_roles_accounts``,
        ``aws_cli_config_filepath`` and ``max_role_duration``)
    :return: None (exit status 0); fatal errors call ``sys.exit(1)``
    """
    # ----------------------------------------
    # Load the YML config file
    yml_config__filepath = args.aws_cli_config_filepath
    try:
        with open(yml_config__filepath, 'r') as file:
            yml_config__content = yaml.load(file, Loader=yaml.FullLoader)
    except (OSError, yaml.YAMLError) as e:
        logger.error("An error occurred while opening the yml configuration file: %s" % e)
        sys.exit(1)
    logger.debug("Successfully parsed YML config file at %s " % yml_config__filepath)

    # ----------------------------------------
    # Look up the parent profile; report a missing profile clearly instead of
    # crashing with AttributeError on None.
    source_profile = str(args.profile)
    profiles = yml_config__content.get('profiles') if isinstance(yml_config__content, dict) else None
    profile_config = profiles.get(source_profile) if isinstance(profiles, dict) else None
    if not isinstance(profile_config, dict):
        logger.error("Profile '%s' not found in the yml configuration file %s"
                     % (source_profile, yml_config__filepath))
        sys.exit(1)
    accounts = profile_config.get('accounts') or []

    # ----------------------------------------
    # Print list of accounts/roles (terminates the process)
    if args.list_roles_accounts:
        print_accounts_and_roles(accounts)

    # ----------------------------------------
    # Load aws credential and config files
    aws_credential__filepath = os.path.join(str(Path.home()), '.aws/credentials')
    aws_credential__content = _read_ini_file(aws_credential__filepath, "AWS credential")
    aws_config__filepath = os.path.join(str(Path.home()), '.aws/config')
    aws_config__content = _read_ini_file(aws_config__filepath, "AWS config")

    # ----------------------------------------
    # Read the long-term credentials of the entry profile
    logger.debug("Searching for %s source account in credential file" % source_profile)
    aws_entry_profile = profile_config.get('profile_default', 'default')
    aws_mfa_profile = profile_config.get('profile_mfa', '')

    if aws_entry_profile not in aws_credential__content:
        logger.error("Profile '%s' not found in the AWS credential file %s"
                     % (aws_entry_profile, aws_credential__filepath))
        sys.exit(1)
    aws_entry_profile__contents = aws_credential__content[aws_entry_profile]
    access_key = aws_entry_profile__contents.get("aws_access_key_id")
    aws_secret_access_key = aws_entry_profile__contents.get("aws_secret_access_key")
    aws_output = aws_entry_profile__contents.get("output", AWS_OUTPUT)
    aws_region = aws_entry_profile__contents.get("region", AWS_DEFAULT_REGION)

    # The MFA profile (and its serial) is only required when an MFA code is given
    mfa_serial = None
    if aws_mfa_profile and aws_mfa_profile in aws_credential__content:
        mfa_serial = aws_credential__content[aws_mfa_profile].get("aws_arn_mfa")
    if args.mfa and not mfa_serial:
        logger.error("No 'aws_arn_mfa' found for the MFA profile '%s' in the AWS credential file %s"
                     % (aws_mfa_profile, aws_credential__filepath))
        sys.exit(1)

    # Initialize STS client with the entry profile's long-term credentials
    sts_client = boto3.client("sts", aws_access_key_id=access_key, aws_secret_access_key=aws_secret_access_key)

    # ---------------------------------------------------------
    # Get session token via get_session_token if MFA is enabled
    # ---------------------------------------------------------
    role__access_key_id = None
    role__secret_access_key = None
    role__session_token = None

    if args.mfa:
        try:
            get_session_token__object = sts_client.get_session_token(DurationSeconds=args.max_role_duration,
                                                                     SerialNumber=mfa_serial,
                                                                     TokenCode=str(args.mfa))
        except Exception as e:
            logger.error("An error occurred while requesting the session token for profile %s: %s"
                         % (aws_mfa_profile, e))
            sys.exit(1)

        session_credentials = get_session_token__object['Credentials']
        role__access_key_id = session_credentials['AccessKeyId']
        role__secret_access_key = session_credentials['SecretAccessKey']
        role__session_token = session_credentials['SessionToken']

        # Store the MFA session credentials in the MFA profile section
        aws_credential__content.set(aws_mfa_profile, 'aws_session_token', role__session_token)
        aws_credential__content.set(aws_mfa_profile, 'aws_access_key_id', role__access_key_id)
        aws_credential__content.set(aws_mfa_profile, 'aws_secret_access_key', role__secret_access_key)

    assume_role_duration = min(args.max_role_duration, _MAX_ASSUME_ROLE_DURATION)
    if not args.mfa and assume_role_duration != args.max_role_duration:
        logger.debug("Clamping the assume-role session duration from %s to %s seconds"
                     % (args.max_role_duration, assume_role_duration))

    # Loop over the accounts configured in the YML file
    for account in accounts:
        # Per-account parameters; str() turns missing values into "None"
        account_id = str(account.get('account_id'))
        profile_prefix = str(account.get('profile_prefix'))
        default_region = str(account.get('default_region'))
        default_output = str(account.get('default_output'))
        # NOTE(review): a missing account_source is written as source_profile "None"
        account_source = str(account.get('account_source'))

        # Per-account defaults win; otherwise fall back to the entry profile's values
        region = aws_region if default_region == "None" else default_region
        output = aws_output if default_output == "None" else default_output

        # Loop over the roles configured for this account (role key -> IAM role name)
        for account_role, role_name in (account.get("account_roles") or {}).items():
            role_arn = "arn:aws:iam::%s:role/%s" % (account_id, role_name)
            role_session = "aws-cli-config-robot-" + randomStringDigits(16)
            profile_name = "%s_%s" % (profile_prefix, account_role)

            # Assume the role (if MFA is not enabled)
            if not args.mfa:
                try:
                    logger.debug("Attempt to assume the role %s with session %s" % (role_arn, role_session))
                    assumed_role__object = sts_client.assume_role(RoleArn=role_arn,
                                                                  RoleSessionName=role_session,
                                                                  DurationSeconds=assume_role_duration)
                except Exception as e:
                    logger.error("An error occurred while assuming role %s for account %s: %s"
                                 % (role_arn, account_id, e))
                    continue

                role_credentials = assumed_role__object['Credentials']
                role__access_key_id = role_credentials['AccessKeyId']
                role__secret_access_key = role_credentials['SecretAccessKey']
                role__session_token = role_credentials['SessionToken']

            # ------------------------------------------------------
            # Update the in-memory AWS credential and config contents
            # ------------------------------------------------------
            if profile_name not in aws_credential__content.sections():
                aws_credential__content.add_section(profile_name)

            aws_credential__content.set(profile_name, 'aws_access_key_id', role__access_key_id)
            aws_credential__content.set(profile_name, 'aws_secret_access_key', role__secret_access_key)
            aws_credential__content.set(profile_name, 'aws_session_token', role__session_token)

            if args.mfa:
                aws_credential__content.set(profile_name, 'aws_arn_mfa', mfa_serial)

            aws_credential__content.set(profile_name, 'region', region)
            aws_credential__content.set(profile_name, 'output', output)

            # The config file uses the "profile <name>" section naming
            config_section = "profile " + profile_name
            if config_section not in aws_config__content.sections():
                aws_config__content.add_section(config_section)
            aws_config__content.set(config_section, 'role_arn', role_arn)
            aws_config__content.set(config_section, 'source_profile', account_source)

    # Write both files once, after every account has been processed, so the
    # MFA session token is persisted even when no account is configured.
    _write_ini_file(aws_credential__content, aws_credential__filepath)
    _write_ini_file(aws_config__content, aws_config__filepath)


if __name__ == '__main__':
    # Parse everything after the program name
    args = parser.parse_args(sys.argv[1:])

    # --verbose raises the module logger to DEBUG
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    logger.debug("%s" % __project__)

    # Remove AWS_PROFILE while main() runs, remembering any previous value so it
    # can be put back in this process's environment afterwards.
    saved_aws_profile = os.environ.pop('AWS_PROFILE', None)
    try:
        logger.debug("Unsetting OS Env variable AWS_PROFILE currently set to %s" % saved_aws_profile)
        sys.exit(main(args))
    finally:
        if saved_aws_profile:
            logger.debug("Reset OS Env variable AWS_PROFILE to previous value %s" % saved_aws_profile)
            os.environ["AWS_PROFILE"] = saved_aws_profile
