#! /usr/bin/env python

import MDAnalysis
import time

from DEERpredict.DEERPrediction import DEERPrediction

import logging
logger = logging.getLogger("MDAnalysis.app")


from argparse import ArgumentParser


def build_parser():
    """Build the command-line parser for the DEER distance prediction script.

    Help strings use argparse's ``%(default)s`` placeholder; the optparse
    spelling ``%default`` makes argparse's help formatting raise TypeError.

    Returns
    -------
    argparse.ArgumentParser
        Parser whose namespace carries every option that :func:`main`
        forwards to ``DEERPrediction``.
    """
    parser = ArgumentParser(
        description="Compute DEER distances for a pair of residues of a "
                    "protein structure, optionally over a trajectory.")

    # A plain string (no nargs=1) so it can be handed directly to
    # MDAnalysis.Universe instead of a one-element list.
    parser.add_argument("topology", type=str,
                        help="Topology to analyze.")

    parser.add_argument("residues", type=int, nargs=2,
                        help="residue pair index to compute DEER distances")

    # Without nargs=1, action='append' yields a flat list of file names
    # (['a.xtc', 'b.xtc']) rather than a list of one-element lists.
    parser.add_argument("--trajectory", type=str, default=None, action='append',
                        help="Trajectory to analyze (may be given several times).")

    # Own dest so it no longer races with the positional ``residues``
    # depending on command-line order; main() gives it precedence.
    parser.add_argument("--resid", type=int, nargs=2, dest="resid", default=None,
                        help="OPTIONAL: the pair of residues to compute DEER distances for; "
                             "overrides the positional residue pair")

    parser.add_argument("--chains", type=str, nargs=2, dest="chains", default=None,
                        help="OPTIONAL: Chains for differentiation on homodimeric proteins")

    parser.add_argument("--discard", dest="discard_frames", type=int, default=0,
                        help="discard the first N frames [%(default)s]")

    parser.add_argument("--output", dest="output_file", default="distances.dat",
                        help="the path and name of the output histogram file; the filename will "
                             "have resid 1 and resid 2 inserted before the extension [%(default)s]")

    parser.add_argument("--dcdfilename", dest="dcd_filename", metavar="FILENAME",
                        help="the path and stem of the DCD files of the fitted MTSS rotamers")

    parser.add_argument("--libname", dest="libname", metavar="NAME", default="MTSSL 175K X1X2",
                        help="name of the rotamer library [%(default)s]")

    parser.add_argument("--plotname", dest="plotname", metavar="FILENAME", default=None,
                        help="plot the histogram to FILENAME (the extension determines the format). "
                             "By default <outputFile>.pdf.")

    parser.add_argument('--replicas', type=int, dest='replicas', metavar='REPLICAS', default=1,
                        help='OPTIONAL: indicate the number of replicas to average for, in case of replica-averaged '
                             'simulation data. These should be concatenated and replica length will be'
                             ' <full trajectory length> / <number of replicas>.')

    parser.add_argument("--frame_saving", dest="record_frames", default=False, action='store_true',
                        help="OPTIONAL: Boolean for turning on frame distribution save feature.")

    parser.add_argument("--start", dest="start_frame", type=int, default=0,
                        help="Index of the starting frame [%(default)s]")

    parser.add_argument("--stop", dest="stop_frame", type=int, default=None,
                        help="Index of the final frame")

    parser.add_argument("--skip", dest="jump_frame", type=int, default=1,
                        help="Jump between frames for analysis")

    parser.add_argument("--form_factor", dest="form_factor", default=False, action='store_true',
                        help="OPTIONAL: Boolean for turning on experimental form factor calculation.")

    return parser


# Module-level parser kept so existing importers of ``parser`` still work.
parser = build_parser()


def _load_universe(topology, trajectories):
    """Return an MDAnalysis Universe for *topology* and optional *trajectories*.

    Parameters
    ----------
    topology : str
        Path of the topology file.
    trajectories : list of str or None
        Trajectory files; ``None`` or empty means topology only.

    Raises
    ------
    IOError
        If MDAnalysis raises ValueError for the given files.
    """
    try:
        if trajectories:
            logger.info("Loading trajectory data as Universe(%s, %s)",
                        topology, ", ".join(trajectories))
            return MDAnalysis.Universe(topology, *trajectories)
        logger.info("Loading trajectory data as Universe(%s)", topology)
        return MDAnalysis.Universe(topology)
    except ValueError:
        logger.critical("Protein structure and/or trajectory not correctly specified")
        raise IOError("Protein structure and/or trajectory not correctly specified")


def main(argv=None):
    """Parse the command line, load the structure and run DEERPrediction.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Raises
    ------
    IOError
        If the topology/trajectory combination cannot be loaded.
    """
    args = parser.parse_args(argv)

    # Both forms are validated by argparse (nargs=2, type=int), so no
    # further length check is needed; --resid wins when supplied.
    residues = args.resid if args.resid is not None else args.residues

    MDAnalysis.start_logging()
    try:
        protein_structure = _load_universe(args.topology, args.trajectory)

        start_time = time.time()
        # NOTE(review): --plotname is parsed but never forwarded here;
        # confirm whether DEERPrediction accepts a plot name.
        DEERPrediction(protein_structure,
                       residues, chains=args.chains,
                       output_file=args.output_file, dcd_filename=args.dcd_filename,
                       libname=args.libname, discard_frames=args.discard_frames, replicas=args.replicas,
                       record_frames=args.record_frames, start_frame=args.start_frame,
                       stop_frame=args.stop_frame, jump_frame=args.jump_frame,
                       form_factor=args.form_factor)
        logger.info("DONE with analysis, elapsed time %6i s", int(time.time() - start_time))
    finally:
        # Stop MDAnalysis logging even if loading or the analysis fails.
        MDAnalysis.stop_logging()


if __name__ == "__main__":
    main()
