#!/usr/bin/env python
"""%(prog)s [options] DIRECTORY [DIRECTORY ...]

Run the free energy analysis for water and octanol in <DIRECTORY>/FEP
and return the octanol-water partition coefficient log P_ow.

DIRECTORY should contain all the files resulting from running
``mdpow.fep.Ghyd.setup()`` and ``mdpow.fep.Goct.setup()`` and the results of
the MD FEP simulations. It relies on the canonical naming scheme (basically:
just use the defaults as in the tutorial).

The dV/dlambda plots can be produced automatically (--plot=auto). If multiple
DIRECTORY arguments are provided then one has to choose the auto option (or
None).

Results are *appended* to a data file with **Output file format**::

 molecule  DeltaG0(kJ/mol) ErrorDeltaG0  "logPow" log P_OW  Error logPow directory

molecule
    molecule name as used in the itp
DeltaG0, ErrorDeltaG0
    transfer free energy water --> octanol, in kJ/mol; error from correlation time
    and error propagation through Simpson's rule (see :meth:`mdpow.fep.Gsolv`)
log P_OW
    base-10 logarithm of the octanol-water partition coefficient; >0: partitions into
    octanol, <0: partitions preferably into water
wat_ok, octa_ok
    (no longer part of the output line in releases >0.2.1)
    "OK": input data was clean; "BAD": some input data xvg files contained
    unparseable lines that were skipped; this can mean that the data became
    corrupted and caution should be used.
directory
    folder in which the simulations were stored
"""
from __future__ import absolute_import, print_function, division

import os
import mdpow.fep
import mdpow.filelock

import logging
logger = logging.getLogger('mdpow')

def load_gsolv(directory, **kwargs):
    """Load the water and octanol free energy objects of a project.

    The data files are looked up at the fixed locations
    ``<directory>/FEP/water/Ghyd.fep`` and
    ``<directory>/FEP/octanol/Goct.fep``.

    :Arguments:
        *directory*
             directory under which the project is stored
        *permissive*
             passed on to :class:`mdpow.fep.Ghyd` and
             :class:`mdpow.fep.Goct` [False]

    :Returns: tuple ``(gwat, goct)``

    :Raises: :exc:`OSError` if one or both data files do not exist; the
             message lists the missing files by solvent.
    """
    fep_root = os.path.join(directory, "FEP")
    files = {
        'water': os.path.join(fep_root, 'water', 'Ghyd.fep'),
        'octanol': os.path.join(fep_root, 'octanol', 'Goct.fep'),
    }

    # check both files up front so that the error reports everything missing
    files_not_found = {solvent: path for solvent, path in files.items()
                       if not os.path.exists(path)}
    if files_not_found:
        raise OSError("Missing input files: %r" % files_not_found)

    permissive = kwargs.pop('permissive', False)

    logger.info("Reading water data %(water)r", files)
    gwat = mdpow.fep.Ghyd(filename=files['water'], basedir=directory,
                          permissive=permissive)

    logger.info("Reading octanol data %(octanol)r", files)
    goct = mdpow.fep.Goct(filename=files['octanol'], basedir=directory,
                          permissive=permissive)

    return gwat, goct


def run_pow(directory, **kwargs):
    """Do the P_OW analysis.

    Loads the water and octanol free energy data with :func:`load_gsolv`,
    computes the transfer free energy and log P_ow with
    :func:`mdpow.fep.pOW`, and optionally appends the results to text
    files and plots the dV/dlambda graphs. Nothing is returned.

    :Arguments:
        *directory*
             directory under which the project is stored

    :Keywords:
        *plotfile*
             name of a file to plot dV/dlambda graphs to; "auto" chooses
             the name with :func:`auto_plotfile`. Plotting is skipped
             (with a warning) when *estimator* is 'alchemlyb'. [None]
        *outfile*
             name of a file to append results to [None]
        *energyfile*
             append individual energy terms to this file [None]
        *force*
             force rereading data files [False]
        *estimator*
             estimator implementation, passed to :func:`mdpow.fep.pOW`
             ['alchemlyb']
        *method*
             estimator method, passed to :func:`mdpow.fep.pOW` ['MBAR']
        *stride*, *start*, *stop*
             selection of the data points to use, passed to
             :func:`mdpow.fep.pOW` [None, 0, None]
        *SI*
             statistical inefficiency analysis switch, passed to
             :func:`mdpow.fep.pOW` [False]
        *permissive*
             passed to :func:`load_gsolv` [False]

    :Raises: :exc:`OSError` (from :func:`load_gsolv`) if input files are
             missing.

    **Output file format (outfile) :**

        molecule  DeltaG0(kJ/mol) ErrorDeltaG0  "logPow" log P_OW  Error logPow directory

    The *energyfile* gets one line for water and one for octanol; each
    starts with the ``summary()`` of the respective object, followed by the
    same transfer free energy, log P_OW and directory columns.

    Removed from output (>0.2.1):
     * wat_ok and octa_ok are "OK" if no lines were skipped in the input data files due to
       file corruption, and "BAD" otherwise.

    """
    import gc

    fmt_result = "%-8s %+8.2f %8.2f  logPow %+8.2f %8.2f   %s\n"
    fmt_energy = "%s  %+8.2f %8.2f   logPow %+8.2f %8.2f  %s\n"
    _datastatus = {True: 'BAD',  # suspect file corruption
                   False: 'OK'}

    def datstat(g):
        return _datastatus[g.contains_corrupted_xvgs()]

    estimator = kwargs.pop('estimator', 'alchemlyb')
    gwat, goct = load_gsolv(directory, permissive=kwargs.pop('permissive', False))
    transferFE, logPow = mdpow.fep.pOW(goct, gwat, stride=kwargs.pop('stride', None),
                                       start=kwargs.pop('start', 0), stop=kwargs.pop('stop', None),
                                       force=kwargs.pop('force', False), SI=kwargs.pop('SI', False),
                                       estimator=estimator,
                                       method=kwargs.pop('method', 'MBAR'))

    # query the corruption status once per solvent and reuse it
    wat_status, oct_status = datstat(gwat), datstat(goct)
    if wat_status == 'BAD' or oct_status == 'BAD':
        logger.warning("Possible file corruption: water:%s, octanol:%s",
                       wat_status, oct_status)

    if kwargs.get('outfile', None):
        # lock the shared results file while appending to it
        with mdpow.filelock.FileLock(kwargs['outfile'], timeout=2):
            with open(kwargs['outfile'], mode='a') as out:
                out.write(fmt_result %
                          (gwat.molecule, transferFE.value, transferFE.error,
                           logPow.value, logPow.error, directory))
                # removed (>0.2.1):
                # datstat(gwat), datstat(goct) --> wat_ok octa_ok
            logger.info("Wrote results to %(outfile)r", kwargs)

    if kwargs.get('energyfile', None):
        with mdpow.filelock.FileLock(kwargs['energyfile'], timeout=2):
            with open(kwargs['energyfile'], mode='a') as out:
                for g in gwat, goct:
                    out.write(fmt_energy %
                              (g.summary(), transferFE.value, transferFE.error,
                               logPow.value, logPow.error, directory))
            logger.info("Wrote solvation energy terms to %(energyfile)r", kwargs)

    plotfile = kwargs.get('plotfile', None)
    if plotfile and estimator == 'alchemlyb':
        # plotting is disabled for this estimator; tell the user instead of
        # silently dropping the request
        logger.warning("Plotting is disabled for estimator %r: ignoring plotfile %r",
                       estimator, plotfile)
        plotfile = None
    if plotfile:
        if plotfile == 'auto':
            plotfile = auto_plotfile(directory, gwat)
        from matplotlib.pyplot import clf, title, savefig
        clf()
        gwat.plot(color='k', ls='--')
        goct.plot(color='r', ls='-')
        activate_subplot(1)
        title(r"[%s] $\Delta A^0 = %.2f\pm%.2f$ kJ/mol" %
              (goct.molecule, transferFE.value, transferFE.error))
        activate_subplot(2)
        title(r"$\log P_{OW} = %.2f\pm%.2f$" % logPow.astuple())
        savefig(plotfile)
        # pass only what the message needs; vars() would attach every local
        # (including gwat and goct) to the log record
        logger.info("Wrote graph to %(plotfile)r", {'plotfile': plotfile})

    del gwat
    del goct
    gc.collect()  # try to free as much memory as possible


def auto_plotfile(directory, gsolv):
    """Return the default plot filename inside *directory*.

    The name has the form ``dVdl_<molecule>_pow.pdf`` where ``<molecule>``
    is taken from the *molecule* attribute of *gsolv*.
    """
    basename = "dVdl_{0}_pow.pdf".format(gsolv.molecule)
    return os.path.join(directory, basename)


def activate_subplot(numPlot):
    """Make subplot *numPlot* active on the canvas.

    Use this if a simple ``subplot(numRows, numCols, numPlot)``
    overwrites the subplot instead of activating it.

    *numPlot* counts subplots starting at 1. The result of the ``axes()``
    call on the selected axes of the current figure is returned.
    """
    # see http://www.mail-archive.com/matplotlib-users@lists.sourceforge.net/msg07156.html
    from pylab import gcf, axes
    figure_axes = gcf().get_axes()
    # plots are 1-based, the list of axes is 0-based
    return axes(figure_axes[numPlot - 1])


def parse_filename_option(fn):
    """Translate the special filename values "none" and "auto".

    :Arguments:
        *fn*
             filename option as a string, or ``None``

    :Returns: ``None`` if *fn* is ``None`` or spells "none" (in any case),
              the lower-case string ``"auto"`` if *fn* spells "auto" (in
              any case), and *fn* unchanged otherwise.

    ``None`` is accepted so that the function can safely be applied to its
    own result, e.g. when it is called repeatedly on the same option.
    """
    if fn is None:
        return None
    keyword = fn.lower()
    if keyword == "none":
        return None
    if keyword == "auto":
        return "auto"
    return fn

def realpath_or_none(fn):
    """Return ``os.path.realpath(fn)``, or ``None`` if *fn* is ``None``."""
    if fn is None:
        return None
    return os.path.realpath(fn)

if __name__ == "__main__":
    # Command line entry point: parse the options, validate combinations
    # that cannot work, then run the analysis once per DIRECTORY.
    import sys
    import os.path
    import argparse

    # the module docstring doubles as the usage message
    parser = argparse.ArgumentParser(usage=__doc__, prog='mdpow-pow',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("directory", nargs='+',
                        help="directory or directories which contain all the files "
                        "resulting from running mdpow.fep.Ghyd.setup() and mdpow.fep.Goct.setup() "
                        "and the results of the MD FEP simulations.",)
    parser.add_argument('--plotfile', dest="plotfile", default='none',
                        help="plot dV/dlambda to FILE; use png or pdf suffix to "
                        "determine the file type. 'auto' generates a pdf file "
                        "DIRECTORY/dVdl_<molname>_pow.pdf and 'none' disables it. "
                        "The plot function is only available for mdpow estimator, "
                        "and is disabled when using alchemlyb estimators.",
                        metavar="FILE")
    parser.add_argument('-o', '--outfile', dest="outfile", default="pow.txt",
                        help="append one-line results summary to FILE",
                        metavar="FILE")
    parser.add_argument('-e', '--energies', dest="energyfile", default="energies.txt",
                        help="append solvation free energies to FILE",
                        metavar="FILE")
    parser.add_argument('--estimator', dest="estimator", default='alchemlyb',
                        choices=['mdpow', 'alchemlyb'],
                        help="choose the estimator implementation to be used, alchemlyb or mdpow estimators")
    parser.add_argument('--method', dest="method", default='MBAR',
                        choices=['TI', 'MBAR', 'BAR'],
                        help="choose the estimator method to calculate free energy",)
    parser.add_argument('--force', dest='force', default=False,
                        action="store_true",
                        help="force rereading all data ")
    # NOTE: --SI is a store_true flag that already defaults to True, so
    # opts.SI is always True; --no-SI is the switch that has an effect.
    SIs = parser.add_mutually_exclusive_group()
    SIs.add_argument('--SI', dest='SI',
                     action="store_true", default=True,
                     help="enable statistical inefficiency (SI) analysis. "
                     "Statistical inefficiency analysis is performed by default when using "
                     "alchemlyb estimators and is always disabled when using mdpow estimator.")
    SIs.add_argument('--no-SI', dest='noSI',
                     action="store_true", default=False,
                     help="disable statistical inefficiency analysis. "
                     "Statistical inefficiency analysis is performed by default when using "
                     "alchemlyb estimators and is disabled when using mdpow estimator. ")
    parser.add_argument('-s', '--stride', dest="stride", type=int, default=1,
                        help="use every N-th datapoint from the original dV/dlambda data.",
                        metavar="N")
    parser.add_argument('--start', dest="start", type=int, default=0,
                        help="start point for the data used from the original dV/dlambda data.")
    parser.add_argument('--stop', dest="stop", type=int, default=None,
                        help="stop point for the data used from the original dV/dlambda data.")
    parser.add_argument('--ignore-corrupted', dest="permissive",
                        action="store_true", default=False,
                        help="skip lines in the md.xvg files that cannot be parsed. "
                        "WARNING: Other lines in the file might have been corrupted in "
                        "such a way that they appear correct but are in fact wrong. "
                        "WRONG RESULTS CAN OCCUR! USE AT YOUR OWN RISK.")
    opts = parser.parse_args()

    # a single explicit plot file would be overwritten by each directory
    if len(opts.directory) > 1 and opts.plotfile.lower() not in ('none', 'auto'):
        logger.fatal("Can only use --plotfile=None or --plotfile=auto with multiple directories.")
        sys.exit(1)

    use_SI = opts.SI and not opts.noSI
    # with --estimator mdpow the user has to pass --no-SI explicitly
    if opts.estimator == 'mdpow' and use_SI:
        logger.fatal("Statistical inefficiency analysis is only available for estimator 'alchemlyb'.")
        sys.exit(1)

    # Translate "none"/"auto" exactly once, before the loop: re-parsing the
    # already translated value (None) on a later iteration would fail.
    plotfile = parse_filename_option(opts.plotfile)

    for directory in opts.directory:
        if not os.path.exists(directory):
            logger.warning("Directory %r not found, skipping...", directory)
            continue
        logger.info("Analyzing directory %r... (can take a while)", directory)

        try:
            run_pow(directory, plotfile=plotfile,
                    outfile=realpath_or_none(opts.outfile),
                    energyfile=realpath_or_none(opts.energyfile),
                    force=opts.force, stride=opts.stride, start=opts.start,
                    stop=opts.stop, estimator=opts.estimator,
                    method=opts.method, permissive=opts.permissive, SI=use_SI)
        except (OSError, IOError) as err:
            # e.g. missing input files: report and continue with the next directory
            logger.error("Running analysis in directory %r failed: %s", directory, str(err))
        except Exception:
            logger.fatal("Running analysis in directory %r failed", directory)
            logger.exception("Catastrophic problem occurred, see the stack trace for hints.")
            raise
