#!/usr/bin/env python
"""%(prog)s [options] DIRECTORY [DIRECTORY ...]

Run the free energy analysis for a solvent in <DIRECTORY>/FEP
and return DeltaG.

DIRECTORY should contain all the files resulting from running
``mdpow.fep.Ghyd.setup()`` (or the corresponding ``Goct.setup()`` or
``Gcyclohexane.setup()``) and the results of the MD FEP
simulations. It relies on the canonical naming scheme (basically: just
use the defaults as in the tutorial).

The dV/dlambda plots can be produced automatically (--plotfile=auto). If
multiple DIRECTORY arguments are provided then one has to choose the auto
option (or none).

The total solvation free energy is calculated as

  DeltaG* = -(DeltaG_coul + DeltaG_vdw)

Note that the standard state refers to the "Ben-Naim" standard state
of transferring 1 M of compound in the gas phase to 1 M in the aqueous
phase.

Results are *appended* to a data file with **Output file format**::

          .                 ---------- kJ/mol ---
          molecule solvent  DeltaG*  coulomb  vdw

All observables are quoted with an error estimate, which is derived
from the correlation time and error propagation through Simpson's rule
(see :meth:`mdpow.fep.Gsolv`). It ultimately comes from the error on
<dV/dlambda>, which is estimated as the error to determine the
average.

molecule
    molecule name as used in the itp
DeltaG*
    solvation free energy vacuum --> solvent, in kJ/mol
coulomb
    discharging contribution to the DeltaG*
vdw
    decoupling contribution to the DeltaG*
directory
    folder in which the simulations were stored

"""
from __future__ import absolute_import, print_function, division

import os
import mdpow.fep
import mdpow.filelock

import logging
logger = logging.getLogger('mdpow')

# Solvent name -> tag used in the name of the saved free energy file
# (``G<tag>.fep``, see load_gsolv()).
suffix = {
    'water': 'hyd',
    'octanol': 'oct',
    'cyclohexane': 'cyclo',
}

def load_gsolv(solvent, directory, **kwargs):
    """Load the saved free energy object for *solvent* below *directory*.

    The data file is looked up as
    ``<directory>/FEP/<solvent>/G<tag>.fep`` where ``<tag>`` comes from
    the module-level ``suffix`` table.

    :Arguments:
        *solvent*
             solvent name; must be a key of ``suffix``
        *directory*
             directory under which the project is stored
        *permissive*
             handed to :class:`mdpow.fep.Gsolv` [``False``]

    :Returns: the :class:`mdpow.fep.Gsolv` instance created from the file

    :Raises: :exc:`OSError` if the data file does not exist
    """
    fep_file = os.path.join(directory, "FEP", solvent,
                            'G{0}.fep'.format(suffix[solvent]))

    if not os.path.exists(fep_file):
        # report the missing file keyed by solvent
        missing = {solvent: fep_file}
        raise OSError("Missing input files: %r" % missing)

    permissive = kwargs.pop('permissive', False)
    logger.info("[%s] Reading {0} data %r".format(solvent), directory, fep_file)
    return mdpow.fep.Gsolv(filename=fep_file, basedir=directory, permissive=permissive)

def run_gsolv(solvent, directory, **kwargs):
    """Do the solvation free energy analysis for one project directory.

    The saved free energy object is loaded with :func:`load_gsolv`,
    (re)analyzed if requested or if no result is stored yet, and its
    summary is optionally appended to *energyfile* and plotted to
    *plotfile*.

    :Arguments:
        *solvent*
             name of the solvent (selects the data file, see
             :func:`load_gsolv`)
        *directory*
             directory under which the project is stored

    :Keywords:
        *estimator*
             "alchemlyb" or "mdpow" ["alchemlyb"]
        *method*
             free energy method that is set on the loaded object; with
             *estimator* = "mdpow" only "TI" is accepted ["MBAR"]
        *plotfile*
             name of a file to plot dV/dlambda graphs to; "auto" picks
             the name with :func:`auto_plotfile`. Plotting is switched
             off when *estimator* = "alchemlyb" [``None``]
        *energyfile*
             append individual energy terms to this file [``None``]
        *force*
             force rereading data files [``False``]
        *permissive*
             handed to :func:`load_gsolv` [``False``]
        *stride*
             handed to the analysis method
        *start*, *stop*
             stored on the loaded object if it does not carry these
             attributes yet [0, ``None``]
        *SI*
             statistical inefficiency switch for the alchemlyb analysis;
             defaults to the ``SI`` attribute of the loaded object or
             ``True`` if it has none

    With *estimator* = "alchemlyb" all keyword arguments that were not
    consumed here are handed to ``analyze_alchemlyb()``.

    :Raises: :exc:`ValueError` for an unknown *estimator* or for a
             *method* other than "TI" together with *estimator* = "mdpow"

    **Output file format:**

        molecule DeltaGhyd  Coulomb VDW directory

        (all values as average and error)
    """
    import gc

    fmt_energy = "%s  %s\n"
    _datastatus = {True: 'BAD',  # suspect file corruption
                   False: 'OK'}

    def datstat(g):
        return _datastatus[g.contains_corrupted_xvgs()]

    gsolv = load_gsolv(solvent, directory, permissive=kwargs.pop('permissive', False),)
    logger.info("The solvent is %s.", solvent)

    # for this version. use the method given instead of the one in the input cfg file
    estimator = kwargs.pop('estimator', 'alchemlyb')
    if estimator not in ('mdpow', 'alchemlyb'):
        errmsg = "estimator = %r is not supported, must be 'mdpow' or 'alchemlyb'" % estimator
        logger.error(errmsg)
        raise ValueError(errmsg)

    gsolv.method = kwargs.pop('method', 'MBAR')
    if estimator == 'mdpow':
        if gsolv.method != "TI":
            errmsg = "Method %s is not implemented in MDPOW, use estimator='alchemlyb'" % gsolv.method
            logger.error(errmsg)
            raise ValueError(errmsg)

    energyfile = kwargs.pop('energyfile', None)
    plotfile = kwargs.pop('plotfile', None)

    # settings already stored on the loaded object take precedence
    if not hasattr(gsolv, 'start'):
        gsolv.start = kwargs.pop('start', 0)
    if not hasattr(gsolv, 'stop'):
        gsolv.stop = kwargs.pop('stop', None)
    if not hasattr(gsolv, 'SI'):
        kwargs.setdefault('SI', True)
    else:
        kwargs.setdefault('SI', gsolv.SI)

    # *force* is optional (default False); look it up without removing it
    # so that it is still available to the analysis call below
    force = kwargs.get('force', False)

    # make sure that we have data
    if force or 'Gibbs' not in gsolv.results.DeltaA:
        logger.info("[%(directory)s] Forcing (re)analysis of hydration free energy data", vars())
        # write out the settings when the analysis is performed
        logger.info("Estimator is %s.", estimator)
        logger.info("Free energy calculation method is %s.", gsolv.method)
        if estimator == 'mdpow':
            # do the analysis (only use keywords that are needed for FEP.analysis())
            akw = {'stride': kwargs.pop('stride', None), 'force': kwargs.pop('force', False)}
            gsolv.analyze(**akw)
        elif estimator == 'alchemlyb':
            if kwargs['SI']:
                logger.info("Statistical inefficiency analysis will be performed.")
            else:
                logger.info("Statistical inefficiency analysis won't be performed.")
            gsolv.analyze_alchemlyb(**kwargs)
    else:
        logger.info("[%(directory)s] Reading existing data from pickle file.", vars())

    gsolv.logger_DeltaA0()

    status = datstat(gsolv)
    if status == 'BAD':
        logger.warning("[%s] Possible file corruption: %s:%s", directory, solvent, status)

    if energyfile:
        # the file lock guards the append against concurrent writers
        with mdpow.filelock.FileLock(energyfile, timeout=2):
            with open(energyfile, mode='a') as out:
                out.write(fmt_energy % (gsolv.summary(), directory))
            logger.info("[%s] Wrote solvation energy terms to %r", directory, energyfile)

    # plotting is only done for the mdpow estimator
    if estimator == 'alchemlyb':
        plotfile = False
    if plotfile:
        if plotfile == 'auto':
            plotfile = auto_plotfile(directory, gsolv)
        import matplotlib
        matplotlib.use('agg')  # quick non-interactive plotting
        from pylab import clf, title, savefig
        clf()
        gsolv.plot(color='k', ls='--')
        activate_subplot(1)
        title(r"[{0}] $\Delta G$".format(gsolv.molecule))
        activate_subplot(2)
        savefig(plotfile)
        logger.info("Wrote graph to %(plotfile)r", vars())

    del gsolv
    gc.collect()  # try to free as much memory as possible


def auto_plotfile(directory, gsolv):
    """Return the automatic dV/dlambda plot file name inside *directory*.

    The file name is built from the ``molecule`` and ``solvent_type``
    attributes of *gsolv*.
    """
    basename = "dVdl_{0}_{1}.pdf".format(gsolv.molecule, gsolv.solvent_type)
    return os.path.join(directory, basename)


def activate_subplot(numPlot):
    """Make subplot *numPlot* active on the canvas.

    Use this if a simple ``subplot(numRows, numCols, numPlot)``
    overwrites the subplot instead of activating it.

    :Arguments:
        *numPlot*
             1-based position of the subplot in the list of axes of the
             current figure

    :Returns: the axes instance that was made current
    """
    # see http://www.mail-archive.com/matplotlib-users@lists.sourceforge.net/msg07156.html
    # sca() is the documented way to make an existing axes current; handing
    # an axes instance to axes() was deprecated in Matplotlib 2.2.
    from pylab import gcf, sca
    ax = gcf().get_axes()[numPlot - 1]  # index is 0-based, plots are 1-based
    sca(ax)
    return ax


def parse_filename_option(fn):
    """Translate the keywords of a file name option.

    "none" (in any capitalization) becomes ``None`` and "auto" (in any
    capitalization) becomes the lower case string "auto"; any other
    value is returned unchanged.
    """
    keyword = fn.lower()
    if keyword == "none":
        return None
    if keyword == "auto":
        return "auto"
    return fn

def realpath_or_none(fn):
    """Return ``os.path.realpath(fn)``, or ``None`` if *fn* is ``None``."""
    if fn is None:
        return None
    return os.path.realpath(fn)

if __name__ == "__main__":
    # Command line driver: parse the options, reject option combinations
    # that cannot work, then analyze each DIRECTORY in turn with run_gsolv().
    import sys
    import argparse

    parser = argparse.ArgumentParser(usage=__doc__, prog='mdpow-solvationenergy',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("directory", nargs='+',
                        help="directory or directories which contain all the files "
                        "resulting from running mdpow.fep.Ghyd.setup() ",)
    parser.add_argument('--plotfile', dest="plotfile", default='none',
                        help="plot dV/dlambda to FILE; use png or pdf suffix to "
                        "determine the file type. 'auto' generates a pdf file "
                        "DIRECTORY/dVdl_<molname>_<solvent>.pdf and 'none' disables it. "
                        "The plot function is only available for mdpow estimator, "
                        "and is disabled when using alchemlyb estimators.",
                        metavar="FILE")
    parser.add_argument('--solvent', '-S', dest='solvent', default='water',
                        choices=['water', 'octanol', 'cyclohexane'], metavar="NAME",
                        help="solvent NAME for compound, 'water', 'octanol', or 'cyclohexane' ")
    # NOTE(review): opts.outfile is parsed but not used anywhere in this
    # script -- confirm whether the one-line summary should still be written.
    parser.add_argument('-o', '--outfile', dest="outfile", default="gsolv.txt",
                        help="append one-line results summary to FILE ",
                        metavar="FILE")
    parser.add_argument('-e', '--energies', dest="energyfile", default="energies.txt",
                        help="append solvation free energies to FILE ",
                        metavar="FILE")
    parser.add_argument('--estimator', dest="estimator", default='alchemlyb',
                        choices=['mdpow', 'alchemlyb'],
                        help="choose the estimator to be used, alchemlyb or mdpow estimators")
    parser.add_argument('--method', dest="method", default='MBAR',
                        choices=['TI', 'MBAR', 'BAR'],
                        help="choose the method to calculate free energy ",)
    parser.add_argument('--force', dest='force', default=False,
                        action="store_true",
                        help="force rereading all data ")
    SIs = parser.add_mutually_exclusive_group()
    SIs.add_argument('--SI', dest='SI',
                     action="store_true", default=True,
                     help="enable statistical inefficiency (SI) analysis. "
                     "Statistical inefficiency analysis is performed by default when using "
                     "alchemlyb estimators and is always disabled when using mdpow estimator.")
    SIs.add_argument('--no-SI', dest='noSI',
                     action="store_true", default=False,
                     help="disable statistical inefficiency analysis. "
                     "Statistical inefficiency analysis is performed by default when using "
                     "alchemlyb estimators and is disabled when using mdpow estimator. ")
    parser.add_argument('-s', '--stride', dest="stride", type=int, default=1,
                        help="use every N-th datapoint from the original dV/dlambda data. ",
                        metavar="N")
    parser.add_argument('--start', dest="start", type=int, default=0,
                        help="start point for the data used from the original dV/dlambda data. ")
    parser.add_argument('--stop', dest="stop", type=int, default=None,
                        help="stop point for the data used from the original dV/dlambda data. ",)
    parser.add_argument('--ignore-corrupted', dest="permissive",
                        action="store_true", default=False,
                        help="skip lines in the md.xvg files that cannot be parsed. "
                        "WARNING: Other lines in the file might have been corrupted in "
                        "such a way that they appear correct but are in fact wrong. "
                        "WRONG RESULTS CAN OCCUR! USE AT YOUR OWN RISK ")
    opts = parser.parse_args()

    # an explicit plot file name can only belong to a single directory
    if len(opts.directory) > 1 and opts.plotfile.lower() not in ('none', 'auto'):
        logger.fatal("Can only use --plotfile=None or --plotfile=auto with multiple directories.")
        sys.exit(1)

    # --SI defaults to True, so SI is requested unless --no-SI was given
    use_SI = opts.SI and not opts.noSI

    # NOTE(review): because --SI defaults to True, "--estimator mdpow" exits
    # here unless --no-SI is also given, although the help text describes SI
    # as disabled for the mdpow estimator -- confirm that this is intended.
    if opts.estimator == 'mdpow' and use_SI:
        logger.fatal("Statistical inefficiency analysis is only available for estimator 'alchemlyb'.")
        sys.exit(1)

    # these do not depend on the directory: work them out once
    plotfile = parse_filename_option(opts.plotfile)
    energyfile = realpath_or_none(opts.energyfile)

    for directory in opts.directory:
        if not os.path.exists(directory):
            logger.warning("Directory %r not found, skipping...", directory)
            continue
        logger.info("Analyzing directory %r... (can take a while)", directory)

        try:
            run_gsolv(opts.solvent, directory, plotfile=plotfile,
                      energyfile=energyfile,
                      force=opts.force, stride=opts.stride, start=opts.start,
                      stop=opts.stop, estimator=opts.estimator,
                      method=opts.method, permissive=opts.permissive, SI=use_SI)
        except (OSError, IOError) as err:
            # missing or unreadable input: report it and go on with the next directory
            logger.error("Running analysis in directory %r failed: %s", directory, err)
        except Exception:
            # anything else is unexpected: log the traceback and abort
            logger.fatal("Running analysis in directory %r failed", directory)
            logger.exception("Catastrophic problem occurred, see the stack trace for hints.")
            raise
