#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import collections
import json
import logging
import multiprocessing
import operator
import os

from networkx.readwrite import json_graph

import pareidoscope.query
from pareidoscope import subgraph_enumeration
from pareidoscope.utils import database
from pareidoscope.utils import statistics

# logging.basicConfig(format="%(levelname)s %(asctime)s: %(message)s", level=logging.INFO)
logging.basicConfig(format="%(levelname)s %(asctime)s: %(message)s", level=logging.DEBUG)


class Sentinel:
    """Marker type used to signal the end of the data in a queue.

    Consumers in this file recognise the marker with ``isinstance``
    checks, so any instance of this class counts as an end-of-data signal.
    """


def arguments():
    """Build the command line parser and return the parsed arguments.

    The returned namespace carries ``collexeme``, ``output``, ``cpu``,
    ``CORPUS`` and ``QUERIES`` (an already opened text file).
    """
    # Each entry: (positional names/flags, keyword options for add_argument).
    argument_table = (
        (
            ("-c", "--collexeme"),
            {"choices": ["word", "lemma"], "default": "lemma", "help": "Should we look for collexemes at the word level or at the lemma level? Default: lemma"},
        ),
        (
            ("-o", "--output"),
            {"type": os.path.abspath, "required": True, "help": "Output prefix"},
        ),
        (
            ("-p", "--cpu"),
            {"type": int, "default": 25, "help": "Percentage of CPUs to use (0-100; default: 25)"},
        ),
        (
            ("CORPUS",),
            {"type": os.path.abspath, "help": "Input corpus as SQLite3 database"},
        ),
        (
            ("QUERIES",),
            {"type": argparse.FileType("r", encoding="utf-8"), "help": "Queries file as JSON list"},
        ),
    )
    cli = argparse.ArgumentParser(description="Perform a simple collexeme analysis, i.e. find words or lemmata associated with a specified slot in the provided linguistic structures")
    for names, options in argument_table:
        cli.add_argument(*names, **options)
    return cli.parse_args()


def identify_focus_point(graph):
    """Locate the focus point vertex of a query graph.

    The focus point is the first vertex (in the iteration order of
    ``graph.nodes(data=True)``) whose attribute dict contains the key
    ``"collo_item"``. That marker is removed from the vertex attributes,
    i.e. the graph is modified in place.

    Arguments:
    - `graph`: query graph offering ``nodes(data=True)``

    Returns:
    A tuple ``(graph, focus_point_vertex)``.

    Raises:
    ValueError if no vertex carries the ``"collo_item"`` marker.
    """
    for vertex, attributes in graph.nodes(data=True):
        if "collo_item" in attributes:
            del attributes["collo_item"]
            return graph, vertex
    # An explicit exception (instead of ``assert``) keeps the check alive
    # when Python runs with -O and never returns an unrelated vertex.
    raise ValueError("query graph has no vertex marked with 'collo_item'")


def build_focus_point_query(graph, focus_point):
    """Translate the focus point's lexical attributes into SQL fragments.

    Attributes named ``word``, ``pos``, ``lemma``, ``wc`` or ``root``
    become equality conditions; their ``not_``-prefixed counterparts
    become inequality conditions on the unprefixed column. All other
    attributes are ignored. For the ``root`` column the bound parameter
    is the boolean ``value == "root"``.

    Returns:
    A tuple ``(sql_query, where, args)``: the base count query, the list
    of condition strings and the list of parameters bound to them.
    """
    columns = ("word", "pos", "lemma", "wc", "root")
    # attribute name -> (column it refers to, condition template)
    conditions_by_attribute = {}
    for column in columns:
        conditions_by_attribute[column] = (column, "%s = ?")
        conditions_by_attribute["not_" + column] = (column, "%s != ?")

    conditions, parameters = [], []
    for attribute, value in graph.nodes[focus_point].items():
        if attribute not in conditions_by_attribute:
            continue
        column, template = conditions_by_attribute[attribute]
        conditions.append(template % column)
        parameters.append(value == "root" if column == "root" else value)
    return "SELECT count(*) FROM tokens", conditions, parameters


def get_sample_size(corpus, graph, focus_point):
    """Find out how many instances of the focus point there are in the
    corpus

    Arguments:
    - `corpus`: path of the corpus database, passed to
      ``database.connect_to_database``
    - `graph`: query graph
    - `focus_point`: vertex of `graph` whose lexical attributes restrict
      the counted tokens

    Returns:
    The first column of the first row returned by the count query.
    """
    sql_query, where, args = build_focus_point_query(graph, focus_point)
    if where:
        # Without any restriction every token is counted.
        sql_query += " WHERE " + " AND ".join(where)
    conn, c = database.connect_to_database(corpus)
    try:
        return c.execute(sql_query, args).fetchall()[0][0]
    finally:
        # Release the connection even if the query fails.
        conn.close()


def get_marginal_frequencies_c1(corpus, o11, sql_query, sql_where, sql_args, collexeme):
    """Count, for every observed collexeme, the tokens matching the focus
    point conditions that also have this collexeme.

    Arguments:
    - `corpus`: path of the corpus database, passed to
      ``database.connect_to_database``
    - `o11`: iterable of collexemes (e.g. the keys of a Counter)
    - `sql_query`: base count query without WHERE clause
    - `sql_where`: list of condition strings for the focus point
    - `sql_args`: parameters bound to `sql_where`
    - `collexeme`: name of the column the collexemes are compared with;
      it is interpolated into the SQL text, so it must be a trusted
      column name (the command line restricts it to "word" or "lemma")

    Returns:
    A dict mapping each item of `o11` to its count.
    """
    local_sql_query = sql_query + " WHERE " + " AND ".join(sql_where + ["%s = ?" % collexeme])
    conn, c = database.connect_to_database(corpus)
    try:
        return {
            colloitem: c.execute(local_sql_query, sql_args + [colloitem]).fetchall()[0][0]
            for colloitem in o11
        }
    finally:
        # Release the connection even if one of the queries fails.
        conn.close()


def get_matches(args):
    """Collect the collexemes found in a single candidate sentence.

    Arguments:
    - `args`: 4-tuple ``(query_graph, target_graph, focus_point_vertex,
      word_or_lemma)``; ``target_graph`` is a JSON string in node-link
      format, ``word_or_lemma`` names the vertex attribute to extract

    Returns:
    A list with the `word_or_lemma` attribute of every vertex yielded by
    ``subgraph_enumeration.get_choke_point_matches``.
    """
    query_graph, serialized_sentence, focus_point_vertex, word_or_lemma = args
    sentence = json_graph.node_link_graph(json.loads(serialized_sentence))
    pattern = pareidoscope.query.strip_vid(query_graph)
    matched_vertices = subgraph_enumeration.get_choke_point_matches(pattern, sentence, focus_point_vertex)
    collexemes = []
    for vertex in matched_vertices:
        collexemes.append(sentence.nodes[vertex][word_or_lemma])
    return collexemes


def write_results(prefix, results, word_or_lemma):
    """Write results to a tab-separated file named ``<prefix>.tsv``

    The file is always written as UTF-8 so that non-ASCII words or
    lemmata do not depend on the locale's default encoding.

    Arguments:
    - `prefix`: output path without the ``.tsv`` extension
    - `results`: one list per query (in query order); every element is a
      dict with the key `word_or_lemma` and the keys "o11", "r1", "c1",
      "n", "log_likelihood", "t_score" and "dice"
    - `word_or_lemma`: name of the collexeme column ("word" or "lemma")
    """
    values = ["o11", "r1", "c1", "n", "log_likelihood", "t_score", "dice"]
    with open("%s.tsv" % prefix, "w", encoding="utf-8") as fh:
        header = ["query_number", word_or_lemma] + values
        fh.write("\t".join(header) + "\n")
        for query_number, query_results in enumerate(results):
            for row in query_results:
                line = [str(query_number), row[word_or_lemma]] + [str(row[v]) for v in values]
                fh.write("\t".join(line) + "\n")


def fill_input_queue(input_queue, corpus, graph, focus_point, collexeme, processes, sentinel):
    """Put one work item per candidate sentence on the input queue.

    Every work item is the tuple ``(graph, sentence, focus_point,
    collexeme)``. After the last sentence, `sentinel` is put on the queue
    `processes` times, i.e. once per consumer.

    Arguments:
    - `input_queue`: queue receiving the work items
    - `corpus`: path of the corpus database
    - `graph`: query graph
    - `focus_point`: focus point vertex of `graph`
    - `collexeme`: attribute to extract ("word" or "lemma")
    - `processes`: number of sentinels to enqueue
    - `sentinel`: end-of-data marker
    """
    connection, cursor = database.connect_to_database(corpus)
    stripped_query = pareidoscope.query.strip_vid(graph)
    for sentence in database.sentence_candidates(cursor, stripped_query):
        work_item = (graph, sentence, focus_point, collexeme)
        input_queue.put(work_item)
    connection.close()
    for _ in range(processes):
        input_queue.put(sentinel)


def process_input_queue(func, input_queue, output_queue, sentinel):
    """Apply `func` to queued items until a sentinel arrives.

    Items are taken from `input_queue` one by one; the result of
    ``func(item)`` is put on `output_queue`. The first item that is an
    instance of ``Sentinel`` stops the loop, after which `sentinel` is put
    on `output_queue` to tell the collector that this consumer is done.
    """
    item = input_queue.get()
    # The end marker is detected by type, not by identity or equality.
    while not isinstance(item, Sentinel):
        output_queue.put(func(item))
        item = input_queue.get()
    output_queue.put(sentinel)


def main():
    """Run the collexeme analysis for all queries and write the results.

    For every query the focus point is identified, candidate sentences
    are distributed to worker processes via queues, the collexemes found
    at the focus point are counted and association measures
    (log-likelihood, t-score, Dice) are computed from the resulting
    frequencies. The results of all queries are written to
    ``<output>.tsv``.
    """
    args = arguments()
    queries = pareidoscope.query.read_queries(args.QUERIES)
    results = []
    cpu_count = multiprocessing.cpu_count()
    # Turn the requested CPU percentage into a process count within
    # [1, cpu_count] ...
    processes = min(max(1, int(cpu_count * args.cpu / 100)), cpu_count)
    # ... and use one fewer worker than that (but at least one); the log
    # message below counts the additional process separately (presumably
    # the producer process started for each query).
    processes = max(1, processes - 1)
    sentinel = Sentinel()
    logging.info("Using %d + 1 processes" % processes)
    for i, query in enumerate(queries):
        logging.info("query no. %d" % i)
        graph, focus_point = identify_focus_point(query)
        # n: number of tokens satisfying the focus point's lexical conditions
        n = get_sample_size(args.CORPUS, graph, focus_point)
        sql_query, sql_where, sql_args = build_focus_point_query(graph, focus_point)
        # o11: how often each collexeme was returned by get_matches
        o11 = collections.Counter()
        # Both queues are bounded (100 entries per worker).
        input_queue = multiprocessing.Queue(maxsize=processes * 100)
        output_queue = multiprocessing.Queue(maxsize=processes * 100)
        producer = multiprocessing.Process(target=fill_input_queue, args=(input_queue, args.CORPUS, graph, focus_point, args.collexeme, processes, sentinel))
        # No tasks are submitted to the pool: process_input_queue is passed
        # as the pool initializer, so every worker process runs the consumer
        # loop on the shared queues.
        with multiprocessing.Pool(processes=processes, initializer=process_input_queue, initargs=(get_matches, input_queue, output_queue, sentinel)):
            producer.start()
            observed_sentinels = 0
            while True:
                data = output_queue.get()
                if isinstance(data, Sentinel):
                    # Each consumer emits one sentinel when it is done; stop
                    # once all of them have reported.
                    observed_sentinels += 1
                    if observed_sentinels == processes:
                        break
                    else:
                        continue
                o11.update(data)
            producer.join()
        # r1: total number of collexeme occurrences found for this query
        r1 = sum(o11.values())
        # c1: per collexeme, the count of tokens matching the focus point
        # conditions that also have this word/lemma
        c1 = get_marginal_frequencies_c1(args.CORPUS, o11, sql_query, sql_where, sql_args, args.collexeme)
        log_likelihood, t_score, dice = {}, {}, {}
        for item, freq in o11.items():
            o, e = statistics.get_contingency_table(freq, r1, c1[item], n)
            log_likelihood[item] = statistics.one_sided_log_likelihood(o, e)
            t_score[item] = statistics.t_score(o, e)
            dice[item] = statistics.dice(o, e)
        # One result dict per collexeme, sorted by descending log-likelihood
        # (ties broken by descending collexeme).
        local_results = [{args.collexeme: item, "o11": o11[item], "r1": r1, "c1": c1[item], "n": n, "log_likelihood": ll, "t_score": t_score[item], "dice": dice[item]} for item, ll in sorted(log_likelihood.items(), key=operator.itemgetter(1, 0), reverse=True)]
        results.append(local_results)
    write_results(args.output, results, args.collexeme)
    logging.info("done")


if __name__ == "__main__":
    main()
