#!/usr/bin/env python3

from PW_explorer.dist_calc import *
from PW_explorer.pwe_helper import (
    load_from_temp_pickle,
    get_current_project_name,
    rel_id_from_rel_name,
    get_sql_conn,
    get_save_folder,
    get_file_save_name,
    set_current_project_name,
    import_custom_module,
)

import numpy as np
import pickle
import argparse
import os.path


def __main__():

    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--project_name", type=str, help="provide session/project name used while parsing")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-symmetric_difference", action='store_true', default=False,
                       help="this option measures distance by measuring the size of the symmetric difference set "
                            "of two PWs. Use either the -rel_ids or -rel_names flag to specify "
                            "the relations to use in this calculation.")
    group.add_argument("-euler_num_overlaps_diff", action='store_true', default=False,
                       help="use this if working with an euler result. This measures the distance as the "
                            "absolute difference in the number of overlaps (""><"") in two PWs. Provide the relation "
                            "name or relation id to use using the -rel_name or rel_id flag respectively. "
                            "Provide the column name to use using the -col flag.")
    group.add_argument("-custom_dist_func", type=str, nargs=2, default=[],
                       help="provide the module name and complete file path to the python module respectively "
                            "The function "
                            "signature should be dist(pw_id_1, pw_id_2, **kwargs) where kwargs contains the follwing:"
                            "dfs, pws, relations"
                            "where the latter three arguments refer to the data acquired from parsing the "
                            "ASP solutions. The function "
                            "should return a floating point number. Ensure that the file is in the same directory as "
                            "this script. You can use the functions in sql_funcs.py to design these dist functions.")
    group.add_argument("-show_relations", action='store_true', default=False,
                       help="to get a list of relations and corresponding relation ids.")

    parser.add_argument("-rel_names", nargs='*', default=[], type=str,
                        help="provide the relation names to use in the distance calculation.")
    # parser.add_argument("-rel_ids", nargs='*', default=[], type=int,
    #                     help="provide the relation ids of the relation to use in the distance calculation. To view "
    #                          "relation ids, use -show_relations")
    parser.add_argument("-rel_name", type=str,
                        help="provide the relation name to use in the distance calculation.")
    # parser.add_argument("-rel_id", type=int,
    #                     help="provide the relation id of the relation to use in the distance calculation. To view "
    #                          "relation ids, use -show_relations")
    parser.add_argument("-calc_dist_matrix", action='store_true', default=False,
                        help="specify this flag to calculate the distance matrix")
    parser.add_argument("-pws", type=int, nargs=2,
                        help="provide the two possible world ids of the possible world to calculate the distance between. "
                             "At least one of -pws and -calc_dist_matrix must be used.")
    parser.add_argument("-col", type=str,
                        help="provide the column to use for the distance calculation, required with the "
                             "euler_num_overlaps_diff distance metric.")

    args = parser.parse_args()

    project_name = ''
    if args.project_name is None:
        project_name = get_current_project_name()
        if project_name is None:
            print("Couldn't find current project. Please provide a project name.")
            exit(1)
    else:
        project_name = args.project_name

    dfs = load_from_temp_pickle(project_name, 'dfs')
    relations = load_from_temp_pickle(project_name, 'relations')
    pws = load_from_temp_pickle(project_name, 'pws')
    conn = get_sql_conn(project_name)
    expected_pws = len(pws)

    if args.symmetric_difference:

        # arg_ids = args.rel_ids

        if args.pws is None and args.calc_dist_matrix is False:
            print("Include at least one of -pws or -calc_dist_matrix flags.")
            exit(0)

        if args.pws is not None and len(args.pws) == 2:
            pw1 = args.pws[0]
            pw2 = args.pws[1]
            print("Distance between PWs {} and {} is {}".format(pw1, pw2,
                                                                sym_diff_dist(pw1, pw2, relations, dfs,
                                                                              pws, args.rel_names)))

        dist_matrix = None

        if args.calc_dist_matrix:
            dist_matrix = np.zeros((len(pws), len(pws)))

            for i in range(1, len(pws) + 1):
                for j in range(i, len(pws) + 1):
                    dist_matrix[i - 1, j - 1] = dist_matrix[j - 1, i - 1] = \
                        sym_diff_dist(i, j, relations, dfs, pws, args.rel_names)

            if np.max(dist_matrix) != np.min(dist_matrix):
                dist_matrix = (dist_matrix - np.min(dist_matrix)) / (np.max(dist_matrix) - np.min(dist_matrix))

            print("Distance Matrix:")
            print(str(dist_matrix))

            with open(get_save_folder(project_name, 'temp_pickle_data') + '/' +
                      get_file_save_name(project_name, 'dist_matrix'), 'wb') as f:
                pickle.dump(dist_matrix, f)

    elif args.euler_num_overlaps_diff:

        if args.rel_name is None:
            print("Please include either the -rel_name flag along with the appropriate argument.")
            exit(0)

        if args.col is None:
            print("-col is required")
            exit(0)

        if args.pws is None and args.calc_dist_matrix is False:
            print("Include atleast one of -pws or -calc_dist_matrix flags.")
            exit(0)

        if args.pws is not None and len(args.pws) == 2:
            pw1 = args.pws[0]
            pw2 = args.pws[1]
            print(
                "Distance between PWs {} and {} is {}".format(pw1, pw2,
                                                              euler_overlap_diff_dist(pw1, pw2, args.rel_name, args.col,
                                                                                      dfs, pws, relations)))

        dist_matrix = None

        if args.calc_dist_matrix:
            dist_matrix = np.zeros((len(pws), len(pws)))

            for i in range(1, len(pws) + 1):
                for j in range(i, len(pws) + 1):
                    dist_matrix[i - 1, j - 1] = dist_matrix[j - 1, i - 1] = \
                        euler_overlap_diff_dist(i, j, args.rel_name, args.col, dfs, pws, relations)

            if np.max(dist_matrix) != np.min(dist_matrix):
                dist_matrix = (dist_matrix - np.min(dist_matrix)) / (np.max(dist_matrix) - np.min(dist_matrix))

            print("Distance Matrix:")
            print(str(dist_matrix))

            with open(get_save_folder(project_name, 'temp_pickle_data') + '/' +
                      get_file_save_name(project_name, 'dist_matrix'), 'wb') as f:
                pickle.dump(dist_matrix, f)

    elif args.show_relations:

        print('Following are the parsed relation IDs and relation names:')
        for i, rl in enumerate(relations):
            print(str(i) + ':', str(rl.relation_name))

    elif args.custom_dist_func:

        try:
            [module_name, module_path] = args.custom_dist_func
            custom_viz_module = import_custom_module(module_name, os.path.abspath(module_path))
            dist_func = custom_viz_module.dist

        except Exception as e:
            print("Error importing from the given file")
            print("Error: ", str(e))
            exit(1)

        if args.pws is None and args.calc_dist_matrix is False:
            print("Include atleast one of -pws or -calc_dist_matrix flags.")
            exit(0)

        if args.pws is not None and len(args.pws) == 2:
            pw1 = args.pws[0]
            pw2 = args.pws[1]
            print("Distance between PWs {} and {} is {}".format(pw1, pw2, dist_func(pw1, pw2, dfs=dfs, pws=pws,
                                                                                    relations=relations)))

        if args.calc_dist_matrix:
            dist_matrix = np.zeros((len(pws), len(pws)))

            for i in range(1, len(pws) + 1):
                for j in range(i, len(pws) + 1):
                    dist_matrix[i - 1, j - 1] = dist_matrix[j - 1, i - 1] = \
                        dist_func(i, j, dfs=dfs, pws=pws, relations=relations)

            if np.max(dist_matrix) != np.min(dist_matrix):
                dist_matrix = (dist_matrix - np.min(dist_matrix)) / (np.max(dist_matrix) - np.min(dist_matrix))

            print("Distance Matrix:")
            print(str(dist_matrix))

            with open(get_save_folder(project_name, 'temp_pickle_data') + '/' +
                      get_file_save_name(project_name, 'dist_matrix'), 'wb') as f:
                pickle.dump(dist_matrix, f)

    set_current_project_name(project_name)
    conn.commit()
    conn.close()


if __name__ == '__main__':
    __main__()