#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import itertools
import json
import logging
import multiprocessing
import os

from networkx.readwrite import json_graph

from pareidoscope.utils import conllu
from pareidoscope.utils import cwb
from pareidoscope.utils import database
from pareidoscope.utils import nx_graph

# Configure logging once at import time: INFO level and above, each record
# rendered as "<LEVEL> <timestamp>: <message>".
logging.basicConfig(format="%(levelname)s %(asctime)s: %(message)s", level=logging.INFO)


class Sentinel:
    """Marker type for end-of-stream objects placed on the queues.

    Instances carry no state. Queue readers recognise them with
    ``isinstance`` rather than identity, so a copy that has travelled through
    a queue is still detected.
    """


def arguments():
    """Parse the command line and return the resulting namespace.

    Returns:
        argparse.Namespace with the attributes ``db`` (absolute path),
        ``no_id`` (bool), ``format`` (``"conllu"`` or ``"cwb"``), ``cpu``
        (int, default 25) and ``CORPUS`` (the corpus opened for reading as
        UTF-8 text).
    """
    cli = argparse.ArgumentParser(
        description="Convert a corpus in CoNLL-U or CWB-treebank format into a corresponding SQLite database"
    )
    cli.add_argument(
        "--db",
        required=True,
        type=os.path.abspath,
        help="SQLite3 database for results",
    )
    cli.add_argument(
        "--no-id",
        action="store_true",
        help="Corpus has no unique sentence IDs, create them on the fly",
    )
    cli.add_argument(
        "-f",
        "--format",
        required=True,
        choices=["conllu", "cwb"],
        help="Input format of the corpus: CoNLL-U or CWB-treebank",
    )
    cli.add_argument(
        "-p",
        "--cpu",
        default=25,
        type=int,
        help="Percentage of CPUs to use (0-100; default: 25)",
    )
    cli.add_argument(
        "CORPUS",
        type=argparse.FileType("r", encoding="utf-8"),
        help="The input corpus",
    )
    return cli.parse_args()


def sentence_to_graph(args):
    """Turn one sentence into a directed graph plus its JSON serialisation.

    Args:
        args: A ``(sentence, origid, corpus_format)`` tuple. The values are
            packed into a single argument so the function can be applied
            directly to items taken from a queue.

    Returns:
        A tuple ``(gs, graph, origid, sensible)``: ``gs`` is the graph object
        built by the ``nx_graph`` factory for the given format, ``sensible``
        is the result of ``nx_graph.is_sensible_graph(gs)``, and ``graph`` is
        the node-link JSON string of ``gs`` when it is sensible, otherwise the
        empty string.

    Raises:
        ValueError: If ``corpus_format`` is neither ``"cwb"`` nor
            ``"conllu"``.
    """
    sentence, origid, corpus_format = args
    if corpus_format == "cwb":
        create_digraph = nx_graph.create_nx_digraph_from_cwb
    elif corpus_format == "conllu":
        create_digraph = nx_graph.create_nx_digraph_from_conllu
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError("Unknown corpus format: %r" % (corpus_format,))
    gs = create_digraph(sentence, origid)
    sensible = nx_graph.is_sensible_graph(gs)
    # Only graphs that pass the sanity check are serialised; the rest get an
    # empty placeholder string.
    graph = ""
    if sensible:
        graph = json.dumps(json_graph.node_link_data(gs), ensure_ascii=False, sort_keys=True)
    return gs, graph, origid, sensible


def fill_input_queue(input_queue, sentences, corpus_format, processes, sentinel):
    """Feed every sentence into the input queue, then signal end of input.

    Args:
        input_queue: Queue-like object with a ``put`` method.
        sentences: Iterable of ``(sentence, sentence_id)`` pairs.
        corpus_format: Format tag appended to every queued item.
        processes: Number of end-of-input markers to queue (one per worker
            reading from ``input_queue``).
        sentinel: The end-of-input marker object.

    Each queued item is a ``(sentence, sentence_id, corpus_format)`` tuple.
    If iterating ``sentences`` raises, the exception propagates, but only
    after the sentinels have been queued.
    """
    try:
        for sentence, sid in sentences:
            input_queue.put((sentence, sid, corpus_format))
    finally:
        # Always release the readers: without the sentinels they would block
        # on the queue forever when the sentence source fails midway.
        for _ in range(processes):
            input_queue.put(sentinel)


def process_input_queue(func, input_queue, output_queue, sentinel):
    """Apply ``func`` to queued items until an end-of-input marker arrives.

    Args:
        func: Callable applied to each item read from ``input_queue``.
        input_queue: Queue the items are read from.
        output_queue: Queue that receives each ``func`` result.
        sentinel: Marker put on ``output_queue`` once this reader is done.

    Reading stops at the first item that is a ``Sentinel`` instance; that
    item is not passed to ``func``.
    """
    item = input_queue.get()
    while not isinstance(item, Sentinel):
        output_queue.put(func(item))
        item = input_queue.get()
    # Tell the downstream reader that this worker has finished.
    output_queue.put(sentinel)


def consume_output_queue(output_queue, db, processes, sentinel):
    """Write queued results to the database until all workers are done.

    Args:
        output_queue: Queue yielding ``(gs, graph, origid, sensible)`` tuples
            interleaved with ``Sentinel`` instances.
        db: Database location handed to ``database.create_db``.
        processes: Number of ``Sentinel`` instances to wait for before
            stopping.
        sentinel: Unused here; markers are recognised by type.

    Only results flagged as sensible are inserted. The connection is
    committed and closed once, after the last sentinel has been seen.
    """
    conn, c = database.create_db(db)
    finished_workers = 0
    while True:
        item = output_queue.get()
        if not isinstance(item, Sentinel):
            gs, graph, origid, sensible = item
            if sensible:
                database.insert_sentence(c, origid, gs, graph)
            continue
        # One more worker has signalled that it is done.
        finished_workers += 1
        if finished_workers == processes:
            break
    conn.commit()
    conn.close()


def main():
    """Convert the corpus named on the command line into an SQLite database.

    The work is organised as a pipeline over two bounded queues: this process
    reads the sentences and feeds the input queue, a pool of workers applies
    ``sentence_to_graph`` to each item, and a separate consumer process
    writes the results to the database.

    The sentences are fed from this process rather than from a child process
    because the sentence generator and the open corpus file cannot be pickled,
    which the "spawn" and "forkserver" start methods require of all
    ``Process`` arguments.
    """
    args = arguments()
    cpu_count = multiprocessing.cpu_count()
    # Translate the requested CPU percentage into a worker count in
    # [1, cpu_count], then give up two of them (but keep at least one).
    # NOTE(review): the two are presumably set aside for the feeding and the
    # consuming process - confirm.
    processes = min(max(1, int(cpu_count * args.cpu / 100)), cpu_count)
    processes = max(1, processes - 2)
    sentinel = Sentinel()
    readers = {"cwb": cwb.sentences_iter, "conllu": conllu.sentences_iter}
    sents = readers[args.format](args.CORPUS, return_id=True)
    if args.no_id:
        # Replace whatever IDs the reader produced with running numbers.
        sents = ((s, "s-%d" % i) for i, (s, _) in enumerate(sents, start=1))
    # Bounded queues keep memory use in check when one stage is slower.
    input_queue = multiprocessing.Queue(maxsize=processes * 100)
    output_queue = multiprocessing.Queue(maxsize=processes * 100)
    consumer = multiprocessing.Process(target=consume_output_queue, args=(output_queue, args.db, processes, sentinel))
    # The pool never receives tasks: each worker runs process_input_queue as
    # its initializer and does all of its work there.
    pool = multiprocessing.Pool(
        processes=processes,
        initializer=process_input_queue,
        initargs=(sentence_to_graph, input_queue, output_queue, sentinel),
    )
    with pool:
        consumer.start()
        try:
            fill_input_queue(input_queue, sents, args.format, processes, sentinel)
        except BaseException:
            # Reading failed or was interrupted: stop the consumer and do not
            # wait for unflushed queue data at exit, so the script terminates
            # instead of hanging on processes that will never finish.
            consumer.terminate()
            input_queue.cancel_join_thread()
            raise
        finally:
            args.CORPUS.close()
        consumer.join()


# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
