#!python
# Copyright 2019-2022 DADoES, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License in the root directory in the "LICENSE" file or at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import signal
import subprocess
import sys
import time
from re import sub

import anatools
from anatools.lib.channel import Channel, find_channelfile
from anatools.lib.print import print_color

# Current user's home directory (resolved with os.path.expanduser). The AWS
# credentials file (~/.aws/credentials) and the s3fs mount points
# (~/.renderedai/volumes/<volumeId>) used below are all built under it.
home = os.path.expanduser('~')


def update_credentials(client, volumes):
    """Fetch mount credentials for each volume and store them as AWS CLI profiles.

    Every profile already present in ``~/.aws/credentials`` is preserved. Each
    volume whose credentials could be retrieved gets (or has replaced) a profile
    named ``[renderedai-anamount-<volumeId>]`` containing the access key id, the
    secret key and the session token returned by ``client.mount_volumes``.

    Args:
        client: anatools client; only ``mount_volumes(volumes=[volumeId])`` is used.
        volumes: Iterable of volumeIds to request credentials for.

    Returns:
        dict mapping each volumeId whose credentials were retrieved to the raw
        ``mount_volumes`` response (later consumed by ``mount_volumes`` for its
        ``'keys'`` and ``'rw'`` entries).
    """
    awsdir = os.path.join(home, '.aws')
    credpath = os.path.join(awsdir, 'credentials')
    awsprofiles = {}
    volumedata = {}

    if os.path.isfile(credpath):
        with open(credpath, 'r') as awscredfile:
            # Lines appearing before the first section header are kept under [default].
            profile = '[default]'
            awsprofiles[profile] = []
            for line in awscredfile:
                line = line.rstrip()
                if line.startswith('[') and line.endswith(']'):
                    profile = line
                    awsprofiles[profile] = []
                else:
                    awsprofiles[profile].append(line)

    for volumeId in volumes:
        data = client.mount_volumes(volumes=[volumeId])
        if not data:
            print_color(f'There was an error retrieving mount credential for volumeId {volumeId}, please contact Rendered.ai for support.', 'ff0000')
            continue
        creds = data['credentials']
        # Values must be written verbatim: any extra character invalidates the credential.
        awsprofiles[f'[renderedai-anamount-{volumeId}]'] = [
            f"aws_access_key_id={creds['accesskeyid']}",
            f"aws_secret_access_key={creds['accesskey']}",
            f"aws_session_token={creds['sessiontoken']}",
        ]
        volumedata[volumeId] = data

    os.makedirs(awsdir, exist_ok=True)
    with open(credpath, 'w') as awscredfile:
        for profile, lines in awsprofiles.items():
            # Sections with no lines (e.g. an unused synthetic [default]) are dropped.
            if lines:
                awscredfile.write(profile + '\n')
                awscredfile.writelines(line + '\n' for line in lines)
    return volumedata

def mount_volumes(volumedata, path, verbose):
    """Start an s3fs mount for every volume and link it into ``path``.

    Each volume is mounted at ``~/.renderedai/volumes/<volumeId>`` using the AWS
    profile ``renderedai-anamount-<volumeId>`` (written by ``update_credentials``).
    A symlink ``<path>/<volumeId>`` is then (re)created to point at that mount
    point. s3fs is launched through ``subprocess.Popen`` in a new session
    (``os.setsid``) and is not waited on.

    Args:
        volumedata: dict of volumeId -> ``mount_volumes`` response; the
            response's ``'keys'`` and ``'rw'`` lists are read index-by-index.
        path: Directory that receives the per-volume symlinks (created if missing).
        verbose: When truthy, s3fs stdout stays attached to the terminal;
            otherwise it is discarded.
    """
    mountroot = f'{home}/.renderedai/volumes'
    os.makedirs(mountroot, exist_ok=True)
    # makedirs (not mkdir) so a user-supplied --path with missing parents works.
    os.makedirs(path, exist_ok=True)
    for volumeId, data in volumedata.items():
        mountpoint = f'{mountroot}/{volumeId}'
        for i, key in enumerate(data['keys']):
            print(f'Mounting volume {volumeId}...', end='')
            os.makedirs(mountpoint, exist_ok=True)
            # NOTE(review): 'r' adds allow_other/umask=0002 rather than a read-only
            # flag; confirm this is the intended mapping.
            rw = '-o allow_other -o umask=0002' if data['rw'][i] == 'r' else ''
            command = f's3fs {key} {mountpoint} -o profile=renderedai-anamount-{volumeId} -o endpoint=us-west-2 -o url="https://s3-us-west-2.amazonaws.com" -o use_cache=/tmp/s3fs/{volumeId} {rw} -f -d'
            stdout = None if verbose else subprocess.DEVNULL
            subprocess.Popen(command, stdout=stdout, shell=True, preexec_fn=os.setsid)
            voldir = os.path.join(path, volumeId)
            # lexists, not exists: a dangling symlink (its target removed by a
            # previous unmount) reports exists() == False, and os.symlink would
            # then fail with FileExistsError.
            if os.path.lexists(voldir):
                os.unlink(voldir)
            os.symlink(mountpoint, voldir)
            print('complete!')


def unmount_volumes(path):
    """Tear down anamount mounts and remove the anamount AWS profiles.

    Every step is best-effort, so cleanup always runs to the end:
      1. SIGKILL every other process whose command line matches ``bin/anamount``.
      2. SIGKILL every ``s3fs`` process.
      3. For each entry in ``~/.renderedai/volumes``, run ``sudo umount -f`` and
         ``sudo rm -rf`` on it, then remove ``<path>/<volumeId>``.
      4. Rewrite ``~/.aws/credentials`` without any ``[renderedai-anamount-*]``
         profiles.

    Args:
        path: Directory holding the per-volume symlinks created by ``mount_volumes``.
    """

    def pids_from(command):
        # pgrep/pidof exit non-zero when nothing matches, and the tool itself may
        # be absent; both mean "no processes to kill".
        try:
            output = subprocess.check_output(command)
        except (subprocess.CalledProcessError, FileNotFoundError):
            return []
        return [int(pid) for pid in output.split()]

    def kill(pid):
        try:
            os.kill(pid, signal.SIGKILL)
        except OSError:
            # Process already exited or belongs to another user.
            pass

    print('Unmounting volumes...', end='')
    current_pid = os.getpid()
    for pid in pids_from(['pgrep', '-f', 'bin/anamount']):
        if pid != current_pid:
            kill(pid)
    for pid in pids_from(['pidof', 's3fs']):
        kill(pid)

    mountroot = os.path.join(home, '.renderedai', 'volumes')
    # The directory is missing if nothing was ever mounted; skip instead of crashing.
    volume_ids = os.listdir(mountroot) if os.path.isdir(mountroot) else []
    for volumeId in volume_ids:
        mountpoint = os.path.join(mountroot, volumeId)
        for command in (['sudo', 'umount', '-f', mountpoint], ['sudo', 'rm', '-rf', mountpoint]):
            try:
                subprocess.run(command)
            except OSError:
                # e.g. sudo not installed; keep going with the remaining cleanup.
                pass
        try:
            os.unlink(os.path.join(path, volumeId))
        except OSError:
            # Link already removed, or the entry is not something unlink can remove.
            pass

    credpath = os.path.join(home, '.aws', 'credentials')
    if os.path.isfile(credpath):
        # Lines appearing before the first section header are kept under [default].
        profile = '[default]'
        awsprofiles = {profile: []}
        with open(credpath, 'r') as awscredfile:
            for line in awscredfile:
                line = line.rstrip()
                if line.startswith('[') and line.endswith(']'):
                    profile = line
                    awsprofiles[profile] = []
                else:
                    awsprofiles[profile].append(line)
        with open(credpath, 'w') as awscredfile:
            for profile, lines in awsprofiles.items():
                if lines and not profile.startswith('[renderedai-anamount-'):
                    awscredfile.write(profile + '\n')
                    awscredfile.writelines(line + '\n' for line in lines)
    print('complete.')


def mount_loop(client, volumes, path, verbose=0, refresh_seconds=3500, poll_seconds=10):
    """Mount volumes and keep remounting them with fresh credentials until stopped.

    Each cycle works in three steps:
      1. Write new credentials (``update_credentials``) and mount the volumes.
      2. Count down ``refresh_seconds``, updating the countdown every
         ``poll_seconds``.
      3. Unmount everything before the next cycle starts.

    Ctrl+C unmounts the volumes and exits. Any other exception also unmounts the
    volumes before it is re-raised, so mounts and s3fs processes are not left behind.

    Args:
        client: anatools client passed to ``update_credentials``.
        volumes: volumeIds to mount.
        path: Directory receiving the per-volume symlinks.
        verbose: Passed to ``mount_volumes``; truthy leaves s3fs output visible.
        refresh_seconds: Seconds between remounts (default 3500, just under an hour).
        poll_seconds: Seconds between countdown updates; must be positive.

    Raises:
        ValueError: if ``poll_seconds`` is not positive (the countdown would never end).
    """
    if poll_seconds <= 0:
        raise ValueError('poll_seconds must be positive')
    try:
        while True:
            volumedata = update_credentials(client, volumes)
            mount_volumes(volumedata, path, verbose)
            remaining = refresh_seconds
            while remaining > 0:
                # Pad the line so a shorter number does not leave stale characters
                # behind after the carriage return.
                print(f'Remounting volumes in {remaining}s...'.ljust(40), end='\r')
                step = min(poll_seconds, remaining)
                time.sleep(step)
                remaining -= step
            unmount_volumes(path)
    except KeyboardInterrupt:
        unmount_volumes(path)
        sys.exit()
    except Exception:
        unmount_volumes(path)
        raise


# Command-line interface. --channel / --volumes select what to mount and --path
# where to link it; the account/connection options are forwarded to anatools.client.
parser = argparse.ArgumentParser(
    description="""
Mount volumes from the Rendered.ai Platform.
    mount channel volumes:  anamount
                            anamount --channel channelName
    mount single volume:    anamount --volumes volumeId
    mount many volumes:     anamount --volumes volumeId1,volumeId2
    unmount volumes:        anamount --unmount
""",
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument('--channel', type=str, default=None, help='The name of the channel, i.e. --channel satrgb.')
parser.add_argument('--volumes', type=str, default=None, help='A list of volumeIds to mount, i.e. --volumes volumeId1,volumeId2.')
parser.add_argument('--path', default=None, help='The path to mount the volumes to, i.e /ana/data/volumes/.')
parser.add_argument('--unmount', action='store_true', help='Unmount all anamount volumes and exit.')
parser.add_argument('--email', type=str, default=None, help='Value passed as email= to anatools.client.')
parser.add_argument('--password', type=str, default=None, help='Value passed as password= to anatools.client.')
parser.add_argument('--environment', type=str, default=None, help='Value passed as environment= to anatools.client.')
parser.add_argument('--endpoint', type=str, default=None, help='Value passed as endpoint= to anatools.client.')
parser.add_argument('--local', action='store_true', help='Pass local=True to anatools.client.')
parser.add_argument('--verbose', action='store_true', help="Pass verbose='debug' to anatools.client and show s3fs output.")
args = parser.parse_args()
# anatools.client verbosity: 'debug' with --verbose, otherwise disabled.
verbose = 'debug' if args.verbose else False

# Resolve where the per-volume symlinks go: --path, else /ana/data/volumes/,
# else <first entry of /workspaces>/data/volumes.
if args.path:
    path = args.path
elif os.path.exists('/ana/data/'):
    path = '/ana/data/volumes/'
else:
    # Guard against /workspaces being absent or empty so the explicit error
    # below is raised instead of FileNotFoundError / IndexError.
    workspaces = os.listdir('/workspaces') if os.path.isdir('/workspaces') else []
    workspace_data = os.path.join('/workspaces', workspaces[0], 'data') if workspaces else None
    if workspace_data and os.path.exists(workspace_data):
        path = os.path.join(workspace_data, 'volumes')
    else:
        raise Exception("No --path parameter provided, could not find /ana/data or /workspaces directories.")

if args.unmount:
    unmount_volumes(path)
    # Unmounting completed normally, so report success to the shell.
    sys.exit(0)

# Collect the volumeIds to mount: from --volumes when given, otherwise from the
# 'volumes' mapping of each package in the channel file (the 'local' entry is skipped).
volumes = []
if args.volumes:
    # Accept "id1,id2" as well as "[id1, id2]"; blank entries (e.g. from a
    # trailing comma) are dropped.
    for volumeId in args.volumes.replace('[', '').replace(']', '').split(','):
        volumeId = volumeId.strip()
        if volumeId:
            volumes.append(volumeId)
else:
    if args.channel is None:
        args.channel = find_channelfile()
    if args.channel:
        channel = Channel(args.channel)
        for package in channel.packages.keys():
            packagedata = channel.packages[package]
            if packagedata is None or 'volumes' not in packagedata:
                continue
            for name, volumeId in packagedata['volumes'].items():
                if name != 'local':
                    volumes.append(volumeId)
# Drop duplicates while keeping first-seen order (set() would scramble it).
volumes = list(dict.fromkeys(volumes))
if not volumes:
    print('No volumes specified.')
    sys.exit(1)

client = anatools.client(
    email=args.email,
    password=args.password,
    environment=args.environment,
    endpoint=args.endpoint,
    local=args.local,
    interactive=False,
    verbose=verbose)

# Keep only the volumes the user may mount. Build a new list: calling
# volumes.remove() while iterating volumes would skip the element after each removal.
mountable = []
for volumeId in volumes:
    volume = client.get_volumes(volumeId=volumeId)
    if volume is False:
        print_color(f'Warning: Unable to mount volume {volumeId}, permission denied.', 'ff0000')
    elif not volume:
        print_color(f'Warning: Unable to mount volume {volumeId}, volume information could not be retrieved.', 'ff0000')
    elif volume[0]['permission'] not in ['read', 'write']:
        print_color(f'Warning: Unable to mount volume {volumeId}, insufficient permissions (view-only).', 'ff0000')
    else:
        mountable.append(volumeId)
volumes = mountable

# Without this check the loop below would run forever mounting nothing.
if not volumes:
    print_color('No volumes could be mounted.', 'ff0000')
    sys.exit(1)

print_color(f'This process to mount Volumes was successfully enabled and will remain open to refresh Volumes. Killing this process will unmount Volumes. Press CTRL+C or close the terminal to kill the process.', 'ffff00')
mount_loop(client, volumes, path, args.verbose)
