#!python
import argparse
import logging
import sys
from time import time
from pathlib import Path

import numpy as np
import redblackgraph as rb

from os import path
from redblackgraph.util.relationship_file_io import RelationshipFileReader, MAX_COLUMNS_EXCEL
from redblackgraph.sparse.csgraph import transitive_closure, avos_canonical_ordering, CycleError
from scipy.stats import describe


# Vertex count at or above which the script warns that writing the xlsx
# output will take noticeably longer.
MAX_PRACTICAL_SIZE = 1500
# Vertex count above which the script refuses to run closure / canonical
# ordering.
# NOTE(review): the trailing 130000 looks like an alternative (previously
# used?) limit -- confirm which value is intended.
MAX_ALLOWABLE_SIZE = 35000 #130000


def init_argparse():
    """Build the command line parser and parse ``sys.argv``.

    Top-level options select the input files and the ingestion filters. The
    ``run`` sub-command selects the analyses to perform and binds
    ``run_analysis`` to ``args.func``. A sub-command is required, so
    ``args.func`` is always present on the returned namespace.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        SystemExit: with status 2, after printing the error (if any) and the
            help text, when the command line cannot be parsed.
    """
    parser = argparse.ArgumentParser(
        description="rbg - analyze Red Black Graph. Reads rbg ingest files and runs various algorithms",
        add_help=False,  # "-h" is taken by --hops below
        usage="rbga -f <base-file> [options] run [run options]",
    )
    parser.add_argument("-f", "--basefile", metavar="<STR>", type=str, help="base file name", required=True)
    parser.add_argument("-h", "--hops", default=4, type=int, help="Number of hops to include in graph")
    parser.add_argument("-i", "--ingest-invalid", action="store_true", default=False,
                        help="In addition to vertex and edge, ingest the invalid edges")
    # NOTE(review): with action="append", nargs="+" and a list default, values
    # given on the command line are appended to the defaults as a nested list
    # instead of replacing them -- confirm the reader expects that shape.
    parser.add_argument("--invalid-filter", action="append",
                        default=["parent_cardinality"], type=str,
                        nargs="+", help="Invalid reasons to accept")
    parser.add_argument("-l", "--filter", action="append",
                        default=["BiologicalParent", "UntypedParent", "UnspecifiedParentType"], type=str,
                        nargs="+", help="Edge relationship type values to accept")
    parser.add_argument("-o", "--outdir", metavar="<STR>", type=str,
                        help="output directory (when omitted, no result files are written)")
    parser.add_argument("-r", "--replace", action="store_true", default=False,
                        help="Replace existing output files")
    parser.add_argument("-v", "--verbose", action="store_true", default=False,
                        help="Increase output verbosity [False]")

    # The sub-command is mandatory: it is what provides ``args.func``. Without
    # this, omitting it only failed with AttributeError after ingestion.
    subparsers = parser.add_subparsers(dest="command", required=True)
    run_parser = subparsers.add_parser("run", help="run graph analysis")
    run_parser.add_argument("-a", "--append-analysis", metavar="<STR>", type=str,
                            help="filename to append analysis to")
    run_parser.add_argument("-c", "--canonical", action="store_true", default=False,
                            help="Run canonical ordering (requires closure)")
    run_parser.add_argument("-l", "--closure", action="store_true", default=False,
                            help="Run closure")
    run_parser.add_argument("--row-and-column-info", action="store_true", default=False,
                            help="Gather descriptive stats about the closed matrix' row and columns")
    run_parser.add_argument("-t", "--topological", action="store_true", default=False,
                            help="Run topological ordering")
    run_parser.set_defaults(func=run_analysis)

    # Route parser errors through exit() so the message travels in the
    # SystemExit caught below and the full help is printed with it.
    parser.error = parser.exit
    try:
        # extract arguments from the command line
        return parser.parse_args()
    except SystemExit as e:
        print(e)
        parser.print_help()
        sys.exit(2)


def write_results(args, writer, graph, name, description, key_permutation=None):
    """Write one analysis result to its output file, logging progress.

    Args:
        args: parsed command line arguments, used to derive the output path
            and to decide whether an existing file may be replaced.
        writer: object whose ``write`` method persists the graph.
        graph: the matrix to write.
        name: result name embedded in the output file name.
        description: human readable name used in the log message.
        key_permutation: optional permutation forwarded to the writer.

    Raises:
        SystemError: when the target exists and ``args.replace`` is not set.
    """
    target = get_file(args, name)
    # Refuse to clobber an existing file unless replacement was requested.
    check_file(target, args)
    logger.info(f"Writing {description} to {target}")
    writer.write(graph, output_file=target, key_permutation=key_permutation)
    logger.info("  Writing complete")


def format_descriptive_stats(stats):
    """Render the max, mean, variance, skewness and kurtosis of *stats*.

    Args:
        stats: object exposing ``minmax`` (indexable, max at position 1),
            ``mean``, ``variance``, ``skewness`` and ``kurtosis``.

    Returns:
        str: comma separated ``label: value`` pairs, each value formatted
        with thousands separators and two decimals.
    """
    labelled_values = (
        ("max", stats.minmax[1]),
        ("mean", stats.mean),
        ("variance", stats.variance),
        ("skewness", stats.skewness),
        ("kurtosis", stats.kurtosis),
    )
    return ", ".join(f"{label}: {value:,.2f}" for label, value in labelled_values)


def run_analysis(args):
    """Run the analyses selected by the ``run`` sub-command.

    Uses the module-level ``graph``, ``reader`` and ``logger`` that the
    ``__main__`` section creates before dispatching here, and reads
    ``args.basename``, which ``get_file`` sets as a side effect.

    Depending on ``args`` this writes the topological ordering, computes the
    transitive closure (optionally with row/column statistics), computes the
    canonical ordering, and appends one CSV row to ``args.append_analysis``.
    Result files are only written when ``args.outdir`` is set.

    A ``CycleError`` raised while computing the closure or the canonical
    ordering is logged and recorded in the analysis row, not propagated.

    Args:
        args: parsed command line arguments.
    """
    # Result files are only produced when an output directory was supplied.
    writer = rb.RedBlackGraphWriter(reader) if args.outdir else None

    # Guarded on the writer like the closure/canonical writes below: without
    # an output directory there is no target path to build.
    if args.topological and writer:
        # the graph is topologically ordered during ingestion
        write_results(args, writer, graph, "topological", "topological ordering")

    R_star = None
    R_canonical = None
    row_stats = None
    col_stats = None
    time_closure = 0
    time_canonical = 0
    vertex_info = None
    if args.closure or args.canonical:
        # Canonical ordering is computed from the closure, so the closure is
        # computed whenever either was requested.
        try:
            logger.info("Computing transitive closure")
            start = time()
            R_star = transitive_closure(graph).W
            time_closure = time() - start
            logger.info(f"  Closure computation complete. There are {np.count_nonzero(R_star):,} edges in the graph.")
            if args.row_and_column_info:
                logger.info("  Gathering descriptive stats on rows and columns of closed matrix")

                # Number of non-zero entries per column (axis 0) and per row (axis 1).
                counts = (R_star != 0)
                non_zero_cols = counts.sum(0)
                non_zero_rows = counts.sum(1)
                row_stats = describe(non_zero_rows)
                col_stats = describe(non_zero_cols)

                logger.info(f"    Row info: {format_descriptive_stats(row_stats)}")
                logger.info(f"    Col info: {format_descriptive_stats(col_stats)}")

            if args.closure and writer:
                write_results(args, writer, R_star, "closure", "closure")

            if args.canonical:
                logger.info("Computing canonical ordering")
                start = time()
                R_canonical = avos_canonical_ordering(R_star)
                time_canonical = time() - start
                logger.info("  Canonical ordering complete")
                if writer:
                    write_results(args, writer, R_canonical.A, "canonical", "canonical ordering",
                                  key_permutation=R_canonical.label_permutation)
        except CycleError as e:
            # Record how long the closure ran before the cycle was detected.
            if time_closure == 0:
                time_closure = time() - start
            vertex_info = reader.get_vertex_key()[e.vertex]
            logger.error(f"  Error: cycle detected. {vertex_info} has a path to itself.")

    if args.append_analysis:
        comp_stats = describe(list(R_canonical.components.values())) if R_canonical is not None else None
        # The header is only emitted when the analysis file is first created.
        header_written = path.exists(args.append_analysis)
        with open(args.append_analysis, "a+") as file:
            if not header_written:
                file.write("#name,hops,multi parents,vertices,edges (simple),edges (closure),time (closure),time (canonical),mean,row max,row variance,row skewness,row kurtosis,col max,col variance,col skewness,col kurtosis,components,comp min,comp max,comp mean,comp variance,comp skewness,comp kurtosis,cycle vertex\n")
            # Values that were not computed are written as empty fields.
            file.write(f"{args.basename},"                                               # name
                       f"{args.hops},"                                                   # hops
                       f"{'T' if args.ingest_invalid else 'F'},"                         # multiple parents allowed
                       f"{graph.shape[0]},"                                              # vertices
                       f"{graph.nnz},"                                                   # edges (simple)
                       f"{np.count_nonzero(R_star) if R_star is not None else ''},"      # edges (closed)
                       f"{time_closure if time_closure else ''},"                        # time (closure)
                       f"{time_canonical if time_canonical else ''},"                    # time (canonical)
                       f"{row_stats.mean if row_stats is not None else ''},"             # mean
                       f"{row_stats.minmax[1] if row_stats is not None else ''},"        # row max
                       f"{row_stats.variance if row_stats is not None else ''},"         # row variance
                       f"{row_stats.skewness if row_stats is not None else ''},"         # row skewness
                       f"{row_stats.kurtosis if row_stats is not None else ''},"         # row kurtosis
                       f"{col_stats.minmax[1] if col_stats is not None else ''},"        # col max
                       f"{col_stats.variance if col_stats is not None else ''},"         # col variance
                       f"{col_stats.skewness if col_stats is not None else ''},"         # col skewness
                       f"{col_stats.kurtosis if col_stats is not None else ''},"         # col kurtosis
                       f"{comp_stats.nobs if comp_stats is not None else ''},"           # components
                       f"{comp_stats.minmax[0] if comp_stats is not None else ''},"      # comp min
                       f"{comp_stats.minmax[1] if comp_stats is not None else ''},"      # comp max
                       f"{comp_stats.mean if comp_stats is not None else ''},"           # comp mean
                       f"{comp_stats.variance if comp_stats is not None else ''},"       # comp variance
                       f"{comp_stats.skewness if comp_stats is not None else ''},"       # comp skewness
                       f"{comp_stats.kurtosis if comp_stats is not None else ''},"       # comp kurtosis
                       f"{vertex_info if vertex_info else ''},"                          # cycle vertex
                       f"\n")


def get_file(args, name: str, extension: str = "xlsx", output=True):
    """Derive an input or output file path from the base file name.

    As a side effect the last path component of ``args.basefile`` is stored
    on ``args.basename``.

    Args:
        args: parsed arguments; ``basefile`` and ``hops`` are always read,
            ``outdir`` only when ``output`` is true.
        name: result or input name placed before the extension.
        extension: file extension without the leading dot.
        output: when true build
            ``<outdir>/<basename>.<hops>.<name>.<extension>``, otherwise
            ``<basefile>.<name>.<extension>``.

    Returns:
        pathlib.Path: the derived path.
    """
    base_path = Path(args.basefile)
    leaf = str(base_path.parts[-1])
    args.basename = leaf
    hop_count = args.hops
    if not output:
        # Input files sit next to the base file and carry no hop count.
        return Path(f"{base_path}.{name}.{extension}")
    return Path(args.outdir) / f"{leaf}.{hop_count}.{name}.{extension}"


def check_file(file: Path, args, existance_required=False):
    """Validate that *file* is in the expected state before it is used.

    Args:
        file: path to check.
        args: parsed arguments; ``replace`` is consulted for output files.
        existance_required: true for input files that must already exist;
            false for output files that must not be overwritten unless
            ``args.replace`` is set.

    Raises:
        SystemError: when a required file is missing, or when an output file
            already exists and replacement was not requested.
    """
    if existance_required:
        if not file.exists():
            raise SystemError(f"{file} not found")
        return
    if args.replace:
        # Overwriting was explicitly allowed; nothing to verify.
        return
    if file.exists():
        raise SystemError(f"{file} already exists.")


if __name__ == '__main__':
    # Script entry point: parse the arguments, ingest the graph files, apply
    # the size limits and dispatch to the selected sub-command.
    args = init_argparse()
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG if args.verbose else logging.INFO)
    logger = logging.getLogger(__name__)

    # Input files are named <basefile>.<name>.csv
    vertices_file = get_file(args, "vertices", "csv", output=False)
    edges_file = get_file(args, "edges", "csv", output=False)
    invalid_edges_file = get_file(args, "invalid.edges", "csv", output=False) if args.ingest_invalid else None
    # The ignore file is optional: it is only handed to the reader if present.
    ignore_file = get_file(args, "ignore.edges", "csv", output=False)
    if not ignore_file.exists():
        ignore_file = None

    check_file(vertices_file, args, True)
    check_file(edges_file, args, True)
    if args.ingest_invalid:
        check_file(invalid_edges_file, args, True)

    logger.info("Reading graph input files")

    reader = RelationshipFileReader(vertices_file, edges_file, args.hops, args.filter, invalid_edges_file,
                                    args.invalid_filter, ignore_file)
    graph: rb.sparse.rb_matrix = reader.read()

    vertex_count = graph.shape[0]

    # The Excel limit applies however large the graph is. Capping this check
    # at MAX_ALLOWABLE_SIZE left output enabled for even bigger graphs.
    if vertex_count >= MAX_COLUMNS_EXCEL and args.outdir:
        logger.error(f"Trying to ingest a graph that exceeds the size excel can handle (Max: {MAX_COLUMNS_EXCEL:,}).")
        # Clearing outdir disables the writing of result files.
        args.outdir = None
    if MAX_PRACTICAL_SIZE <= vertex_count <= MAX_ALLOWABLE_SIZE and args.outdir:
        logger.warning("This graph is on the large size. It will take a few seconds more to write the xlsx file.")

    logger.info(f"  Reading complete. There are {graph.nnz:,} edges in the graph.")

    if vertex_count > MAX_ALLOWABLE_SIZE:
        # Beyond a certain size, allocating a numpy array will result in a SIGKILL.
        # The limit is machine dependent (the original author observed 130,000 on
        # a macbook). It is unclear how to determine it programmatically, but it
        # can be determined empirically via the memtest script.
        logger.error(f"Unable to process more than {MAX_ALLOWABLE_SIZE:,} vertices. Attempting to process "
                     f"{graph.shape[0]:,} vertices.")
        args.closure = args.canonical = False

    args.func(args)
