#! /usr/bin/env python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Runs a sampler to find the posterior distributions.
"""

import argparse
import logging
import numpy
import pycbc.opt
import pycbc.weave
import random
import pycbc
from pycbc import fft
from pycbc import inference
from pycbc import psd
from pycbc import scheme
from pycbc import strain
from pycbc import types
from pycbc import waveform
from pycbc.io.inference_hdf import InferenceFile
from pycbc.workflow import WorkflowConfigParser
from pycbc.inference import option_utils

def select_waveform_generator(approximant):
    """ Returns the waveform generator class to use for the approximant.

    Parameters
    ----------
    approximant : str
        Name of the approximant.

    Returns
    -------
    generator
        One of the generator classes from ``pycbc.waveform``:
        ``FDomainCBCGenerator``, ``TDomainCBCGenerator``,
        ``FDomainRingdownGenerator`` or ``FDomainMultiModeRingdownGenerator``.

    Raises
    ------
    ValueError
        If the approximant is a time domain ringdown, is a frequency domain
        ringdown that has no generator associated with it here, or is not
        known at all.
    """
    if approximant in waveform.fd_approximants():
        return waveform.FDomainCBCGenerator
    if approximant in waveform.td_approximants():
        return waveform.TDomainCBCGenerator
    if approximant in waveform.ringdown_fd_approximants:
        if approximant == 'FdQNM':
            return waveform.FDomainRingdownGenerator
        if approximant == 'FdQNMmultiModes':
            return waveform.FDomainMultiModeRingdownGenerator
        # fail loudly instead of implicitly returning None for a frequency
        # domain ringdown approximant that is not handled above
        raise ValueError("No generator is available for the frequency "
                         "domain ringdown approximant %s." % approximant)
    if approximant in waveform.ringdown_td_approximants:
        raise ValueError("Time domain ringdowns not supported")
    raise ValueError("%s is not a valid approximant."%approximant)

def convert_liststring_to_list(lstring):
    """ Converts the string representation of a list into a list of strings.

    The string must start with ``[`` and end with ``]``. The text in between
    is split on commas, and each element has surrounding whitespace and
    single quotes stripped, eg. ``"['a', 'b']"`` becomes ``['a', 'b']``.

    Parameters
    ----------
    lstring : str
        String to convert, eg. an option value from a configuration file.

    Returns
    -------
    list of str
        The elements of the list, as strings.

    Raises
    ------
    ValueError
        If ``lstring`` is not enclosed in square brackets.
    """
    if not (lstring.startswith('[') and lstring.endswith(']')):
        raise ValueError("%r is not the string representation of a list"
                         % (lstring,))
    # split once and clean up each element
    return [str(element.strip().strip("'"))
            for element in lstring[1:-1].split(',')]

# build the command line parser; the module docstring is used as the
# description shown in the help message
parser = argparse.ArgumentParser(description=__doc__,
                                 usage="%s [--options]" % __file__)

# options describing the data to analyze
parser.add_argument(
    "--instruments", required=True, nargs="+", type=str,
    help="IFOs, eg. H1 L1.")
parser.add_argument(
    "--frame-type", nargs="+", type=str, metavar="IFO:FRAME_TYPE",
    action=types.MultiDetOptionAction,
    help="Frame type for each IFO.")
parser.add_argument(
    "--low-frequency-cutoff", required=True, type=float,
    help="Low frequency cutoff for each IFO.")
parser.add_argument(
    "--psd-start-time", type=float,
    help="Start time to use for PSD estimation if "
         "different from analysis.")
parser.add_argument(
    "--psd-end-time", type=float,
    help="End time to use for PSD estimation if "
         "different from analysis.")

# options for the inference itself
parser.add_argument(
    "--likelihood-evaluator", required=True,
    choices=inference.likelihood_evaluators.keys(),
    help="Evaluator class to use to calculate the likelihood.")
parser.add_argument(
    "--seed", type=int, default=0,
    help="Seed to use for the random number generator "
         "that initially distributes the walkers. "
         "Default is 0.")

# options for the sampler
option_utils.add_sampler_option_group(parser)

# options for the configuration file(s)
parser.add_argument(
    "--config-files", required=True, nargs="+", type=str,
    help="A file parsable by pycbc.workflow.WorkflowConfigParser.")
parser.add_argument(
    "--config-overrides", nargs="+", type=str,
    metavar="SECTION:OPTION:VALUE",
    help="List of section:option:value combinations "
         "to add into the configuration file.")

# options controlling what is written and when
parser.add_argument(
    "--output-file", required=True, type=str,
    help="Output file path.")
parser.add_argument(
    "--checkpoint-interval", type=int,
    help="Number of iterations to take before saving "
         "new samples to file.")
parser.add_argument(
    "--checkpoint-fast", action="store_true",
    help="Do not calculate derived data (eg. ACL or evidence) after each "
         "checkpoint. Calculate them at the end.")

# option to turn on logging messages
parser.add_argument(
    "--verbose", action="store_true",
    help="Print logging messages.")

# option groups that are pre-defined by other pycbc modules
fft.insert_fft_option_group(parser)
pycbc.opt.insert_optimization_option_group(parser)
psd.insert_psd_option_group_multi_ifo(parser)
scheme.insert_processing_option_group(parser)
strain.insert_strain_option_group_multi_ifo(parser)
pycbc.weave.insert_weave_option_group(parser)

# read the options given on the command line
opts = parser.parse_args()

# check that the parsed options are consistent with each other
# NOTE: the checks of the PSD and strain options are commented out
fft.verify_fft_options(opts, parser)
pycbc.opt.verify_optimization_options(opts, parser)
#psd.verify_psd_options(opts, parser)
scheme.verify_processing_options(opts, parser)
#strain.verify_strain_options(opts, parser)
pycbc.weave.verify_weave_options(opts, parser)

# initialize logging; --verbose controls whether messages are printed
pycbc.init_logging(opts.verbose)

# seed numpy's random number generator
numpy.random.seed(opts.seed)
logging.info("Using seed %i" % opts.seed)

# use measure level 0 for the FFTW backend
fft.fftw.set_measure_level(0)

# set up the processing scheme and the FFT backend from the options
ctx = scheme.from_cli(opts)
fft.from_cli(opts)

# read the strain time series to analyze for every detector
strain_dict = strain.from_cli_multi_ifos(opts, opts.instruments,
                                         precision="double")

# decide what data the PSDs are estimated from: either both PSD times are
# given, in which case a separate time series is read, or neither is given,
# in which case the analysis data is reused
has_psd_start = bool(opts.psd_start_time)
has_psd_end = bool(opts.psd_end_time)
if has_psd_start != has_psd_end:
    raise ValueError("Must give --psd-start-time and --psd-end-time")
if has_psd_start:
    logging.info("Will generate a different time series for PSD estimation")
    # psd_opts is the same object as opts (not a copy), so the GPS times
    # stored on opts are overwritten by the PSD times here
    psd_opts = opts
    psd_opts.gps_start_time = psd_opts.psd_start_time
    psd_opts.gps_end_time = psd_opts.psd_end_time
    psd_strain_dict = strain.from_cli_multi_ifos(psd_opts, opts.instruments,
                                                precision="double")
else:
    psd_strain_dict = strain_dict

# everything that follows runs inside the processing scheme context that was
# created from the command line options
with ctx:

    # FFT the strain of each detector and save the length of the FFT,
    # delta_f, and low frequency cutoff of each detector to a dict
    logging.info("FFT strain")
    stilde_dict = {}
    length_dict = {}
    delta_f_dict = {}
    low_frequency_cutoff_dict = {}
    for ifo in opts.instruments:
        stilde_dict[ifo] = strain_dict[ifo].to_frequencyseries()
        length_dict[ifo] = len(stilde_dict[ifo])
        delta_f_dict[ifo] = stilde_dict[ifo].delta_f
        low_frequency_cutoff_dict[ifo] = opts.low_frequency_cutoff

    # get PSD as frequency series
    psd_dict = psd.from_cli_multi_ifos(opts, length_dict, delta_f_dict,
                               low_frequency_cutoff_dict, opts.instruments,
                               strain_dict=psd_strain_dict, precision="double")

    # apply dynamic range factor for saving PSDs since plotting code expects it
    logging.info("Saving PSDs")
    psd_dyn_dict = {}
    for ifo, ifo_psd in psd_dict.items():
        psd_dyn_dict[ifo] = types.FrequencySeries(
                                        ifo_psd * pycbc.DYN_RANGE_FAC**2,
                                        delta_f=ifo_psd.delta_f)

    # save PSD; opening with "w" starts a new output file
    with InferenceFile(opts.output_file, "w") as fp:
        fp.write_psds(psds=psd_dyn_dict,
                      low_frequency_cutoff=low_frequency_cutoff_dict)

    # read configuration file
    logging.info("Reading configuration file")
    if opts.config_overrides is not None:
        overrides = [override.split(":") for override in opts.config_overrides]
    else:
        overrides = None
    cp = WorkflowConfigParser(opts.config_files, overrides)

    # sanity check that each parameter in [variable_args] has a priors section
    # a prior subsection name may list several parameters joined by "+"
    variable_args = cp.options("variable_args")
    subsections = cp.get_subsections("prior")
    tags = [tag for subsection in subsections
            for tag in subsection.split("+")]
    if not all(param in tags for param in variable_args):
        raise KeyError("You are missing a priors section in the config file.")

    # get parameters that do not change in sampler
    static_args = dict((key, cp.get_opt_tags("static_args", key, []))
                       for key in cp.options("static_args"))

    # convert each static argument to a float if possible, else to a list of
    # strings if it is written as a list, else leave the value as it was read
    for key, val in list(static_args.items()):
        try:
            static_args[key] = float(val)
            continue
        except (TypeError, ValueError):
            pass
        try:
            static_args[key] = convert_liststring_to_list(val)
        except (ValueError, NameError, IndexError, TypeError, AttributeError):
            # the value is not a list written as a string (the exception
            # type depends on the input), so keep the value unchanged
            pass

    # get prior distribution for each variable parameter
    logging.info("Setting up priors for each parameter")
    distributions = inference.read_distributions_from_config(cp, "prior")

    # construct class that will return the prior
    prior = inference.PriorEvaluator(variable_args, *distributions)

    # select generator that will generate waveform
    # for likelihood evaluator
    logging.info("Setting up sampler")
    generator_function = select_waveform_generator(static_args["approximant"])

    # construct class that will generate waveforms; the epoch and delta_f
    # are taken from one detector's entry in the dicts
    generator = waveform.FDomainDetFrameGenerator(generator_function,
                        epoch=list(stilde_dict.values())[0].epoch,
                        variable_args=variable_args,
                        detectors=opts.instruments,
                        delta_f=list(delta_f_dict.values())[0], **static_args)

    # construct class that will return the natural logarithm of likelihood
    likelihood = inference.likelihood_evaluators[opts.likelihood_evaluator](
                        generator, stilde_dict, opts.low_frequency_cutoff,
                        psds=psd_dict, prior=prior)

    # create sampler that will run
    sampler = option_utils.sampler_from_cli(opts, likelihood)

    # get distribution to draw from for each walker initial position
    logging.info("Setting walkers initial conditions for varying parameters")

    # check if user wants to use different distributions than prior for setting
    # walkers initial positions
    if len(cp.get_subsections("initial")):
        initial_distributions = inference.read_distributions_from_config(cp,
            section="initial")
    else:
        initial_distributions = inference.read_distributions_from_config(cp,
            section="prior")

    # set the initial positions
    sampler.set_p0(initial_distributions)

    # setup checkpointing: intervals holds the iteration counts at which the
    # sampler is stopped so that the new samples can be saved
    if opts.checkpoint_interval:
        # a checkpoint every checkpoint_interval iterations; the last
        # interval ends at niterations and may be shorter than the others
        intervals = list(range(0, opts.niterations, opts.checkpoint_interval))
        intervals.append(opts.niterations)

    # if not checkpointing then set intervals to run sampler in one call
    else:
        intervals = [0, opts.niterations]

    intervals = numpy.array(intervals)

    # check if user wants to skip the burn in
    if not opts.skip_burn_in:
        logging.info("Burn in")
        sampler.burn_in()
        n_burnin = sampler.burn_in_iterations
        logging.info("Used %i burn in samples" % n_burnin)
        with InferenceFile(opts.output_file, "a") as fp:
            # write the burn in results
            sampler.write_results(fp, max_iterations=opts.niterations+n_burnin)
    else:
        n_burnin = 0

    # increase the intervals to account for the burn in
    intervals += n_burnin

    # the last checkpoint includes the burn in offset, so compare against it
    # rather than against opts.niterations to detect the end of the run
    final_iteration = intervals[-1]

    # loop over number of checkpoints
    for start, end in zip(intervals[:-1], intervals[1:]):

        # run sampler; it picks up from where it left off on each call
        logging.info("Running sampler for %i to %i out of %i iterations"%(
            start-n_burnin,end-n_burnin,opts.niterations))
        sampler.run(end-start)

        # write new samples
        with InferenceFile(opts.output_file, "a") as fp:

            logging.info("Writing results to file")
            sampler.write_results(fp, max_iterations=opts.niterations+n_burnin)

            # with --checkpoint-fast the derived data is only calculated
            # after the final interval
            if not opts.checkpoint_fast or end == final_iteration:

                # compute the acls and write
                logging.info("Computing acls")
                sampler.write_acls(fp, sampler.compute_acls(fp))

                # compute evidence, if supported
                try:
                    lnz, dlnz = sampler.calculate_logevidence(fp)
                    logging.info("Saving evidence")
                    sampler.write_logevidence(fp, lnz, dlnz)
                except NotImplementedError:
                    pass

# exit
logging.info("Done")
