#!python
# Copyright 2019-2022 DADoES, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License in the root directory in the "LICENSE" file or at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import subprocess
import sys
import time
from traceback import print_exc

import yaml

import anatools
from anatools.lib.channel import Channel, find_channelfile
from anatools.lib.print import print_color


def create_channel(client, organization, channel, name, volumes=None):
    """Create a remote channel named *name* in *organization*, or reuse an existing one.

    If a channel called *name* already exists in the organization, the user is asked to either
    enter a different name (a new channel is then created) or deploy to the existing channel.
    When the existing channel is reused and its volumes differ from *volumes*, the differences
    are shown and the user may push the local settings to the remote channel.

    Args:
        client: anatools client used for every platform call.
        organization: organization record; its 'organizationId' and 'name' keys are read.
        channel: the local Channel object. Not used here; kept for interface compatibility.
        name: requested name for the remote channel.
        volumes: volume IDs the remote channel should have; defaults to no volumes.

    Returns:
        The channel record returned by ``client.get_channels``, or None when any step raised
        (the error is printed rather than propagated).
    """
    # None sentinel instead of a shared mutable [] default argument.
    volumes = [] if volumes is None else volumes
    try:
        remotechannel = None
        orgchannels = client.get_channels(organizationId=organization['organizationId'])
        channelnames = [c['name'] for c in orgchannels]
        if name in channelnames:
            print_color(f"The channel {name} already exists in this organization, would you like to enter another name or use that channel?", '91e600')
            print_color("  [0] Enter a new name", '91e600')
            print_color(f"  [1] Deploy to the {name} channel", '91e600')
            resp = input("Select an option: ")
            while resp not in ["0", "1"]:
                resp = input("Invalid input, please enter either 0 or 1: ")
            if resp == "0":
                resp = input("Enter the name you'd like to use: ")
                while resp in channelnames:
                    resp = input("A channel with that name already exists, enter another name: ")
                name = resp
            else:
                remotechannel = next(c for c in orgchannels if c['name'] == name)

        if remotechannel is None:
            channelId = client.create_channel(organizationId=organization['organizationId'], name=name, volumes=volumes)
            return client.get_channels(channelId=channelId)[0]

        # Reusing an existing channel: offer to push local settings that differ.
        # The remote record may carry no volumes at all; treat that as an empty list.
        remote_volumes = remotechannel.get('volumes') or []
        changes = {}
        if sorted(volumes) != sorted(remote_volumes):
            changes['volumes'] = {"local": volumes, "remote": remote_volumes}
        if name and name != remotechannel['name']:
            changes['name'] = {"local": name, "remote": remotechannel['name']}
        if changes:
            print_color(f"The channel {remotechannel['name']} in the {organization['name']} organization has the following differences from the local channel:", '91e600')
            for change, values in changes.items():
                print_color(f"  {change}:   {values['remote']} (remote) -> {values['local']} (local)", '91e600')
            resp = input("Would you like to update the remote channel? ")
            while resp.lower() not in ['y', 'yes', 'n', 'no']:
                resp = input('Invalid input, please enter either "y" or "n": ')
            if resp.lower() in ['y', 'yes']:
                client.edit_channel(channelId=remotechannel['channelId'], name=name, volumes=volumes)
                remotechannel = client.get_channels(channelId=remotechannel['channelId'])[0]
        return remotechannel
    except Exception as e:
        print_color('Channel creation failed, please contact support@rendered.ai.', 'ff0000')
        print(e)
        return None


parser = argparse.ArgumentParser(
    description="""
Deploy a channel to the Rendered.ai Platform.
    deploy a channel:               anadeploy
    specify a channel file:         anadeploy --channel path/to/channelfile
    deploy to a specific channel:   anadeploy --channelId channelId
    deploy with verbose logging:    anadeploy --verbose
""",
    formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument('--channel', default=None, help='path to the channel file; searched for automatically when omitted')
parser.add_argument('--channelId', default=None, help='ID of an existing remote channel to deploy to')
parser.add_argument('--email', default=None, help='account email passed to the anatools client')
parser.add_argument('--password', default=None, help='account password passed to the anatools client')
parser.add_argument('--environment', default=None, help='platform environment passed to the anatools client')
parser.add_argument('--endpoint', default=None, help='API endpoint passed to the anatools client')
parser.add_argument('--local', action='store_true', help='pass local=True to the anatools client')
parser.add_argument('--verbose', action='store_true', help='verbose client output; also shows IDs in selection menus')
args = parser.parse_args()
# 'debug' is forwarded to anatools.client(verbose=...); False leaves verbose output off.
verbose = 'debug' if args.verbose else False

# Locate the channel file: an explicit --channel wins, otherwise search for one.
if args.channel is None:
    args.channel = find_channelfile()
if args.channel is None:
    print('No channel file was specified or found.')
    sys.exit(1)
localchannel = Channel(args.channel)

# Collect the volume IDs declared by the channel's packages; each package's 'local'
# volume entry is skipped.
volumes = []
for package_name in localchannel.packages.keys():
    package_config = localchannel.packages[package_name]
    if package_config is None or 'volumes' not in package_config:
        continue
    package_volumes = package_config['volumes']
    if not package_volumes:
        continue
    volumes.extend(
        volume_id
        for volume_key, volume_id in package_volumes.items()
        if volume_key != 'local'
    )

def _run_channel_check(label, mode, failure):
    """Run ``anautils --mode=<mode>`` for the current channel and exit the script if it fails.

    Args:
        label: progress text printed (without a newline) while the check runs.
        mode: value passed to the anautils ``--mode`` option.
        failure: prefix of the error message printed when the command exits non-zero.
    """
    # Flush so the progress text is visible while the (possibly slow) check runs.
    print(label, end="", flush=True)
    try:
        subprocess.run(
            f"anautils --mode={mode} --output=/tmp/",
            shell=True,  # fixed command string; no user input is interpolated
            check=True,  # raise CalledProcessError if the command fails
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True)
    except subprocess.CalledProcessError as e:
        print()  # finish the progress line before reporting the failure
        print(f"{failure} failed with return code {e.returncode}")
        print(e.stderr)
        sys.exit(1)
    print("done.")


# make sure schema is valid
_run_channel_check("Checking Channel Schema...", "schema", "Schema check")

# make sure channel documentation is valid
_run_channel_check("Checking Channel Documentation...", "help", "Documentation check")

def _ask_yes_no(prompt):
    """Prompt until the user answers y/yes/n/no (case-insensitive); return True for yes."""
    resp = input(prompt)
    while resp.lower() not in ['y', 'yes', 'n', 'no']:
        resp = input('Invalid input, please enter either "y" or "n": ')
    return resp.lower() in ['y', 'yes']


def _ask_number(prompt, low, high):
    """Prompt until the user enters an integer between *low* and *high* (inclusive); return it."""
    valid = [str(i) for i in range(low, high + 1)]
    resp = input(prompt).strip()
    while resp not in valid:
        resp = input(f'Invalid input, please enter a number between {low} and {high}: ').strip()
    return int(resp)


def _remote_differences(remotechannel, volumes, name=None, description=None, instance=None, timeout=None):
    """Return ``{field: {"local": ..., "remote": ...}}`` for local settings that differ from *remotechannel*.

    Volumes are compared order-insensitively. The other fields are only compared when they are
    set locally (a falsy local value means "keep whatever the remote has").
    """
    remote_volumes = remotechannel.get('volumes') or []
    changes = {}
    if sorted(volumes) != sorted(remote_volumes):
        changes['volumes'] = {"local": volumes, "remote": remote_volumes}
    fields = (
        ('name', name, 'name'),
        ('description', description, 'description'),
        ('instance', instance, 'instanceType'),
        ('timeout', timeout, 'timeout'),
    )
    for field, local_value, remote_key in fields:
        remote_value = remotechannel.get(remote_key)
        if local_value and local_value != remote_value:
            changes[field] = {"local": local_value, "remote": remote_value}
    return changes


def _print_differences(remotechannel, organization_name, changes):
    """Print each entry of *changes* as ``remote -> local`` for the given remote channel."""
    print_color(f"The channel {remotechannel['name']} in the {organization_name} organization has the following differences from the local channel:", '91e600')
    for change, values in changes.items():
        print_color(f"  {change}:   {values['remote']} (remote) -> {values['local']} (local)", '91e600')


def _leading_spaces(line):
    """Return the number of leading space characters in *line*."""
    return len(line) - len(line.lstrip(' '))


def _content_end(lines, lo, hi):
    """Move *hi* back past trailing blank lines and column-0 comments, never below *lo*."""
    while hi > lo and (not lines[hi - 1].strip() or lines[hi - 1].startswith('#')):
        hi -= 1
    return hi


def _entry_channel_id(entry_lines):
    """Return the channelId written in one `remotes:` list entry, or None if it has none."""
    for line in entry_lines:
        text = line.strip()
        if text.startswith('-'):
            text = text[1:].strip()
        if text.startswith('channelId:'):
            value = text.split(':', 1)[1].split(' #', 1)[0].strip()
            return value.strip('\'"')
    return None


def _render_remote_lines(remoteconfig, indent=2):
    """Render *remoteconfig* as one YAML list entry whose dash sits at column *indent*."""
    pad = ' ' * indent

    def quoted(value):
        # Free text is written as a JSON string, which is also a valid YAML double-quoted
        # scalar, so characters such as ': ' or ' #' cannot break the channel file.
        return 'null' if value is None else json.dumps(str(value), ensure_ascii=False)

    def plain(value):
        return 'null' if value is None else str(value)

    return [
        f"{pad}- channelId: {remoteconfig['channelId']}\n",
        f"{pad}  name: {quoted(remoteconfig['name'])}\n",
        f"{pad}  description: {quoted(remoteconfig['description'])}\n",
        f"{pad}  instance: {plain(remoteconfig['instance'])}\n",
        f"{pad}  timeout: {plain(remoteconfig['timeout'])}\n",
    ]


def _save_remote_config(channelfile, remoteconfig):
    """Insert or replace this channel's entry under the top-level `remotes:` key of *channelfile*.

    The file is edited line by line, not re-dumped through yaml, so comments and formatting
    elsewhere in the channel file are preserved. An entry with the same channelId is replaced
    in place; otherwise the entry is appended to the end of the `remotes:` section (which is
    created at the end of the file when missing).
    """
    with open(channelfile, 'r') as yf:
        lines = yf.readlines()

    remotes_line = next((i for i, line in enumerate(lines) if line.startswith('remotes:')), None)
    if remotes_line is None:
        if lines and not lines[-1].endswith('\n'):
            lines[-1] += '\n'
        lines += ['\n', '\n', 'remotes:\n'] + _render_remote_lines(remoteconfig)
    else:
        # An empty inline value (e.g. `remotes: []`) cannot take block-style entries below it.
        if lines[remotes_line].split(':', 1)[1].strip() in ('[]', 'null', '~'):
            lines[remotes_line] = 'remotes:\n'
        # The section runs until the next top-level key; indented lines, list dashes at
        # column 0 and comments belong to it.
        section_end = next(
            (i for i in range(remotes_line + 1, len(lines))
             if lines[i].strip() and lines[i][0] not in ' \t-#'),
            len(lines))
        dash_lines = [i for i in range(remotes_line + 1, section_end)
                      if lines[i].lstrip(' ').startswith('-')]
        # List entries are the dash lines at the shallowest indentation.
        indent = min((_leading_spaces(lines[i]) for i in dash_lines), default=2)
        starts = [i for i in dash_lines if _leading_spaces(lines[i]) == indent]
        bounds = zip(starts, starts[1:] + [section_end])
        target = next(
            ((s, e) for s, e in bounds
             if _entry_channel_id(lines[s:e]) == str(remoteconfig['channelId'])),
            None)
        if target is not None:
            start, end = target[0], _content_end(lines, target[0], target[1])
        else:
            start = end = _content_end(lines, remotes_line + 1, section_end)
            if not lines[end - 1].endswith('\n'):
                lines[end - 1] += '\n'
        lines[start:end] = _render_remote_lines(remoteconfig, indent)

    with open(channelfile, 'w') as yf:
        yf.writelines(lines)


try:
    client = anatools.client(
        email=args.email,
        password=args.password,
        environment=args.environment,
        endpoint=args.endpoint,
        interactive=False,
        local=args.local,
        verbose=verbose)

    # Remotes saved in the channel file (may be None/empty when none are saved).
    local_remotes = localchannel.remotes or []
    # Default remote channel name: the channel file's base name without its extension.
    default_name = os.path.splitext(os.path.basename(args.channel))[0]

    # Candidate remote channels: --channelId overrides the remotes saved in the channel file.
    # Both sources are normalized to dicts carrying a 'channelId' key.
    remotes = [{'channelId': args.channelId}] if args.channelId else local_remotes
    remotechannels = []
    for i, remote in enumerate(remotes):
        remote_id = remote.get('channelId') if isinstance(remote, dict) else remote
        try:
            c = client.get_channels(channelId=remote_id)[0]
            c['organization'] = client.get_organizations(organizationId=c['organizationId'])[0]
            c['index'] = i  # position in `remotes`, used to look up the saved settings
            remotechannels.append(c)
        except Exception as e:
            print_color(f"There was an issue retrieving remote channel with channelId={remote_id}.", "ff0000")
            print(e)

    # Menu: one option per reachable remote, then "new channel" and "existing channel".
    new_option = len(remotechannels)
    existing_option = new_option + 1
    print_color('Choose an option for deploying the channel: ', '91e600')
    for i, remote in enumerate(remotechannels):
        print_color(f'  [{i}] Deploy to the {remote["name"]} channel in the {remote["organization"]["name"]} organization', '91e600')
    print_color(f'  [{new_option}] Deploy to a new channel', '91e600')
    print_color(f'  [{existing_option}] Deploy to an existing channel', '91e600')
    selection = _ask_number('Please select an option: ', 0, existing_option)

    if selection < new_option:
        remotechannel = remotechannels[selection]
        local = remotes[remotechannel['index']]
        if not isinstance(local, dict):
            local = {}
        name = local.get('name')
        description = local.get('description')
        instance = local.get('instance')
        timeout = local.get('timeout')
        changes = _remote_differences(remotechannel, volumes, name, description, instance, timeout)
        if changes:
            _print_differences(remotechannel, remotechannel['organization']['name'], changes)
            if _ask_yes_no("Would you like to update the remote channel? "):
                print('Editing channel...', end='', flush=True)
                client.edit_channel(channelId=remotechannel['channelId'], name=name, volumes=volumes, description=description, instance=instance, timeout=timeout)
                print('done.')
                remotechannel = client.get_channels(channelId=remotechannel['channelId'])[0]
    else:
        organizations = sorted(client.get_organizations(), key=lambda x: x['name'].lower())
        if not organizations:
            print_color('You have no organizations to deploy to, please contact sales@rendered.ai.', 'ff0000')
            sys.exit(1)
        if len(organizations) == 1:
            remoteorganization = organizations[0]
        else:
            print_color("Choose an organization to deploy the channel to: ", '91e600')
            for i, organization in enumerate(organizations):
                organizationId = organization['organizationId']
                print_color(f"  [{i}] {organization['name']} {'' if not verbose else f'({organizationId})'}", '91e600')
            remoteorganization = organizations[_ask_number('Please select an organization: ', 0, len(organizations) - 1)]

        if selection == new_option:
            remotechannel = create_channel(client=client, organization=remoteorganization, channel=localchannel, name=default_name, volumes=volumes)
        else:
            channels = sorted(client.get_channels(organizationId=remoteorganization['organizationId']), key=lambda x: x['name'].lower())
            if channels:
                print_color("Choose a channel to deploy to: ", '91e600')
                # Channels are listed 1..n, so only 1..n are accepted.
                for i, channel in enumerate(channels, start=1):
                    channelId = channel["channelId"]
                    print_color(f" [{i}] {channel['name']} {'' if not verbose else f'({channelId})'}", '91e600')
                remotechannel = channels[_ask_number('Please select a channel: ', 1, len(channels)) - 1]
                if sorted(volumes) != sorted(remotechannel.get('volumes') or []):
                    client.edit_channel(channelId=remotechannel['channelId'], volumes=volumes)
                    remotechannel = client.get_channels(channelId=remotechannel['channelId'])[0]
            else:
                if not _ask_yes_no('No channels found in this organization. Would you like to create a new channel?: '):
                    print('Exiting...')
                    sys.exit()
                remotechannel = create_channel(client=client, organization=remoteorganization, channel=localchannel, name=default_name, volumes=volumes)

    if remotechannel is None:
        # create_channel has already reported why it failed.
        sys.exit(1)

    starttime = time.time()
    deploymentId = client.deploy_channel(channelId=remotechannel['channelId'], channelfile=args.channel)
    print('Registering Channel Image...', flush=True)
    registerstart = time.time()
    status = client.get_deployment_status(deploymentId=deploymentId, stream=True)
    if status['status']['state'] == 'Channel Deployment Failed':
        raise Exception(f"Channel deployment failed: {status['status']['message']}")
    print(f'\033[1F\033[FRegistering Channel Image...done.  [{time.time()-registerstart:.3f}s]\033[K\n\033[K')
    print_color("The channel has been deployed and is ready to use!", '91e600')
    print(f'Deployment Time: {time.time()-starttime:.3f}s\n')

    remoteconfig = {
        "channelId": remotechannel['channelId'],
        "name": remotechannel['name'],
        "description": remotechannel.get('description'),
        "instance": remotechannel.get('instanceType'),
        "timeout": remotechannel.get('timeout'),
    }
    saved = next(
        (r for r in local_remotes if isinstance(r, dict) and r.get('channelId') == remoteconfig['channelId']),
        None)
    # Offer to save when this channel is not recorded yet, or a recorded field is out of date.
    outdated = saved is not None and any(
        saved[key] != remoteconfig[key] for key in saved if key in remoteconfig)
    if saved is None or outdated:
        if _ask_yes_no('Would you like to save the remote config in your Channel for next time? '):
            _save_remote_config(args.channel, remoteconfig)
            print(f"Saved remote config to {args.channel}")

except KeyboardInterrupt:
    print_color('\nChannel deployment cancelled by user.', 'ff0000')
    sys.exit(130)  # conventional exit status for an interrupted command
except Exception:
    print_color('\nChannel deployment failed, please notify support@rendered.ai.', 'ff0000')
    print_exc()
    sys.exit(1)