#!python
import click
import collections
import time
from libcloud.compute.drivers.dimensiondata import DimensionDataNodeDriver
from libcloud.loadbalancer.drivers.dimensiondata import DimensionDataLBDriver
from libcloud.common.dimensiondata import DEFAULT_REGION

from dimensiondata.dr.utils import get_vm_mapping_from_file


# Context settings applied to the root command group below. With
# auto_envvar_prefix set, click also looks for option values in environment
# variables that start with the DIDATA prefix.
CONTEXT_SETTINGS = dict(auto_envvar_prefix='DIDATA')


class DiDataCLIClient(object):
    """Shared state passed to every sub-command through the click context.

    Holds the global CLI flags and, once ``init_client`` has run, the two
    libcloud drivers (compute and load balancer) the commands work with.
    """

    def __init__(self):
        self.verbose = False
        self.force = False
        # Declared up front so the attributes always exist; they stay None
        # until init_client() is given credentials.
        self.node = None
        self.loadbalancer = None

    def init_client(self, user, password, region):
        """Create the compute and load balancer drivers for ``region``.

        :param user: API user name handed to both drivers.
        :param password: API password handed to both drivers.
        :param region: region identifier handed to both drivers.
        """
        self.node = DimensionDataNodeDriver(user, password, region=region)
        self.loadbalancer = DimensionDataLBDriver(user, password, region=region)

# Decorator that hands the DiDataCLIClient stored on the click context to a
# command as its first argument; ensure=True asks click to create the object
# when the context does not hold one yet.
pass_client = click.make_pass_decorator(DiDataCLIClient, ensure=True)


@click.group(context_settings=CONTEXT_SETTINGS)
@click.option('--verbose', is_flag=True, default=False)
@click.option('--user', prompt=True)
@click.option('--password', prompt=True, hide_input=True)
@click.option('--region', default=DEFAULT_REGION)
@click.option('--force', is_flag=True, default=False)
@pass_client
def compare(client, verbose, user, password, region, force):
    # Root command group. Builds the API drivers from the supplied
    # credentials and records the global flags on the shared client so the
    # sub-commands can read them.
    # (Documented with comments on purpose: click would print a docstring
    # as the --help text.)
    client.init_client(user, password, region)
    client.verbose, client.force = verbose, force


@compare.command()
@click.option('--primaryDc', type=click.UNPROCESSED, required=True, help='the master datacenter to compare')
@click.option('--secondaryDc', type=click.UNPROCESSED, required=True, help='the datacenter to compare to')
@click.option('--vmFile', required=True, help="the vm file to use")
@pass_client
def affinity_rules(client, primarydc, secondarydc, vmfile):
    # Loads the VM mapping file, lists the nodes of both datacenters and
    # reports spec differences between the mapped pairs.
    # NOTE(review): despite its name this command runs exactly the same node
    # comparison as servers_between_dcs; nothing affinity-specific is
    # queried here -- confirm whether that is intended.
    mapping = get_vm_mapping_from_file(vmfile)
    driver = client.node
    # Both listings are fetched before either one is converted.
    listings = [driver.list_nodes(ex_location=dc) for dc in (primarydc, secondarydc)]
    primary_specs, secondary_specs = [convert_nodes(nodes) for nodes in listings]
    compare_nodes(primary_specs, secondary_specs, mapping,
                  primarydc, secondarydc, client.verbose)


@compare.command()
@click.option('--primaryDc', type=click.UNPROCESSED, required=True, help='the master datacenter to compare')
@click.option('--secondaryDc', type=click.UNPROCESSED, required=True, help='the datacenter to compare to')
@click.option('--vmFile', required=True, help="the vm file to use")
@pass_client
def servers_between_dcs(client, primarydc, secondarydc, vmfile):
    # Read-only check: loads the VM mapping file, lists the nodes of both
    # datacenters and prints which mapped pairs match and which differ.
    mapping = get_vm_mapping_from_file(vmfile)
    driver = client.node
    # Both listings are fetched before either one is converted.
    listings = [driver.list_nodes(ex_location=dc) for dc in (primarydc, secondarydc)]
    primary_specs, secondary_specs = [convert_nodes(nodes) for nodes in listings]
    compare_nodes(primary_specs, secondary_specs, mapping,
                  primarydc, secondarydc, client.verbose)


@compare.command()
@click.option('--primaryDC', type=click.UNPROCESSED, required=True, help='The master datacenter to compare')
@click.option('--secondaryDC', type=click.UNPROCESSED, required=True, help='The datacenter to compare to')
@click.option('--vmFile', required=True, help="The vm file to use")
@pass_client
def sync_servers_between_dcs(client, primarydc, secondarydc, vmfile):
    # Same inputs as servers_between_dcs, but the differences found are
    # handed to sync_nodes, which applies them to the secondary nodes.
    mapping = get_vm_mapping_from_file(vmfile)
    driver = client.node
    # Both listings are fetched before either one is converted.
    listings = [driver.list_nodes(ex_location=dc) for dc in (primarydc, secondarydc)]
    primary_specs, secondary_specs = [convert_nodes(nodes) for nodes in listings]
    sync_nodes(client, primary_specs, secondary_specs, mapping,
               primarydc, secondarydc, client.verbose)


@compare.command()
@pass_client
def loadbalancers(client):
    # Prints the IP of every member of every load balancer, one per line.
    lb_driver = client.loadbalancer
    for balancer in lb_driver.list_balancers():
        for member in lb_driver.balancer_list_members(balancer):
            print("{0}".format(member.ip))


def _find_network_domain(client, name, location, role, location_option):
    """Return the only network domain called ``name`` or exit with status 1.

    :param client: CLI client whose ``node`` driver lists the domains.
    :param name: network domain name to look for.
    :param location: optional location passed to the driver listing call.
    :param role: ``'primary'`` or ``'secondary'``; only used in messages.
    :param location_option: CLI flag that narrows the search; only used in
        the "more than one match" message.
    :raises SystemExit: with code 1 when zero or several domains match.
    """
    # Build a real list: on Python 3 filter() returns an iterator, which
    # supports neither len() nor indexing.
    matches = [domain
               for domain in client.node.ex_list_network_domains(location=location)
               if domain.name == name]
    if len(matches) > 1:
        click.secho("More than 1 domain found for {0} network domain. "
                    "Please specify {1} for further searching".format(role, location_option),
                    fg='red', bold=True)
        for domain in matches:
            click.secho("Name: {0} Id: {1} Location: {2}".format(domain.name, domain.id,
                                                                 domain.location.id), fg='red')
        raise SystemExit(1)
    if not matches:
        click.secho("No network domain found with name {0}".format(name))
        raise SystemExit(1)
    return matches[0]


@compare.command()
@click.option('--primaryNetDomain', type=click.UNPROCESSED, required=True, help='First Network Domain')
@click.option('--primaryLocation', type=click.UNPROCESSED, help='Location of the first Network Domain')
@click.option('--secondaryNetDomain', type=click.UNPROCESSED, required=True, help='Secondary Network Domain')
@click.option('--secondaryLocation', type=click.UNPROCESSED, help='Location of the secondary Network Domain')
@pass_client
def firewall_rules_between_net_domains(client, primarynetdomain, primarylocation,
                                       secondarynetdomain, secondarylocation):
    """Compare the firewall rules of two network domains."""
    # Each domain is resolved by name (optionally narrowed by location) and
    # must be unambiguous; the command exits with status 1 otherwise.
    primary_domain = _find_network_domain(client, primarynetdomain, primarylocation,
                                          'primary', '--primaryLocation')
    primary_rules = client.node.ex_list_firewall_rules(primary_domain.id)

    # The "not found" message now names the secondary domain; it used to
    # print the primary domain's name by mistake.
    secondary_domain = _find_network_domain(client, secondarynetdomain, secondarylocation,
                                            'secondary', '--secondaryLocation')
    secondary_rules = client.node.ex_list_firewall_rules(secondary_domain.id)

    compare_firewall_rules(convert_firewall_rules(primary_rules),
                           convert_firewall_rules(secondary_rules))


def convert_firewall_rules(rules):
    """Index firewall rules by their flattened signature, keeping input order.

    :param rules: iterable of firewall rule objects.
    :return: ``OrderedDict`` mapping ``flatten_firewall_rule(rule)`` to the
        rule; when two rules flatten to the same string the later one wins.
    """
    return collections.OrderedDict(
        (flatten_firewall_rule(entry), entry) for entry in rules
    )


def flatten_firewall_rule(rule):
    """Build a '-'-joined signature string from a firewall rule's fields.

    The signature is made of action, IP version, protocol, enabled flag and
    status, followed by the five address/port fields of the source and then
    of the destination. The rule's name and id are not part of it.
    """
    def endpoint_fields(endpoint):
        return (endpoint.any_ip, endpoint.ip_address, endpoint.ip_prefix_size,
                endpoint.port_begin, endpoint.port_end)

    fields = (rule.action, rule.ip_version, rule.protocol, rule.enabled, rule.status)
    fields += endpoint_fields(rule.source)
    fields += endpoint_fields(rule.destination)
    # format(value) applies the same empty format spec as "{0}".format(value).
    return "-".join(format(value) for value in fields)


def compare_firewall_rules(primary_rules, secondary_rules, verbose=False):
    """Print a position-by-position comparison of two rule dictionaries.

    :param primary_rules: ordered mapping of rule signature -> rule.
    :param secondary_rules: ordered mapping of rule signature -> rule.
    :param verbose: accepted for symmetry with the other helpers; unused.

    For every primary rule one line is printed: green when the secondary
    holds the same signature at the same position, yellow when it holds it
    elsewhere, red when it is missing. Secondary-only rules are then listed
    in red.
    """
    secondary_order = list(secondary_rules.keys())

    for position, (signature, _rule) in enumerate(primary_rules.items()):
        same_position = (position < len(secondary_order)
                         and secondary_order[position] == signature)
        if same_position:
            click.secho("Rule {0} matches between network domains".format(signature), fg='green', bold=True)
        elif signature in secondary_rules:
            click.secho("Rule {0} is present but not in the same order"
                        " as in the secondary network domain".format(signature),
                        fg='yellow', bold=True)
        else:
            click.secho("Could not find rule {0} from primary network domain"
                        " on secondary network domain".format(signature),
                        fg='red', bold=True)

    for signature in secondary_rules:
        if signature in primary_rules:
            continue
        click.secho("Extra rule {0} found in secondary network domain".format(signature), fg='red', bold=True)


def sync_nodes(client, master, secondary, vm_mapping, primary_dc, secondary_dc, verbose=False):
    """Collect the differences between mapped node pairs and apply them.

    :param client: CLI client; passed on to the comparison and sync helpers.
    :param master: node dictionary of the primary DC (as built by
        ``convert_nodes``).
    :param secondary: node dictionary of the secondary DC.
    :param vm_mapping: mapping where ``vm_mapping[primary_dc][node]['vm']``
        names the secondary counterpart of ``node``.
    :param primary_dc: primary datacenter identifier.
    :param secondary_dc: secondary datacenter identifier.
    :param verbose: print extra progress lines when true.

    Nodes missing on either side are reported and skipped. Every difference
    found for the remaining pairs is handed to ``do_sync`` in one batch.
    """
    items_to_sync = []
    for node in vm_mapping[primary_dc]:
        if node not in master:
            click.secho("Node {0} not found in DC {1}".format(node, primary_dc), fg='red')
            continue
        comp_node = vm_mapping[primary_dc][node]['vm']
        if verbose:
            click.secho("Node {0} in DC {1} should match Node {2} in DC {3}".format(
                node, primary_dc, comp_node, secondary_dc
            ))
        if comp_node not in secondary:
            # Was spelt `.formaT`, which raised AttributeError instead of
            # printing this message.
            click.secho("Could not find node {0} in DC {1}".format(comp_node, secondary_dc))
            continue
        if verbose:
            click.secho("Node {0} exists in DC {1}".format(comp_node, secondary_dc))
        items_to_sync.extend(
            get_out_of_sync_items(client, master[node], secondary[comp_node], verbose))
    do_sync(client, items_to_sync)


def do_sync(client, items_to_sync):
    """Print a summary of the pending changes, then apply them one by one.

    :param client: CLI client passed through to ``sync_item``.
    :param items_to_sync: change tuples shaped like the ones built by
        ``get_out_of_sync_items``: index 0 is the node, 1 the change type,
        2 the current value; the target value sits at index 5 for disk
        speed/size updates and at index 3 otherwise.
    """
    disk_updates = ('update_disk_speed', 'update_disk_size')

    if items_to_sync:
        click.secho("Here are the items that need to sync:")
    for entry in items_to_sync:
        click.secho("ID: {0} Name: {1} Update: {2}".format(entry[0].id, entry[0].name, entry[1]))
        target_index = 5 if entry[1] in disk_updates else 3
        click.secho("{0}->{1}".format(entry[2], entry[target_index]))
    click.secho("")

    for entry in items_to_sync:
        sync_item(client, entry)


def sync_item(client, item):
    """Apply one change tuple, asking for confirmation unless --force is set.

    :param client: CLI client; ``client.force`` skips the prompt.
    :param item: change tuple ``(node, change_type, current, *args)``; the
        elements from index 3 onwards are the arguments of the update.

    Unknown change types are ignored.
    """
    target, action = item[0], item[1]
    if not client.force and not click.confirm("Changing item {0} {1}".format(target.id, action)):
        click.secho("Skipping item")
        return
    click.secho("Attempting to change item {0} {1}".format(target.id, action))

    extra_args = item[3:]
    # Updates that take just the new value.
    single_value_updates = {
        'update_cpu_count': update_cpu_count,
        'update_memory': update_memory,
    }
    # Updates that take every remaining element of the tuple.
    disk_updates = {
        'add_disk': add_disk,
        'update_disk_speed': update_disk_speed,
        'update_disk_size': update_disk_size,
    }
    if action in single_value_updates:
        single_value_updates[action](client, target, extra_args[0])
    elif action in disk_updates:
        disk_updates[action](client, target, *extra_args)


def add_disk(client, node, scsi_id, speed, size):
    """Add a disk to ``node`` and wait for it before reporting success.

    :param client: CLI client whose ``node`` driver performs the call.
    :param node: node object to add storage to.
    :param scsi_id: SCSI id for the new disk.
    :param speed: disk speed value forwarded to the driver.
    :param size: disk size value forwarded to the driver.
    """
    driver = client.node
    # The driver call takes size before speed, unlike this function.
    driver.ex_add_storage_to_node(node, size, speed, scsi_id)
    wait_for_disk_state(client, node, scsi_id)
    click.secho("Successfully added disk", bold=True, fg='green')


def update_disk_speed(client, node, scsi_id, disk_id, speed):
    """Change the speed of an existing disk and wait for it to settle.

    :param client: CLI client whose ``node`` driver performs the call.
    :param node: node object owning the disk.
    :param scsi_id: SCSI id used to find the disk while waiting.
    :param disk_id: id of the disk to change.
    :param speed: new speed value forwarded to the driver.
    """
    client.node.ex_change_storage_speed(node, disk_id, speed)
    wait_for_disk_state(client, node, scsi_id)
    # The message used to say "added disk" (copied from add_disk).
    click.secho("Successfully changed disk speed", fg='green', bold=True)


def update_disk_size(client, node, scsi_id, disk_id, size):
    """Change the size of an existing disk and wait for it to settle.

    :param client: CLI client whose ``node`` driver performs the call.
    :param node: node object owning the disk.
    :param scsi_id: SCSI id used to find the disk while waiting.
    :param disk_id: id of the disk to change.
    :param size: new size value forwarded to the driver.
    """
    client.node.ex_change_storage_size(node, disk_id, size)
    wait_for_disk_state(client, node, scsi_id)
    # The message used to say "added disk" (copied from add_disk).
    click.secho("Successfully changed disk size", fg='green', bold=True)


def update_cpu_count(client, node, count):
    """Reconfigure ``node`` to ``count`` CPUs and wait until it is running.

    :param client: CLI client whose ``node`` driver performs the calls.
    :param node: node object to reconfigure.
    :param count: new CPU count; the other reconfigure arguments are None.
    """
    driver = client.node
    driver.ex_reconfigure_node(node, None, count, None, None)
    # Poll the node by id until it reports 'running' (arguments 2 and 300
    # are forwarded to the driver's wait_for_state unchanged).
    driver.connection.wait_for_state('running', driver.ex_get_node_by_id, 2, 300, node.id)
    click.secho("Successfully change CPU count", bold=True, fg='green')


def update_memory(client, node, count):
    """Reconfigure the RAM of ``node`` and wait until it is running.

    :param client: CLI client whose ``node`` driver performs the calls.
    :param node: node object to reconfigure.
    :param count: new memory amount; divided by 1024 before it is sent, so
        it is presumably expressed in MB (callers in this file pass the
        node's ``memoryMb`` value).
    """
    # Floor division keeps the value an integer on Python 3 as well; plain
    # `/` produced a float such as 4.0 there.
    memory_gb = int(count) // 1024
    client.node.ex_reconfigure_node(node, memory_gb, None, None, None)
    client.node.connection.wait_for_state('running', client.node.ex_get_node_by_id, 2, 300, node.id)
    click.secho("Successfully changed RAM to {0}GB".format(memory_gb), fg='green', bold=True)


def get_out_of_sync_items(client, n1, n2, verbose=False):
    """List the changes needed to bring node ``n2`` in line with ``n1``.

    :param client: CLI client; its ``node`` driver fetches the node to update.
    :param n1: reference node dictionary (entries as built by ``convert_nodes``).
    :param n2: dictionary of the node that should be changed.
    :param verbose: accepted for symmetry with the other helpers; unused.
    :return: list of change tuples ``(node, change_type, current, *args)``.

    CPU count, memory, missing disks, disk speed and disk growth are
    checked. A disk that is larger on ``n2`` than on ``n1`` is left alone,
    and disks that exist only on ``n2`` are not reported.
    """
    click.secho("Comparing node {} and {}".format(n1['name'], n2['name']))
    target = client.node.ex_get_node_by_id(n2['id'])
    pending = []

    scalar_checks = (('cpu_count', 'update_cpu_count'), ('memory', 'update_memory'))
    for field, action in scalar_checks:
        if n1[field] != n2[field]:
            pending.append((target, action, n2[field], n1[field]))

    wanted_disks, current_disks = n1['disks'], n2['disks']
    if wanted_disks != current_disks:
        for scsi_id, wanted in wanted_disks.items():
            if scsi_id not in current_disks:
                pending.append((target, 'add_disk', None, scsi_id,
                                wanted['speed'], wanted['size_gb']))
                continue
            current = current_disks[scsi_id]
            if wanted['speed'] != current['speed']:
                pending.append((target, 'update_disk_speed', current['speed'], scsi_id,
                                current['id'], wanted['speed']))
            # Only growth is synced.
            if wanted['size_gb'] > current['size_gb']:
                pending.append((target, 'update_disk_size', current['size_gb'], scsi_id,
                                current['id'], wanted['size_gb']))

    if pending:
        click.secho("Node {} and {} do not match".format(n1['name'], n2['name']), fg='red', bold=True)
    else:
        click.secho("Node {} and {} match!".format(n1['name'], n2['name']), fg='green', bold=True)
    return pending


def wait_for_disk_state(client, node, scsi_id, timeout=600, poll_interval=2):
    """Poll a node until the disk with ``scsi_id`` reports state 'NORMAL'.

    :param client: CLI client whose ``node`` driver re-fetches the node.
    :param node: node object; only its ``id`` is used for re-fetching.
    :param scsi_id: SCSI id of the disk to watch.
    :param timeout: approximate number of seconds to keep polling.
    :param poll_interval: seconds to sleep between polls; must be positive.
    :return: True once the disk is 'NORMAL', False when the timeout expires.
    :raises ValueError: if ``poll_interval`` is not positive.

    A disk that is not listed on the node yet is treated as "not ready"
    instead of raising AttributeError on ``None.state``.
    """
    if poll_interval <= 0:
        raise ValueError("poll_interval must be positive")
    polls = 0
    # Same bound as `polls < timeout / poll_interval`, written without the
    # division.
    while polls * poll_interval < timeout:
        node = client.node.ex_get_node_by_id(node.id)
        found_disk = find_disk_from_scsi_id(node, scsi_id)
        if found_disk is not None and found_disk.state == 'NORMAL':
            return True
        time.sleep(poll_interval)
        polls += 1
    return False


def find_disk_from_scsi_id(node, scsi_id):
    """Return the first disk of ``node`` with the given SCSI id, or None.

    :param node: node object whose ``extra['disks']`` is searched.
    :param scsi_id: SCSI id to look for.
    """
    return next((disk for disk in node.extra['disks'] if disk.scsi_id == scsi_id), None)


def convert_nodes(node_list):
    """Summarise libcloud nodes as plain dictionaries keyed by private IP.

    :param node_list: iterable of node objects.
    :return: dict mapping each node's first private IP to a dictionary with
        its CPU, OS, disk and memory details plus ``name`` and ``id``.
    """
    summaries = {}
    for node in node_list:
        extra = node.extra
        cpu = extra['cpu']
        summary = {
            'cpu_count': cpu.cpu_count,
            'cpu_performance': cpu.performance,
            'cpu_cores_per_socket': cpu.cores_per_socket,
            'os_name': extra['OS_displayName'],
            'os_type': extra['OS_type'],
            'disks': convert_disks(extra['disks']),
            'memory': extra['memoryMb'],
            'name': node.name,
            'id': node.id,
        }
        summaries[node.private_ips[0]] = summary
    return summaries


def compare_nodes(master, secondary, vm_mapping, primary_dc, secondary_dc, verbose=False):
    """Print whether each mapped primary node matches its secondary node.

    :param master: node dictionary of the primary DC (as built by
        ``convert_nodes``).
    :param secondary: node dictionary of the secondary DC.
    :param vm_mapping: mapping where ``vm_mapping[primary_dc][node]['vm']``
        names the secondary counterpart of ``node``.
    :param primary_dc: primary datacenter identifier.
    :param secondary_dc: secondary datacenter identifier.
    :param verbose: print extra progress lines when true.

    Nodes missing on either side are reported and skipped.
    """
    for node in vm_mapping[primary_dc]:
        if node not in master:
            click.secho("Node {0} not found in DC {1}".format(node, primary_dc), fg='red')
            continue
        comp_node = vm_mapping[primary_dc][node]['vm']
        if verbose:
            click.secho("Node {0} in DC {1} should match Node {2} in DC {3}".format(
                node, primary_dc, comp_node, secondary_dc
            ))
        if comp_node not in secondary:
            # Was spelt `.formaT`, which raised AttributeError instead of
            # printing this message.
            click.secho("Could not find node {0} in DC {1}".format(comp_node, secondary_dc))
            continue
        if verbose:
            click.secho("Node {0} exists in DC {1}".format(comp_node, secondary_dc))
        if node_specs_match(master[node], secondary[comp_node], verbose):
            click.secho("Node matches", fg='green')
        else:
            click.secho("Node {} with ip {} does not match {} with ip {}".format(
                master[node]['name'], node, secondary[comp_node]['name'], comp_node), fg='red')

        click.secho("")


def node_specs_match(n1, n2, verbose=False):
    """Compare two node dictionaries field by field and print the result.

    :param n1: reference node dictionary; its keys drive the comparison.
    :param n2: node dictionary compared against ``n1``.
    :param verbose: also print a line for every field that matches.
    :return: True when every compared field is equal, False otherwise.

    ``name`` and ``id`` are skipped.
    """
    click.secho("Comparing node {} and {}".format(n1['name'], n2['name']))
    match = True
    for key in n1:
        # Equality, not identity: `key is 'name'` only worked when the
        # interpreter happened to intern both strings.
        if key in ('name', 'id'):
            continue
        if n1[key] != n2[key]:
            match = False
            click.secho("{} does not match {} vs {}".format(key, n1[key], n2[key]), fg='red')
        elif verbose:
            click.secho("{} matches for node".format(key), fg='green')
    return match


def convert_disks(disks):
    """Index disk objects by SCSI id.

    :param disks: iterable of disk objects.
    :return: dict mapping ``scsi_id`` to ``{'size_gb', 'speed', 'id'}``;
        when two disks share a SCSI id the later one wins.
    """
    return {
        disk.scsi_id: {'size_gb': disk.size_gb, 'speed': disk.speed, 'id': disk.id}
        for disk in disks
    }


# Run the click command group when the file is executed as a script.
if __name__ == '__main__':
    compare()
