#!python
import argparse
import logging
import sys
from time import time
from pathlib import Path
from typing import Union

import numpy as np
import redblackgraph as rb

from os import path
from fscrawler import RelationshipDbReader
from redblackgraph.util.relationship_file_io import RelationshipFileReader, MAX_COLUMNS_EXCEL
from redblackgraph.util import RbgGraphBuilder
from redblackgraph.sparse.csgraph import transitive_closure, avos_canonical_ordering, CycleError
from scipy.stats import describe
# noinspection PyProtectedMember
from scipy.stats._stats_py import DescribeResult


# Module-level logger so the functions below also work when this file is imported
# rather than executed as a script (the __main__ guard only configures handlers).
logger = logging.getLogger(__name__)

# Vertex count at or above which a warning is logged that writing the xlsx output is slow.
MAX_PRACTICAL_SIZE = 1500
# Vertex count above which closure and canonical ordering are skipped.
# The limit is empirical (memory bound); 130000 was a previously used value.
MAX_ALLOWABLE_SIZE = 35000  # 130000


def init_argparse():
    """Build the command-line parser and parse ``sys.argv``.

    The built-in ``-h/--help`` option is disabled because ``-h`` is used for
    ``--hops``; instead, the help text is printed whenever parsing fails.

    Returns:
        The parsed ``argparse.Namespace``. A sub-command is mandatory, so the
        namespace always carries ``func`` (the analysis entry point) in
        addition to the global options.

    Exits:
        Prints the parser's message followed by the help text and exits with
        status 2 when the command line cannot be parsed.
    """
    parser = argparse.ArgumentParser(
        description="rbg - analyze Red Black Graph. Reads rbg ingest files and runs various algorithms",
        add_help=False,
        usage="rbg -f <base-file>",
    )
    try:
        parser.add_argument("-f", "--base_file", metavar="<STR>", type=str, help="base file name", required=True)
        parser.add_argument("-h", "--hops", default=4, type=int, help="Number of hops to include in graph")
        parser.add_argument("-i", "--ingest-invalid", action="store_true", default=False,
                            help="In addition to vertex and edge, ingest the invalid edges")
        # "extend" (rather than "append") keeps the result a flat list of strings.
        # With "append" and nargs="+", user-supplied values were added as a nested
        # list inside the defaults and therefore could never match a string value.
        parser.add_argument("--invalid-filter", action="extend",
                            default=["parent_cardinality"], type=str,
                            nargs="+", help="Invalid reasons to accept")
        parser.add_argument("-l", "--filter", action="extend",
                            default=["BiologicalParent", "UntypedParent", "UnspecifiedParentType"], type=str,
                            nargs="+", help="Edge relationship type values to accept")
        parser.add_argument("-o", "--outdir", metavar="<STR>", type=str,
                            help="output directory (when omitted, no output files are written)")
        parser.add_argument("-r", "--replace", action="store_true", default=False,
                            help="Replace existing output files")
        parser.add_argument("-v", "--verbose", action="store_true", default=False,
                            help="Increase output verbosity [False]")

        # The sub-command is required: its ``func`` default is what gets invoked
        # once the graph is loaded, so a missing sub-command must fail here and
        # not with an AttributeError after the graph has been read.
        subparsers = parser.add_subparsers(dest="command", required=True)
        run_parser = subparsers.add_parser("run", help="run graph analysis")
        run_parser.add_argument("-a", "--append-analysis", metavar="<STR>", type=str,
                                help="filename to append analysis to")
        run_parser.add_argument("-c", "--canonical", action="store_true", default=False,
                                help="Run canonical ordering (requires closure)")
        run_parser.add_argument("-l", "--closure", action="store_true", default=False,
                                help="Run closure")
        run_parser.add_argument("--row-and-column-info", action="store_true", default=False,
                                help="Gather descriptive stats about the closed matrix' row and columns")
        run_parser.add_argument("-t", "--topological", action="store_true", default=False,
                                help="Run topological ordering")
        run_parser.set_defaults(func=run_analysis)

        # extract arguments from the command line
        # Routing error() to exit() passes the error message as the exit status,
        # so the SystemExit caught below carries the message and print(e) shows it.
        parser.error = parser.exit
        arguments = parser.parse_args()
        return arguments
    except SystemExit as e:
        print(e)
        parser.print_help()
        sys.exit(2)


def write_results(run_args, writer, graph, name, description, key_permutation=None):
    """Write one analysis result to its output file.

    Args:
        run_args: Parsed arguments; used to derive the output path and to
            decide whether an existing file may be replaced.
        writer: Object whose ``write`` method persists the graph.
        graph: The matrix to write.
        name: Short result name embedded in the output file name.
        description: Human-readable name of the result, used in log messages.
        key_permutation: Optional permutation forwarded to the writer.

    Raises:
        SystemError: If the target exists and replacing was not requested
            (raised by ``check_file``).
    """
    target = get_file(run_args, name)
    # Refuse to clobber an existing file unless --replace was given.
    check_file(target, run_args)
    logger.info(f"Writing {description} to {target}")
    write_options = {"output_file": target, "key_permutation": key_permutation}
    writer.write(graph, **write_options)
    logger.info("  Writing complete")


def format_descriptive_stats(stats):
    """Render descriptive statistics as a single comma-separated log fragment.

    Args:
        stats: Object exposing ``minmax`` (indexable, maximum at index 1),
            ``mean``, ``variance``, ``skewness`` and ``kurtosis``.

    Returns:
        A string such as ``"max: 1.00, mean: 1.00, ..."`` with every value
        formatted with thousands separators and two decimals.
    """
    labelled_values = (
        ("max", stats.minmax[1]),
        ("mean", stats.mean),
        ("variance", stats.variance),
        ("skewness", stats.skewness),
        ("kurtosis", stats.kurtosis),
    )
    return ", ".join(f"{label}: {value:,.2f}" for label, value in labelled_values)


def run_analysis(run_args, graph, reader):
    """Run the analyses selected on the command line and record the results.

    Depending on the flags in ``run_args`` this writes the topologically
    ordered graph, computes the transitive closure (optionally gathering
    row/column statistics), computes the canonical ordering, and appends a
    CSV summary row to the analysis file.

    Args:
        run_args: Parsed arguments. Reads ``outdir``, ``topological``,
            ``closure``, ``canonical``, ``row_and_column_info`` and
            ``append_analysis``; the summary row also reads ``basename``,
            ``hops`` and ``ingest_invalid``.
        graph: The ingested graph matrix (must expose ``shape`` and ``nnz``
            when a summary row is written).
        reader: The reader the graph came from; handed to the writer and used
            to look up the vertex key when a cycle is reported.

    Output files are only written when ``run_args.outdir`` is set. A
    ``CycleError`` raised during closure or canonical ordering is logged and
    recorded in the summary row instead of being propagated.
    """
    # Output files are only produced when an output directory is configured.
    writer = rb.RedBlackGraphWriter(reader) if run_args.outdir else None

    if run_args.topological and writer:
        # The graph is topologically ordered during ingestion, so it is written as is.
        # Like the closure/canonical outputs, this is skipped when there is no writer;
        # attempting the write without an output directory would fail.
        write_results(run_args, writer, graph, "topological", "topological ordering")

    r_star = None
    r_canonical = None
    row_stats: Union[DescribeResult, None] = None
    col_stats: Union[DescribeResult, None] = None
    time_closure = 0
    time_canonical = 0
    vertex_info = None
    start = time()
    if run_args.closure or run_args.canonical:
        # Canonical ordering is computed from the closure, so either flag triggers it.
        try:
            logger.info("Computing transitive closure")
            start = time()
            r_star = transitive_closure(graph).W
            time_closure = time() - start
            logger.info(f"  Closure computation complete. There are {np.count_nonzero(r_star):,} edges in the graph.")
            if run_args.row_and_column_info:
                logger.info(f"  Gathering descriptive stats on rows and columns of closed matrix")

                # Count the non-zero entries per column (axis 0) and per row (axis 1).
                counts = (r_star != 0)
                non_zero_cols = counts.sum(0)
                non_zero_rows = counts.sum(1)
                row_stats = describe(non_zero_rows)
                col_stats = describe(non_zero_cols)

                logger.info(f"    Row info: {format_descriptive_stats(row_stats)}")
                logger.info(f"    Col info: {format_descriptive_stats(col_stats)}")

            if run_args.closure and writer:
                write_results(run_args, writer, r_star, "closure", "closure")

            if run_args.canonical:
                logger.info("Computing canonical ordering")
                start = time()
                r_canonical = avos_canonical_ordering(r_star)
                time_canonical = time() - start
                logger.info("  Canonical ordering complete")
                if writer:
                    write_results(run_args, writer, r_canonical.A, "canonical", "canonical ordering",
                                  key_permutation=r_canonical.label_permutation)
        except CycleError as e:
            # If the closure itself failed, still record how long it ran before failing.
            if time_closure == 0:
                time_closure = time() - start
            vertex_info = reader.get_vertex_key()[e.vertex]
            logger.error(f"  Error: cycle detected. {vertex_info} has a path to itself.")

    if run_args.append_analysis:
        comp_stats: Union[DescribeResult, None] = (
            describe(list(r_canonical.components.values())) if r_canonical is not None else None
        )
        # The header is only written when the analysis file is being created.
        header_written = path.exists(run_args.append_analysis)
        with open(run_args.append_analysis, "a+") as file:
            if not header_written:
                file.write("#name,hops,multi parents,vertices,edges (simple),edges (closure),time (closure),"
                           "time (canonical),mean,row max,row variance,row skewness,row kurtosis,col max,"
                           "col variance,col skewness,col kurtosis,components,comp min,comp max,comp mean,"
                           "comp variance,comp skewness,comp kurtosis,cycle vertex\n")
            # NOTE(review): every data row ends with a trailing comma, i.e. one more (empty)
            # field than the header declares. Kept as is so rows stay consistent with
            # files that are already being appended to.
            file.write(f"{run_args.basename},"                                               # name
                       f"{run_args.hops},"                                                   # hops
                       f"{'T' if run_args.ingest_invalid else 'F'},"                         # multiple parents allowed
                       f"{graph.shape[0]},"                                              # vertices
                       f"{graph.nnz},"                                                   # edges (simple)
                       f"{np.count_nonzero(r_star) if r_star is not None else ''},"      # edges (closed)
                       f"{time_closure if time_closure else ''},"                        # time (closure)
                       f"{time_canonical if time_canonical else ''},"                    # time (canonical)
                       f"{row_stats.mean if row_stats is not None else ''},"             # mean
                       f"{row_stats.minmax[1] if row_stats is not None else ''},"        # row max
                       f"{row_stats.variance if row_stats is not None else ''},"         # row variance
                       f"{row_stats.skewness if row_stats is not None else ''},"         # row skewness
                       f"{row_stats.kurtosis if row_stats is not None else ''},"         # row kurtosis
                       f"{col_stats.minmax[1] if col_stats is not None else ''},"        # col max
                       f"{col_stats.variance if col_stats is not None else ''},"         # col variance
                       f"{col_stats.skewness if col_stats is not None else ''},"         # col skewness
                       f"{col_stats.kurtosis if col_stats is not None else ''},"         # col kurtosis
                       f"{comp_stats.nobs if comp_stats is not None else ''},"           # components
                       f"{comp_stats.minmax[0] if comp_stats is not None else ''},"      # comp min
                       f"{comp_stats.minmax[1] if comp_stats is not None else ''},"      # comp max
                       f"{comp_stats.mean if comp_stats is not None else ''},"           # comp mean
                       f"{comp_stats.variance if comp_stats is not None else ''},"       # comp variance
                       f"{comp_stats.skewness if comp_stats is not None else ''},"       # comp skewness
                       f"{comp_stats.kurtosis if comp_stats is not None else ''},"       # comp kurtosis
                       f"{vertex_info if vertex_info else ''},"                          # cycle vertex
                       f"\n")


def get_file(run_args, name: Union[str, None], extension: str = "xlsx", output=True):
    """Derive an input or output file path from the base file name.

    As a side effect, the last path component of ``run_args.base_file`` is
    stored on ``run_args.basename``.

    Args:
        run_args: Parsed arguments; reads ``base_file``, ``hops`` and, for
            output paths, ``outdir``.
        name: Optional qualifier inserted before the extension; omitted from
            the file name when empty or ``None``.
        extension: File extension without the leading dot.
        output: When true, build ``<outdir>/<basename>.<hops>[.<name>].<extension>``;
            otherwise build ``<base_file>[.<name>].<extension>``.

    Returns:
        The resulting ``Path``.
    """
    source = Path(run_args.base_file)
    stem = str(source.parts[-1])
    run_args.basename = stem
    hop_count = run_args.hops

    # The qualifier is optional; without it only the extension is appended.
    tail = f"{name}.{extension}" if name else f"{extension}"

    if not output:
        return Path(f"{source}.{tail}")
    return Path(run_args.outdir) / f"{stem}.{hop_count}.{tail}"


def check_file(file: Path, run_args, existence_required=False):
    """Validate that a file is in the state the caller needs.

    Args:
        file: Path to check.
        run_args: Parsed arguments; ``replace`` is consulted for output files.
        existence_required: When true, the file must already exist (input
            file). Otherwise the file must not exist unless ``run_args.replace``
            is set (output file).

    Raises:
        SystemError: If a required file is missing, or if a file exists and
            replacing was not requested.
    """
    if existence_required:
        # Input file: only its presence matters.
        if file.exists():
            return
        raise SystemError(f"{file} not found")

    # Output file: overwriting is allowed only when explicitly requested.
    if run_args.replace:
        return
    if file.exists():
        raise SystemError(f"{file} already exists.")


def run(run_args):
    """Load the graph and dispatch to the selected sub-command.

    The graph is read from ``<base_file>.db`` when that file exists and from
    the csv ingest files otherwise. Size limits are then applied: output is
    disabled for graphs too large for Excel, and closure / canonical ordering
    are disabled for graphs too large to process.

    Args:
        run_args: Parsed arguments. ``outdir``, ``closure`` and ``canonical``
            may be overwritten here before ``run_args.func`` is invoked.
    """
    db_file = get_file(run_args, None, "db", output=False)
    if db_file.exists():
        logger.info(f"Reading graph from database {db_file}")
        reader = RelationshipDbReader(db_file, run_args.hops, RbgGraphBuilder())
    else:
        reader = init_file_reader(run_args)
    graph = reader.read()
    vertex_count = graph.shape[0]

    # No upper bound on this check: a graph beyond MAX_ALLOWABLE_SIZE is just as
    # unwritable as one slightly over the Excel limit, and leaving outdir set would
    # still let the topological output attempt an xlsx write.
    if vertex_count >= MAX_COLUMNS_EXCEL and run_args.outdir:
        logger.error(f"Trying to ingest a graph that exceeds the size excel can handle (Max: {MAX_COLUMNS_EXCEL:,}).")
        run_args.outdir = None
    # Only relevant while output is still enabled (outdir may have been cleared above).
    if vertex_count >= MAX_PRACTICAL_SIZE and run_args.outdir:
        logger.warning("This graph is on the large size. It will take a few seconds more to write the xlsx file.")

    logger.info(f"  Reading complete. There are {graph.nnz:,} edges in the graph.")

    if vertex_count > MAX_ALLOWABLE_SIZE:
        # Beyond a certain size, allocating a numpy array will result in SIGKILL.
        # MAX_ALLOWABLE_SIZE is an empirical limit (130,000 was observed on the original
        # author's macbook). I'm uncertain how to determine this programmatically, but it
        # can be determined empirically via the memtest script
        logger.error(f"Unable to process more than {MAX_ALLOWABLE_SIZE:,} vertices. Attempting to process "
                     f"{graph.shape[0]:,} vertices.")
        run_args.closure = run_args.canonical = False

    run_args.func(run_args, graph, reader)


def init_file_reader(run_args):
    """Create a reader for the csv ingest files that sit next to the base file.

    The vertices and edges files are mandatory; the invalid-edges file is
    mandatory only when ``run_args.ingest_invalid`` is set. The ignore-edges
    file is optional and passed as ``None`` when it does not exist.

    Args:
        run_args: Parsed arguments; reads ``ingest_invalid``, ``hops``,
            ``filter`` and ``invalid_filter``.

    Returns:
        A ``RelationshipFileReader`` configured for the ingest files.

    Raises:
        SystemError: If a mandatory ingest file is missing (raised by
            ``check_file``).
    """
    def ingest_path(kind):
        # All ingest files share the base file name and the csv extension.
        return get_file(run_args, kind, "csv", output=False)

    vertices = ingest_path("vertices")
    edges = ingest_path("edges")
    invalid_edges = ingest_path("invalid.edges") if run_args.ingest_invalid else None
    ignored_edges = ingest_path("ignore.edges")
    if not ignored_edges.exists():
        ignored_edges = None

    mandatory = [vertices, edges]
    if run_args.ingest_invalid:
        mandatory.append(invalid_edges)
    for ingest_file in mandatory:
        check_file(ingest_file, run_args, True)

    logger.info("Reading graph input files")
    return RelationshipFileReader(vertices, edges, run_args.hops, run_args.filter, invalid_edges,
                                  run_args.invalid_filter, ignored_edges)


if __name__ == '__main__':
    # Arguments are parsed first because --verbose selects the log level.
    args = init_argparse()
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(format='%(asctime)s %(message)s', level=log_level)
    logger = logging.getLogger(__name__)
    run(args)
