#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import collections
import logging
import multiprocessing
import os
import itertools
import json
import tempfile

from networkx.readwrite import json_graph

import pareidoscope.query
from pareidoscope import subgraph_isomorphism
from pareidoscope.utils import database
from pareidoscope.utils import statistics
from pareidoscope.utils import nx_graph

# Verbose (DEBUG) logging for the whole run; use logging.INFO for less output.
logging.basicConfig(level=logging.DEBUG, format="%(levelname)s %(asctime)s: %(message)s")


# Counts under the four counting methods, used both per collexeme pair and
# as per-sentence sample sizes.
Frequencies = collections.namedtuple("Frequencies", "embeddings subgraphs focus_points sentences")


def arguments():
    """Parse the command line.

    Returns an ``argparse.Namespace`` with ``collexeme`` ("word" or "lemma"),
    ``output`` (output prefix), ``cpu`` (percentage of CPUs), ``CORPUS``
    (absolute path of the SQLite3 corpus) and ``QUERIES`` (the queries file,
    already opened for reading as UTF-8).
    """
    description = "Perform a covarying collexeme analysis, i.e. find cooccurring words within a linguistic structure. If the structure is a single dependency relation, this is equivalent to relational cooccurrences."
    ap = argparse.ArgumentParser(description=description)
    ap.add_argument(
        "-c", "--collexeme",
        default="lemma",
        choices=["word", "lemma"],
        help="Should we look for collexemes at the word level or at the lemma level? Default: lemma",
    )
    ap.add_argument("-o", "--output", required=True, type=str, help="Output prefix")
    ap.add_argument("-p", "--cpu", default=25, type=int, help="Percentage of CPUs to use (0-100; default: 25)")
    # Positionals: CORPUS must come before QUERIES.
    ap.add_argument("CORPUS", help="Input corpus as SQLite3 database", type=os.path.abspath)
    ap.add_argument("QUERIES", help="Queries file as JSON list", type=argparse.FileType("r", encoding="utf-8"))
    namespace = ap.parse_args()
    return namespace


def identify_collo_items(graph):
    """Locate and unmark the collo_A, collo_B and focus-point vertices of a query graph.

    The marker attributes ``collo_A``, ``collo_B`` and ``focus_point`` are
    deleted from the node labels *in place*. If no vertex carries
    ``focus_point``, ``nx_graph.get_choke_point(graph)`` is used instead.

    Returns ``(graph, focus_point, collo_a, collo_b)``. ``graph`` is the same
    (mutated) object that was passed in.

    Raises ``ValueError`` if no vertex is marked ``collo_A`` or ``collo_B``, or
    if ``get_choke_point`` returns ``None``.
    """
    collo_a, collo_b, focus_point = None, None, None
    for vertex, label in graph.nodes(data=True):
        if "collo_A" in label:
            collo_a = vertex
            del label["collo_A"]
        if "collo_B" in label:
            collo_b = vertex
            del label["collo_B"]
        if "focus_point" in label:
            focus_point = vertex
            del label["focus_point"]
    # Explicit checks rather than ``assert``: assertions are stripped under
    # ``python -O``, which would let ``None`` vertices reach the worker processes.
    # Vertex ids may be falsy (e.g. 0), hence the ``is None`` comparisons.
    if collo_a is None or collo_b is None:
        raise ValueError("query graph must mark one vertex as collo_A and one as collo_B")
    if focus_point is None:
        focus_point = nx_graph.get_choke_point(graph)
        if focus_point is None:
            raise ValueError("could not determine a focus point for the query graph")
    return graph, focus_point, collo_a, collo_b


def get_cooccurrences(args):
    """Count the collexeme pairs produced by one query in one sentence.

    Worker for ``multiprocessing.Pool.imap_unordered``. It takes a single tuple
    so it can be mapped directly.

    Arguments:
    - `args`: tuple ``(query_graph, target_graph, focus_point, collo_a, collo_b, word_or_lemma)``.
      ``target_graph`` is the sentence as a JSON string in networkx node-link
      format. ``focus_point``, ``collo_a`` and ``collo_b`` are query vertices.
      ``word_or_lemma`` is the node attribute whose value forms the pair items.

    Returns ``(pairs, sample_sizes, inconsistencies)``:
    - `pairs`: dict mapping ``(item_A, item_B)`` to ``Frequencies``. It holds the
      number of embeddings, distinct subgraphs and distinct focus points that
      produce the pair, with ``sentences=1``.
    - `sample_sizes`: ``Frequencies`` totals for this sentence. ``sentences`` is 1
      if any pair was found and 0 otherwise.
    - `inconsistencies`: dict mapping ``(item_A, item_B)`` to a 4-element list
      ordered (embeddings, subgraphs, focus_points, sentences). Each entry
      counts the subgraphs / focus points / sentences in which item_A and item_B
      both occur but never as a pair. The embeddings entry is always 0.
      Sentences without any match contribute nothing.
    """
    query_graph, target_graph, focus_point, collo_a, collo_b, word_or_lemma = args
    embeddings = collections.defaultdict(int)
    subgraphs = collections.defaultdict(set)
    focus_points = collections.defaultdict(set)
    subgraph_to_pairs = collections.defaultdict(set)
    focus_point_to_pairs = collections.defaultdict(set)
    sentence_pairs = set()
    target_graph = json_graph.node_link_graph(json.loads(target_graph))
    # ``Graph.node`` was removed in networkx 2.4, and ``Graph.nodes[v]`` is the
    # supported attribute lookup since 2.0. Keep ``.node`` where it still exists.
    node_attrs = target_graph.node if hasattr(target_graph, "node") else target_graph.nodes
    isomorphisms = subgraph_isomorphism.get_subgraph_isomorphisms_nx(pareidoscope.query.strip_vid(query_graph), target_graph)
    for iso in isomorphisms:
        pair = (node_attrs[iso[collo_a]][word_or_lemma], node_attrs[iso[collo_b]][word_or_lemma])
        # NOTE(review): assumes iterating ``iso`` yields the matched target
        # vertices (e.g. a sequence indexed by query vertex) — TODO confirm.
        subgraph = frozenset(iso)
        fp = iso[focus_point]
        embeddings[pair] += 1
        subgraphs[pair].add(subgraph)
        subgraph_to_pairs[subgraph].add(pair)
        focus_points[pair].add(fp)
        focus_point_to_pairs[fp].add(pair)
        sentence_pairs.add(pair)
    pairs = {pair: Frequencies(embeddings[pair], len(subgraphs[pair]), len(focus_points[pair]), 1) for pair in sentence_pairs}
    # The keys of the *_to_pairs maps are exactly the distinct subgraphs / focus points.
    sample_sizes = Frequencies(sum(embeddings.values()), len(subgraph_to_pairs), len(focus_point_to_pairs), min(1, len(sentence_pairs)))

    inconsistencies = collections.defaultdict(lambda: [0, 0, 0, 0])

    def count_unpaired(unit_pairs, index):
        """Increment ``inconsistencies[(a, b)][index]`` for each A item and B item of one unit that are never paired in it."""
        # An empty unit has no items; without this guard ``product()`` would yield ``()``.
        if not unit_pairs:
            return
        items_a = {a for a, _ in unit_pairs}
        items_b = {b for _, b in unit_pairs}
        for candidate in itertools.product(items_a, items_b):
            if candidate not in unit_pairs:
                inconsistencies[candidate][index] += 1

    count_unpaired(sentence_pairs, 3)
    for unit_pairs in focus_point_to_pairs.values():
        count_unpaired(unit_pairs, 2)
    for unit_pairs in subgraph_to_pairs.values():
        count_unpaired(unit_pairs, 1)
    return pairs, sample_sizes, dict(inconsistencies)


def write_results(prefix, results, word_or_lemma):
    """Write the scored cooccurrences of all queries to ``<prefix>.tsv``.

    The file is UTF-8 encoded and tab-separated. It has one header line and one
    row per cooccurring pair; the first column is the 0-based query index.

    Arguments:
    - `prefix`: output path without the ``.tsv`` extension.
    - `results`: one list per query, in query order, of result dicts. Each dict
      holds the two collexemes under ``"<word_or_lemma>_A"`` and
      ``"<word_or_lemma>_B"``. For every counting method it also holds a dict
      with the statistics named in ``values`` below.
    - `word_or_lemma`: ``"word"`` or ``"lemma"``, used in the column names and keys.
    """
    counting_methods = ("embeddings", "subgraphs", "focus_points", "sentences")
    values = ("o11", "r1", "c1", "n", "inconsistent", "log_likelihood", "t_score", "dice")
    columns = list(itertools.product(counting_methods, values))
    key_a, key_b = "%s_A" % word_or_lemma, "%s_B" % word_or_lemma
    # Collexemes may contain non-ASCII characters; don't depend on the locale's
    # default encoding (the rest of the script reads and writes UTF-8).
    with open("%s.tsv" % prefix, "w", encoding="utf-8") as fh:
        header = ["query_number", key_a, key_b] + [":".join(column) for column in columns]
        fh.write("\t".join(header) + "\n")
        for query_number, cooccurrences in enumerate(results):
            for coocc in cooccurrences:
                line = [str(query_number), coocc[key_a], coocc[key_b]] + [str(coocc[cm][v]) for cm, v in columns]
                fh.write("\t".join(line) + "\n")


def _add_frequencies(left, right):
    """Return the element-wise sum of two equally long frequency vectors as a list."""
    return [x + y for x, y in zip(left, right)]


def _collect_query_counts(pool, cursor, graph, focus_point, collo_a, collo_b, word_or_lemma):
    """Run ``get_cooccurrences`` on every candidate sentence of one query and sum the results.

    Returns ``(pairs, sample_sizes, inconsistencies)``. ``pairs`` and
    ``inconsistencies`` are defaultdicts mapping ``(item_A, item_B)`` to
    4-element vectors ordered (embeddings, subgraphs, focus_points, sentences).
    ``sample_sizes`` is one such vector for the whole query.
    """
    pairs = collections.defaultdict(lambda: (0, 0, 0, 0))
    inconsistencies = collections.defaultdict(lambda: (0, 0, 0, 0))
    sample_sizes = (0, 0, 0, 0)
    with tempfile.TemporaryFile() as spool:
        # Spool the candidate sentences to disk, then stream them back line by line.
        for sentence in database.sentence_candidates(cursor, pareidoscope.query.strip_vid(graph)):
            spool.write((sentence + "\n").encode(encoding="utf-8"))
        spool.seek(0)
        sentences = (line.decode(encoding="utf-8").rstrip() for line in spool)
        query_args = zip(itertools.repeat(graph), sentences, itertools.repeat(focus_point), itertools.repeat(collo_a), itertools.repeat(collo_b), itertools.repeat(word_or_lemma))
        # All results are consumed while the spool file is still open.
        for ps, sam_siz, inc in pool.imap_unordered(get_cooccurrences, query_args, 10):
            sample_sizes = _add_frequencies(sample_sizes, sam_siz)
            for pair, freqs in ps.items():
                pairs[pair] = _add_frequencies(pairs[pair], freqs)
            for pair, freqs in inc.items():
                inconsistencies[pair] = _add_frequencies(inconsistencies[pair], freqs)
    return pairs, sample_sizes, inconsistencies


def _score_pairs(pairs, sample_sizes, inconsistencies, word_or_lemma):
    """Compute association measures for every pair of one query.

    Returns a list of result dicts in the format ``write_results`` expects. The
    list is sorted by descending focus-point log-likelihood, with ties broken by
    descending pair.
    """
    counting_methods = ("embeddings", "subgraphs", "focus_points", "sentences")
    key_a, key_b = "%s_A" % word_or_lemma, "%s_B" % word_or_lemma
    # Row/column marginals: sum the pair frequencies over all partners of each item.
    marginals_a = collections.defaultdict(lambda: (0, 0, 0, 0))
    marginals_b = collections.defaultdict(lambda: (0, 0, 0, 0))
    for (item_a, item_b), freqs in pairs.items():
        marginals_a[item_a] = _add_frequencies(marginals_a[item_a], freqs)
        marginals_b[item_b] = _add_frequencies(marginals_b[item_b], freqs)
    local_result = {}
    for pair, freq in pairs.items():
        item_a, item_b = pair
        result = {key_a: item_a, key_b: item_b}
        frequencies = zip(freq, marginals_a[item_a], marginals_b[item_b], sample_sizes)
        for cm, (o11, r1, c1, n), inc in zip(counting_methods, frequencies, inconsistencies[pair]):
            # Reduce both marginals by half the number of units in which the two
            # items occur without being paired.
            if inc > 0:
                r1 -= inc / 2
                c1 -= inc / 2
            o, e = statistics.get_contingency_table(o11, r1, c1, n)
            result[cm] = {
                "o11": o11,
                "r1": r1,
                "c1": c1,
                "n": n,
                "inconsistent": inc,
                "log_likelihood": statistics.one_sided_log_likelihood(o, e),
                "t_score": statistics.t_score(o, e),
                "dice": statistics.dice(o, e),
            }
        local_result[pair] = result
    ranked = sorted(local_result, key=lambda p: (local_result[p]["focus_points"]["log_likelihood"], p), reverse=True)
    return [local_result[p] for p in ranked]


def main():
    """Run the covarying collexeme analysis for every query and write ``<output>.tsv``.

    The corpus connection and the queries file are closed when the run ends,
    including when it fails.
    """
    args = arguments()
    results = []
    conn, c = database.connect_to_database(args.CORPUS)
    try:
        queries = pareidoscope.query.read_queries(args.QUERIES)
        cpu_count = multiprocessing.cpu_count()
        # args.cpu is a percentage; always use at least one and at most all CPUs.
        processes = min(max(1, int(cpu_count * args.cpu / 100)), cpu_count)
        logging.info("Using %d processes", processes)
        with multiprocessing.Pool(processes=processes) as pool:
            for i, query in enumerate(queries):
                logging.info("query no. %d", i)
                graph, focus_point, collo_a, collo_b = identify_collo_items(query)
                pairs, sample_sizes, inconsistencies = _collect_query_counts(pool, c, graph, focus_point, collo_a, collo_b, args.collexeme)
                results.append(_score_pairs(pairs, sample_sizes, inconsistencies, args.collexeme))
        write_results(args.output, results, args.collexeme)
    finally:
        args.QUERIES.close()
        # NOTE(review): assumes connect_to_database returns a DB-API (sqlite3)
        # connection with close() — confirm.
        conn.close()
    logging.info("done")


# Script entry point: the analysis only runs when the file is executed directly.
if __name__ == "__main__":
    main()
