#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import itertools
import logging
import multiprocessing
import os
import queue

import pareidoscope.query
from pareidoscope.utils import conllu
from pareidoscope.utils import cwb
from pareidoscope.utils import database
# from pareidoscope.utils import helper
from pareidoscope.utils import nx_graph
from pareidoscope.utils import statistics

# Configure the root logger once at import time: INFO level and above, each
# record formatted as "<level name> <timestamp>: <message>".
logging.basicConfig(format="%(levelname)s %(asctime)s: %(message)s", level=logging.INFO)


class Sentinel:
    """Marker object signalling the end of a queue's data stream.

    Instances travel through ``multiprocessing.Queue`` objects, which pickle
    their items, so the receiving process gets a copy rather than the very
    same object.  Consumers therefore recognise the marker with
    ``isinstance(item, Sentinel)`` instead of an identity comparison.
    """


def arguments():
    """Define the command-line interface and parse ``sys.argv``.

    Returns the ``argparse.Namespace`` with these attributes:
    - `format`: one of "conllu", "cwb" or "db" (required)
    - `output`: output prefix (required)
    - `cpu`: percentage of CPUs to use, an int (default: 25)
    - `CORPUS`: corpus path, made absolute with ``os.path.abspath``
    - `QUERIES`: the queries file, already opened for reading as UTF-8 text
    """
    cli = argparse.ArgumentParser(description='Run a batch of queries against a corpus')
    cli.add_argument(
        "-f",
        "--format",
        choices=["conllu", "cwb", "db"],
        required=True,
        help="Input format of the corpus: Either a text-based format (CoNLL-U or CWB-treebank) or a database created by pareidoscope_corpus_to_sqlite",
    )
    cli.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="Output prefix",
    )
    cli.add_argument(
        "-p",
        "--cpu",
        type=int,
        default=25,
        help="Percentage of CPUs to use (0-100; default: 25)",
    )
    cli.add_argument(
        "CORPUS",
        type=os.path.abspath,
        help="Input corpus",
    )
    cli.add_argument(
        "QUERIES",
        type=argparse.FileType("r", encoding="utf-8"),
        help="Queries file as JSON list",
    )
    return cli.parse_args()


def remove_auxiliary_attributes(graph, abcn):
    """Strip the bookkeeping labels from a query graph, in place.

    Arguments:
    - `graph`: graph whose vertex and edge label dicts are cleaned up; it is
      modified in place and also returned
    - `abcn`: which derived graph this is.  For "a" and "n" the attributes
      listed under "only_B" are deleted, for "b" and "n" those listed under
      "only_A".  For every value except "n" the "focus_point" label is
      deleted as well.

    The "only_A" and "only_B" lists themselves are always removed, and
    finally the "query" label is deleted from every vertex.

    Raises ``KeyError`` if a listed attribute or a vertex's "query" label is
    missing.
    """
    drop_only_b = abcn in ("a", "n")
    drop_only_a = abcn in ("b", "n")
    keep_focus_point = abcn == "n"

    vertex_labels = [label for _, label in graph.nodes(data=True)]
    edge_labels = [label for _, _, label in graph.edges(data=True)]
    for label in vertex_labels + edge_labels:
        if drop_only_b and "only_B" in label:
            for attribute in label["only_B"]:
                del label[attribute]
        if drop_only_a and "only_A" in label:
            for attribute in label["only_A"]:
                del label[attribute]
        if not keep_focus_point and "focus_point" in label:
            del label["focus_point"]
        # The lists naming the auxiliary attributes are themselves auxiliary.
        label.pop("only_A", None)
        label.pop("only_B", None)

    for _, label in graph.nodes(data=True):
        del label["query"]
    return graph


def extract_graphs(query):
    """Derive the four query graphs C, A, B and N from one annotated query.

    Arguments:
    - `query`: query graph whose vertices carry a "query" label of "A", "B"
      or "AB"

    Returns a list ``[gc, ga, gb, gn]``:
    - `gc`: the complete query
    - `ga`: the subgraph induced by the "A" and "AB" vertices
    - `gb`: the subgraph induced by the "B" and "AB" vertices
    - `gn`: the subgraph induced by the "AB" vertices

    Each graph has its auxiliary attributes removed and is passed through
    ``nx_graph.ensure_consecutive_vertices``.  All four graphs are built from
    copies, so `query` itself is left untouched and may be processed again.
    """
    def induced_copy(memberships):
        # Independent copy of the subgraph induced by the selected vertices.
        vertices = [v for v, l in query.nodes(data=True) if l["query"] in memberships]
        return query.subgraph(vertices).copy()

    a = induced_copy(("A", "AB"))
    b = induced_copy(("B", "AB"))
    n = induced_copy(("AB",))
    # Work on a copy here as well: stripping the labels in place would remove
    # the "query" label from the caller's graph.
    c = query.copy()
    ga = remove_auxiliary_attributes(a, "a")
    gb = remove_auxiliary_attributes(b, "b")
    gn = remove_auxiliary_attributes(n, "n")
    gc = remove_auxiliary_attributes(c, "c")
    return [nx_graph.ensure_consecutive_vertices(g) for g in (gc, ga, gb, gn)]


def add_focus_point(query):
    """Attach a focus point vertex to a query.

    Arguments:
    - `query`: the four graphs ``(gc, ga, gb, gn)``

    The first vertex of `gn` that carries a "focus_point" label is used and
    the label is deleted from it.  If no vertex is labelled that way, the
    vertex returned by ``nx_graph.get_choke_point(gn)`` is used instead.

    Returns the tuple ``(gc, ga, gb, gn, focus_point_vertex)``.
    """
    gc, ga, gb, gn = query
    labelled = next(
        ((vertex, label) for vertex, label in gn.nodes(data=True) if "focus_point" in label),
        None,
    )
    if labelled is None:
        focus_point_vertex = None
    else:
        focus_point_vertex, label = labelled
        del label["focus_point"]
    if focus_point_vertex is None:
        # No explicit focus point in the query: fall back to the choke point.
        focus_point_vertex = nx_graph.get_choke_point(gn)
    assert focus_point_vertex is not None
    return gc, ga, gb, gn, focus_point_vertex


def write_results(prefix, results):
    """Write results to the tab-separated file ``<prefix>.tsv``

    Arguments:
    - `prefix`: output prefix; ".tsv" is appended to form the file name
    - `results`: one dict per query, mapping a counting method
      ("embeddings", "subgraphs", "focus_points", "sentences") to a dict
      with the keys "o11", "r1", "c1", "n", "inconsistent",
      "log_likelihood", "t_score" and "dice"

    The header row is "query_number" followed by one
    "<counting method>:<value>" column for every counting method that occurs
    in at least one result.  An empty `results` list yields a header-only
    file.  A query that has no entry for a counting method (for example
    because nothing was merged for it) gets empty cells in those columns.

    Raises ``KeyError`` if a counting method is present for a query but one
    of its values is missing.
    """
    known_methods = ("embeddings", "subgraphs", "focus_points", "sentences")
    values = ("o11", "r1", "c1", "n", "inconsistent", "log_likelihood", "t_score", "dice")
    # Look at every result, not just the first one: the first query may be
    # the one without any counts.
    counting_methods = [cm for cm in known_methods if any(cm in r for r in results)]
    columns = list(itertools.product(counting_methods, values))
    with open("%s.tsv" % prefix, "w", encoding="utf-8") as fh:
        header = ["query_number"] + [":".join(column) for column in columns]
        fh.write("\t".join(header) + "\n")
        for i, r in enumerate(results):
            cells = [str(r[cm][v]) if cm in r else "" for cm, v in columns]
            fh.write("\t".join([str(i)] + cells) + "\n")


def process_input_queue(func, input_queue, output_queue, sentinel):
    """Worker loop: apply `func` to queued items until a sentinel arrives.

    Arguments:
    - `func`: callable applied to every item taken from `input_queue`
    - `input_queue`: queue the work items are read from
    - `output_queue`: queue the return values of `func` are written to
    - `sentinel`: object put on `output_queue` once this worker is done

    The loop ends as soon as an instance of ``Sentinel`` is read from
    `input_queue`; that item is not passed to `func`.
    """
    item = input_queue.get()
    while not isinstance(item, Sentinel):
        output_queue.put(func(item))
        item = input_queue.get()
    # Tell the consumer that this worker will not produce anything else.
    output_queue.put(sentinel)


def fill_input_queue_db(input_queue, corpus, query, processes, sentinel):
    """Producer: queue one work item per candidate sentence from the database.

    Arguments:
    - `input_queue`: queue the work items are put on
    - `corpus`: database to connect to
    - `query`: tuple ``(gc, ga, gb, gn, choke_point)``
    - `processes`: number of workers; one sentinel is queued for each
    - `sentinel`: end-of-input marker

    Every item is the tuple ``(gc, ga, gb, gn, choke_point, sentence)``,
    where the sentences come from ``database.sentence_candidates`` called
    with ``pareidoscope.query.strip_vid(gn)``.  The connection is closed
    before the sentinels are queued.
    """
    connection, cursor = database.connect_to_database(corpus)
    gc, ga, gb, gn, choke_point = query
    pattern = pareidoscope.query.strip_vid(gn)
    for sentence in database.sentence_candidates(cursor, pattern):
        input_queue.put((gc, ga, gb, gn, choke_point, sentence))
    connection.close()
    for _ in range(processes):
        input_queue.put(sentinel)


def fill_input_queue_text(input_queue, corpus, corpus_format, queries, processes, sentinel):
    """Producer: queue one work item per sentence of a text corpus.

    Arguments:
    - `input_queue`: queue the work items are put on
    - `corpus`: path of the corpus file, read as UTF-8
    - `corpus_format`: "cwb" or "conllu"; selects the sentence reader
    - `queries`: the queries, passed along unchanged with every sentence
    - `processes`: number of workers; one sentinel is queued for each
    - `sentinel`: end-of-input marker

    Every item is the tuple ``(sentence, corpus_format, queries)``.  The
    sentinels are queued after the whole corpus has been read.

    Raises ``ValueError`` for any other `corpus_format`; in that case the
    corpus is not opened and nothing is queued.
    """
    if corpus_format == "cwb":
        sentences_iter = cwb.sentences_iter
    elif corpus_format == "conllu":
        sentences_iter = conllu.sentences_iter
    else:
        # Fail with a clear message instead of hitting an unbound variable
        # further down.
        raise ValueError("unsupported corpus format: %r" % (corpus_format,))
    with open(corpus, encoding="utf-8") as corpus_file:
        for s in sentences_iter(corpus_file, return_id=False):
            input_queue.put((s, corpus_format, queries))
    for _ in range(processes):
        input_queue.put(sentinel)


def association_strength_db(args, queries, processes, sentinel):
    """Run every query against a database corpus, one query at a time.

    Arguments:
    - `args`: parsed command-line arguments; only ``args.CORPUS`` is used
    - `queries`: tuples ``(gc, ga, gb, gn, choke_point)``
    - `processes`: number of worker processes
    - `sentinel`: end-of-data marker handed to producer and workers

    For each query a producer process (``fill_input_queue_db``) fills a fresh
    input queue, and `processes` pool workers run ``process_input_queue``
    with ``pareidoscope.query.run_queries_db`` as their initializer; the pool
    is never given any tasks.  This process merges whatever arrives on the
    output queue into ``results[i]`` with
    ``pareidoscope.query.merge_result_db`` until it has seen one sentinel per
    worker.

    Returns a list with one result dict per query.

    Raises ``RuntimeError`` if the producer process exits with a non-zero
    exit code before all sentinels have been seen.
    """
    results = [{} for _ in queries]
    for i, query in enumerate(queries):
        logging.info("query no. %d", i)
        input_queue = multiprocessing.Queue(maxsize=processes * 100)
        output_queue = multiprocessing.Queue(maxsize=processes * 100)
        producer = multiprocessing.Process(
            target=fill_input_queue_db,
            args=(input_queue, args.CORPUS, query, processes, sentinel),
        )
        # The workers start consuming as soon as the pool exists; leaving the
        # ``with`` block terminates them.
        with multiprocessing.Pool(
            processes=processes,
            initializer=process_input_queue,
            initargs=(pareidoscope.query.run_queries_db, input_queue, output_queue, sentinel),
        ):
            producer.start()
            observed_sentinels = 0
            while observed_sentinels < processes:
                try:
                    data = output_queue.get(timeout=1)
                except queue.Empty:
                    # Nothing arrived for a while.  A producer that died
                    # never sends the sentinels, so waiting any longer would
                    # block forever.
                    if producer.exitcode not in (None, 0):
                        producer.join()
                        raise RuntimeError(
                            "producer for query no. %d exited with code %d" % (i, producer.exitcode)
                        )
                    continue
                if isinstance(data, Sentinel):
                    observed_sentinels += 1
                    continue
                pareidoscope.query.merge_result_db(data, i, results)
            producer.join()
        logging.info(results[i])
    return results


def association_strength_text(args, queries, processes, sentinel):
    """Run all queries against a text corpus in a single pass.

    Arguments:
    - `args`: parsed command-line arguments; ``args.CORPUS`` and
      ``args.format`` are used
    - `queries`: the queries, sent to the workers along with every sentence
    - `processes`: number of worker processes
    - `sentinel`: end-of-data marker handed to producer and workers

    A producer process (``fill_input_queue_text``) fills the input queue, and
    `processes` pool workers run ``process_input_queue`` with
    ``pareidoscope.query.run_queries`` as their initializer; the pool is
    never given any tasks.  Every item on the output queue is a pair
    ``(result, sensible)``; results flagged as sensible are merged with
    ``pareidoscope.query.merge_result`` until one sentinel per worker has
    been seen.

    Returns a list with one result dict per query.

    Raises ``RuntimeError`` if the producer process exits with a non-zero
    exit code before all sentinels have been seen.
    """
    results = [{} for _ in queries]
    input_queue = multiprocessing.Queue(maxsize=processes * 100)
    output_queue = multiprocessing.Queue(maxsize=processes * 100)
    producer = multiprocessing.Process(
        target=fill_input_queue_text,
        args=(input_queue, args.CORPUS, args.format, queries, processes, sentinel),
    )
    # The workers start consuming as soon as the pool exists; leaving the
    # ``with`` block terminates them.
    with multiprocessing.Pool(
        processes=processes,
        initializer=process_input_queue,
        initargs=(pareidoscope.query.run_queries, input_queue, output_queue, sentinel),
    ):
        producer.start()
        observed_sentinels = 0
        while observed_sentinels < processes:
            try:
                data = output_queue.get(timeout=1)
            except queue.Empty:
                # Nothing arrived for a while.  A producer that died never
                # sends the sentinels, so waiting any longer would block
                # forever.
                if producer.exitcode not in (None, 0):
                    producer.join()
                    raise RuntimeError(
                        "corpus reader exited with code %d" % producer.exitcode
                    )
                continue
            if isinstance(data, Sentinel):
                observed_sentinels += 1
                continue
            result, sensible = data
            if sensible:
                pareidoscope.query.merge_result(result, results)
        producer.join()
    return results


def main():
    """Entry point: read the queries, run them and write the results.

    Steps:
    - parse the command line and load the queries file
    - split every query into its C, A, B and N graphs and add a focus point
    - run the queries against the corpus, using the database or the
      text-based code path depending on ``--format``
    - add the association measures "log_likelihood", "t_score" and "dice" to
      every frequency dict
    - write everything to ``<output prefix>.tsv``
    """
    args = arguments()
    # argparse has already opened the queries file.  Close it once the
    # queries are built instead of keeping the handle open for the whole run.
    with args.QUERIES as queries_file:
        queries = [extract_graphs(q) for q in pareidoscope.query.read_queries(queries_file)]
    queries = [add_focus_point(q) for q in queries]
    cpu_count = multiprocessing.cpu_count()
    # Turn the requested percentage into a process count between 1 and the
    # number of CPUs, then give one of them up (but never go below 1); the
    # log message below reports the total as "<workers> + 1".
    processes = min(max(1, int(cpu_count * args.cpu / 100)), cpu_count)
    processes = max(1, processes - 1)
    sentinel = Sentinel()
    logging.info("Using %d + 1 processes", processes)
    if args.format == "db":
        results = association_strength_db(args, queries, processes, sentinel)
    else:
        results = association_strength_text(args, queries, processes, sentinel)
    for result in results:
        for freq in result.values():
            o, e = statistics.get_contingency_table(freq["o11"], freq["r1"], freq["c1"], freq["n"])
            freq["log_likelihood"] = statistics.one_sided_log_likelihood(o, e)
            freq["t_score"] = statistics.t_score(o, e)
            freq["dice"] = statistics.dice(o, e)
    write_results(args.output, results)


# Run the batch only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
