#!/usr/bin/env python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
import itertools
import logging
import sys

import numpy

from matplotlib import use
use('agg')
from matplotlib import pyplot as plt

import pycbc
from pycbc import __version__
from pycbc import (distributions, results)
from pycbc.inference.option_utils import ParseParametersArg
from pycbc.workflow import WorkflowConfigParser
from pycbc.results.scatter_histograms import create_multidim_plot


def cartesian(arrays):
    """Return the cartesian product of a list of iterables as a numpy array.

    Parameters
    ----------
    arrays : iterable of iterables
        The iterables to combine; each one is passed as a separate argument
        to ``itertools.product``.

    Returns
    -------
    numpy.ndarray
        Array built from one ``numpy.array`` per combination, in the order
        that ``itertools.product`` yields the combinations.
    """
    combinations = itertools.product(*arrays)
    # convert each combination tuple separately before stacking them
    rows = list(map(numpy.array, combinations))
    return numpy.array(rows)

# command line usage
parser = argparse.ArgumentParser(
    usage="pycbc_inference_plot_prior [--options]",
    description="Plots prior distributions.")

# add input options
# one or more configuration files; all of them are handed to
# WorkflowConfigParser together
parser.add_argument("--config-files", type=str, nargs="+", required=True,
                    help="A file parsable by "
                         "pycbc.workflow.WorkflowConfigParser.")
# ParseParametersArg stores the parsed parameters on opts.parameters and the
# matching labels on opts.parameters_labels (both are read further below)
parser.add_argument("--parameters", nargs="+", action=ParseParametersArg,
                    metavar="PARAM[:LABEL]",
                    help="Only plot the given parameters. May also provide a "
                         "label for each parameter. If provided, the "
                         "parameters can be any of the parameters in the "
                         "[variable_params] section, derived parameters from "
                         "them, or any function of them. Syntax for "
                         "functions is python; any math functions in the "
                         "numpy library may be used. Can optionally also "
                         "specify a LABEL for each parameter. If no LABEL is "
                         "provided, PARAM will be used as the LABEL. If "
                         "LABEL is the same as a parameter in "
                         "pycbc.waveform.parameters, the label "
                         "property of that parameter will be used. If not "
                         "provided, will plot all of the parameters in the "
                         "[variable_params] section of the config file.")
# each entry is either SECTION or SECTION-SUBSECTION; a bare SECTION (like
# the default) selects every subsection of that section
parser.add_argument("--sections", type=str, nargs="+", default=["prior"],
                    help="Name of section, optionally followed by a "
                         "subsection, with distribution configurations, "
                         "e.g. prior-mass1. If only the section is given, "
                         "all of its subsections are used. Default is "
                         "prior.")

# add output options
parser.add_argument("--nsamples", type=int, default=10000,
                    help="The number of samples to draw from the prior for "
                         "plotting. Default is 10000.")
parser.add_argument("--output-file", type=str, required=True,
                    help="Path to output plot.")

# verbose option; the flag is forwarded to pycbc.init_logging
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Enable verbose logging output.")
parser.add_argument("--version", action="version", version=__version__,
                    help="show version number and exit")

# parse the command line
opts = parser.parse_args()

# configure logging using the --verbose flag
pycbc.init_logging(opts.verbose)

# load all of the given configuration files into a single parser
logging.info("Reading configuration files")
cp = WorkflowConfigParser(opts.config_files)

# Build one distribution for every requested (section, subsection) pair.
# Each --sections entry is either "SECTION" or "SECTION-SUBSECTION"; when no
# "-" is present every subsection of SECTION found in the config is used.
logging.info("Constructing prior")
dists = []
variable_params = []
for sec in opts.sections:
    # split on the first "-" only; the remainder may itself contain "-"
    section, dash, remainder = sec.partition("-")
    subsections = [remainder] if dash else cp.get_subsections(section)
    for subsection in subsections:
        # the "name" option picks the distribution class to instantiate
        name = cp.get_opt_tag(section, "name", subsection)
        dist_class = distributions.distribs[name]
        dist = dist_class.from_config(cp, section, subsection)
        dists.append(dist)
        variable_params.extend(dist.params)
variable_params.sort()
ndim = len(variable_params)

# combine the distributions into a joint prior and draw samples from it
logging.info("Drawing samples")
prior = distributions.JointDistribution(variable_params, *dists)
samples = prior.rvs(opts.nsamples)

# choose the parameters to plot and the label shown for each of them
if opts.parameters is not None:
    # explicit selection from the command line, with the labels parsed
    # alongside it
    parameters = opts.parameters
    labels = opts.parameters_labels
else:
    # nothing requested: plot every variable parameter, labelled by name
    parameters = variable_params
    labels = dict(zip(variable_params, variable_params))

logging.info("Plotting")
# marginals, density and contours are enabled; scatter points are disabled
plot_options = {
    "labels": labels,
    "plot_marginal": True,
    "marginal_percentiles": [5, 50, 95],
    "plot_scatter": False,
    "plot_density": True,
    "plot_contours": True,
    "contour_percentiles": [50, 90],
}
fig, axis_dict = create_multidim_plot(parameters, samples, **plot_options)

# set DPI
fig.set_dpi(200)

# save figure with meta-data; the title lists every variable parameter of
# the prior, and the command line used is stored with the figure
caption_kwargs = {"parameters": ", ".join(variable_params)}
title = "Prior Distributions for {parameters}".format(**caption_kwargs)
caption = ("This plot shows the probability density function (PDF) "
           "from the \nprior distributions.")
results.save_fig_with_metadata(fig, opts.output_file,
                               title=title,
                               caption=caption,
                               cmd=" ".join(sys.argv))
plt.close()

# exit
logging.info("Done")
