#!python
# Copyright 2019-2022 DADoES, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License in the root directory in the "LICENSE" file or at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import shlex
import signal
import subprocess
import sys
import time

import anatools
from anatools.lib.channel import Channel, find_channelfile
from anatools.lib.print import print_color

home = os.path.expanduser('~')


def update_credentials(client, workspaces, volumes):
    awsprofiles = {}
    mountdata = {'workspaces': {}, 'volumes': {}}
    if os.path.isdir(f'{home}/.aws') and os.path.isfile(f'{home}/.aws/credentials'):
        with open(f'{home}/.aws/credentials', 'r') as awscredfile:
            lines = awscredfile.readlines()
            profile = '[default]'
            awsprofiles[profile] = []
            for line in lines:
                line = line.rstrip()
                if line.startswith('[') and line.endswith(']'):
                    profile = line
                    awsprofiles[profile] = []
                else: awsprofiles[profile].append(line)
    for workspaceId in workspaces:
        data = client.mount_workspaces(workspaces=[workspaceId])
        if data is False:
            print_color(f'There was an error retrieving mount credential for workspaceId {workspaceId}, please contact Rendered.ai for support.', 'ff0000')
            workspaces.remove(workspaceId)
        awsprofiles[f'[renderedai-workspaces-{workspaceId}]'] = [
            f"aws_access_key_id={data['credentials']['accesskeyid']}",
            f"aws_secret_access_key={data['credentials']['accesskey']}",
            f"aws_session_token={data['credentials']['sessiontoken']}]" ]
        mountdata['workspaces'][workspaceId]=data
    for volumeId in volumes:
        data = client.mount_volumes(volumes=[volumeId])
        if data == False:
            print_color(f'There was an error retrieving mount credential for volumeId {volumeId}, please contact Rendered.ai for support.', 'ff0000')
            continue
        awsprofiles[f'[renderedai-volumes-{volumeId}]'] = [
            f"aws_access_key_id={data['credentials']['accesskeyid']}",
            f"aws_secret_access_key={data['credentials']['accesskey']}",
            f"aws_session_token={data['credentials']['sessiontoken']}]" ]
        mountdata['volumes'][volumeId]=data
    if not os.path.isdir(f'{home}/.aws'): os.mkdir(f'{home}/.aws')
    with open(f'{home}/.aws/credentials', 'w+') as awscredfile:
        for profile in awsprofiles.keys():
            if len(awsprofiles[profile]):
                awscredfile.write(profile+'\n')
                awscredfile.writelines([line + '\n' for line in awsprofiles[profile]])
    return mountdata

def mount_data(mountdata, path, verbose):
    if not os.path.isdir(f'{home}/.renderedai/workspaces/'): os.makedirs(f'{home}/.renderedai/workspaces/', exist_ok=True)
    if not os.path.isdir(f'{home}/.renderedai/volumes/'): os.makedirs(f'{home}/.renderedai/volumes/', exist_ok=True)
    if not os.path.isdir(path): os.mkdir(path)
    if not os.path.isdir(os.path.join(path, 'workspaces')): os.mkdir(os.path.join(path, 'workspaces'))
    if not os.path.isdir(os.path.join(path, 'volumes')): os.mkdir(os.path.join(path, 'volumes'))
    for workspaceId in mountdata['workspaces'].keys():
        for i in range(len(mountdata['workspaces'][workspaceId]['keys'])):
            print(f'Mounting workspace {workspaceId}...', end='')
            os.makedirs(f'{home}/.renderedai/workspaces/{workspaceId}', exist_ok=True)
            rw = ''
            if mountdata['workspaces'][workspaceId]['rw'][i] == 'r': rw = '-o allow_other -o umask=0002'
            command = f's3fs {mountdata["workspaces"][workspaceId]["keys"][i]} {home}/.renderedai/workspaces/{workspaceId} -o profile=renderedai-workspaces-{workspaceId} -o endpoint=us-west-2 -o url="https://s3-us-west-2.amazonaws.com" -o use_cache=/tmp/s3fs/workspaces/{workspaceId} {rw} -f -d'
            if verbose: subprocess.Popen(command, shell=True, preexec_fn=os.setsid)
            else:       subprocess.Popen(command, stdout=subprocess.DEVNULL, shell=True, preexec_fn=os.setsid)
            wdir = os.path.join(path, 'workspaces', workspaceId)
            if os.path.exists(wdir): os.unlink(wdir)
            os.symlink(f'{home}/.renderedai/workspaces/{workspaceId}', wdir)
            print('complete!')
    for volumeId in mountdata['volumes'].keys():
        for i in range(len(mountdata['volumes'][volumeId]['keys'])):
            print(f'Mounting volume {volumeId}...', end='')
            os.makedirs(f'{home}/.renderedai/volumes/{volumeId}', exist_ok=True)
            rw = ''
            if mountdata['volumes'][volumeId]['rw'][i] == 'r': rw = '-o allow_other -o umask=0002'
            command = f's3fs {mountdata["volumes"][volumeId]["keys"][i]} {home}/.renderedai/volumes/{volumeId} -o profile=renderedai-volumes-{volumeId} -o endpoint=us-west-2 -o url="https://s3-us-west-2.amazonaws.com" -o use_cache=/tmp/s3fs/volumes/{volumeId} {rw} -f -d'
            if verbose: subprocess.Popen(command, shell=True, preexec_fn=os.setsid)
            else:       subprocess.Popen(command, stdout=subprocess.DEVNULL, shell=True, preexec_fn=os.setsid)
            vdir = os.path.join(path, 'volumes', volumeId)
            if os.path.exists(vdir): os.unlink(vdir)
            os.symlink(f'{home}/.renderedai/volumes/{volumeId}', vdir)
            print('complete!')


def unmount_data(path):
    """Tear down every mount created by this tool and remove its AWS profiles.

    Steps:
        1. Kill every other process matching ``bin/anamount`` and every ``s3fs`` process.
        2. Force-unmount and delete each mountpoint under
           ``~/.renderedai/volumes`` and ``~/.renderedai/workspaces``.
        3. Remove the matching symlinks under *path*.
        4. Strip the ``[renderedai-workspaces-*]`` / ``[renderedai-volumes-*]``
           profiles from ``~/.aws/credentials``.

    Every step is best-effort: missing tools, processes or directories are
    ignored so the remaining cleanup still runs.

    Args:
        path: Directory whose ``volumes/`` and ``workspaces/`` subdirectories hold the symlinks.
    """

    def pids_of(command):
        """Return the PIDs printed by *command*, or [] if nothing matches or the tool is missing."""
        try:
            return [int(pid) for pid in subprocess.check_output(command).split()]
        except (subprocess.CalledProcessError, OSError):
            # pgrep/pidof exit non-zero when there is no match.
            return []

    def force_kill(pid):
        """SIGKILL *pid*, ignoring processes that already exited or belong to another user."""
        try:
            os.kill(pid, signal.SIGKILL)
        except OSError:
            pass

    print('Unmounting volumes...', end='')
    current_pid = os.getpid()
    for pid in pids_of(['pgrep', '-f', 'bin/anamount']):
        if pid != current_pid:
            force_kill(pid)
    for pid in pids_of(['pidof', 's3fs']):
        force_kill(pid)

    for kind in ('volumes', 'workspaces'):
        root = f'{home}/.renderedai/{kind}'
        if not os.path.isdir(root):
            continue  # nothing was ever mounted for this kind
        for target_id in os.listdir(root):
            mountpoint = os.path.join(root, target_id)
            # argv lists (no shell) so paths with spaces/metacharacters are passed intact.
            for command in (['sudo', 'umount', '-f', mountpoint], ['sudo', 'rm', '-rf', mountpoint]):
                try:
                    subprocess.run(command)
                except OSError:
                    pass  # e.g. sudo not installed
            try:
                os.unlink(os.path.join(path, kind, target_id))
            except OSError:
                pass  # symlink already gone

    credfile = f'{home}/.aws/credentials'
    if os.path.isfile(credfile):
        profiles = {'[default]': []}
        profile = '[default]'
        with open(credfile, 'r') as awscredfile:
            for raw in awscredfile:
                line = raw.rstrip()
                if line.startswith('[') and line.endswith(']'):
                    profile = line
                    profiles[profile] = []
                else:
                    profiles[profile].append(line)
        with open(credfile, 'w') as awscredfile:
            for profile, lines in profiles.items():
                if lines and not profile.startswith(('[renderedai-workspaces-', '[renderedai-volumes-')):
                    awscredfile.write(profile + '\n')
                    awscredfile.writelines(line + '\n' for line in lines)
    print('complete.')


def mount_loop(client, workspaces, volumes, path, verbose=0):
    """Mount the targets, then refresh credentials and mounts forever.

    Each cycle fetches fresh credentials, mounts everything, and counts down
    3500 seconds in 10-second steps. It then unmounts and starts over.
    Ctrl+C unmounts and exits the process.
    """
    try:
        while True:
            mount_data(update_credentials(client, workspaces, volumes), path, verbose)
            for remaining in range(3500, 0, -10):
                print(f'Remounting volumes in {remaining}s...', end='\r')
                time.sleep(10)
            unmount_data(path)
    except KeyboardInterrupt:
        unmount_data(path)
        sys.exit()


parser = argparse.ArgumentParser(
    description="""
Mount volumes from the Rendered.ai Platform.
    mount channel volumes:  anamount
                            anamount --channel channelName
    mount single volume:    anamount --volumes volumeId
    mount many volumes:     anamount --volumes volumeId1,volumeId2
    mount single workspace: anamount --workspaces workspaceId
    mount many workspaces:  anamount --workspaces workspaceId1,workspaceId2
    unmount volumes:        anamount --unmount
""",
    formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument('--channel', type=str, default=None, help='The name of the channel, i.e. --channel satrgb.')
parser.add_argument('--volumes', type=str, default=None, help='A list of volumeIds to mount, i.e. --volumes volumeId1,volumeId2.')
parser.add_argument('--workspaces', type=str, default=None, help='A list of workspaceIds to mount, i.e. --workspaces workspaceId1,workspaceId2.')
parser.add_argument('--path', default=None, help='The directory to mount into; targets are linked under its workspaces/ and volumes/ subdirectories, e.g. /ana/data/.')
parser.add_argument('--unmount', action='store_true')
parser.add_argument('--email', type=str, default=None)
parser.add_argument('--password', type=str, default=None)
parser.add_argument('--environment', type=str, default=None)
parser.add_argument('--endpoint', type=str, default=None)
parser.add_argument('--local', action='store_true')
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
verbose = 'debug' if args.verbose else False
if args.path:
    path = args.path
elif os.path.exists('/ana/data/'):
    path = '/ana/data/'
else:
    raise Exception("No --path parameter provided, could not find /ana/data directory.")
if args.unmount:
    unmount_data(path)
    sys.exit(0)  # a successful unmount is not an error


def _parse_id_list(raw):
    """Split a comma-separated, optionally bracketed ID list; drop blanks and duplicates, keeping order."""
    ids = (item.strip() for item in raw.replace('[', '').replace(']', '').split(','))
    return list(dict.fromkeys(item for item in ids if item))


# --workspaces and --volumes are independent; both may be given together.
workspaces = _parse_id_list(args.workspaces) if args.workspaces else []
volumes = _parse_id_list(args.volumes) if args.volumes else []
if not args.workspaces and not args.volumes:
    # No explicit targets: fall back to the (non-'local') volumes declared by the channel's packages.
    if args.channel is None:
        args.channel = find_channelfile()
    if args.channel:
        channel = Channel(args.channel)
        for package in channel.packages.keys():
            package_config = channel.packages[package]
            if package_config is None:
                continue
            if 'volumes' in package_config:
                for name in package_config['volumes'].keys():
                    if name != 'local':
                        volumes.append(package_config['volumes'][name])
        volumes = list(dict.fromkeys(volumes))

if not workspaces and not volumes:
    print('No mount targets specified, please specify --channel, --volumes or --workspaces.')
    sys.exit(1)

client = anatools.client(
    email=args.email,
    password=args.password,
    environment=args.environment,
    endpoint=args.endpoint,
    local=args.local,
    interactive=False,
    verbose=verbose)

# Keep only targets the user may mount. Build new lists rather than calling
# list.remove() while iterating, which would skip the following element.
permitted_workspaces = []
for workspaceId in workspaces:
    if client.get_workspaces(workspaceId=workspaceId) is False:
        print_color(f'Warning: Unable to mount workspace {workspaceId}, permission denied.', 'ff0000')
    else:
        permitted_workspaces.append(workspaceId)
workspaces = permitted_workspaces

permitted_volumes = []
for volumeId in volumes:
    volume = client.get_volumes(volumeId=volumeId)
    if not volume:  # False (denied) or an empty result: nothing to index below
        print_color(f'Warning: Unable to mount volume {volumeId}, permission denied.', 'ff0000')
    elif volume[0]['permission'] not in ['read', 'write']:
        print_color(f'Warning: Unable to mount volume {volumeId}, insufficient permissions (view-only).', 'ff0000')
    else:
        permitted_volumes.append(volumeId)
volumes = permitted_volumes

if not workspaces and not volumes:
    print_color('No mountable workspaces or volumes remain, exiting.', 'ff0000')
    sys.exit(1)

print_color('This process to mount Volumes was successfully enabled and will remain open to refresh Volumes. Killing this process will unmount Volumes. Press CTRL+C or close the terminal to kill the process.', 'ffff00')
mount_loop(client, workspaces, volumes, path, args.verbose)
