#! /usr/bin/env python
# -*- coding: utf-8 -*-

import os.path

import MDAnalysis
import time

from DEERpredict.FRETPrediction import FRETPrediction

import logging
# Script-level logger. It lives under the "MDAnalysis" logger namespace, so the
# handlers installed by MDAnalysis.start_logging() presumably also receive its
# records (TODO confirm against the MDAnalysis logging docs).
logger = logging.getLogger(name="MDAnalysis.app")



from argparse import ArgumentParser

# Command-line interface for the FRET prediction driver. Every option's `dest`
# is read by the __main__ block and forwarded to FRETPrediction.
parser = ArgumentParser()

# Required positionals (argparse enforces presence, so no default is needed).
parser.add_argument("topology", type=str,
                    help="Topology to analyze.")

parser.add_argument("residues", type=int, nargs=2,
                    help="residue pair index to compute FRET distances")

# May be given several times; collected into a list, or None when absent.
parser.add_argument("--trajectory", type=str, default=None, action='append',
                    help="Trajectory to analyze.")

parser.add_argument("--chains", type=str, nargs=2, dest="chains", default=None,
                    help="OPTIONAL: Chains for differentiation on homodimeric proteins")

parser.add_argument("--output", dest="output_file", default="distances.dat",
                    help="the path and name of the output histogram file; the filename will "
                         "have resid 1 and resid 2 inserted before the extension [%(default)s]")

parser.add_argument("--lib1", dest="libname_1", metavar="NAME", default="Alexa 594 50cutoff 3step",
                    help="name of the first position rotamer library [%(default)s]")

parser.add_argument("--lib2", dest="libname_2", metavar="NAME", default="Alexa 488 50cutoff 3step",
                    help="name of the second position rotamer library [%(default)s]")

parser.add_argument("--temp", dest="temp", default=300, type=float,
                    help="Calculation temperature")

# The 5.4 default is the Alexa 488/594 Förster radius expressed in nm
# (5.4 Å would be physically implausible), so the help advertises nm.
parser.add_argument("--r0", dest="r0", default=5.4, type=float,
                    help="Förster radius (nm) for the chromophore pair used [%(default)s]")

# NOTE(review): `plotname` is parsed but never forwarded to FRETPrediction by
# the __main__ block of this script; confirm whether it should be.
parser.add_argument("--plotname", dest="plotname", metavar="FILENAME", default=None,
                    help="plot the histogram to FILENAME (the extensions determines the format) "
                         "By default <outputFile>.pdf.")

parser.add_argument('--replicas', type=int, dest='replicas', metavar='REPLICAS', default=1,
                    help='OPTIONAL: indicate the number of replicas to average for, in case of replica-averaged '
                         'simulation data. These should be concatenated and replica length will be'
                         ' <full trajectory length> / <number of replicas>.')

parser.add_argument("--frame_saving", dest="record_frames", default=False, action='store_true',
                    help="OPTIONAL: Boolean for turning on frame distribution save feature.")

# Frame slicing controls; `stop_frame` stays None when --end is omitted.
parser.add_argument("--start", dest="start_frame", type=int, default=0,
                    help="Index of the starting frame [%(default)s]")

parser.add_argument("--end", dest="stop_frame", type=int,
                    help="Index of the final frame")

parser.add_argument("--skip", dest="jump_frame", type=int, default=1,
                    help="Jump between frames for analysis")

# Orientation-factor (kappa^2) calculation mode switches.
parser.add_argument("--dyn_k2", dest="k2_calc_dyn", action='store_true',
                    help="K2 dynamic calculation")

parser.add_argument("--stat_k2", dest="k2_calc_static", action='store_true',
                    help="K2 static calculation")

def _load_universe(topology, trajectories):
    """Build an MDAnalysis Universe from a topology and optional trajectories.

    :param topology: path of the topology/structure file.
    :param trajectories: list of trajectory paths (one per ``--trajectory``
        flag) or ``None``/empty to load the topology alone.
    :returns: the loaded ``MDAnalysis.Universe``.
    :raises IOError: if MDAnalysis rejects the inputs with a ``ValueError``.
    """
    try:
        # Branch explicitly instead of relying on a TypeError: the old fallback
        # silently discarded user-supplied trajectories on any TypeError.
        if trajectories:
            universe = MDAnalysis.Universe(topology, trajectories)
            logger.info("Loading trajectory data as Universe(%s, %s)",
                        topology, ", ".join(trajectories))
        else:
            universe = MDAnalysis.Universe(topology)
            logger.info("Loading trajectory data as Universe(%s)", topology)
    except ValueError:
        logger.critical("Protein structure and/or trajectory not correctly specified")
        raise IOError("Protein structure and/or trajectory not correctly specified")
    return universe


if __name__ == "__main__":
    args = parser.parse_args()

    MDAnalysis.start_logging()
    # try/finally so logging is shut down even when loading or analysis fails.
    try:
        # load the reference protein structure (plus trajectories, if given)
        proteinStructure = _load_universe(args.topology, args.trajectory)

        # argparse already enforces exactly two ints for the positional
        # `residues`; this guard only protects against programmatic misuse.
        if args.residues is None or len(args.residues) != 2:
            raise ValueError("Provide residue ids as positional arguments: R1 R2")

        startTime = time.time()
        # The constructor call performs the analysis; `R` is kept so the
        # result object is available when the script is run interactively.
        # NOTE(review): args.plotname is parsed but not passed here.
        R = FRETPrediction(proteinStructure,
                           args.residues, chains=args.chains,
                           output_file=args.output_file,
                           libname_1=args.libname_1, libname_2=args.libname_2,
                           temperature=args.temp, replicas=args.replicas,
                           record_frames=args.record_frames, start_frame=args.start_frame,
                           stop_frame=args.stop_frame, jump_frame=args.jump_frame,
                           k2_calc_dynamic=args.k2_calc_dyn, k2_calc_static=args.k2_calc_static,
                           r0=args.r0)
        logger.info("DONE with analysis, elapsed time %6i s", int(time.time() - startTime))
    finally:
        MDAnalysis.stop_logging()
