#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import collections
import itertools
import json
import logging
import multiprocessing
import os
import queue

from networkx.readwrite import json_graph

import pareidoscope.query
from pareidoscope import subgraph_isomorphism
from pareidoscope.utils import database
from pareidoscope.utils import statistics
from pareidoscope.utils import nx_graph

# Log at DEBUG level with a "LEVEL timestamp: message" layout; change `level`
# to logging.INFO for less verbose output.
logging.basicConfig(format="%(levelname)s %(asctime)s: %(message)s", level=logging.DEBUG)


# One count per counting method; the field order is relied upon wherever
# these records are summed element-wise or zipped with method names.
Frequencies = collections.namedtuple(
    "Frequencies",
    [
        "embeddings",    # number of isomorphisms
        "subgraphs",     # number of distinct matched subgraphs
        "focus_points",  # number of distinct vertices matched by the focus point
        "sentences",     # number of sentences
    ],
)


class Sentinel:
    """Marker object put on the queues to signal the end of the data.

    Consumers recognise it with ``isinstance`` rather than by identity,
    because an object that travels through a multiprocessing queue arrives
    as a copy.
    """


def arguments():
    """Parse the command line and return the resulting namespace.

    The namespace carries:
    - `collexeme`: "word" or "lemma" (default "lemma")
    - `output`: output prefix (required)
    - `cpu`: percentage of CPUs to use as an int (default 25)
    - `CORPUS`: absolute path of the corpus database
    - `QUERIES`: the queries file, already opened for reading as UTF-8
    """
    description = "Perform a covarying collexeme analysis, i.e. find cooccurring words within a linguistic structure. If the structure is a single dependency relation, this is equivalent to relational cooccurrences."
    cli = argparse.ArgumentParser(description=description)
    cli.add_argument(
        "-c",
        "--collexeme",
        choices=["word", "lemma"],
        default="lemma",
        help="Should we look for collexemes at the word level or at the lemma level? Default: lemma",
    )
    cli.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="Output prefix",
    )
    cli.add_argument(
        "-p",
        "--cpu",
        type=int,
        default=25,
        help="Percentage of CPUs to use (0-100; default: 25)",
    )
    # Positional arguments: the corpus path is normalised to an absolute
    # path, the queries file is opened by argparse itself.
    cli.add_argument(
        "CORPUS",
        type=os.path.abspath,
        help="Input corpus as SQLite3 database",
    )
    cli.add_argument(
        "QUERIES",
        type=argparse.FileType("r", encoding="utf-8"),
        help="Queries file as JSON list",
    )
    namespace = cli.parse_args()
    return namespace


def identify_collo_items(graph):
    """Locate the two collexeme slots and the focus point of a query graph.

    The vertex attributes are scanned for the marker keys ``collo_A``,
    ``collo_B`` and ``focus_point``. Every marker that is found is removed
    from the attribute dictionary, i.e. `graph` is modified in place. If
    several vertices carry the same marker, the last one wins. When no
    vertex is marked as focus point, ``nx_graph.get_choke_point`` is asked
    for one.

    Arguments:
    - `graph`: query graph whose ``nodes(data=True)`` yields
      ``(vertex, attribute dict)`` pairs

    Returns:
    - tuple ``(graph, focus_point, collo_a, collo_b)``

    Raises:
    - `ValueError`: if ``collo_A`` or ``collo_B`` is not marked or no focus
      point can be determined
    """
    found = {}
    for vertex, attributes in graph.nodes(data=True):
        for marker in ("collo_A", "collo_B", "focus_point"):
            if marker in attributes:
                found[marker] = vertex
                del attributes[marker]
    # Explicit exceptions instead of `assert`: queries are user input and
    # the check must survive `python -O`.
    missing = [marker for marker in ("collo_A", "collo_B") if found.get(marker) is None]
    if missing:
        raise ValueError("query graph has no vertex marked as %s" % " and ".join(missing))
    focus_point = found.get("focus_point")
    if focus_point is None:
        focus_point = nx_graph.get_choke_point(graph)
    if focus_point is None:
        raise ValueError("could not determine a focus point for the query graph")
    return graph, focus_point, found["collo_A"], found["collo_B"]


def _unobserved_pairs(observed):
    """Yield the (item_a, item_b) combinations that are missing from `observed`.

    All first items are combined with all second items of the pairs in
    `observed`; combinations that are themselves in `observed` are skipped.
    Nothing is yielded for an empty collection.
    """
    if not observed:
        # itertools.product() without arguments yields a single empty
        # tuple, which must not be reported as a pair.
        return
    items_a, items_b = (set(items) for items in zip(*observed))
    for pair in itertools.product(items_a, items_b):
        if pair not in observed:
            yield pair


def get_cooccurrences(args):
    """Count the collexeme pairs the query matches in one sentence.

    Arguments:
    - `args`: tuple ``(query_graph, target_graph, focus_point, collo_a,
      collo_b, word_or_lemma)``. `target_graph` is a JSON string in
      node-link format; `focus_point`, `collo_a` and `collo_b` are query
      vertices; `word_or_lemma` is the vertex attribute to read the items
      from.

    Returns a tuple ``(pairs, sample_sizes, inconsistencies)``:
    - `pairs`: dict mapping each observed ``(item_a, item_b)`` pair to its
      `Frequencies` in this sentence
    - `sample_sizes`: `Frequencies` with the totals for this sentence
    - `inconsistencies`: dict mapping unobserved pairs to a list of four
      counts in `Frequencies` order: how often the pair could be formed
      from the items seen in the same sentence, at the same focus point or
      in the same subgraph without being observed there (the embeddings
      count is always 0)
    """
    query_graph, target_graph, focus_point, collo_a, collo_b, word_or_lemma = args
    embeddings = collections.defaultdict(int)
    subgraphs = collections.defaultdict(set)
    focus_points = collections.defaultdict(set)
    subgraph_to_pairs = collections.defaultdict(set)
    focus_point_to_pairs = collections.defaultdict(set)
    total_subgraphs, total_focus_points = set(), set()
    sentences = set()
    target_graph = json_graph.node_link_graph(json.loads(target_graph))
    isomorphisms = subgraph_isomorphism.get_subgraph_isomorphisms_nx(pareidoscope.query.strip_vid(query_graph), target_graph)
    for iso in isomorphisms:
        item_a = target_graph.nodes[iso[collo_a]][word_or_lemma]
        item_b = target_graph.nodes[iso[collo_b]][word_or_lemma]
        pair = (item_a, item_b)
        embeddings[pair] += 1
        # NOTE(review): assumes that iterating over `iso` yields the matched
        # target vertices (e.g. a sequence indexed by query vertex) -- confirm
        subgraph = frozenset(iso)
        subgraphs[pair].add(subgraph)
        subgraph_to_pairs[subgraph].add(pair)
        total_subgraphs.add(subgraph)
        matched_focus_point = iso[focus_point]
        focus_points[pair].add(matched_focus_point)
        focus_point_to_pairs[matched_focus_point].add(pair)
        total_focus_points.add(matched_focus_point)
        sentences.add(pair)
    # Within a single sentence every observed pair has a sentence count of 1.
    pairs = {pair: Frequencies(embeddings[pair], len(subgraphs[pair]), len(focus_points[pair]), 1) for pair in sentences}
    sample_sizes = Frequencies(sum(embeddings.values()), len(total_subgraphs), len(total_focus_points), 1 if sentences else 0)
    # Determine possible inconsistencies; the list positions follow the
    # field order of Frequencies (1: subgraphs, 2: focus_points, 3: sentences).
    inconsistencies = collections.defaultdict(lambda: [0, 0, 0, 0])
    for pair in _unobserved_pairs(sentences):
        inconsistencies[pair][3] += 1
    for pairs_at_focus_point in focus_point_to_pairs.values():
        for pair in _unobserved_pairs(pairs_at_focus_point):
            inconsistencies[pair][2] += 1
    for pairs_in_subgraph in subgraph_to_pairs.values():
        for pair in _unobserved_pairs(pairs_in_subgraph):
            inconsistencies[pair][1] += 1
    # Return a plain dict: the result is sent through a queue and the
    # default factory above cannot be pickled.
    return pairs, sample_sizes, dict(inconsistencies)


def write_results(prefix, results, word_or_lemma):
    """Write the results of all queries to the tab-separated file ``<prefix>.tsv``.

    The file starts with a header line. Every cooccurrence becomes one line
    consisting of the query number, the two items and, for each combination
    of counting method and value, the column ``<method>:<value>``.

    Arguments:
    - `prefix`: output path without the ``.tsv`` extension
    - `results`: one list per query (in query order); each list contains
      one dict per cooccurrence with the keys ``<word_or_lemma>_A``,
      ``<word_or_lemma>_B`` and one dict of values per counting method
    - `word_or_lemma`: "word" or "lemma"; selects the names of the item
      columns and keys
    """
    counting_methods = ("embeddings", "subgraphs", "focus_points", "sentences")
    values = ("o11", "r1", "c1", "n", "inconsistent", "log_likelihood", "t_score", "dice")
    columns = list(itertools.product(counting_methods, values))
    item_a = "%s_A" % word_or_lemma
    item_b = "%s_B" % word_or_lemma
    header = ["query_number", item_a, item_b] + [":".join(column) for column in columns]
    # Explicit encoding: the items come from a UTF-8 corpus and must be
    # writable regardless of the locale's default encoding.
    with open("%s.tsv" % prefix, "w", encoding="utf-8") as fh:
        fh.write("\t".join(header) + "\n")
        for query_number, cooccurrences in enumerate(results):
            for coocc in cooccurrences:
                line = [str(query_number), coocc[item_a], coocc[item_b]]
                line.extend(str(coocc[cm][v]) for cm, v in columns)
                fh.write("\t".join(line) + "\n")


def fill_input_queue(input_queue, corpus, graph, focus_point, collo_a, collo_b, collexeme, processes, sentinel):
    """Put one work item per candidate sentence on the input queue.

    Each work item is the tuple ``(graph, sentence, focus_point, collo_a,
    collo_b, collexeme)``. Afterwards `sentinel` is put on the queue once
    per consumer so that every consumer sees the end of the data.

    The database connection is closed and the sentinels are sent even if
    an error occurs, so that the consumers do not wait forever; the error
    itself still propagates.

    Arguments:
    - `input_queue`: queue to fill
    - `corpus`: path of the corpus database
    - `graph`: query graph
    - `focus_point`, `collo_a`, `collo_b`: query vertices, passed through
    - `collexeme`: "word" or "lemma", passed through
    - `processes`: number of consumers, i.e. number of sentinels to send
    - `sentinel`: end-of-data marker
    """
    try:
        conn, c = database.connect_to_database(corpus)
        try:
            sents = database.sentence_candidates(c, pareidoscope.query.strip_vid(graph))
            for s in sents:
                input_queue.put((graph, s, focus_point, collo_a, collo_b, collexeme))
        finally:
            conn.close()
    finally:
        for _ in range(processes):
            input_queue.put(sentinel)


def process_input_queue(func, input_queue, output_queue, sentinel):
    """Apply `func` to the items of `input_queue` until a sentinel arrives.

    Every result is put on `output_queue`. When a `Sentinel` instance is
    read from `input_queue`, processing stops and `sentinel` is put on
    `output_queue` to announce that this consumer is finished.

    Arguments:
    - `func`: callable taking one work item
    - `input_queue`: queue to read work items from
    - `output_queue`: queue to write results to
    - `sentinel`: end-of-data marker to pass on
    """
    item = input_queue.get()
    while not isinstance(item, Sentinel):
        output_queue.put(func(item))
        item = input_queue.get()
    output_queue.put(sentinel)


def _sum_counts(left, right):
    """Return the element-wise sum of two count sequences as a list."""
    return [a + b for a, b in zip(left, right)]


def main():
    """Run all queries against the corpus and write the association scores.

    For every query a producer process puts the candidate sentences on an
    input queue, a pool of worker processes counts the cooccurrences per
    sentence, and the counts are summed up here. The sums are turned into
    contingency tables and association measures for each counting method.
    The cooccurrences of each query are sorted by the log-likelihood of the
    focus point counts (descending), and everything is written with
    `write_results`.

    Raises:
    - `RuntimeError`: if the producer process exits abnormally
    """
    args = arguments()
    counting_methods = ("embeddings", "subgraphs", "focus_points", "sentences")
    results = []
    queries = pareidoscope.query.read_queries(args.QUERIES)
    cpu_count = multiprocessing.cpu_count()
    processes = min(max(1, int(cpu_count * args.cpu / 100)), cpu_count)
    # One process is taken by the producer that fills the input queue.
    processes = max(1, processes - 1)
    sentinel = Sentinel()
    logging.info("Using %d + 1 processes", processes)
    for i, query in enumerate(queries):
        logging.info("query no. %d", i)
        pairs = collections.defaultdict(lambda: (0, 0, 0, 0))
        marginals_a = collections.defaultdict(lambda: (0, 0, 0, 0))
        marginals_b = collections.defaultdict(lambda: (0, 0, 0, 0))
        inconsistencies = collections.defaultdict(lambda: (0, 0, 0, 0))
        sample_sizes = (0, 0, 0, 0)
        graph, focus_point, collo_a, collo_b = identify_collo_items(query)
        input_queue = multiprocessing.Queue(maxsize=processes * 100)
        output_queue = multiprocessing.Queue(maxsize=processes * 100)
        producer = multiprocessing.Process(target=fill_input_queue, args=(input_queue, args.CORPUS, graph, focus_point, collo_a, collo_b, args.collexeme, processes, sentinel))
        # The workers run their whole processing loop inside the pool
        # initializer; no tasks are ever submitted to the pool itself.
        with multiprocessing.Pool(processes=processes, initializer=process_input_queue, initargs=(get_cooccurrences, input_queue, output_queue, sentinel)):
            producer.start()
            observed_sentinels = 0
            # Every worker announces its end with one sentinel.
            while observed_sentinels < processes:
                try:
                    data = output_queue.get(timeout=1)
                except queue.Empty:
                    # Nothing arrived for a while: make sure that we are
                    # not waiting for a producer that has crashed.
                    if producer.exitcode not in (None, 0):
                        raise RuntimeError("producer process for query no. %d exited with code %s" % (i, producer.exitcode))
                    continue
                if isinstance(data, Sentinel):
                    observed_sentinels += 1
                    continue
                ps, sam_siz, inc = data
                sample_sizes = _sum_counts(sample_sizes, sam_siz)
                for pair, freqs in ps.items():
                    pairs[pair] = _sum_counts(pairs[pair], freqs)
                for pair, freqs in inc.items():
                    inconsistencies[pair] = _sum_counts(inconsistencies[pair], freqs)
            producer.join()
            if producer.exitcode != 0:
                raise RuntimeError("producer process for query no. %d exited with code %s" % (i, producer.exitcode))
        # The marginal frequencies of an item are the sums over all pairs
        # it takes part in.
        for pair, freqs in pairs.items():
            item_a, item_b = pair
            marginals_a[item_a] = _sum_counts(marginals_a[item_a], freqs)
            marginals_b[item_b] = _sum_counts(marginals_b[item_b], freqs)
        local_result = {}
        for pair, freq in pairs.items():
            item_a, item_b = pair
            local_result[pair] = {"%s_A" % args.collexeme: item_a, "%s_B" % args.collexeme: item_b}
            frequencies = zip(freq, marginals_a[item_a], marginals_b[item_b], sample_sizes)
            for cm, f, inc in zip(counting_methods, frequencies, inconsistencies[pair]):
                o11, r1, c1, n = f
                if inc > 0:
                    # Half of the inconsistency count is taken off each
                    # marginal (this may turn them into floats).
                    r1 -= inc / 2
                    c1 -= inc / 2
                o, e = statistics.get_contingency_table(o11, r1, c1, n)
                log_likelihood = statistics.one_sided_log_likelihood(o, e)
                t_score = statistics.t_score(o, e)
                dice = statistics.dice(o, e)
                local_result[pair][cm] = {"o11": o11, "r1": r1, "c1": c1, "n": n, "inconsistent": inc, "log_likelihood": log_likelihood, "t_score": t_score, "dice": dice}
        sorted_pairs = sorted(local_result, key=lambda x: (local_result[x]["focus_points"]["log_likelihood"], x), reverse=True)
        results.append([local_result[p] for p in sorted_pairs])
    write_results(args.output, results, args.collexeme)
    logging.info("done")


# Run the analysis only when the file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
