#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import collections
import functools
import itertools
import json
import logging
import multiprocessing
import os
import tempfile

from networkx.readwrite import json_graph

import pareidoscope.query
from pareidoscope import subgraph_isomorphism
from pareidoscope import subgraph_enumeration
from pareidoscope.utils import database
from pareidoscope.utils import statistics
from pareidoscope.utils import nx_graph

# Verbose output; switch level to logging.INFO for quieter runs.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(levelname)s %(asctime)s: %(message)s",
)


# Counts of a structure under each of the four counting methods used in this script.
Frequencies = collections.namedtuple("Frequencies", "embeddings subgraphs focus_points sentences")


def arguments():
    """Parse and return the command-line arguments.

    The returned ``argparse.Namespace`` carries ``max_size``, ``min_coocc``,
    ``output``, ``cpu``, ``CORPUS`` (converted to an absolute path) and
    ``QUERIES`` (a file handle opened by argparse for reading as UTF-8 text).
    """
    cli = argparse.ArgumentParser(
        description="Find associated larger structures, i.e. structures with additional adjacent vertices."
    )
    cli.add_argument(
        "--max-size",
        type=int,
        default=7,
        help="Maximal number of vertices in the larger structure. Default: 7",
    )
    cli.add_argument(
        "--min-coocc",
        type=int,
        default=5,
        help="Minimal number of sentences in which the larger structure has to cooccur with the query. Default: 5",
    )
    cli.add_argument("-o", "--output", type=str, required=True, help="Output prefix")
    cli.add_argument(
        "-p",
        "--cpu",
        type=int,
        default=25,
        help="Percentage of CPUs to use (0-100; default: 25)",
    )
    # Positional arguments
    cli.add_argument("CORPUS", type=os.path.abspath, help="Input corpus as SQLite3 database")
    cli.add_argument(
        "QUERIES",
        type=argparse.FileType("r", encoding="utf-8"),
        help="Queries file as JSON list",
    )
    namespace = cli.parse_args()
    return namespace


def get_focus_point(query):
    """Determine the focus point vertex of *query*.

    The first vertex carrying a "focus_point" attribute is chosen, and that
    attribute is deleted from it (the query graph is modified in place).
    If no vertex is marked, ``nx_graph.get_choke_point(query)`` decides.

    Returns the tuple ``(query, focus_point_vertex)``.
    """
    marked_vertex, marked_attrs = next(
        ((vertex, attrs) for vertex, attrs in query.nodes(data=True) if "focus_point" in attrs),
        (None, None),
    )
    if marked_attrs is not None:
        # The marker is only an annotation for this function; drop it from the query.
        del marked_attrs["focus_point"]
    focus_point_vertex = marked_vertex
    if focus_point_vertex is None:
        focus_point_vertex = nx_graph.get_choke_point(query)
    assert focus_point_vertex is not None
    return query, focus_point_vertex


def delexicalize(query):
    """Return a copy of *query* without "word" and "lemma" vertex attributes.

    The graph passed in is left untouched; the attributes are deleted from a
    ``query.copy()``. Asserts that at least one attribute was removed.
    """
    stripped = query.copy()
    removed = 0
    for _vertex, attrs in stripped.nodes(data=True):
        lexical = [name for name in ("word", "lemma") if name in attrs]
        for name in lexical:
            del attrs[name]
        removed += len(lexical)
    assert removed > 0
    return stripped


def get_cooccurring_structures(args):
    """Count the larger structures that co-occur with the query in one sentence.

    Worker for ``multiprocessing.Pool.imap_unordered``. *args* is the tuple
    ``(query_graph, target_graph, focus_point_vertex, max_size)``, where
    *target_graph* is the sentence graph serialised as node-link JSON.

    For every vertex set onto which the query embeds, the connected subgraphs
    of its one-step neighbourhood are enumerated that have at least one vertex
    more than that vertex set (and at most *max_size* vertices), contain the
    vertex set and contain every query edge. In each of them, embedded
    vertices keep only the attributes their query vertex has; all other
    vertices lose "word", "pos", "lemma" and "root". The result is canonised
    and serialised as JSON.

    Returns ``(o11, r1)``: *o11* maps each canonical larger structure (JSON
    string) to its ``Frequencies`` in this sentence (sentence count 1); *r1*
    is the ``Frequencies`` of the query itself in this sentence.
    """
    query_graph, target_graph, focus_point_vertex, max_size = args
    target_graph = json_graph.node_link_graph(json.loads(target_graph))
    isomorphisms = subgraph_isomorphism.get_subgraph_isomorphisms_nx(pareidoscope.query.strip_vid(query_graph), target_graph)
    r1_counts = collections.defaultdict(set)
    mappings = collections.defaultdict(lambda: collections.defaultdict(set))
    o11 = {}
    o11_counts = collections.defaultdict(lambda: collections.defaultdict(set))
    for iso in isomorphisms:
        subgraph = frozenset(iso)
        focus_point = iso[focus_point_vertex]
        r1_counts["embeddings"].add(iso)
        r1_counts["subgraphs"].add(subgraph)
        r1_counts["focus_points"].add(focus_point)
        mappings["subgraph_to_embeddings"][subgraph].add(iso)
        mappings["embedding_to_focus_point"][iso] = focus_point
    mappings["subgraph_to_focus_points"] = {sg: {mappings["embedding_to_focus_point"][e] for e in embs} for sg, embs in mappings["subgraph_to_embeddings"].items()}
    for subg in r1_counts["subgraphs"]:
        # Direct neighbours (in either direction) of every embedded vertex. Building
        # sets first works whether successors()/predecessors() return lists
        # (networkx 1.x) or iterators (networkx 2.x).
        neighbors = [set(target_graph.successors(v)) | set(target_graph.predecessors(v)) for v in subg]
        nbunch = functools.reduce(lambda x, y: x.union(y), neighbors + [subg])
        induced_star = target_graph.subgraph(nbunch)
        bfo_graph, bfo_to_raw = subgraph_enumeration.get_bfo(induced_star, fragment=True)
        for sg in subgraph_enumeration.enumerate_csg_minmax(bfo_graph, bfo_to_raw, min_vertices=len(subg) + 1, max_vertices=min(induced_star.number_of_nodes(), max_size)):
            # all vertices from subg have to be in sg
            if not subg <= set(sg.nodes()):
                continue
            # all edges from query_graph have to be in sg
            embedding = next(iter(mappings["subgraph_to_embeddings"][subg]))
            if not all(sg.has_edge(embedding[s], embedding[t]) for s, t in query_graph.edges()):
                continue
            subgraph_to_query = {embedding[v]: v for v in sorted(query_graph.nodes())}
            # Delexicalize a private copy, as get_frequencies() does. Depending on
            # how the enumeration builds sg, its attribute dicts may be shared with
            # target_graph (networkx 1.x subgraphs share them), and deleting from
            # them would corrupt the sentence graph for later candidates.
            sg = sg.copy()
            # delexicalize: remove all word, pos, lemma and root attributes unless they are in query
            for v in sg.nodes():
                if v in subgraph_to_query:
                    for attribute in list(sg.node[v]):
                        if attribute not in query_graph.node[subgraph_to_query[v]]:
                            del sg.node[v][attribute]
                else:
                    for attribute in ("word", "pos", "lemma", "root"):
                        if attribute in sg.node[v]:
                            del sg.node[v][attribute]
            gc = nx_graph.canonize(sg)
            gc_json = json.dumps(json_graph.node_link_data(gc), ensure_ascii=False, sort_keys=True)
            o11_counts[gc_json]["subgraphs"].add(subg)
            o11_counts[gc_json]["embeddings"] |= mappings["subgraph_to_embeddings"][subg]
            o11_counts[gc_json]["focus_points"] |= mappings["subgraph_to_focus_points"][subg]
            # for focus_point in mappings["subgraph_to_focus_points"][subg]:
            #     sg.node[focus_point]["focus_point"] = True
            #     gc = nx_graph.canonize(sg)
            #     gc_json = json.dumps(json_graph.node_link_data(gc), ensure_ascii=False, sort_keys=True)
            #     del sg.node[focus_point]["focus_point"]
            #     o11_counts[gc_json]["subgraphs"].add(subg)
            #     o11_counts[gc_json]["embeddings"] |= mappings["subgraph_to_embeddings"][subg]
            #     o11_counts[gc_json]["focus_points"] |= mappings["subgraph_to_focus_points"][subg]
    for gc_json, counts in o11_counts.items():
        o11[gc_json] = Frequencies(len(counts["embeddings"]), len(counts["subgraphs"]), len(counts["focus_points"]), 1)
    r1 = Frequencies(len(r1_counts["embeddings"]), len(r1_counts["subgraphs"]), len(r1_counts["focus_points"]), min(1, len(r1_counts["embeddings"])))
    return o11, r1


def derive_frequency_tuple(embeddings, focus_point_vertex):
    """Summarise a collection of embeddings under all four counting methods.

    Returns ``(subgraphs, focus_points, frequencies)``:

    - the set of distinct vertex sets (as frozensets) covered by *embeddings*;
    - the set of target vertices that *focus_point_vertex* is mapped to;
    - a ``Frequencies`` tuple holding the number of embeddings, the sizes of
      the two sets, and a sentence count of 1 if there is any embedding, else 0.
    """
    subgraphs = {frozenset(embedding) for embedding in embeddings}
    focus_points = {embedding[focus_point_vertex] for embedding in embeddings}
    embedding_count = len(embeddings)
    counts = Frequencies(embedding_count, len(subgraphs), len(focus_points), min(embedding_count, 1))
    return subgraphs, focus_points, counts


def get_frequencies(args):
    """Compute contingency-table frequencies of the selected larger structures in one sentence.

    Worker for ``multiprocessing.Pool.imap_unordered``. *args* is the tuple
    ``(gn, ga, gc_to_gb, gbs, gs, focus_point_vertex, max_size)``:

    - *gn*: the delexicalized query; its embeddings in the sentence make up N.
    - *ga*: the lexicalized query; embeddings of *gn* whose target vertex
      attributes match *ga* (``nx_graph.dictionary_match``) make up R1.
    - *gc_to_gb*: maps each larger structure (canonical JSON) to its
      delexicalized counterpart (canonical JSON).
    - *gbs*: the set of values of *gc_to_gb*.
    - *gs*: the sentence graph serialised as node-link JSON.
    - *focus_point_vertex*, *max_size*: as in ``get_cooccurring_structures``.

    Returns a dict that maps every key of *gc_to_gb* to a dict. That dict maps
    each counting method ("embeddings", "subgraphs", "focus_points",
    "sentences") to the tuple ``(o11, r1, c1, n, inconsistent)``. Items counted
    for both R1 and C1 but not for O11 are "inconsistent"; half of their number
    is subtracted from r1 and half from c1.
    """
    gn, ga, gc_to_gb, gbs, gs, focus_point_vertex, max_size = args
    counting_methods = ("embeddings", "subgraphs", "focus_points", "sentences")
    n_counts = collections.defaultdict(set)
    mappings = collections.defaultdict(lambda: collections.defaultdict(set))
    subsumed_by_ga = {}
    subsumed_by_gb = collections.defaultdict(dict)
    subsumed_by_gc = collections.defaultdict(dict)
    result = collections.defaultdict(dict)
    stripped_ga = pareidoscope.query.strip_vid(ga)
    target_graph = json_graph.node_link_graph(json.loads(gs))
    vs = set(target_graph.nodes())
    isomorphisms = subgraph_isomorphism.get_subgraph_isomorphisms_nx(pareidoscope.query.strip_vid(gn), target_graph)
    for iso in isomorphisms:
        subgraph = frozenset(iso)
        focus_point = iso[focus_point_vertex]
        n_counts["embeddings"].add(iso)
        n_counts["subgraphs"].add(subgraph)
        n_counts["focus_points"].add(focus_point)
        mappings["subgraph_to_embeddings"][subgraph].add(iso)
        # check if iso is instance of ga
        subsumed_by_ga[iso] = all((nx_graph.dictionary_match(qv, tv) for qv, tv in zip((l for v, l in sorted(stripped_ga.nodes(data=True))), (target_graph.node[v] for v in iso))))
    for subg in n_counts["subgraphs"]:
        # Direct neighbours (in either direction) of every embedded vertex. Building
        # sets first works whether successors()/predecessors() return lists
        # (networkx 1.x) or iterators (networkx 2.x).
        neighbors = [set(target_graph.successors(v)) | set(target_graph.predecessors(v)) for v in subg]
        nbunch = functools.reduce(lambda x, y: x.union(y), neighbors + [subg])
        induced_star = target_graph.subgraph(nbunch)
        bfo_graph, bfo_to_raw = subgraph_enumeration.get_bfo(induced_star, fragment=True)
        subg_embeddings = mappings["subgraph_to_embeddings"][subg]
        for sg in subgraph_enumeration.enumerate_csg_minmax(bfo_graph, bfo_to_raw, min_vertices=len(subg) + 1, max_vertices=min(induced_star.number_of_nodes(), max_size)):
            # all vertices from subg have to be in sg
            if not subg <= set(sg.nodes()):
                continue
            # all edges from gn have to be in sg
            embedding = next(iter(subg_embeddings))
            if not all(sg.has_edge(embedding[s], embedding[t]) for s, t in gn.edges()):
                continue
            subgraph_to_query = {embedding[v]: v for v in sorted(gn.nodes())}
            gc = sg.copy()
            # delexicalize: remove all word, pos, lemma and root attributes unless they are in query
            for v in gc.nodes():
                if v in subgraph_to_query:
                    for attribute in list(gc.node[v]):
                        if attribute not in ga.node[subgraph_to_query[v]]:
                            del gc.node[v][attribute]
                else:
                    for attribute in ("word", "pos", "lemma", "root"):
                        if attribute in gc.node[v]:
                            del gc.node[v][attribute]
            gc_json = json.dumps(json_graph.node_link_data(nx_graph.canonize(gc)), ensure_ascii=False, sort_keys=True)
            for iso in subg_embeddings:
                mappings["gc_json_to_embedding_and_gc"][gc_json].add((iso, gc))
            gb = delexicalize(gc)
            gb_json = json.dumps(json_graph.node_link_data(nx_graph.canonize(gb)), ensure_ascii=False, sort_keys=True)
            # check if gb is one of the delexicalized structures of interest
            if gb_json in gbs:
                # The subsumption check depends only on gb and the vertex set subg,
                # which all embeddings of subg share. Compute it once and record
                # the result for every embedding, instead of once per embedding.
                isomorphism_candidates = [{v} if v in subg else vs - subg for v in sorted(gb.nodes())]
                consecutive_gb = nx_graph.ensure_consecutive_vertices(gb)
                normal_cand_b = nx_graph.get_vertex_candidates(consecutive_gb, target_graph)
                vert_cand_b = [a & b for a, b in zip(normal_cand_b, isomorphism_candidates)]
                gb_subsumes = subgraph_enumeration.subsumes_nx(consecutive_gb, target_graph, vertex_candidates=vert_cand_b)
                for iso in subg_embeddings:
                    subsumed_by_gb[gb_json][iso] = gb_subsumes
    r1_embeddings = {iso for iso, t in subsumed_by_ga.items() if t}
    r1_subgraphs, r1_focus_points, r1 = derive_frequency_tuple(r1_embeddings, focus_point_vertex)
    n = Frequencies(len(n_counts["embeddings"]), len(n_counts["subgraphs"]), len(n_counts["focus_points"]), min(1, len(n_counts["embeddings"])))
    for gc, gb in gc_to_gb.items():
        c1_embeddings = {iso for iso, t in subsumed_by_gb.get(gb, {}).items() if t}
        c1_subgraphs, c1_focus_points, c1 = derive_frequency_tuple(c1_embeddings, focus_point_vertex)
        if gc in mappings["gc_json_to_embedding_and_gc"]:
            for iso, g in mappings["gc_json_to_embedding_and_gc"][gc]:
                subg = set(iso)
                isomorphism_candidates = [{v} if v in subg else vs - subg for v in sorted(g.nodes())]
                consecutive_gc = nx_graph.ensure_consecutive_vertices(g)
                normal_cand_c = nx_graph.get_vertex_candidates(consecutive_gc, target_graph)
                vert_cand_c = [a & b for a, b in zip(normal_cand_c, isomorphism_candidates)]
                subsumed_by_gc[gc][iso] = subgraph_enumeration.subsumes_nx(consecutive_gc, target_graph, vertex_candidates=vert_cand_c)
        o11_embeddings = {iso for iso, t in subsumed_by_gc.get(gc, {}).items() if t}
        o11_subgraphs, o11_focus_points, o11 = derive_frequency_tuple(o11_embeddings, focus_point_vertex)
        inconsistent_embeddings = (r1_embeddings & c1_embeddings) - o11_embeddings
        inconsistent_subgraphs = (r1_subgraphs & c1_subgraphs) - o11_subgraphs
        inconsistent_focus_points = (r1_focus_points & c1_focus_points) - o11_focus_points
        inconsistent_sentence = 1 if r1.sentences == 1 and c1.sentences == 1 and o11.sentences == 0 else 0
        inconsistencies = Frequencies(len(inconsistent_embeddings), len(inconsistent_subgraphs), len(inconsistent_focus_points), inconsistent_sentence)
        # Split the inconsistent items evenly: half is taken off r1, half off c1.
        for cm in counting_methods:
            inconsistent = getattr(inconsistencies, cm)
            result[gc][cm] = (getattr(o11, cm), getattr(r1, cm) - inconsistent / 2, getattr(c1, cm) - inconsistent / 2, getattr(n, cm), inconsistent)
    return result


def write_results(prefix, results):
    """Write the association results of all queries to ``<prefix>.tsv``.

    The file is UTF-8 encoded, tab-separated and starts with a header row.
    Columns are ``query_number`` and ``larger_structure``, followed by one
    column per (counting method, value) pair named ``<method>:<value>``.

    Arguments:
    - `prefix`: output path prefix; ``.tsv`` is appended.
    - `results`: one list per query, in query order, of result dicts. Each dict
      has a ``"larger_structure"`` string and, for every counting method, a
      dict mapping every value name to a number.
    """
    counting_methods = ("embeddings", "subgraphs", "focus_points", "sentences")
    values = ("o11", "r1", "c1", "n", "inconsistent", "log_likelihood", "t_score", "dice")
    # Computed once; header and rows use the same column order.
    columns = list(itertools.product(counting_methods, values))
    # Larger structures may contain non-ASCII text (words/lemmas), so the
    # encoding is fixed instead of depending on the platform's locale default.
    with open("%s.tsv" % prefix, "w", encoding="utf-8") as fh:
        header = ["query_number", "larger_structure"] + [":".join(column) for column in columns]
        fh.write("\t".join(header) + "\n")
        for i, r in enumerate(results):
            for coocc in r:
                line = [str(i), coocc["larger_structure"]] + [str(coocc[cm][v]) for cm, v in columns]
                fh.write("\t".join(line) + "\n")


def _spool_sentences(fp, sentences):
    """Buffer *sentences* in the binary file *fp* and return a lazy reader over them.

    Each sentence is written UTF-8 encoded on its own line, then *fp* is
    rewound. The returned generator reads *fp* line by line, so *fp* must stay
    open while it is consumed. Yielded sentences have trailing whitespace
    stripped.
    """
    for sentence in sentences:
        fp.write((sentence + "\n").encode(encoding="utf-8"))
    fp.seek(0)
    return (line.decode(encoding="utf-8").rstrip() for line in fp)


def main():
    """Run the association analysis for every query and write ``<output>.tsv``.

    For each query:

    1. count the larger structures that co-occur with it in the candidate
       sentences;
    2. keep those found in at least ``--min-coocc`` sentences;
    3. recount them against the delexicalized query to obtain O11, R1, C1 and
       N per counting method;
    4. score them with log-likelihood, t-score and Dice.

    Per query, results are sorted by focus-point log-likelihood (descending).
    """
    args = arguments()
    results = []
    counting_methods = ("embeddings", "subgraphs", "focus_points", "sentences")
    # conn stays bound while the cursor c is used below.
    conn, c = database.connect_to_database(args.CORPUS)
    queries = pareidoscope.query.read_queries(args.QUERIES)
    cpu_count = multiprocessing.cpu_count()
    processes = min(max(1, int(cpu_count * args.cpu / 100)), cpu_count)
    logging.info("Using %d processes", processes)
    with multiprocessing.Pool(processes=processes) as pool:
        for query_number, query in enumerate(queries):
            r1 = Frequencies(0, 0, 0, 0)
            o11 = collections.defaultdict(lambda: Frequencies(0, 0, 0, 0))
            frequencies = collections.defaultdict(lambda: collections.defaultdict(lambda: (0, 0, 0, 0, 0)))
            logging.info("query no. %d", query_number)
            graph, focus_point = get_focus_point(query)
            delexicalized = delexicalize(graph)
            logging.info("Collect cooccurring larger structures")
            with tempfile.TemporaryFile() as fp:
                sentences = _spool_sentences(fp, database.sentence_candidates(c, pareidoscope.query.strip_vid(graph)))
                query_args = zip(itertools.repeat(graph), sentences, itertools.repeat(focus_point), itertools.repeat(args.max_size))
                for local_o11, local_r1 in pool.imap_unordered(get_cooccurring_structures, query_args, 10):
                    r1 = [sum(_) for _ in zip(r1, local_r1)]
                    for g, freq in local_o11.items():
                        o11[g] = [sum(_) for _ in zip(o11[g], freq)]
            # Keep structures seen in at least --min-coocc sentences (index 3 is the
            # sentence count) and compute their delexicalized counterparts.
            gc_to_gb = {}
            gbs = set()
            for gc, freq in o11.items():
                if freq[3] < args.min_coocc:
                    continue
                gb = json_graph.node_link_graph(json.loads(gc))
                gb = delexicalize(gb)
                gb = nx_graph.canonize(gb)
                gb = json.dumps(json_graph.node_link_data(gb), ensure_ascii=False, sort_keys=True)
                gc_to_gb[gc] = gb
                gbs.add(gb)
            logging.info("Determine association strengths")
            with tempfile.TemporaryFile() as fp:
                sentences = _spool_sentences(fp, database.sentence_candidates(c, pareidoscope.query.strip_vid(delexicalized)))
                query_args = zip(itertools.repeat(delexicalized), itertools.repeat(graph), itertools.repeat(gc_to_gb), itertools.repeat(gbs), sentences, itertools.repeat(focus_point), itertools.repeat(args.max_size))
                for res in pool.imap_unordered(get_frequencies, query_args, 10):
                    for gc, freqs in res.items():
                        for cm in counting_methods:
                            frequencies[gc][cm] = [sum(_) for _ in zip(frequencies[gc][cm], freqs[cm])]
            local_result = {}
            for gc in gc_to_gb:
                # Cross-check the two passes: O11 must match the first pass, and
                # R1 plus half of the inconsistent items must match as well.
                assert frequencies[gc]["embeddings"][0] == o11[gc][0]
                assert frequencies[gc]["embeddings"][1] + frequencies[gc]["embeddings"][4] / 2 == r1[0]
                assert frequencies[gc]["subgraphs"][0] == o11[gc][1]
                assert frequencies[gc]["subgraphs"][1] + frequencies[gc]["subgraphs"][4] / 2 == r1[1]
                assert frequencies[gc]["focus_points"][0] == o11[gc][2]
                assert frequencies[gc]["focus_points"][1] + frequencies[gc]["focus_points"][4] / 2 == r1[2]
                assert frequencies[gc]["sentences"][0] == o11[gc][3]
                assert frequencies[gc]["sentences"][1] + frequencies[gc]["sentences"][4] / 2 == r1[3]
                if frequencies[gc]["sentences"][0] < args.min_coocc:
                    continue
                local_result[gc] = {"larger_structure": gc}
                for cm in counting_methods:
                    lo11, lr1, lc1, ln, linc = frequencies[gc][cm]
                    o, e = statistics.get_contingency_table(lo11, lr1, lc1, ln)
                    log_likelihood = statistics.one_sided_log_likelihood(o, e)
                    t_score = statistics.t_score(o, e)
                    dice = statistics.dice(o, e)
                    local_result[gc][cm] = {"o11": lo11, "r1": lr1, "c1": lc1, "n": ln, "inconsistent": linc, "log_likelihood": log_likelihood, "t_score": t_score, "dice": dice}
            sorted_gcs = sorted(local_result.keys(), key=lambda x: (local_result[x]["focus_points"]["log_likelihood"], x), reverse=True)
            results.append([local_result[gc] for gc in sorted_gcs])
        write_results(args.output, results)
    logging.info("done")


if __name__ == "__main__":
    main()
