#!python
from __future__ import print_function
import argparse
import sys
import time
from multiprocessing.dummy import Pool as ThreadPool
from six.moves import urllib
import threading
import re

import boto3
from botocore.exceptions import ClientError

# Bucket that restored copies are written to. It stays None unless the
# __main__ block assigns the --destination_bucket argument to it; restore()
# reads it and, when it is None, restores the object in its original bucket.
destination_bucket = None


# Lock used when the script entry point has not created ``s_print_lock``
# (for example when this module is imported instead of executed).
_S_PRINT_FALLBACK_LOCK = threading.Lock()


def s_print(*a, **b):
    """Thread safe print function.

    Holds a lock around ``print`` so that concurrent callers do not
    interleave their output. The module-level ``s_print_lock`` is used when
    it exists; otherwise a private fallback lock is used, so calling this
    function never fails with ``NameError``.

    :param a: positional arguments forwarded to ``print``
    :param b: keyword arguments forwarded to ``print``
    """
    lock = globals().get('s_print_lock', _S_PRINT_FALLBACK_LOCK)
    with lock:
        print(*a, **b)


class AtomicInteger:
    """Counter whose updates are performed while holding a lock."""

    def __init__(self, value=0):
        """
        :param value: starting value of the counter
        """
        self._lock = threading.Lock()
        self._value = value

    def inc(self, addition=1):
        """
        Add ``addition`` to the counter while holding the lock.

        :param addition: amount to add
        """
        with self._lock:
            self._value += addition

    def print_and_inc(self, records, limit):
        """
        Print newline-separated records, counting each printed one.

        The lock is held for the whole call. Printing stops as soon as the
        counter is greater than or equal to ``limit``.

        :param records: string holding one record per line
        :param limit: counter value at which printing stops
        """
        with self._lock:
            for line in records.split("\n"):
                if self.value() >= limit:
                    break
                s_print(line)
                self._value += 1

    def value(self):
        """Return the current counter value (read without the lock)."""
        return self._value


def get_matching_s3_keys_and_sizes(s3_url):
    """
    Yield ``(url, size)`` pairs for the objects found under an S3 prefix.

    This is a generator, so nothing (including the scheme check) runs until
    the first item is requested.

    :param s3_url: URL of the prefix for files to restore
    :raises Exception: if the URL scheme is not ``s3`` or a listing page
        carries no ``Contents`` entry
    """
    parts = urllib.parse.urlparse(s3_url)
    if parts.scheme != "s3":
        raise Exception("Prefix scheme must be s3!")

    bucket_name = parts.netloc
    key_prefix = parts.path.lstrip('/')

    client = boto3.client('s3')
    request = {'Bucket': bucket_name}

    # Only a plain string prefix is handed to the S3 API for filtering.
    if isinstance(key_prefix, str):
        request['Prefix'] = key_prefix

    more_pages = True
    while more_pages:
        # Each response describes one page of objects under 'Contents'.
        page = client.list_objects_v2(**request)
        if 'Contents' not in page:
            raise Exception("No S3 files for prefix " + key_prefix)

        for entry in page['Contents']:
            name = entry['Key']
            if name.startswith(key_prefix):
                yield "s3://" + bucket_name + "/" + name, entry['Size']

        # A page without a continuation token is the last one; otherwise the
        # token is sent with the next request to fetch the following page.
        if 'NextContinuationToken' in page:
            request['ContinuationToken'] = page['NextContinuationToken']
        else:
            more_pages = False


def restore(to_restore):
    """
    Ask S3 to restore a single object.

    When the module-level ``destination_bucket`` is set, the request carries
    an ``OutputLocation`` pointing at that bucket with the object's key as
    prefix; otherwise the request carries ``Days``.

    If S3 answers ``GlacierExpeditedRetrievalNotAvailable`` the call sleeps
    for one minute and tries again. The retry is a loop rather than a
    recursive call so that a long outage cannot exhaust the call stack.

    :param to_restore: dict with ``file`` (a dict holding ``bucket`` and
        ``key``), ``tier`` (Glacier retrieval tier) and ``days`` (only read
        when no destination bucket is configured)
    :raises ClientError: for any S3 error code not handled here
    """
    key = to_restore['file']['key'].lstrip('/')
    bucket = to_restore['file']['bucket']

    restore_request = {
        'GlacierJobParameters': {'Tier': to_restore['tier']},
    }
    if destination_bucket is not None:
        restore_request['OutputLocation'] = {
            'S3': {'BucketName': destination_bucket, 'Prefix': key},
        }
    else:
        restore_request['Days'] = to_restore['days']

    while True:
        # Printed on every attempt, as the recursive version did.
        s_print("Restoring s3://{}/{}".format(bucket, key))
        try:
            boto3.client('s3').restore_object(
                Bucket=bucket, Key=key, RestoreRequest=restore_request)
        except ClientError as e:
            error_code = e.response['Error']['Code']
            if error_code == 'GlacierExpeditedRetrievalNotAvailable':
                s_print(
                    "Expedited retrieval not available. Sleeping for 1 minute")
                time.sleep(60)
                continue
            if error_code == 'RestoreAlreadyInProgress':
                s_print("Restore already in progress")
            elif error_code == 'InvalidObjectState':
                s_print("Object's storage class is not GLACIER")
            else:
                # Bare raise keeps the original traceback.
                raise
        return


def check_status(to_check):
    """
    Print the restore status of one S3 object and update the counters.

    The object is inspected with ``head_object``. The module-level
    ``not_on_glacier_count`` is incremented when the storage class is not
    ``GLACIER``; ``restored_count`` is incremented when the ``Restore`` field
    reports a finished restore.

    :param to_check: dict whose ``file`` entry holds ``bucket`` and ``key``
    """
    key = to_check['file']['key'].lstrip("/")
    bucket = to_check['file']['bucket']
    head = boto3.client('s3').head_object(Bucket=bucket, Key=key)
    location = "s3://" + bucket + "/" + key

    # A missing StorageClass is treated the same as a non-GLACIER one.
    if head.get('StorageClass') != "GLACIER":
        s_print("Object {} storageClass is not Glacier".format(location))
        not_on_glacier_count.inc()
        return

    if 'Restore' not in head:
        s_print(
            "Object {} is in Glacier and not being restored".format(location))
        return

    restore_info = head['Restore']
    if 'ongoing-request="false"' not in restore_info:
        s_print("Object {} is being restored".format(location))
        return

    # Restore finished: report the expiry date found in the Restore field.
    matches = re.search("expiry-date=\"(.*)\"", restore_info)
    s_print("Object {} is restored until {}".format(
        location, matches.group(1)))
    restored_count.inc()


def restore_main(
        s3_url,
        input_file,
        number_of_threads,
        days_to_keep,
        status_check):
    """
    Collect the objects to process, then either report their restore status
    or prompt for a retrieval tier and request the restores.

    :param s3_url: ``s3://`` prefix URL to list, or None to read
        ``input_file`` instead
    :param input_file: path of a text file with one ``s3://`` URL per line;
        only read when ``s3_url`` is None
    :param number_of_threads: size of the worker thread pool
    :param days_to_keep: value passed to each restore request as ``days``
    :param status_check: when true, only print the status of every object
    :raises Exception: if any collected URL does not use the ``s3`` scheme
    """
    if s3_url is not None:
        s_print("Getting a listing of the files... ", end="")
        s3_urls_and_sizes = get_matching_s3_keys_and_sizes(s3_url)
    else:
        # Sizes are not known for an explicit list, so each entry counts as
        # 0 bytes. Blank lines (such as a trailing empty line) are skipped;
        # they would otherwise fail the scheme check below.
        with open(input_file) as listing:
            s3_urls_and_sizes = [
                (line.rstrip("\n"), 0)
                for line in listing
                if line.strip()]

    total_size = 0
    to_restore = []

    for url_to_restore, size in s3_urls_and_sizes:
        url_parsed = urllib.parse.urlparse(url_to_restore)
        if url_parsed.scheme != "s3":
            raise Exception("Prefix scheme must be s3!")
        to_restore.append(dict(key=url_parsed.path, bucket=url_parsed.netloc))
        total_size = total_size + int(size)

    s_print("Done\n")

    if status_check:
        # The context manager shuts the pool down once map() has returned,
        # so the worker threads are not left behind.
        with ThreadPool(number_of_threads) as pool:
            pool.map(check_status, [dict(file=item) for item in to_restore])
        s_print("Restored count: {}/{}".format(
            restored_count.value(), len(to_restore)))
        if not_on_glacier_count.value() > 0:
            s_print("Also found: {} files with storage class that is not "
                    "Glacier. For them restore isn't needed"
                    .format(not_on_glacier_count.value()))
        return

    total_size_in_gb = total_size / 1024 ** 3
    s_print("About to restore {:0.2f}GB in {} files"
            .format(total_size_in_gb, len(to_restore)))

    prompt_template = """
Restore will cost us:
1) Expedited tier: ${expedited:0.2f}
2) Standard tier:  ${standard:0.2f}
3) Bulk tier:      ${bulk:0.2f}
Keeping files on S3 will cost: ${one_day_cost:0.2f} per day

Press number in front of an option you wish or any other key to exit: """
    # NOTE(review): the rates below look like per-GB and per-1000-request
    # prices; confirm them against current AWS pricing.
    file_count = len(to_restore)
    prompt = prompt_template.format(
        expedited=total_size_in_gb * 0.03 + file_count * 10 / 1000,
        standard=total_size_in_gb * 0.01 + file_count * 0.05 / 1000,
        bulk=total_size_in_gb * 0.0025 + file_count * 0.025 / 1000,
        one_day_cost=total_size_in_gb * 0.022 / 30
    )
    tier_nr = input(prompt)
    tier = {"1": "Expedited", "2": "Standard", "3": "Bulk"}.get(tier_nr)
    if tier is None:
        s_print("Chicken!")
        sys.exit(0)

    s_print("\nStarting restore using {} tier... ".format(tier), end="")
    # one thread ensures we can sleep if we hit Expedited throughput limit
    # anyway we don't have enough provisioned capacity to benefit from threading
    worker_count = number_of_threads
    if tier == "Expedited":
        s_print("\nLowering number of threads to 1 as this is Expedited "
                "retrieval... ")
        worker_count = 1

    # The pool is created only now, at its final size, and is shut down by
    # the context manager when the restores have been submitted.
    with ThreadPool(worker_count) as pool:
        start = time.time()
        pool.map(
            restore,
            [dict(file=item, days=days_to_keep, tier=tier)
             for item in to_restore])
        end = time.time()

    s_print("Done")
    s_print("Time elapsed: {:0.0f}s".format(end - start))


if __name__ == '__main__':
    # Lock that s_print holds around every print call.
    s_print_lock = threading.Lock()

    parser = argparse.ArgumentParser(
        description=(
            'Utility script to restore files on AWS S3 that have '
            'GLACIER storage class'))
    parser.add_argument('-p', '--prefix', help='S3 prefix URL to restore')
    parser.add_argument(
        '-i', '--input_file',
        help='Input file containing all s3 paths to restore')
    parser.add_argument(
        '-d', '--days_to_keep', type=int,
        help='How many days to keep restored files')
    # TODO add here more functionality
    parser.add_argument(
        '-D', '--destination_bucket',
        help=(
            'Restore to this bucket instead of to original bucket '
            'while preserving same path structure as in original bucket. '
            'This is useful if you don\'t know for how long you\'ll need '
            'restored files. Once you don\'t need them you can delete them '
            'from destination bucket.'))
    # 40 threads from Mac (home network) results in 2K files restored per minute
    parser.add_argument(
        '-t', '--threads', type=int, default=40,
        help='Number of threads to use. Default=40')
    parser.add_argument(
        '-s', '--status_print', action='store_true',
        help=(
            'Just print status of files and how many of them are in glacier '
            'and how many of them are restored already'))
    parser.add_argument(
        "--profile",
        help="Use a specific AWS profile from your credential file.")

    args = vars(parser.parse_args())

    # Counters updated by check_status and reported by restore_main.
    restored_count = AtomicInteger()
    not_on_glacier_count = AtomicInteger()

    # A source of objects is mandatory: a prefix URL or an input file.
    if not (args['input_file'] is not None or args['prefix'] is not None):
        raise Exception("Must either specify input_file or prefix")

    # Unless only the status is printed, a restore needs either a retention
    # period or a destination bucket.
    if not (args['status_print']
            or args['days_to_keep'] is not None
            or args['destination_bucket'] is not None):
        raise Exception("Must either specify status_print or days_to_keep or "
                        "destination_bucket")

    destination_bucket = args['destination_bucket']

    if args['profile'] is not None:
        boto3.setup_default_session(profile_name=args['profile'])

    restore_main(
        s3_url=args['prefix'],
        input_file=args['input_file'],
        number_of_threads=args['threads'],
        days_to_keep=args['days_to_keep'],
        status_check=args['status_print'])
