#! /usr/bin/env python

# Copyright (C) 2016 Miriam Cabero Mueller
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

import argparse
import logging
import numpy
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot
import pycbc.results
from pycbc.inference import option_utils
from pycbc.io.inference_hdf import InferenceFile
from pycbc.results.scatter_histograms import create_multidim_plot, \
    get_scale_fac

parser = argparse.ArgumentParser()

parser.add_argument("--input-file", type=str, required=True,
                    help="Results file path.")
# The number of frames is set by exactly one of --frame-number or
# --frame-step. Letting argparse enforce this rejects a bad command line with
# a usage message before any input file is opened.
frame_group = parser.add_mutually_exclusive_group(required=True)
frame_group.add_argument("--frame-number", type=int,
                         help="Number of frames for the movie.")
frame_group.add_argument("--frame-step", type=int,
                         help="Step in the sample between frames for the "
                              "movie. Only provide if not --frame-number "
                              "given.")
parser.add_argument("--output-file", type=str, required=True,
                    help="Output plot path without extension.")
parser.add_argument("--parameters", type=str, nargs="+",
                    metavar="PARAM[:LABEL]",
                    help="Name of parameters to plot in same format "
                         "as for pycbc_inference_plot_posterior.")
parser.add_argument("--thin-start", type=int, default=0,
                    help="Sample number to start collecting samples to plot. "
                         "If none provided, will start at 0.")
parser.add_argument("--thin-end", type=int,
                    help="Sample number to stop collecting samples to plot. "
                         "If none provided, will stop at the last sample.")
parser.add_argument('--verbose', action='store_true')
# add options for what plots to create
option_utils.add_plot_posterior_option_group(parser)
# add scatter and density configuration options
option_utils.add_scatter_option_group(parser)
option_utils.add_density_option_group(parser)

opts = parser.parse_args()
pycbc.init_logging(opts.verbose)

# Get data
# Check the frame options first: they do not depend on the input file, so a
# bad command line should fail before anything is opened. A value of 0 is
# treated the same as "not given".
if opts.frame_number and opts.frame_step:
    raise ValueError("Only one of frame-number or frame-step must be "
                     "provided.")
if not (opts.frame_number or opts.frame_step):
    raise ValueError("At least one of frame-number or frame-step must be "
                     "provided.")
if opts.frame_number and opts.frame_number < 2:
    # (frame_number - 1) is used as a divisor below
    raise ValueError("frame-number must be at least 2.")
if opts.frame_step and opts.frame_step < 1:
    raise ValueError("frame-step must be at least 1.")

logging.info('Loading parameters')
fp, parameters, labels, _ = option_utils.results_from_cli(opts,
                            load_samples=False)

# Define thin interval based on number of frames or frame step.
# Floor division keeps both values integers; true division would turn them
# into floats on Python 3, and nframes is later passed to range().
niterations = int(fp.niterations)
if opts.frame_number:
    nframes = opts.frame_number
    thinint = niterations // (nframes - 1) - 1
    if thinint < 1:
        # With a non-positive interval every frame would map to the same
        # sample number, and therefore to the same output file name.
        fp.close()
        raise ValueError("frame-number {} is too large for a file with {} "
                         "iterations.".format(nframes, niterations))
else:
    thinint = opts.frame_step
    nframes = niterations // (thinint + 1) + 1

# NOTE(review): thin_start is fixed at 0 here; --thin-start / --thin-end are
# not consulted in this section. Confirm whether they should be applied.
samples = fp.read_samples(parameters, thin_start=0,
                          thin_interval=thinint, flatten=False)

# Get z-values; a colorbar is only shown when a z-argument was requested
show_colorbar = opts.z_arg is not None
if show_colorbar:
    logging.info("Getting likelihood stats")
    # read with the same thinning as the samples
    lstats = fp.read_likelihood_stats(thin_start=0, thin_interval=thinint,
                                      flatten=False)
    zvals, zlbl = option_utils.get_zvalues(fp, opts.z_arg, lstats)
    # Set common min and max for colorbar in all plots: values given on the
    # command line win, otherwise use the extremes of the z-values
    vmin = zvals.min() if opts.vmin is None else opts.vmin
    vmax = zvals.max() if opts.vmax is None else opts.vmax
else:
    zvals = zlbl = vmin = vmax = None

fp.close()

logging.info('Choosing common characteristics for all figures')
# Set common min and max for axis in all plots: start from the ranges
# returned for the parsed options, then fill in any parameter that has no
# entry with the extremes of the loaded samples.
mins, maxs = option_utils.plot_ranges_from_cli(opts)
for param in parameters:
    if param not in mins:
        mins[param] = samples[param].min()
    if param not in maxs:
        maxs[param] = samples[param].max()

# Make each figure
# For sorting purposes, the sample number in each file name is zero-padded
# to a fixed width with the appropriate number of 0's.
pad_width = len(str(nframes * thinint))
# Position of the sample-number label in figure fractions; it sits further
# left when a colorbar is shown.
label_x = 0.85 if show_colorbar else 0.9
label_y = 0.95
for frame in range(nframes):
    logging.info('Plotting frame %i', frame + 1)
    frame_samples = samples[:, frame]
    frame_z = zvals[:, frame] if zvals is not None else None
    sample_num = str(frame * thinint + 1).zfill(pad_width)
    outpath = opts.output_file + '-{}.png'.format(sample_num)

    fig, _ = create_multidim_plot(
        parameters, frame_samples, labels=labels,
        mins=mins, maxs=maxs,
        plot_marginal=opts.plot_marginal,
        plot_scatter=opts.plot_scatter,
        zvals=frame_z, show_colorbar=show_colorbar,
        cbar_label=zlbl, vmin=vmin, vmax=vmax,
        scatter_cmap=opts.scatter_cmap,
        plot_density=opts.plot_density,
        plot_contours=opts.plot_contours,
        density_cmap=opts.density_cmap,
        contour_color=opts.contour_color,
        use_kombine=opts.use_kombine_kde)

    # Write sample number, with the font size scaled by the figure's
    # scale factor
    pyplot.annotate('Sample {}'.format(sample_num),
                    xy=(label_x, label_y),
                    xycoords='figure fraction',
                    horizontalalignment='right',
                    verticalalignment='top',
                    fontsize=8 * get_scale_fac(fig))

    fig.savefig(outpath, bbox_inches='tight', dpi=200)
    pyplot.close()

logging.info('Done')
