#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import itertools
import logging
import multiprocessing
import os
import tempfile

import json
from networkx.readwrite import json_graph

import pareidoscope.query
from pareidoscope.utils import conllu
from pareidoscope.utils import cwb
from pareidoscope.utils import database
# from pareidoscope.utils import helper
from pareidoscope.utils import nx_graph
from pareidoscope.utils import statistics

# Configure the root logger once at import time: records at INFO level and
# above are emitted as "<level name> <timestamp>: <message>".
logging.basicConfig(format="%(levelname)s %(asctime)s: %(message)s", level=logging.INFO)


def arguments():
    """Parse the command line and return the resulting namespace.

    Options:
    - ``-f/--format``: required; one of ``conllu``, ``cwb`` or ``db``.
    - ``-o/--output``: required output prefix.
    - ``-p/--cpu``: integer percentage of CPUs to use (default 25).

    Positionals:
    - ``CORPUS``: corpus path, converted to an absolute path.
    - ``QUERIES``: queries file, opened for reading as UTF-8 text.
    """
    cli = argparse.ArgumentParser(description='Run a batch of queries against a corpus')
    cli.add_argument(
        "-f", "--format",
        required=True,
        choices=["conllu", "cwb", "db"],
        help="Input format of the corpus: Either a text-based format (CoNLL-U or CWB-treebank) or a database created by pareidoscope_corpus_to_sqlite",
    )
    cli.add_argument(
        "-o", "--output",
        required=True,
        type=str,
        help="Output prefix",
    )
    cli.add_argument(
        "-p", "--cpu",
        default=25,
        type=int,
        help="Percentage of CPUs to use (0-100; default: 25)",
    )
    cli.add_argument(
        "CORPUS",
        type=os.path.abspath,
        help="Input corpus",
    )
    cli.add_argument(
        "QUERIES",
        type=argparse.FileType("r", encoding="utf-8"),
        help="Queries file as JSON list",
    )
    namespace = cli.parse_args()
    return namespace


def remove_auxiliary_attributes(graph, abcn):
    """Strip bookkeeping attributes from the vertex and edge labels of `graph`.

    The label dictionaries are modified in place and `graph` is returned.

    Arguments:
    - `graph`: graph whose ``nodes(data=True)`` / ``edges(data=True)``
      yield mutable attribute dictionaries
    - `abcn`: which query graph is being prepared.  For ``"a"`` and
      ``"n"`` the attributes named in a label's ``only_B`` list are
      deleted; for ``"b"`` and ``"n"`` those named in ``only_A``.
      ``focus_point`` is kept only for ``"n"``.

    The ``only_A`` / ``only_B`` markers themselves are always removed, and
    ``query`` is removed from every vertex label.
    """
    drop_b_specific = abcn in ("a", "n")
    drop_a_specific = abcn in ("b", "n")
    drop_focus_point = abcn != "n"

    vertex_labels = [attrs for _, attrs in graph.nodes(data=True)]
    edge_labels = [attrs for _, _, attrs in graph.edges(data=True)]

    for attrs in vertex_labels + edge_labels:
        if drop_b_specific and "only_B" in attrs:
            for name in attrs["only_B"]:
                del attrs[name]
        if drop_a_specific and "only_A" in attrs:
            for name in attrs["only_A"]:
                del attrs[name]
        if drop_focus_point and "focus_point" in attrs:
            del attrs["focus_point"]
        # The markers are bookkeeping only and never part of a final graph.
        for marker in ("only_A", "only_B"):
            if marker in attrs:
                del attrs[marker]

    # Every vertex is expected to carry its "query" membership label.
    for attrs in vertex_labels:
        del attrs["query"]
    return graph


def extract_graphs(query):
    """Derive the C, A, B and N query graphs from one annotated query graph.

    Every vertex of `query` must carry a ``query`` attribute.  Vertices
    labelled ``A`` or ``AB`` form graph A, vertices labelled ``B`` or ``AB``
    form graph B, vertices labelled ``AB`` form graph N, and the complete
    graph serves as C.  Auxiliary attributes are stripped from all four
    graphs; note that `query` itself is modified in place in the process.

    Returns a list ``[gc, ga, gb, gn]`` after passing each graph through
    ``nx_graph.ensure_consecutive_vertices``.
    """
    def induced_copy(memberships):
        # Independent copy of the subgraph induced by the vertices whose
        # "query" label is one of `memberships`.
        selected = [v for v, attrs in query.nodes(data=True) if attrs["query"] in memberships]
        return query.subgraph(selected).copy()

    # All three copies must be taken before `query` is stripped below,
    # because stripping deletes the "query" labels used for selection.
    graph_a = induced_copy(("A", "AB"))
    graph_b = induced_copy(("B", "AB"))
    graph_n = induced_copy(("AB",))

    graph_a = remove_auxiliary_attributes(graph_a, "a")
    graph_b = remove_auxiliary_attributes(graph_b, "b")
    graph_n = remove_auxiliary_attributes(graph_n, "n")
    graph_c = remove_auxiliary_attributes(query, "c")

    ordered = (graph_c, graph_a, graph_b, graph_n)
    return [nx_graph.ensure_consecutive_vertices(graph) for graph in ordered]


def add_focus_point(query):
    """Append the focus point vertex to a ``(gc, ga, gb, gn)`` query.

    The first vertex of ``gn`` whose label contains ``focus_point`` is used
    and that attribute is removed from its label.  If no vertex is marked,
    ``nx_graph.get_choke_point(gn)`` chooses the vertex instead.

    Returns the tuple ``(gc, ga, gb, gn, focus_point_vertex)``.
    """
    gc, ga, gb, gn = query
    marked = ((v, attrs) for v, attrs in gn.nodes(data=True) if "focus_point" in attrs)
    first_marked = next(marked, None)

    vertex = None
    if first_marked is not None:
        vertex, attrs = first_marked
        del attrs["focus_point"]
    if vertex is None:
        # No explicit marking in the query: fall back to the choke point.
        vertex = nx_graph.get_choke_point(gn)
    assert vertex is not None
    return gc, ga, gb, gn, vertex


def write_results(prefix, results):
    """Write the per-query results as a tab-separated table to ``<prefix>.tsv``.

    The first column is the query number (the index into `results`).  It is
    followed by one ``<counting method>:<value>`` column for every
    combination of a counting method (embeddings, subgraphs, focus_points,
    sentences) that occurs in at least one result with the values o11, r1,
    c1, n, inconsistent, log_likelihood, t_score and dice.

    Arguments:
    - `prefix`: path prefix of the output file; ``.tsv`` is appended
    - `results`: list with one dictionary per query, mapping counting
      method names to dictionaries of values

    A query that has no entry for a counting method (e.g. because nothing
    was merged for it) gets ``NA`` in the corresponding cells.  An empty
    `results` list produces a file that contains only the header.

    Raises KeyError if a counting method is present for a query but lacks
    one of the expected values.
    """
    known_methods = ("embeddings", "subgraphs", "focus_points", "sentences")
    values = ("o11", "r1", "c1", "n", "inconsistent", "log_likelihood", "t_score", "dice")
    # Look at all results, not only the first one: the first query may have
    # no counts at all, and there may be no queries in the first place.
    counting_methods = [cm for cm in known_methods if any(cm in r for r in results)]
    columns = list(itertools.product(counting_methods, values))
    with open("%s.tsv" % prefix, "w", encoding="utf-8") as fh:
        header = ["query_number"] + [":".join(column) for column in columns]
        fh.write("\t".join(header) + "\n")
        for i, r in enumerate(results):
            cells = [str(r[cm][v]) if cm in r else "NA" for cm, v in columns]
            fh.write("\t".join([str(i)] + cells) + "\n")


def _run_db_queries(pool, corpus, queries, results):
    """Run every query against the sentence database at `corpus`, merging into `results`."""
    # NOTE(review): the connection is never closed explicitly here; confirm
    # whether database.connect_to_database expects the caller to do so.
    conn, c = database.connect_to_database(corpus)
    for i, qline in enumerate(queries):
        logging.info("query no. %d", i)
        _, _, _, gn, _ = qline
        with tempfile.TemporaryFile() as fp:
            # Spool the candidate sentences to a temporary file and read
            # them back lazily (presumably so that the database cursor is
            # exhausted before the workers start -- TODO confirm).
            for s in database.sentence_candidates(c, pareidoscope.query.strip_vid(gn)):
                fp.write((s + "\n").encode(encoding="utf-8"))
            fp.seek(0)
            sentences = (line.decode(encoding="utf-8") for line in fp)
            # One flat tuple (gc, ga, gb, gn, focus point, sentence) per
            # task.  Plain tuples are used because task arguments are
            # pickled for the worker processes and itertools objects can no
            # longer be pickled as of Python 3.14.
            query_args = (qline + (s,) for s in sentences)
            r = pool.imap_unordered(pareidoscope.query.run_queries_db, query_args, 10)
            # for debugging, it is often better to avoid multiprocessing:
            # r = map(pareidoscope.query.run_queries_db, query_args)
            for result in r:
                pareidoscope.query.merge_result_db(result, i, results)
        logging.info(results[i])


def _run_text_queries(pool, corpus_path, corpus_format, queries, results):
    """Run all queries on every sentence of a CWB or CoNLL-U file, merging into `results`."""
    with open(corpus_path, encoding="utf-8") as corpus:
        if corpus_format == "cwb":
            sentences = cwb.sentences_iter(corpus, return_id=False)
        elif corpus_format == "conllu":
            sentences = conllu.sentences_iter(corpus, return_id=False)
        tasks = zip(sentences, itertools.repeat(corpus_format), itertools.repeat(queries))
        r = pool.imap_unordered(pareidoscope.query.run_queries, tasks, 10)
        # for debugging, it is often better to avoid multiprocessing:
        # r = map(pareidoscope.query.run_queries, tasks)
        for result, sensible in r:
            if sensible:
                pareidoscope.query.merge_result(result, results)


def _add_association_measures(results):
    """Add log_likelihood, t_score and dice to every frequency dictionary in `results`."""
    for result in results:
        for freq in result.values():
            o, e = statistics.get_contingency_table(freq["o11"], freq["r1"], freq["c1"], freq["n"])
            freq["log_likelihood"] = statistics.one_sided_log_likelihood(o, e)
            freq["t_score"] = statistics.t_score(o, e)
            freq["dice"] = statistics.dice(o, e)


def main():
    """Run a batch of queries against a corpus and write the result table.

    Reads the queries, derives the (gc, ga, gb, gn, focus point) tuple for
    each of them, evaluates them in a pool of worker processes and finally
    writes ``<output prefix>.tsv``.  The number of worker processes is the
    requested percentage of the available CPUs, clamped to the range from 1
    to the CPU count.

    With ``--format db`` the queries are processed one after the other on
    the candidate sentences supplied by the database; with the text formats
    the corpus is read once and all queries are run on each sentence.
    """
    args = arguments()
    # Consume the queries inside the ``with`` block so that the handle
    # opened by argparse is closed again, even if reading happens lazily.
    with args.QUERIES as queries_file:
        queries = [extract_graphs(q) for q in pareidoscope.query.read_queries(queries_file)]
    queries = [add_focus_point(q) for q in queries]
    results = [{} for _ in queries]
    cpu_count = multiprocessing.cpu_count()
    processes = min(max(1, int(cpu_count * args.cpu / 100)), cpu_count)
    logging.info("Using %d processes", processes)
    with multiprocessing.Pool(processes=processes) as pool:
        logging.info("using %d cpus", processes)
        if args.format == "db":
            _run_db_queries(pool, args.CORPUS, queries, results)
        else:
            _run_text_queries(pool, args.CORPUS, args.format, queries, results)
    # All worker results have been consumed at this point, so the remaining
    # work does not need the pool any more.
    _add_association_measures(results)
    write_results(args.output, results)


# Run the batch job only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
