#!/usr/bin/env python
# -*- mode: python; coding: utf-8; -*-

"""Parse input text into elementary discourse segments.

USAGE:
discourse_segmenter [GLOBAL_OPTIONS] type [TYPE_SPECIFIC_OPTIONS] [FILEs]

"""

##################################################################
# Imports
from __future__ import absolute_import, print_function, unicode_literals

from dsegmenter.bparseg import BparSegmenter, CTree
from dsegmenter.bparseg import read_tok_trees as bpar_read_tok_trees
from dsegmenter.bparseg import read_trees as bpar_read_trees
from dsegmenter.bparseg import trees2segs as bpar_trees2segs
from dsegmenter.common import read_segments
from dsegmenter.edseg import EDSSegmenter, CONLL
from dsegmenter.mateseg import MateSegmenter
from dsegmenter.mateseg import read_tok_trees as mate_read_tok_trees
from dsegmenter.mateseg import read_trees as mate_read_trees
from dsegmenter.mateseg import trees2segs as mate_trees2segs

from collections import defaultdict
from sklearn.model_selection import KFold
import argparse
import codecs
import glob
import sys
import os
import re

##################################################################
# Constants and Variables
# default text encoding for reading input and writing output
DEFAULT_ENCODING = "utf-8"
# default value of the `a_encoding` parameter of `crossval()`
ENCODING = DEFAULT_ENCODING
# names of the segmenter types (first-level CLI sub-commands)
EDSEG = "edseg"
BPARSEG = "bparseg"
MATESEG = "mateseg"
# names of the actions offered by the BitPar and Mate segmenters
# (second-level CLI sub-commands)
CV = "cv"
TEST = "test"
TRAIN = "train"
SEGMENT = "segment"
# NOTE(review): not referenced anywhere in the visible part of this file
Segmenter = None
# default number of folds used for cross-validation
N_FOLDS = 10


##################################################################
# Methods
def _set_train_test_args(a_parser):
    """Add CLI options to ArgumentParser instance.

    Args:
      a_parser (argparse.ArgumentParser): ArgumentParser instance to which new
        arguments should be added

    Returns:
      void:

    """
    a_parser.add_argument("--tree-sfx",
                          help="suffix of the file names with input trees",
                          type=str, default="")
    a_parser.add_argument("--seg-sfx",
                          help="suffix of the names of segmentation files",
                          type=str, default="")
    a_parser.add_argument("tree_dir",
                          help="directory containing files with input trees",
                          type=str)
    a_parser.add_argument("seg_dir",
                          help="directory containing segmentation files",
                          type=str)


def _set_bpar_mate_args(a_parser, a_name, a_in_name, a_dflt_path):
    """Add CLI options common to BitPar and Mate segmenter.

    A new sub-command ``a_name`` is created which in turn offers the
    actions ``train``, ``cv``, ``test``, and ``segment``.

    Args:
      a_parser: object on which ``add_parser()`` is called to create the
        new sub-command (``main()`` passes the result of
        ``ArgumentParser.add_subparsers()``)
      a_name (str): name of the segmenter for which new options should be added
      a_in_name (str): name of the input data structure (only used in help
        messages)
      a_dflt_path (str): path to default model

    Returns:
      void:

    Note:
      modifies ``a_parser`` in place

    """
    parser = a_parser.add_parser(a_name,
                                 help="machine-learning discourse"
                                 " segmenter for {:s}".format(a_in_name))
    # the chosen action is stored in the `mode` attribute of parsed arguments
    subparsers = parser.add_subparsers(help="action to perform",
                                       dest="mode")
    # training
    parser_train = subparsers.add_parser(TRAIN,
                                         help="train new model"
                                         " on {:s} and segment"
                                         " files".format(a_in_name))
    parser_train.add_argument("model",
                              help="path to file in which to store the"
                              " trained model", type=str)
    _set_train_test_args(parser_train)

    # cross-validation
    parser_cv = subparsers.add_parser(CV,
                                      help="train and evaluate model"
                                      " using cross-validation")
    parser_cv.add_argument("-o", "--output-dir",
                           help="output directory (leave empty for"
                           " no output)", type=str, default="")
    parser_cv.add_argument("model",
                           help="path to the file in which the best"
                           " trained model should be stored",
                           type=str)
    _set_train_test_args(parser_cv)

    # testing (the placeholder in the help message has to be filled in
    # explicitly, argparse does not expand `{}`-style fields)
    parser_test = subparsers.add_parser(TEST,
                                        help="test model on {:s}"
                                        " and segment files".format(
                                            a_in_name))
    parser_test.add_argument("-m", "--model",
                             help="path to file containing model",
                             type=str,
                             default=a_dflt_path)
    _set_train_test_args(parser_test)

    # segmentation of new input
    parser_segment = subparsers.add_parser(SEGMENT,
                                           help="split {:s} into"
                                           " discourse units".format(
                                               a_in_name))
    parser_segment.add_argument("-m", "--model",
                                help="path to file containing model",
                                type=str,
                                default=a_dflt_path)
    parser_segment.add_argument("files", help="input files",
                                nargs='*', metavar="file")


def _read_files(a_files, a_encoding=DEFAULT_ENCODING, a_skip_line=""):
    """Return iterator over lines of the input files.

    Lines equal to ``a_skip_line`` are not passed to the caller but printed
    to the standard output right away.  The comparison ignores the line
    terminator of the input line, so a skip line such as ``"<SEP>"`` matches
    the input line ``"<SEP>\\n"``.  With the default empty ``a_skip_line``
    nothing is skipped: empty input lines are yielded as empty strings and
    it is up to the caller to handle them.

    Args:
      a_files (list): files to read from (the standard input is read if this
        list is empty)
      a_encoding (str): text encoding used for input/output
      a_skip_line (str): line which should be skipped during iteration

    Yields:
      input lines with trailing whitespace removed

    """
    def _decoded_lines():
        """Yield decoded, still newline-terminated lines of the input."""
        if not a_files:
            # stdin delivers byte strings which have to be decoded manually
            for raw_line in sys.stdin:
                yield raw_line.decode(a_encoding)
        else:
            for fname in a_files:
                with codecs.open(fname, encoding=a_encoding,
                                 errors="replace") as ifile:
                    for raw_line in ifile:
                        yield raw_line

    for line in _decoded_lines():
        content = line.rstrip("\r\n")
        # `line == a_skip_line` keeps skip lines working which were passed
        # with their line terminator included
        if line == a_skip_line or (a_skip_line and content == a_skip_line):
            # `print` appends the newline itself, so output the bare content
            print(content.encode(a_encoding))
        else:
            yield line.rstrip()


def _align_files(a_tree_dir, a_seg_dir, a_tree_sfx, a_seg_sfx):
    """Align BitPar and segment files in two directories.

    Args:
      a_tree_dir (str): directory containing files with BitPar trees
      a_seg_dir (str): directory containing files with discourse segments
      a_tree_sfx (str): suffix of the names of BitPar files
      a_seg_sfx (str): suffix of the names of segmentation files

    Yields:
      2-tuples with BitPar and segment file

    """
    segf = ""
    basefname = ""
    BP_SFX_RE = re.compile(re.escape(a_tree_sfx) + '$')
    bpar_files = glob.iglob(os.path.join(a_tree_dir, '*' + a_tree_sfx))
    for bpf in bpar_files:
        basefname = BP_SFX_RE.sub("", os.path.basename(bpf))
        segf = os.path.join(a_seg_dir, basefname + a_seg_sfx)
        if os.path.isfile(segf) and os.access(segf, os.R_OK):
            yield (bpf, segf)
        else:
            print(
                "WARNING: No counterpart file found for BitPar"
                " file '{:s}'.".format(bpf), file=sys.stderr)


def _read_trees_segments(a_tree_dir, a_seg_dir, a_tree_sfx, a_seg_sfx,
                         a_read_tok_trees=bpar_read_tok_trees,
                         a_trees2segs=bpar_trees2segs,
                         a_fname2item=False,
                         a_encoding=DEFAULT_ENCODING):
    """Read aligned tree and segment files and pair trees with segments.

    Args:
      a_tree_dir (str): directory containing files with syntax trees
      a_seg_dir (str): directory containing files with discourse segments
      a_tree_sfx (str): suffix of the names of tree files
      a_seg_sfx (str): suffix of the names of segmentation files
      a_read_tok_trees (func): custom function for reading syntax trees
        (only the first element of its 2-tuple result is used)
      a_trees2segs (func or None): custom function for aligning syntax trees
        with segments; nothing is collected if it is None
      a_fname2item (bool): collect the items per file (in dictionaries keyed
        by the tree and the segment file name, respectively) instead of
        putting them into two flat lists
      a_encoding (str): text encoding used for reading the files

    Returns:
      tuple: trees and the segments corresponding to them

    """
    if a_fname2item:
        trees, segments = defaultdict(list), defaultdict(list)
    else:
        trees, segments = [], []
    # flat mode: everything ends up in the two result lists themselves
    tree_sink, seg_sink = trees, segments
    aligned_files = _align_files(a_tree_dir, a_seg_dir,
                                 a_tree_sfx, a_seg_sfx)
    for tree_fname, seg_fname in aligned_files:
        if a_fname2item:
            # per-file mode: switch to the lists belonging to this file pair
            tree_sink = trees[tree_fname]
            seg_sink = segments[seg_fname]
        with codecs.open(tree_fname, 'r', encoding=a_encoding) as tree_file:
            toks2trees, _ = a_read_tok_trees(tree_file)
        with codecs.open(seg_fname, 'r', encoding=a_encoding) as seg_file:
            toks2segs = read_segments(seg_file)
        if a_trees2segs is None:
            continue
        aligned = a_trees2segs(toks2trees, toks2segs)
        for tree, segment in aligned.iteritems():
            tree_sink.append(tree)
            seg_sink.append(segment)
    return (trees, segments)


def _output_segment_forrest(a_forrest, a_segmenter, a_output, a_encoding):
    """Split CONLL sentences in elementary discourse units and output them.

    Args:
      a_forrest (dsegmenter.edseg.CONLL): forrest of CoNLL trees
      a_segmenter (Segmenter): pointer to discourse segmenter
      a_output (bool): flag indicating whether dependency trees
        should be printed
      a_encoding (str): text encoding used for output

    Returns:
      void:

    """
    if a_forrest.is_empty():
        return
    else:
        if a_output:
            print(unicode(a_forrest).encode(a_encoding))
        sds_list = [a_segmenter.segment(sent) for sent in a_forrest]
        for sds in sds_list:
            sds.pretty_print(a_encoding=a_encoding)
        a_forrest.clear()


def edseg_segment(a_ilines, a_output_trees, a_encoding=DEFAULT_ENCODING):
    """Perform rule-based segmentation of CONLL dependency trees.

    Non-empty input lines are collected in a CoNLL forrest.  Every empty
    line triggers segmentation and output of the lines collected so far and
    is then printed itself.

    Args:
      a_ilines (iterable): iterator over input lines
      a_output_trees (bool): flag indicating whether dependency trees
        should be printed
      a_encoding (str): text encoding used for input/output

    Returns:
      void:

    """
    conll_forrest = CONLL()
    rule_segmenter = EDSSegmenter()
    for iline in a_ilines:
        if iline:
            # regular line: keep collecting
            conll_forrest.add_line(iline)
            continue
        # empty line: flush what has been collected and echo the line
        _output_segment_forrest(conll_forrest, rule_segmenter,
                                a_output_trees, a_encoding)
        print(iline.encode(a_encoding))
    # flush lines which were not followed by an empty line
    _output_segment_forrest(conll_forrest, rule_segmenter,
                            a_output_trees, a_encoding)


def bpar_mate_segment(a_segmenter, a_lines, a_encoding=DEFAULT_ENCODING,
                      a_ostream=sys.stdout, a_read_trees=CTree.parse_lines):
    """Segment trees read from input lines and print the segments.

    Each tree is segmented on its own; the last element of every item
    returned by the segmenter is converted to text, and the texts of one
    tree are written to the output stream joined by newlines.

    Args:
      a_segmenter (Segmenter): segmenter whose ``segment()`` method is
        applied to single-tree lists
      a_lines (iterable): iterator over input lines
      a_encoding (str): text encoding used for output
      a_ostream (IOstream): output stream
      a_read_trees (func): custom function for reading trees from
        ``a_lines``

    Returns:
      void:

    """
    for tree in a_read_trees(a_lines):
        tree_segments = a_segmenter.segment([tree])
        seg_texts = [unicode(seg[-1]) for seg in tree_segments]
        joined = '\n'.join(seg_texts)
        print(joined.encode(a_encoding), file=a_ostream)


def bpar_mate_test(a_segmenter, a_trees, a_segments):
    """Evaluate the segmenter and report its F1-scores on stderr.

    Args:
      a_segmenter (Segmenter): segmenter whose ``test()`` method returns the
        macro- and micro-averaged F1-scores
      a_trees (list): list of input trees
      a_segments (list): list of discourse segments corresponding to the
        trees

    Returns:
      void:

    """
    macro_f1, micro_f1 = a_segmenter.test(a_trees, a_segments)
    report = (("Macro F1-score: {:.2%}", macro_f1),
              ("Micro F1-score: {:.2%}", micro_f1))
    for template, score in report:
        print(template.format(score), file=sys.stderr)


def _cnt_stat(a_gold_segs, a_pred_segs):
    """Estimate true positives, false positives, and false negatives

    Args:
      a_gold_segs (list[dsegmenter.treeseg.DiscourseSegment]):
        gold segments
      a_pred_segs  (list[dsegmenter.treeseg.DiscourseSegment]):
        predicted segments

    Returns:
      tuple: true positives, false positives, and false negatives

    """
    tp = fp = fn = 0
    for gs, ps in zip(a_gold_segs, a_pred_segs):
        gs = gs.lower()
        ps = ps.lower()
        if gs == "none":
            if ps != "none":
                fp += 1
        elif gs == ps:
            tp += 1
        else:
            fn += 1
    return tp, fp, fn


def crossval(a_segmenter, a_path, a_fname2trees, a_fname2segs,
             a_tree_sfx, a_seg_sfx,
             a_output=False, a_out_dir=".", a_out_sfx=".tree",
             a_folds=N_FOLDS, a_encoding=ENCODING, a_read_trees=None):
    """Train and evaluate model using n-fold cross-validation.

    Whenever a fold reaches a new best macro-averaged F1-score, its model is
    dumped to ``a_path``, so this file holds the best model after the run.

    Args:
      a_segmenter (Segmenter): pointer to untrained segmenter instance
      a_path (str): path in which to store the model
      a_fname2trees (dict): mapping from file names to trees
      a_fname2segs (dict): mapping from file names to segments
      a_tree_sfx (str): suffix of the names of parse files
      a_seg_sfx (str): suffix of the names of segmentation files
      a_output (bool): boolean flag indicating whether output files should
        be produced
      a_out_dir (str): directory for writing output files
      a_out_sfx (str): suffix which should be appended to output files
      a_folds (int): number of folds
      a_encoding (str): default output encoding
      a_read_trees (func or None): method for reading input trees when
        output files are produced (the default reader of
        ``bpar_mate_segment()`` is used if None)

    Returns:
      tuple: list of macro F-scores, list of micro F-scores, F1_{tp,fp},
        and the index of the fold with the best macro F-score; the integer
        -1 is returned instead if less than two files were provided

    """
    # do necessary imports
    from sklearn.metrics import precision_recall_fscore_support
    from sklearn.externals import joblib
    import numpy as np

    # check conditions
    assert len(a_fname2trees) == len(a_fname2segs), \
        "Unmatching number of files with trees and segments."
    # make file names in `a_fname2trees` and `a_fname2segs` uniform and convert
    # segment classes to strings
    seg_sfx_re = re.compile(re.escape(a_seg_sfx) + '$')
    a_fname2segs = {seg_sfx_re.sub("", os.path.basename(k)) + a_out_sfx:
                    [str(iseg) for iseg in v]
                    for k, v in a_fname2segs.items()}
    tree_sfx_re = re.compile(re.escape(a_tree_sfx) + '$')
    # remember the original input file of every normalized file name
    ofname2ifname = {tree_sfx_re.sub("", os.path.basename(k)) + a_out_sfx: k
                     for k in a_fname2trees}
    a_fname2trees = {tree_sfx_re.sub("", os.path.basename(k)) + a_out_sfx: v
                     for k, v in a_fname2trees.items()}
    # estimate the number of and generate folds (an explicit list is needed
    # because file names are looked up by the indices produced by `KFold`)
    fnames = list(a_fname2trees)
    n_fnames = len(fnames)
    if n_fnames < 2:
        print("Insufficient number of samples for"
              " cross-validation: {:d}.".format(n_fnames),
              file=sys.stderr)
        return -1
    kf = KFold(n_splits=min(n_fnames, a_folds))
    # generate features for trees
    fname2feats = {fname: [a_segmenter.featgen(*t)
                           if isinstance(t, tuple)
                           else a_segmenter.featgen(t)
                           for t in trees]
                   for fname, trees in a_fname2trees.items()}
    # only pass a tree reader on if the caller has provided one; otherwise
    # `bpar_mate_segment()` falls back to its own default instead of trying
    # to call `None`
    read_kwargs = {} if a_read_trees is None \
        else dict(a_read_trees=a_read_trees)
    # initialize auxiliary variables
    F1_tpfp = 0.
    macro_F1s = []
    micro_F1s = []
    best_macro_f1 = float("-inf")
    best_i = -1                 # index of the best run
    processed_fnames = {}
    tp = fp = fn = 0
    # iterate over folds
    for i, (train, test) in enumerate(kf.split(fnames)):
        print("Fold: {:d}".format(i), file=sys.stderr)
        train_feats = [feat for k in train for feat in fname2feats[fnames[k]]]
        train_segs = [seg for k in train for seg in a_fname2segs[fnames[k]]]
        test_feats = [feat for k in test for feat in fname2feats[fnames[k]]]
        test_segs = [seg for k in test for seg in a_fname2segs[fnames[k]]]
        # train classifier model
        a_segmenter.model.fit(train_feats, train_segs)
        # obtain new predictions
        pred_segs = a_segmenter.model.predict(test_feats)
        # update statistics and F1 scores
        tp_i, fp_i, fn_i = _cnt_stat(test_segs, pred_segs)
        tp += tp_i
        fp += fp_i
        fn += fn_i
        _, _, macro_f1, _ = precision_recall_fscore_support(test_segs,
                                                            pred_segs,
                                                            average='macro',
                                                            pos_label=None)
        _, _, micro_f1, _ = precision_recall_fscore_support(test_segs,
                                                            pred_segs,
                                                            average='micro',
                                                            pos_label=None)
        macro_F1s.append(macro_f1)
        micro_F1s.append(micro_f1)
        print("Macro F1: {:.2%}".format(macro_f1), file=sys.stderr)
        print("Micro F1: {:.2%}".format(micro_f1), file=sys.stderr)
        # update maximum macro F-score and store the most successful model
        if macro_f1 > best_macro_f1:
            best_i = i
            best_macro_f1 = macro_f1
            joblib.dump(a_segmenter.model, a_path)
        # generate new output files, if necessary
        if a_output:
            for k in test:
                test_fname = fnames[k]
                # keep an existing output file which stems from a better run
                if test_fname in processed_fnames and \
                   processed_fnames[test_fname] > macro_f1:
                    continue
                processed_fnames[test_fname] = macro_f1
                in_fname = ofname2ifname[test_fname]
                out_fname = os.path.join(a_out_dir, test_fname)
                with open(out_fname, "w") as ofile:
                    print("(TEXT", file=ofile)
                    bpar_mate_segment(a_segmenter,
                                      _read_files([in_fname], a_encoding),
                                      a_encoding=a_encoding,
                                      a_ostream=ofile,
                                      **read_kwargs)
                    print(")", file=ofile)
    print("Average macro F1: {:.2%} +/- {:.2%}".format(
        np.mean(macro_F1s), np.std(macro_F1s)), file=sys.stderr)
    print("Average micro F1: {:.2%} +/- {:.2%}".format(
        np.mean(micro_F1s), np.std(micro_F1s)), file=sys.stderr)
    # F1 estimated from the counts accumulated over all folds
    if tp or fp or fn:
        F1_tpfp = (2. * tp / (2. * tp + fp + fn))
    print("F1_{{tp,fp}} {:.2%}".format(F1_tpfp),
          file=sys.stderr)
    return (macro_F1s, micro_F1s, F1_tpfp, best_i)


def main(argv):
    """Read input text and segment it into elementary discourse units.

    Args:
      argv (list[str]): command line arguments (without the program name)

    Returns:
      int: 0 on success, non-0 otherwise

    """
    # process arguments
    parser = argparse.ArgumentParser(
        description="Script for segmenting text into elementary"
        " discourse units.")

    # define global options (the encoding is deliberately parsed as a plain
    # string: with `nargs=1` it would become a one-element list which is not
    # accepted as an encoding name)
    parser.add_argument("-e", "--encoding", help="input encoding of text",
                        type=str, default=DEFAULT_ENCODING)
    parser.add_argument("-s", "--skip-line",
                        help="lines which should be ignored during the "
                        "processing and output without changes"
                        " (defaults to empty lines)",
                        type=str, default="")

    # add type-specific subparsers
    subparsers = parser.add_subparsers(help="type of discourse segmenter"
                                       " to use", dest="dtype")

    # edseg argument parser
    parser_edseg = subparsers.add_parser(EDSEG,
                                         help="rule-based discourse segmenter"
                                         " for CONLL dependency trees")
    parser_edseg.add_argument("-o", "--output-trees",
                              help="output dependency trees along with"
                              " segments", action="store_true")
    parser_edseg.add_argument("files", help="input files",
                              nargs='*', metavar="file")

    # add bpar arguments
    _set_bpar_mate_args(subparsers, BPARSEG, "BitPar constituency trees",
                        BparSegmenter.DEFAULT_MODEL)
    # add mate arguments
    _set_bpar_mate_args(subparsers, MATESEG, "Mate dependency trees",
                        MateSegmenter.DEFAULT_MODEL)
    # parse the arguments that were actually passed to this function
    args = parser.parse_args(argv)

    # process input files (only sub-commands with a `files` argument read
    # their input line by line)
    ifiles = []
    if hasattr(args, "files"):
        ifiles = _read_files(args.files, args.encoding, args.skip_line)

    # process input with edseg
    if args.dtype == EDSEG:
        edseg_segment(ifiles, args.output_trees, args.encoding)
        return 0

    # process input with bparseg or mateseg
    if args.dtype == BPARSEG:
        read_trees = bpar_read_trees
        read_tok_trees = bpar_read_tok_trees
        trees2segs = bpar_trees2segs
    else:
        read_trees = mate_read_trees
        read_tok_trees = mate_read_tok_trees
        trees2segs = mate_trees2segs

    if args.mode == TRAIN or args.mode == CV:
        # make sure there is a directory for storing the model
        mdir = os.path.dirname(args.model)
        if mdir == '':
            pass
        elif os.path.exists(mdir):
            # the model will be written, so write permission is required
            if not os.path.isdir(mdir) or not os.access(mdir, os.W_OK):
                print("Can't write to directory '{:s}'.".format(mdir),
                      file=sys.stderr)
                return 1
        else:
            os.makedirs(mdir)

        if args.dtype == BPARSEG:
            segmenter = BparSegmenter(
                a_model=BparSegmenter.DEFAULT_PIPELINE)
        else:
            segmenter = MateSegmenter(model=MateSegmenter.DEFAULT_PIPELINE)

        # cross-validation needs the trees and segments grouped by file
        trees, segments = _read_trees_segments(
            args.tree_dir, args.seg_dir,
            args.tree_sfx, args.seg_sfx,
            read_tok_trees, trees2segs,
            args.mode == CV, args.encoding)

        if args.mode == TRAIN:
            segmenter.train(trees, segments, args.model)
        else:
            crossval(segmenter, args.model, trees, segments,
                     args.tree_sfx, args.seg_sfx,
                     bool(args.output_dir), args.output_dir,
                     a_encoding=args.encoding,
                     a_read_trees=read_trees)
        return 0

    # testing and segmentation require an existing model (checked explicitly
    # rather than with `assert`, which is stripped under `python -O`)
    if not (os.path.isfile(args.model) and os.access(args.model, os.R_OK)):
        print("Can't read model file '{:s}'.".format(args.model),
              file=sys.stderr)
        return 1

    if args.dtype == BPARSEG:
        segmenter = BparSegmenter(a_model=args.model)
    else:
        segmenter = MateSegmenter(model=args.model)

    if args.mode == TEST:
        trees, segments = _read_trees_segments(
            args.tree_dir, args.seg_dir,
            args.tree_sfx, args.seg_sfx,
            a_read_tok_trees=read_tok_trees,
            a_trees2segs=trees2segs,
            a_encoding=args.encoding)
        bpar_mate_test(segmenter, trees, segments)
    else:
        bpar_mate_segment(segmenter, ifiles, args.encoding,
                          a_read_trees=read_trees)
    return 0


##################################################################
# Main
if __name__ == "__main__":
    # propagate the result of `main()` as exit status of the script (a
    # result of `None` is reported as success by `sys.exit()`)
    sys.exit(main(sys.argv[1:]))
