#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import collections
import itertools
import json
import logging
import multiprocessing
import operator
import os
import tempfile

from networkx.readwrite import json_graph

import pareidoscope.query
from pareidoscope import subgraph_enumeration
from pareidoscope.utils import database
from pareidoscope.utils import statistics

# Log at DEBUG level by default. For less verbose output, swap in the
# INFO-level configuration below:
# logging.basicConfig(format="%(levelname)s %(asctime)s: %(message)s", level=logging.INFO)
logging.basicConfig(format="%(levelname)s %(asctime)s: %(message)s", level=logging.DEBUG)


def arguments():
    """Parse the command line and return the resulting namespace.

    Options:
    - ``-c/--collexeme``: ``word`` or ``lemma`` (default ``lemma``).
    - ``-o/--output``: required output prefix, converted to an absolute path.
    - ``-p/--cpu``: integer percentage of CPUs to use (default 25).

    Positional arguments:
    - ``CORPUS``: path to the corpus database, converted to an absolute path.
    - ``QUERIES``: queries file, opened for reading as UTF-8 text.
    """
    parser = argparse.ArgumentParser(
        description="Perform a simple collexeme analysis, i.e. find words or lemmata associated with a specified slot in the provided linguistic structures",
    )
    add = parser.add_argument
    add(
        "-c",
        "--collexeme",
        default="lemma",
        choices=["word", "lemma"],
        help="Should we look for collexemes at the word level or at the lemma level? Default: lemma",
    )
    add("-o", "--output", required=True, type=os.path.abspath, help="Output prefix")
    add(
        "-p",
        "--cpu",
        default=25,
        type=int,
        help="Percentage of CPUs to use (0-100; default: 25)",
    )
    add("CORPUS", type=os.path.abspath, help="Input corpus as SQLite3 database")
    add(
        "QUERIES",
        type=argparse.FileType("r", encoding="utf-8"),
        help="Queries file as JSON list",
    )
    namespace = parser.parse_args()
    return namespace


def identify_focus_point(graph):
    """Locate the focus point vertex of a query graph.

    The focus point is the first vertex, in the iteration order of
    ``graph.nodes(data=True)``, whose attribute dict contains the key
    ``collo_item``.  That marker key is deleted from the attribute dict
    before returning.

    Arguments:
    - `graph`: query graph exposing ``nodes(data=True)``

    Returns:
    - tuple ``(graph, focus_point_vertex)``; ``graph`` is the object that
      was passed in

    Raises:
    - `ValueError`: if no vertex carries a ``collo_item`` attribute
    """
    for vertex, attributes in graph.nodes(data=True):
        if "collo_item" in attributes:
            del attributes["collo_item"]
            return graph, vertex
    # Raise explicitly instead of asserting: asserts are stripped under
    # ``python -O``, which would let a query without a focus point slip
    # through with an arbitrary (or unbound) vertex.
    raise ValueError("query graph has no vertex marked with 'collo_item'")


def build_focus_point_query(graph, focus_point):
    """Build the pieces of a token-count query for the focus point vertex.

    The attributes of ``graph.node[focus_point]`` are scanned in order.
    An attribute named ``word``, ``pos``, ``lemma``, ``wc`` or ``root``
    yields an equality condition on the column of the same name; the
    same name prefixed with ``not_`` yields an inequality condition on
    that column.  All other attributes are ignored.  For the ``root``
    column, the bound parameter is the boolean ``value == "root"``
    rather than the raw value.

    Returns:
    - tuple ``(sql_query, where, args)``: the base ``SELECT`` statement
      without a ``WHERE`` clause, the list of condition strings with
      ``?`` placeholders, and the list of parameters to bind, in the
      same order as the conditions
    """
    columns = {"word", "pos", "lemma", "wc", "root"}
    # Maps each negated attribute name to the column it restricts.
    negated = {"not_%s" % name: name for name in columns}
    conditions, parameters = [], []
    for attribute, value in graph.node[focus_point].items():
        if attribute in columns:
            column, comparison = attribute, "="
        elif attribute in negated:
            column, comparison = negated[attribute], "!="
        else:
            continue
        conditions.append("%s %s ?" % (column, comparison))
        parameters.append(value == "root" if column == "root" else value)
    return "SELECT count(*) FROM tokens", conditions, parameters


def get_sample_size(c, graph, focus_point):
    """Count the tokens in the corpus that match the focus point.

    The count query is assembled by ``build_focus_point_query``.  If the
    focus point yields no parameters, the bare count query is executed
    without a ``WHERE`` clause.

    Arguments:
    - `c`: cursor-like object providing ``execute(...).fetchall()``
    - `graph`: query graph
    - `focus_point`: focus point vertex of `graph`

    Returns:
    - first column of the first row returned by the count query
    """
    base_query, conditions, parameters = build_focus_point_query(graph, focus_point)
    if not parameters:
        return c.execute(base_query).fetchall()[0][0]
    restricted_query = base_query + " WHERE " + " AND ".join(conditions)
    rows = c.execute(restricted_query, parameters).fetchall()
    return rows[0][0]

def get_matches(args):
    """Collect the collexeme candidates found in one target sentence.

    Takes a single tuple so that it can be mapped over by a process
    pool.

    Arguments:
    - `args`: tuple ``(query_graph, target_graph, focus_point_vertex,
      word_or_lemma)``; ``target_graph`` is a JSON string in node-link
      format

    Returns:
    - list with the ``word_or_lemma`` attribute of every target vertex
      returned by ``get_choke_point_matches`` for the focus point
    """
    query_graph, target_graph, focus_point_vertex, word_or_lemma = args
    sentence_graph = json_graph.node_link_graph(json.loads(target_graph))
    pattern = pareidoscope.query.strip_vid(query_graph)
    matched_vertices = subgraph_enumeration.get_choke_point_matches(pattern, sentence_graph, focus_point_vertex)
    collexemes = []
    for vertex in matched_vertices:
        collexemes.append(sentence_graph.node[vertex][word_or_lemma])
    return collexemes


def write_results(prefix, results, word_or_lemma):
    """Write the results of all queries to ``<prefix>.tsv``.

    The file is tab-separated with one header line followed by one line
    per collexeme.  It is always written as UTF-8, independent of the
    locale, so that non-ASCII words and lemmata are stored reliably.

    Arguments:
    - `prefix`: output path without the ``.tsv`` extension
    - `results`: list with one entry per query; each entry is a list of
      dicts holding the collexeme under the key `word_or_lemma` plus the
      keys ``o11``, ``r1``, ``c1``, ``n``, ``log_likelihood``,
      ``t_score`` and ``dice``
    - `word_or_lemma`: name of the collexeme column (also the dict key
      under which the collexeme is stored)
    """
    values = ["o11", "r1", "c1", "n", "log_likelihood", "t_score", "dice"]
    header = ["query_number", word_or_lemma] + values
    # Explicit encoding: the platform default may be unable to represent
    # the corpus vocabulary.
    with open("%s.tsv" % prefix, "w", encoding="utf-8") as fh:
        fh.write("\t".join(header) + "\n")
        for query_number, collexemes in enumerate(results):
            for collexeme in collexemes:
                line = [str(query_number), collexeme[word_or_lemma]]
                line.extend(str(collexeme[v]) for v in values)
                fh.write("\t".join(line) + "\n")


def main():
    """Run the collexeme analysis for every query and write the results.

    For each query the following frequencies are determined:
    - ``n``: number of corpus tokens matching the focus point
      restrictions (see ``get_sample_size``)
    - ``o11``: per collexeme, how often it was returned by
      ``get_matches`` over all candidate sentences
    - ``r1``: sum of all ``o11`` values
    - ``c1``: per collexeme, number of corpus tokens matching the focus
      point restrictions and having that word/lemma

    From these, log-likelihood, t-score and Dice are computed per
    collexeme.  The collexemes of each query are sorted by descending
    log-likelihood (ties broken by descending collexeme) and everything
    is written via ``write_results``.
    """
    args = arguments()
    conn, c = database.connect_to_database(args.CORPUS)
    queries = pareidoscope.query.read_queries(args.QUERIES)
    # Turn the CPU percentage into a worker count between 1 and the
    # number of available CPUs.
    available = multiprocessing.cpu_count()
    processes = min(max(1, int(available * args.cpu / 100)), available)
    logging.info("Using %d processes", processes)
    results = []
    with multiprocessing.Pool(processes=processes) as pool:
        for query_number, query in enumerate(queries):
            logging.info("query no. %d", query_number)
            graph, focus_point = identify_focus_point(query)
            n = get_sample_size(c, graph, focus_point)
            sql_query, sql_where, sql_args = build_focus_point_query(graph, focus_point)
            # Spool the candidate sentences to a temporary file first, then
            # stream them from that file to the worker processes.
            with tempfile.TemporaryFile() as spool:
                for candidate in database.sentence_candidates(c, pareidoscope.query.strip_vid(graph)):
                    spool.write((candidate + "\n").encode(encoding="utf-8"))
                spool.seek(0)
                sentences = (raw.decode(encoding="utf-8").rstrip() for raw in spool)
                jobs = zip(
                    itertools.repeat(graph),
                    sentences,
                    itertools.repeat(focus_point),
                    itertools.repeat(args.collexeme),
                )
                matches = pool.imap_unordered(get_matches, jobs, 10)
                # for debugging, it is often better to avoid multiprocessing:
                # matches = map(get_matches, jobs)
                # The counter must be filled while the temporary file is
                # still open, because the jobs are read from it lazily.
                observed = collections.Counter(itertools.chain.from_iterable(matches))
            r1 = sum(observed.values())
            count_query = sql_query + " WHERE " + " AND ".join(sql_where + ["%s = ?" % args.collexeme])
            c1 = {
                item: c.execute(count_query, sql_args + [item]).fetchall()[0][0]
                for item in observed
            }
            # item -> (log_likelihood, t_score, dice)
            scores = {}
            for item, freq in observed.items():
                o, e = statistics.get_contingency_table(freq, r1, c1[item], n)
                log_likelihood = statistics.one_sided_log_likelihood(o, e)
                t_score = statistics.t_score(o, e)
                dice = statistics.dice(o, e)
                scores[item] = (log_likelihood, t_score, dice)
            ranking = sorted(
                scores.items(),
                key=lambda entry: (entry[1][0], entry[0]),
                reverse=True,
            )
            local_results = []
            for item, (log_likelihood, t_score, dice) in ranking:
                local_results.append({
                    args.collexeme: item,
                    "o11": observed[item],
                    "r1": r1,
                    "c1": c1[item],
                    "n": n,
                    "log_likelihood": log_likelihood,
                    "t_score": t_score,
                    "dice": dice,
                })
            results.append(local_results)
    write_results(args.output, results, args.collexeme)
    logging.info("done")


# Run the analysis only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
