#!/usr/bin/env python
# -*- coding: utf-8; mode: python; -*-

"""Script for evaluating the agreement between two forests of segmentation trees.

Given two folders, each containing a forest of segmentation trees, several
metrics are calculated:
 * PK (Beeferman Berger 1999)
 * Windiff (Pevzner Hearst 2002)
 * Fleiss PI boundary edit distance (Fournier 2013)
 * Krippendorff Alpha Unitizing, untyped (Krippendorff 1999)
 * Krippendorff Alpha Unitizing, typed (Krippendorff 2004)
 * Unlabelled parseval F1 score
 * Labelled parseval F1 score

This file provides helper methods for preparing and passing the data. All
credits for the implementations of the metrics go to the authors of the
used packages:
 - segeval
 - dkpro agreement
 - parseval

@author = Andreas Peldszus
@mail = <peldszus at uni dash potsdam dot de>
@version = 0.1.0

"""


##################################################################
# Imports
from dsegmenter.evaluation.align import align_tokenized_tree, AlignmentError
from dsegmenter.evaluation.metrics import (
    metric_pk, metric_windiff, metric_pi_bed, metric_f1, metric_lf1,
    metric_alpha_unit, metric_alpha_unit_untyped, metric_kappa, avg, sigma)
from dsegmenter.evaluation.segmentation import get_confusions

from nltk import ConfusionMatrix
from nltk.tree import Tree
import argparse
import os
import re
import string


##################################################################
# Variables and Constants
# Matches underscores; read_files() replaces every match with a space
# before the file contents are split into whitespace-separated tokens.
BAD_CHAR_RE = re.compile('[_]')


##################################################################
# Methods
def analyze_confusions(forest1, forest2):
    """Print span-matching statistics and categorical agreement.

    The trees of both forests are paired positionally (surplus trees of the
    longer forest are ignored by ``zip``) and every pair is handed to
    ``get_confusions``.  The resulting ``(a, b)`` label pairs are pooled
    over all tree pairs.  A pair containing ``None`` is counted as a
    mismatching span -- presumably a span that exists in only one of the
    two trees; confirm against ``get_confusions``.

    Args:
        forest1: sequence of trees of the first (gold) annotation.
        forest2: sequence of trees of the second (pred) annotation, in the
            same order as ``forest1``.

    Returns:
        None.  All results are written to stdout.
    """
    confusions = []
    for tree1, tree2 in zip(forest1, forest2):
        confusions.extend(get_confusions(tree1, tree2))

    print('\n### Span matching')
    if not confusions:
        # Without a single span pair the ratios below would divide by zero
        # and there is nothing to tabulate, so stop here.
        print('No span pairs to compare.')
        return

    matching_spans = [p for p in confusions if None not in p]
    mismatching_spans_all = [p for p in confusions if None in p]
    mismatching_spans_with_gold = [(a, b) for a, b in confusions if b is None]
    matching_categories = [(a, b) for a, b in matching_spans if a == b]
    gold_vs_pred_spans = matching_spans + mismatching_spans_with_gold

    n_matching = len(matching_spans)
    print('Matching spans: %d' % n_matching)
    print('Mismatching spans all: %d' % len(mismatching_spans_all))
    print('Mismatching spans with gold: %d' % len(mismatching_spans_with_gold))
    # matching and mismatching spans partition `confusions`, so its length
    # is the denominator of the first ratio.
    print('Ratio of matching spans with all mismatches: %.2f percent' %
          _percentage(n_matching, len(confusions)))
    print('Ratio of matching spans with mismatches to gold: %.2f percent' %
          _percentage(n_matching, len(gold_vs_pred_spans)))
    print('Matching spans and categories: %d' % len(matching_categories))

    print('\n### Confusions of categories')
    cm = ConfusionMatrix([a for a, b in confusions],
                         [b for a, b in confusions])
    print(cm.pretty_format(sort_by_count=True))

    print('\n### Categorical agreement')
    _print_kappa(confusions, 'all span pairs')
    _print_kappa(gold_vs_pred_spans, 'gold_vs_pred span pairs')
    _print_kappa(matching_spans, 'matching span pairs')


def _percentage(part, whole):
    """Return ``part`` as a percentage of ``whole``; NaN if ``whole`` is 0."""
    if not whole:
        return float('nan')
    return 100.0 * part / whole


def _print_kappa(pairs, description):
    """Print the ``metric_kappa`` result for a list of ``(a, b)`` pairs."""
    if not pairs:
        # Agreement over zero items is undefined; do not call the metric.
        print('kappa n/a (no span pairs) for %s' % description)
        return
    k, e, o = metric_kappa([a for a, b in pairs],
                           [b for a, b in pairs])
    print('kappa %.4f (EA %.4f, OA %.4f) for %s' % (k, e, o, description))


def read_files(filename1, filename2, enc="utf-8"):
    """Read two tree files and parse them into token-aligned trees.

    Both files are decoded with ``enc``, underscores are replaced by
    spaces (``BAD_CHAR_RE``), and the text is split on whitespace.  The two
    token lists are passed to ``align_tokenized_tree``; the aligned tokens
    are joined with single spaces and parsed with ``Tree.fromstring``.

    Args:
        filename1: path of the first tree file; its base name is passed to
            the aligner as ``tree_pair_name``.
        filename2: path of the second tree file.
        enc: encoding used to decode both files.

    Returns:
        A ``(tree1, tree2)`` tuple of ``nltk.tree.Tree`` objects, or
        ``(None, None)`` if the aligner raised ``AlignmentError`` (the
        error message is printed to stdout in that case).
    """
    tok_1 = _read_tokens(filename1, enc)
    tok_2 = _read_tokens(filename2, enc)

    # align tokenizations, if impossible, return None
    try:
        # NOTE(review): the first return value is ignored here, as it was
        # before -- confirm that it carries nothing worth reporting.
        _error, aligned_tok_1, aligned_tok_2 = align_tokenized_tree(
            tok_1, tok_2, tree_pair_name=os.path.basename(filename1))
    except AlignmentError as err:
        print(unicode(err).encode("utf-8"))
        return None, None

    # join the aligned tokens again and parse the trees
    tree1 = Tree.fromstring(' '.join(aligned_tok_1))
    tree2 = Tree.fromstring(' '.join(aligned_tok_2))
    return tree1, tree2


def _read_tokens(filename, enc):
    """Return the whitespace-separated tokens of ``filename``.

    The file is closed as soon as it has been read; underscores are turned
    into spaces before splitting.
    """
    with open(filename, 'rb') as infile:
        text = infile.read().decode(enc)
    return BAD_CHAR_RE.sub(' ', text).split()


def load_trees(dir1, dir2, suffix):
    """Yield aligned tree pairs for files present in both directories.

    The files of ``dir1`` are visited in sorted order.  Files whose name
    does not end with ``suffix`` are skipped silently; files without a
    counterpart of the same name in ``dir2``, and pairs for which
    ``read_files`` returns ``None``, are skipped with a message on stdout.

    Args:
        dir1: directory with the first set of tree files.
        dir2: directory with the second set of tree files.
        suffix: file name suffix of the files to consider.

    Yields:
        ``(basename, tree1, tree2)`` for every pair that could be read.

    Raises:
        ValueError: if ``dir1`` or ``dir2`` is not a directory, or if both
            paths have the same base name.  As this is a generator, the
            checks run when the first item is requested.
    """
    # Validate with explicit raises: `assert` would vanish under `python -O`.
    if not os.path.isdir(dir1):
        raise ValueError(dir1 + " is not a directory.")
    if not os.path.isdir(dir2):
        raise ValueError(dir2 + " is not a directory.")
    a1_id = os.path.basename(os.path.abspath(dir1))
    a2_id = os.path.basename(os.path.abspath(dir2))
    if a1_id == a2_id:
        raise ValueError("first_dir and second_dir identical")

    for basename in sorted(os.listdir(dir1)):
        # skip irrelevant files
        if not basename.endswith(suffix):
            continue
        fname1 = os.path.join(dir1, basename)
        fname2 = os.path.join(dir2, basename)
        if not os.path.isfile(fname2):
            print("Couldn't find corresponding file %s" % fname2)
            continue
        # read and align the trees of both files
        tree1, tree2 = read_files(fname1, fname2)
        if tree1 is None or tree2 is None:
            print("Couldn't align files %s and %s: Skipping."
                  % (fname1, fname2))
            continue
        yield basename, tree1, tree2


def main():
    """Command-line entry point: score the agreement of two tree forests.

    Parses the command line, loads all tree pairs found in both input
    directories via ``load_trees``, prints the confusion analysis and then
    a table with one row of segmentation scores per text, followed by the
    average, the standard deviation and the scores on the full corpus.

    Exits through ``ArgumentParser.error`` (status 2) when no tree pair
    could be loaded at all.
    """
    # initialize argument parser
    aparser = argparse.ArgumentParser(
        description=("Script for computing agreement measures for two forests "
                     "of segmentation trees, given as two separate "
                     "directories."))
    aparser.add_argument(
        "dir1",
        help="input dir containing the 1st or gold set of segmentation trees")
    aparser.add_argument(
        "dir2",
        help="input dir containing the 2nd or pred set of segmentation trees")
    # NOTE(review): --delexicalize is accepted, but args.delexicalize is never
    # read in this function, so the option currently has no effect here.
    aparser.add_argument(
        "--delexicalize",
        help="replace all tokens by a tok-symbol after aligning",
        action='store_true')
    aparser.add_argument(
        "-s", "--suffix",
        help="the suffix of files to look for", default='.tree')
    args = aparser.parse_args()

    anno1 = os.path.basename(os.path.abspath(args.dir1))
    anno2 = os.path.basename(os.path.abspath(args.dir2))
    print("Calculating agreement for annotator pair <%s,%s>" % (anno1, anno2))

    corpus = {
        text: (tree1, tree2)
        for text, tree1, tree2 in load_trees(args.dir1, args.dir2, args.suffix)
    }
    if not corpus:
        # Nothing to score; bail out with a clear message instead of failing
        # later while unpacking an empty forest.
        aparser.error("no tree pairs with suffix %r could be loaded from "
                      "%s and %s" % (args.suffix, args.dir1, args.dir2))
    texts_order = sorted(corpus)

    # Build both forests in the same, sorted text order so that corpus-level
    # results do not depend on dict iteration order.
    forest1 = [corpus[text][0] for text in texts_order]
    forest2 = [corpus[text][1] for text in texts_order]
    analyze_confusions(forest1, forest2)

    metrics = {
        'pk': metric_pk,
        'windiff': metric_windiff,
        'pi_bed': metric_pi_bed,
        'f1': metric_f1,
        'lf1': metric_lf1,
        'aU+c': metric_alpha_unit,
        'aU-c': metric_alpha_unit_untyped
    }
    # column order of the score table
    metrics_order = ['pk', 'windiff', 'pi_bed', 'aU-c', 'aU+c', 'f1', 'lf1']

    print('\n### Segmentation scores')
    scores = {metric: {text: None for text in corpus} for metric in metrics}
    print(_score_row('text',
                     ['{:>s}'.format(metric) for metric in metrics_order]))
    # text-wise scores
    for text in texts_order:
        tree1, tree2 = corpus[text]
        for metric, metric_func in metrics.items():
            scores[metric][text] = metric_func([tree1], [tree2])
        print(_score_row(text,
                         ['{:>3.2f}'.format(scores[metric][text])
                          for metric in metrics_order]))
    # average
    print(_score_row('average',
                     ['{:>3.2f}'.format(avg(scores[metric].values()))
                      for metric in metrics_order]))
    # standard deviation
    print(_score_row('std',
                     ['{:>3.2f}'.format(sigma(scores[metric].values()))
                      for metric in metrics_order]))
    # on full corpus; stored only after average and std have been computed,
    # so the '_ALL_' entry does not enter those statistics
    for metric, metric_func in metrics.items():
        scores[metric]['_ALL_'] = metric_func(forest1, forest2)
    print(_score_row('full corpus',
                     ['{:>3.2f}'.format(scores[metric]['_ALL_'])
                      for metric in metrics_order]))


def _score_row(label, cells):
    """Return a tab-separated table row with a left-aligned, padded label."""
    return '\t'.join(['{:<22s}'.format(label)] + list(cells))

##################################################################
# Main
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
