#!/usr/bin/env python3

import argparse
import itertools
import logging
import multiprocessing
import os
import statistics

from someweta import utils
from someweta import ASPTagger

# Configure the root logger to emit bare messages (the format contains only
# the message itself) for everything at DEBUG level and above.
logging.basicConfig(format='%(message)s', level=logging.DEBUG)


def arguments(argv=None):
    """Process command line arguments.

    Exactly one of the modes ``--train``, ``--tag``, ``--evaluate`` or
    ``--crossvalidate`` has to be given, together with the positional
    ``CORPUS`` argument.

    Args:
        argv: Optional list of argument strings to parse.  When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding the parsed arguments.
    """
    parser = argparse.ArgumentParser(description="An averaged perceptron part-of-speech tagger")

    # The four operating modes are mutually exclusive and one is mandatory.
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument("--train", type=os.path.abspath,
                      help="Train the tagger on the input corpus and write the model to the specified file")
    mode.add_argument("--tag", type=os.path.abspath,
                      help="Tag the input corpus using the specified model")
    mode.add_argument("--evaluate", type=os.path.abspath,
                      help="Evaluate the performance of the specified model on the input corpus")
    mode.add_argument("--crossvalidate", action="store_true",
                      help="Evaluate tagger performance via 10-fold cross-validation on the input corpus")

    # Optional resources; which modes make use of them is stated in the help.
    parser.add_argument("--brown", type=argparse.FileType("r"),
                        help="Brown clusters (paths output file produced by wcluster "
                             "(https://github.com/percyliang/brown-cluster)); optional and "
                             "only for training or cross-validation")
    parser.add_argument("--w2v", type=argparse.FileType("r"),
                        help="Word2Vec vectors; optional and only for training or cross-validation")
    parser.add_argument("--lexicon", type=os.path.abspath,
                        help="Additional full-form lexicon; optional and only for training or cross-validation")
    parser.add_argument("--mapping", type=os.path.abspath,
                        help="Additional mapping to coarser tagset; optional and only for tagging, evaluating or cross-validation")
    parser.add_argument("--prior", type=os.path.abspath,
                        help="Prior weights, i.e. a model trained on another corpus; optional and only for training or cross-validation")

    # Hyperparameters.
    parser.add_argument("-i", "--iterations", type=int, default=10,
                        help="Only for training or cross-validation: Number of iterations; default: 10")
    parser.add_argument("-b", "--beam-size", type=int, default=5,
                        help="Size of the search beam; default: 5")

    parser.add_argument("CORPUS", type=argparse.FileType("r"),
                        help="""Input corpus. Path to a file or "-" for STDIN. Format for training,
                             evaluation and cross-validation: One
                             token-pos pair per line, separated by a
                             tab; sentences delimited by an empty
                             line. Format for tagging: One token per
                             line; sentences delimited by an empty
                             line.""")
    return parser.parse_args(argv)


def evaluate_fold(args):
    """Train a tagger on everything outside one fold and evaluate it on that fold.

    All inputs arrive packed in a single tuple so that the function can be
    used as a worker for ``multiprocessing.Pool.map``.

    Args:
        args: A 13-tuple ``(i, beam_size, iterations, lexicon, mapping,
            brown_clusters, word_to_vec, words, tags, lengths,
            sentence_ranges, div, mod)``.  ``i`` is the index of the
            held-out fold; ``words`` and ``tags`` are sliced by token
            offsets; ``lengths`` holds one entry per sentence;
            ``sentence_ranges`` holds one ``(start, length)`` pair per
            sentence indexing into ``words``/``tags``; ``div`` and ``mod``
            are the quotient and remainder of the number of sentences
            divided by the number of folds.

    Returns:
        The 6-tuple returned by ``ASPTagger.evaluate``: ``(accuracy,
        accuracy_iv, accuracy_oov, coarse_accuracy, coarse_accuracy_iv,
        coarse_accuracy_oov)``.

    Raises:
        ValueError: If fold ``i`` contains no sentences, i.e. the corpus has
            fewer sentences than there are folds.
    """
    (i, beam_size, iterations, lexicon, mapping, brown_clusters, word_to_vec,
     words, tags, lengths, sentence_ranges, div, mod) = args

    # Sentence-level bounds of the held-out fold.  The first `mod` folds each
    # take one extra sentence so that the remainder is spread evenly.
    first = i * div + min(i, mod)
    last = (i + 1) * div + min(i + 1, mod)

    test_ranges = sentence_ranges[first:last]
    if not test_ranges:
        # Fail early with a clear message instead of an IndexError below.
        raise ValueError("Fold %d contains no sentences; the corpus has too few sentences for cross-validation" % i)

    # Token-level bounds: from the start of the first test sentence to the
    # end (start + length) of the last one.
    test_start = test_ranges[0][0]
    test_end = test_ranges[-1][0] + test_ranges[-1][1]

    test_lengths = lengths[first:last]
    train_lengths = lengths[:first] + lengths[last:]
    test_words = words[test_start:test_end]
    test_tags = tags[test_start:test_end]
    train_words = words[:test_start] + words[test_end:]
    train_tags = tags[:test_start] + tags[test_end:]

    asptagger = ASPTagger(beam_size, iterations, lexicon, mapping, brown_clusters, word_to_vec)
    asptagger.train(train_words, train_tags, train_lengths)
    results = asptagger.evaluate(test_words, test_tags, test_lengths)
    accuracy, accuracy_iv, accuracy_oov, coarse_accuracy, coarse_accuracy_iv, coarse_accuracy_oov = results
    logging.info("Accuracy: %.2f%%", accuracy * 100)
    if coarse_accuracy is not None:
        logging.info("Accuracy on mapped tagset: %.2f%%", coarse_accuracy * 100)
    return accuracy, accuracy_iv, accuracy_oov, coarse_accuracy, coarse_accuracy_iv, coarse_accuracy_oov


def main():
    """Run the command line interface.

    Depending on the selected mode this trains a model and saves it, tags
    the input corpus and prints the result to STDOUT, evaluates a model on
    a tagged corpus, or performs 10-fold cross-validation in parallel
    worker processes.
    """
    args = arguments()

    # Load the optional resources only for the modes that use them.
    lexicon, mapping, brown_clusters, word_to_vec = None, None, None, None
    if args.mapping and (args.tag or args.evaluate or args.crossvalidate):
        mapping = utils.read_mapping(args.mapping)
    if args.lexicon and (args.train or args.crossvalidate):
        lexicon = utils.read_lexicon(args.lexicon)
    if args.brown and (args.train or args.crossvalidate):
        brown_clusters = utils.read_brown_clusters(args.brown)
    if args.w2v and (args.train or args.crossvalidate):
        word_to_vec = utils.read_word2vec_vectors(args.w2v)

    asptagger = ASPTagger(args.beam_size, args.iterations, lexicon, mapping, brown_clusters, word_to_vec)
    # NOTE(review): in cross-validation mode the prior is loaded into this
    # tagger only; the per-fold workers build their own taggers and are not
    # handed the prior -- confirm whether that is intended.
    if args.prior and (args.train or args.crossvalidate):
        asptagger.load_prior_model(args.prior)

    if args.train:
        words, tags, lengths = utils.read_corpus(args.CORPUS, tagged=True)
        asptagger.train(words, tags, lengths)
        asptagger.save(args.train)
    elif args.tag:
        asptagger.load(args.tag)
        corpus = utils.iter_corpus(args.CORPUS, tagged=False)
        # Tag in batches of 100 corpus items to keep memory usage bounded.
        slc = list(itertools.islice(corpus, 100))
        while slc:
            words, lengths = zip(*slc)
            words = list(itertools.chain.from_iterable(words))
            tagged_corpus = asptagger.tag(words, lengths)
            for sentence in tagged_corpus:
                # One tab-separated token per line, then an empty line.
                # `end` is used instead of a second positional argument so
                # that print() does not insert its default " " separator
                # after the last tag of the sentence.
                print("\n".join("\t".join(t) for t in sentence), end="\n\n")
            slc = list(itertools.islice(corpus, 100))
    elif args.evaluate:
        asptagger.load(args.evaluate)
        words, tags, lengths = utils.read_corpus(args.CORPUS, tagged=True)
        accuracy, accuracy_iv, accuracy_oov, coarse_accuracy, coarse_accuracy_iv, coarse_accuracy_oov = asptagger.evaluate(words, tags, lengths)
        print("Accuracy: %.2f%%; IV: %.2f%%; OOV: %.2f%%" % (accuracy * 100, accuracy_iv * 100, accuracy_oov * 100))
        if coarse_accuracy is not None:
            print("Accuracy on mapped tagset: %.2f%%; IV: %.2f%%; OOV: %.2f%%" % (coarse_accuracy * 100, coarse_accuracy_iv * 100, coarse_accuracy_oov * 100))
    elif args.crossvalidate:
        words, tags, lengths = utils.read_corpus(args.CORPUS)
        # (start offset, length) for every sentence: the running total of
        # the lengths minus the sentence's own length is its start offset.
        sentence_ranges = list(zip((a - b for a, b in zip(itertools.accumulate(lengths), lengths)), lengths))
        div, mod = divmod(len(sentence_ranges), 10)
        # Every worker receives the same data plus the index of its fold.
        shared = (args.beam_size, args.iterations, lexicon, mapping, brown_clusters, word_to_vec,
                  words, tags, lengths, sentence_ranges, div, mod)
        fold_args = [(i,) + shared for i in range(10)]
        with multiprocessing.Pool() as pool:
            accs = pool.map(evaluate_fold, fold_args)
        accuracies, accuracies_iv, accuracies_oov, coarse_accuracies, coarse_accuracies_iv, coarse_accuracies_oov = zip(*accs)
        mean_accuracy = statistics.mean(accuracies)
        double_stdev = 2 * statistics.stdev(accuracies)
        print("Mean accuracy: %.2f%% ±%.2f" % (mean_accuracy * 100, double_stdev * 100))
        # Coarse accuracies are None when no mapping was supplied.
        if coarse_accuracies[0] is not None:
            coarse_mean_accuracy = statistics.mean(coarse_accuracies)
            coarse_double_stdev = 2 * statistics.stdev(coarse_accuracies)
            print("Mean accuracy on mapped tagset: %.2f%% ±%.2f" % (coarse_mean_accuracy * 100, coarse_double_stdev * 100))


# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
