#! /usr/bin/env python
""" Plots the fraction of injections with their parameter value recovered
within a credible interval versus credible interval.
"""

import argparse
import logging
import matplotlib as mpl; mpl.use("Agg")
import matplotlib.colorbar as cbar
import matplotlib.pyplot as plt
import numpy
import pycbc
from matplotlib import cm
from pycbc import inject
from pycbc import transforms
from pycbc.inference import option_utils

# parse command line
parser = argparse.ArgumentParser(usage=__file__ + " [--options]",
                                 description=__doc__)
parser.add_argument("--output-file", required=True, type=str,
                    help="Path to save output plot.")
parser.add_argument("--verbose", action="store_true",
                    help="Allows print statements.")
# each quantile is the width of a central credible interval given as a
# fraction, eg. 0.9 is the interval between the 5th and 95th percentiles
parser.add_argument("--quantiles", nargs="+", type=float,
                    default=numpy.arange(0.05, 1.05, 0.05),
                    help="List of quantiles to plot.")
parser.add_argument("--injection-hdf-group", default="H1/injections",
                    help="HDF group that contains injection values.")
# add the shared inference results and scatter plot options
option_utils.add_inference_results_option_group(parser)
option_utils.add_scatter_option_group(parser)
opts = parser.parse_args()

# argparse yields a plain list when --quantiles is given on the command line
# but the default is a numpy array; always store an array so that later
# array-only operations (eg. ``opts.quantiles.size``) work in both cases
opts.quantiles = numpy.asarray(opts.quantiles, dtype=float)

# set logging
pycbc.init_logging(opts.verbose)

# read results
_, parameters, labels, samples = option_utils.results_from_cli(opts)

# create figure
fig = plt.figure()
ax = fig.add_subplot(111)

# typecast to list for iteration
# the test is made on samples because parameters and labels are lists whether
# there is one input file or several; presumably results_from_cli returns the
# bare per-file values when only one input file is given -- TODO confirm
if not isinstance(samples, list):
    parameters, labels, samples = [parameters], [labels], [samples]

# quantiles as an array; argparse gives a list if the option was set by hand
quantiles = numpy.asarray(opts.quantiles, dtype=float)

# percentiles that bound the central credible interval of each quantile
low_percentiles = 50 - 100 * quantiles / 2.0
high_percentiles = 50 + 100 * quantiles / 2.0

# dict to hold counts of injections parameters recovered in or out interval
inj_inside = {p: numpy.zeros(quantiles.size, dtype=int)
              for p in parameters[0]}
inj_outside = {p: numpy.zeros(quantiles.size, dtype=int)
               for p in parameters[0]}

# loop over input files and its samples
logging.info("Plotting")
for input_file, input_parameters, input_samples in zip(
                                        opts.input_file, parameters, samples):

    # read injections from HDF input file
    injs = inject.InjectionSet(input_file, hdf_group=opts.injection_hdf_group)

    # check if need extra parameters than parameters stored in injection file
    # NOTE(review): opts.parameters is passed exactly as given on the command
    # line; confirm the transforms helper accepts that form
    _, ts = transforms.get_common_cbc_transforms(opts.parameters,
                                                 injs.table.fieldnames)

    # add parameters not included in injection file
    inj_parameters = transforms.apply_transforms(injs.table, ts)

    # loop over parameters
    for p in input_parameters:

        # compute bounds of every credible interval of the sampled results
        sample_vals = input_samples[p]
        low_vals = numpy.percentile(sample_vals, low_percentiles)
        high_vals = numpy.percentile(sample_vals, high_percentiles)

        # count injections strictly inside each interval; everything else,
        # including values equal to a bound, is counted as outside
        injected_vals = numpy.atleast_1d(inj_parameters[p])
        for i, (low_val, high_val) in enumerate(zip(low_vals, high_vals)):
            n_inside = numpy.count_nonzero((injected_vals > low_val) &
                                           (injected_vals < high_val))
            inj_inside[p][i] += n_inside
            inj_outside[p][i] += injected_vals.size - n_inside

# determine the fraction found within each credible interval
fractions = {}
for p in inj_inside:
    inside = inj_inside[p].astype(float)
    total = inside + inj_outside[p]

    # NaN instead of a division by zero if no injections were counted
    fractions[p] = numpy.where(total > 0,
                               inside / numpy.maximum(total, 1.0),
                               numpy.nan)

    # plot parameter
    # labels is assumed to hold one list of labels per input file, in the
    # same order as parameters, so use the first file's label for parameter j
    j = parameters[0].index(p)
    ax.plot(quantiles, fractions[p], label=r"{}".format(labels[0][j]))

# legend for the curves drawn so far
ax.legend()

# axis titles
y_title = r"Fraction of Injections Recovered in Credible Interval"
x_title = r"Credible Interval"
ax.set_ylabel(y_title)
ax.set_xlabel(x_title)

# background grid
ax.grid()

# dashed gray 1:1 reference line from (0, 0) to (1, 1)
diagonal = [0, 1]
ax.plot(diagonal, diagonal, linestyle="dashed", color="gray", zorder=9)

# write the figure to the requested output path
fig.savefig(opts.output_file)

# done
logging.info("Done")

