#!/usr/bin/env python

import argparse
try:
    import configparser
except ImportError:
    import ConfigParser as configparser
import datetime
import logging
import os
import sys

import boto3

# Location of the shared AWS credentials file in the user's home directory.
AWS_CREDS_PATH = os.path.expanduser('~') + '/.aws/credentials'

# Handler that echoes every record (DEBUG and above) to stdout, rendered as
# "<LEVEL> - <message>".
stdout_handler = logging.StreamHandler(stream=sys.stdout)
stdout_handler.setLevel(logging.DEBUG)
stdout_handler.setFormatter(
    logging.Formatter(fmt='%(levelname)s - %(message)s'))

# Script-wide logger; it lets everything through to the handler above.
logger = logging.getLogger('botomfa')
logger.setLevel(logging.DEBUG)
logger.addHandler(stdout_handler)


def main():
    """Parse the command line, apply environment fallbacks and validate.

    Options that are not given on the command line fall back to environment
    variables: ``MFA_DEVICE`` for ``--device``, ``MFA_STS_DURATION`` for
    ``--duration`` and ``MFA_ASSUME_ROLE`` for ``--assume-role``.  When no
    duration is supplied either way, 900 seconds is used.

    The credentials file at ``AWS_CREDS_PATH`` is then loaded and passed,
    together with the resolved arguments, to ``validate``.

    Exits through ``sys.exit`` with a message when the credentials file does
    not exist, when no MFA device is supplied, or when ``MFA_STS_DURATION``
    is not an integer.
    """
    STS_DEFAULT = 900
    parser = argparse.ArgumentParser()
    parser.add_argument('--device',
                        required=False,
                        metavar='arn:aws:iam::123456788990:mfa/dudeman',
                        help="The MFA Device ARN. This value can also be "
                        "provided via the environment variable 'MFA_DEVICE'")
    parser.add_argument('--duration',
                        type=int,
                        help="The duration, in seconds, indicating how long "
                        "the temporary credentials should be valid. The "
                        "minimum is 900 seconds (15 minutes) and the maximum "
                        "is 3600 seconds (1 hour). This value can also be "
                        "provided via the environment variable "
                        "'MFA_STS_DURATION'.")
    parser.add_argument('--profile',
                        help="If using profiles, specify the name here. The "
                        "default profile name is 'default'",
                        required=False)
    parser.add_argument('--assume-role',
                        metavar='arn:aws:iam::123456788990:role/RoleName',
                        help="The ARN of the AWS IAM Role you would like to "
                        "assume, if specified. This value can also be provided "
                        "via the environment variable 'MFA_ASSUME_ROLE'",
                        required=False)
    parser.add_argument('--role-session-name',
                        help="Friendly session name required when using "
                        "--assume-role",
                        required=False)
    args = parser.parse_args()

    if not os.path.isfile(AWS_CREDS_PATH):
        sys.exit('Could not locate credentials file at %s' % (AWS_CREDS_PATH,))

    # The MFA device is mandatory: command line first, then environment.
    if not args.device:
        args.device = os.environ.get('MFA_DEVICE')
        if not args.device:
            sys.exit('You must provide --device or MFA_DEVICE')

    if not args.duration:
        env_duration = os.environ.get('MFA_STS_DURATION')
        if env_duration:
            # A non-numeric value would otherwise surface as a raw
            # ValueError traceback; report it like the other input errors.
            try:
                args.duration = int(env_duration)
            except ValueError:
                sys.exit('MFA_STS_DURATION must be an integer number of '
                         'seconds, got %r' % (env_duration,))
        else:
            args.duration = STS_DEFAULT

    # Only overwrite the option when the environment actually provides one.
    env_role = os.environ.get('MFA_ASSUME_ROLE')
    if not args.assume_role and env_role:
        args.assume_role = env_role

    config = configparser.RawConfigParser()
    config.read(AWS_CREDS_PATH)

    validate(args, config)


def validate(args, config):
    """Check the short-term profile and renew its credentials when needed.

    The short-term profile is ``args.profile`` (``'default'`` when unset);
    the long-term credentials are read from the section named
    ``<profile>-long-term``.  Credentials are renewed through
    ``get_credentials`` when the short-term section does not exist, when its
    stored expiration has passed, or when its bookkeeping options
    (``expiration``, ``assumed_role``, ``assumed_role_arn``) are missing or
    cannot be parsed.  Otherwise the remaining validity is logged.

    :param args: parsed command-line arguments (``profile`` is read here;
        the object is passed on to ``get_credentials``).
    :param config: parser already loaded with the credentials file.

    Exits with status 1 when the long-term section or one of its two key
    options is missing.
    """
    short_term_name = args.profile if args.profile else 'default'
    long_term_name = '%s-long-term' % (short_term_name,)
    logger.info('Using profile: %s', short_term_name)

    try:
        key_id = config.get(long_term_name, 'aws_access_key_id')
        access_key = config.get(long_term_name, 'aws_secret_access_key')
    except configparser.NoSectionError:
        logger.error(
            "Long term credentials session '[%s]' is missing. "
            "You must add this section to your credentials file "
            "along with your long term 'aws_access_key_id' and "
            "'aws_secret_access_key'", long_term_name)
        # Non-zero status so calling scripts can detect the failure.
        sys.exit(1)
    except configparser.NoOptionError as e:
        logger.error(e)
        sys.exit(1)

    # Validate presence of short-term section
    if not config.has_section(short_term_name):
        if short_term_name == 'default':
            # Creating a section called 'default' needs special handling:
            # the module-level default section name is pointed at it, the
            # section is only added explicitly on Python 3, and a throwaway
            # option is set and removed again.
            configparser.DEFAULTSECT = short_term_name
            if sys.version_info.major == 3:
                config.add_section(short_term_name)
            config.set(short_term_name, 'CREATE', 'TEST')
            config.remove_option(short_term_name, 'CREATE')
        else:
            config.add_section(short_term_name)
        get_credentials(short_term_name, key_id, access_key, args, config)
        return

    # A missing or unparsable expiration means the stored credentials cannot
    # be trusted.  Only the reads are guarded, so errors raised while
    # renewing are not mistaken for bad stored data.
    try:
        expiration = config.get(short_term_name, 'Expiration')
        exp = datetime.datetime.strptime(expiration, '%Y-%m-%d %H:%M:%S')
    except (configparser.NoOptionError, ValueError):
        logger.info("Your credentials are invalid, renewing.")
        get_credentials(short_term_name, key_id, access_key, args, config)
        return

    diff = exp - datetime.datetime.utcnow()
    if diff.total_seconds() <= 0:
        logger.info("Your credentials have expired, renewing.")
        get_credentials(short_term_name, key_id, access_key, args, config)
        return

    try:
        role_arn = None
        if config.getboolean(short_term_name, 'assumed_role'):
            role_arn = config.get(short_term_name, 'assumed_role_arn')
    except (configparser.NoOptionError, ValueError):
        logger.info("Your credentials are invalid, renewing.")
        get_credentials(short_term_name, key_id, access_key, args, config)
        return

    if role_arn is not None:
        logger.info('You are assuming the role: %s', role_arn)
    logger.info(
        'Your credentials are still valid for %s seconds'
        ' they will expire at %s',
        diff.total_seconds(), expiration)


def get_credentials(short_term_name, lt_key_id, lt_access_key, args, config):
    """Obtain temporary credentials from STS and store them in the profile.

    Prompts for an MFA code, then calls ``assume_role`` when
    ``args.assume_role`` is set or ``get_session_token`` otherwise, using
    the long-term key pair.  The returned key id, secret key, session token
    and expiration are written to the ``short_term_name`` section of
    ``config`` and the whole configuration is saved to ``AWS_CREDS_PATH``.

    :param short_term_name: section that receives the temporary credentials.
    :param lt_key_id: long-term access key id used to call STS.
    :param lt_access_key: long-term secret access key used to call STS.
    :param args: parsed arguments; ``device``, ``duration``, ``assume_role``
        and ``role_session_name`` are read.
    :param config: parser holding the credentials file contents.

    Exits with status 1 when a role is to be assumed but no role session
    name was given.
    """
    # Check this before prompting so the user does not type a one-time code
    # for a request that cannot be made.
    if args.assume_role and args.role_session_name is None:
        logger.error("You must specify a role session name "
                     "via --role-session-name")
        sys.exit(1)

    # Python 2 reads a line with raw_input; Python 3 only has input.
    try:
        token_input = raw_input
    except NameError:
        token_input = input

    mfa_token = token_input('Enter AWS MFA code for device [%s] '
                            '(renewing for %s seconds):' %
                            (args.device, args.duration))

    client = boto3.client(
        'sts',
        aws_access_key_id=lt_key_id,
        aws_secret_access_key=lt_access_key
    )

    if args.assume_role:
        response = client.assume_role(
            RoleArn=args.assume_role,
            RoleSessionName=args.role_session_name,
            DurationSeconds=args.duration,
            SerialNumber=args.device,
            TokenCode=mfa_token
        )
        config.set(short_term_name, 'assumed_role', 'True')
        config.set(short_term_name, 'assumed_role_arn', args.assume_role)
    else:
        response = client.get_session_token(
            DurationSeconds=args.duration,
            SerialNumber=args.device,
            TokenCode=mfa_token
        )
        config.set(short_term_name, 'assumed_role', 'False')
        # Drop any role ARN left over from an earlier assume-role run.
        config.remove_option(short_term_name, 'assumed_role_arn')

    credentials = response['Credentials']

    # aws_session_token and aws_security_token are both added
    # to support boto and boto3
    options = [
        ('aws_access_key_id', 'AccessKeyId'),
        ('aws_secret_access_key', 'SecretAccessKey'),
        ('aws_session_token', 'SessionToken'),
        ('aws_security_token', 'SessionToken'),
    ]
    for option, value in options:
        config.set(short_term_name, option, credentials[value])

    # Save expiration individually, so it can be manipulated
    config.set(
        short_term_name,
        'expiration',
        credentials['Expiration'].strftime('%Y-%m-%d %H:%M:%S')
    )
    with open(AWS_CREDS_PATH, 'w') as configfile:
        config.write(configfile)
    logger.info(
        'Success! Your credentials will expire in %s seconds at: %s',
        args.duration, credentials['Expiration'])

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
