#!/usr/bin/python
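""" Plot the search sensitive distance (or sensitive volume, or volume x time)
as a function of a significance threshold, one curve per injection bin.
"""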
import argparse, h5py, numpy, logging, matplotlib, sys
matplotlib.use('Agg')
from matplotlib.pyplot import cm
import pylab, pycbc.pnutils, pycbc.results, pycbc, pycbc.version
from pycbc import sensitivity

# Command-line interface. Options fall into four groups: input/output files,
# injection binning, the x-axis significance measure, and the sensitive
# volume estimation method (with the extra inputs the 'mc' method needs).
parser = argparse.ArgumentParser()
parser.add_argument('--version', action='version', version=pycbc.version.git_verbose_msg)
parser.add_argument('--verbose', action='count',
                    help="Print informational log messages")
parser.add_argument('--injection-file', required=True,
                    help="Required. HDF format injection result file.")
parser.add_argument('--output-file', required=True,
                    help="Required. Path of the plot file to write.")
parser.add_argument('--bin-type', choices=['spin', 'mchirp',
                                           'total_mass', 'eta'],
                    default='mchirp',
                    help="Parameter used to bin injections. Default 'mchirp'")
# FIXME: Make bins options properly optional, if no bins are specified then
# plot everything together
# Until the FIXME above is addressed the option is enforced as required, so
# that omitting it gives a usage error rather than a failure further down.
parser.add_argument('--bins', nargs='+', required=True,
                    help="Required. Parameter bin boundaries, ex. 1.3 2.6 4.1")
parser.add_argument('--sig-type', choices=['ifar', 'fap', 'stat'],
                    default='ifar',
                    help="x-axis significance measure. Default 'ifar'")
parser.add_argument('--sig-bins', nargs='*',
                    help="Boundaries of x-axis significance bins. If not given"
                         ", hard-coded defaults will be used")
parser.add_argument('--dist-type', choices=['distance', 'volume', 'vt'],
                    default='distance',
                    help="Quantity plotted on the y-axis: sensitive distance, "
                         "sensitive volume, or volume times time. "
                         "Default 'distance'")
parser.add_argument('--log-dist', action='store_true',
                    help='Plot the sensitive distance axis in log scale')
parser.add_argument('--min-dist', type=float,
                    help="Lower y-axis limit for sensitive distance")
parser.add_argument('--max-dist', type=float,
                    help="Upper y-axis limit for sensitive distance")
parser.add_argument('--integration-method', choices=['pylal', 'shell', 'mc'],
                    default='pylal',
                    help="Sensitive volume estimation method. Default 'pylal'")
parser.add_argument('--dist-bins', type=int, default=100,
                    help="Number of distance bins for 'pylal' volume "
                         "estimation. Default 100")
parser.add_argument('--distance-param', choices=['distance', 'chirp_distance'],
                    help="Parameter D used to generate injection distribution "
                         "over distance, required for 'mc' volume estimation")
parser.add_argument('--distribution',
                    choices=['log', 'uniform', 'distancesquared', 'volume'],
                    help="Form of distribution over D, required by 'mc' method")
parser.add_argument('--limits-param', choices=['distance', 'chirp_distance'],
                    help="Parameter Dlim specifying limits of injection "
                         "distribution, used by 'mc' method. If not given, "
                         "will be set equal to --distance-param")
parser.add_argument('--max-param', type=float,
                    help="Maximum value of Dlim, used by 'mc' method. If not "
                         "given, the maximum injected value will be used")
parser.add_argument('--min-param', type=float,
                    help="Minimum value of Dlim, used by 'mc' method with log "
                         "distribution. If not given, min injected value will "
                         "be used")
parser.add_argument('--spin-frame', choices=['line-of-sight', 'orbit'],
                    default='orbit', help='Frame convention used by injections '
                    'for specifying spin vectors. LAL versions after summer '
                    '2015 should use the orbit convention, which is the '
                    'default choice here.')

args = parser.parse_args()

# Guard against a missing --bins (args.bins is None in that case) before
# taking its length, so the user gets the message below instead of a
# TypeError from len(None). One bin needs two boundaries.
if args.bins is None or len(args.bins) < 2:
    raise RuntimeError("At least 2 injection bin boundaries are required!")

# The Monte Carlo estimate needs to know how the injections were distributed
# over distance; the limits parameter falls back to the distance parameter.
if args.integration_method == 'mc':
    if args.distance_param is None or args.distribution is None:
        raise RuntimeError("The 'mc' method requires --distance-param and "
                           "--distribution !")
    if args.limits_param is None:
        args.limits_param = args.distance_param

# Logging is only configured when --verbose was given; otherwise the
# logging module's defaults apply to the info messages below.
if args.verbose:
    log_level = logging.INFO
    logging.basicConfig(level=log_level, format='%(asctime)s : %(message)s')

logging.info('Read in the data')
f = h5py.File(args.injection_file, 'r')
inj = f['injections']

time = inj['end_time'][:]

# Indices into the injection arrays: injections recovered at any significance
# and injections that were missed, both after vetoes.
found = f['found_after_vetoes/injection_index'][:]
missed = f['missed/after_vetoes'][:]

# Injection parameters, needed both for binning and for the volume estimate
dist = inj['distance'][:]
m1 = inj['mass1'][:]
m2 = inj['mass2'][:]
s1z = inj['spin1z'][:]
s2z = inj['spin2z'][:]
s1x = inj['spin1x'][:]
s2x = inj['spin2x'][:]
s1y = inj['spin1y'][:]
s2y = inj['spin2y'][:]
inc = inj['inclination'][:]
mchirp, eta = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)

# Spin component of each body that enters the mass-weighted spin below.
# In the 'orbit' convention this is simply the z component; in the
# 'line-of-sight' convention the x and z components are combined using the
# inclination angle.
if args.spin_frame == 'orbit':
    s1, s2 = s1z, s2z
elif args.spin_frame == 'line-of-sight':
    sin_inc, cos_inc = numpy.sin(inc), numpy.cos(inc)
    s1 = s1x * sin_inc + s1z * cos_inc
    s2 = s2x * sin_inc + s2z * cos_inc

# Per-injection value of every parameter that --bin-type can select
values = {
    'mchirp': mchirp,
    'eta': eta,
    'total_mass': m1 + m2,
    'spin': (m1 * s1 + m2 * s2) / (m1 + m2),
}

# Legend label templates, one per --bin-type choice. Each template takes the
# lower and upper bin boundary via %-formatting. Raw strings keep the LaTeX
# backslashes literal without relying on invalid escape sequences.
labels = {}
labels['mchirp'] = r"$ M_{\rm chirp} \in [%5.2f, %5.2f] M_\odot $"
labels['total_mass'] = r"$ M_{\rm total} \in [%5.2f, %5.2f] M_\odot $"
labels['spin'] = r"$\chi^{\parallel}_{\rm eff} \in [%5.2f, %5.2f] $"
# 'eta' is an accepted bin type and needs a template too; without it the
# legend label lookup fails with a KeyError.
labels['eta'] = r"$ \eta \in [%5.2f, %5.2f] $"

# Axis labels; filled in once the significance and distance types are known
ylabel = xlabel = ""

# For each significance measure: the dataset holding the value for every
# found injection, the dataset holding its "_exc" counterpart (None when
# there is none), and the x-axis label.
sig_settings = {
    'stat': ('found_after_vetoes/stat', None,
             'Ranking Statistic Value'),
    'ifar': ('found_after_vetoes/ifar', 'found_after_vetoes/ifar_exc',
             'Inverse False Alarm Rate (years)'),
    'fap': ('found_after_vetoes/fap', 'found_after_vetoes/fap_exc',
            'False Alarm Probability'),
}
sig_path, sig_exc_path, xlabel = sig_settings[args.sig_type]
sig = f[sig_path][:]
sig_exc = f[sig_exc_path][:] if sig_exc_path is not None else None

# Significance thresholds at which the sensitivity is evaluated: the user's
# --sig-bins if any were given, otherwise a hard-coded grid per measure.
if args.sig_bins:
    x_values = [float(v) for v in args.sig_bins]
elif args.sig_type == 'stat':
    x_values = numpy.arange(9, 14, .05)
elif args.sig_type == 'ifar':
    x_values = 10 ** (numpy.arange(0, 4, .05))
elif args.sig_type == 'fap':
    x_values = 10.0 ** -numpy.arange(1, 7)

# Plot the sensitive distance/volume against the significance threshold:
# one coloured curve with an error band per injection parameter bin.
bin_edges = [float(edge) for edge in args.bins]
bin_colors = cm.rainbow(numpy.linspace(0, 1, len(bin_edges) - 1))
binval = values[args.bin_type]

# Every bin is drawn once per significance array: `sig` is drawn more opaque
# and carries the legend label, `sig_exc` is drawn fainter and unlabeled.
# `sig_exc` is None when the chosen significance type has no such array.
curve_styles = [(sig, True, .8), (sig_exc, False, .3)]

fig = pylab.figure()
for left, right, c in zip(bin_edges[:-1], bin_edges[1:], bin_colors):
    # Distances of missed injections inside this parameter bin. This does not
    # depend on the significance array, so it is computed once per bin.
    mbm = numpy.logical_and(binval[missed] > left, binval[missed] < right)
    m_dist = dist[missed][mbm]

    # Skip the bin if it has too few missed injections to calculate with
    if len(m_dist) < 2:
        continue

    for sig_val, do_label, alpha in curve_styles:
        if sig_val is None:
            continue

        vols, vol_errors = [], []

        # Calculate the sensitive volume at each significance threshold
        for x_val in x_values:
            # Split found injections into those at least as significant as
            # the threshold (foundg) and the rest (foundm). For 'fap' a
            # smaller value is more significant, so the comparison flips.
            if args.sig_type == 'ifar' or args.sig_type == 'stat':
                foundg = found[sig_val >= x_val]
                foundm = found[sig_val < x_val]
            elif args.sig_type == 'fap':
                foundg = found[sig_val <= x_val]
                foundm = found[sig_val > x_val]

            # distances of injections found in the bin and passing the
            # threshold
            mbf = numpy.logical_and(binval[foundg] > left,
                                    binval[foundg] < right)
            f_dist = dist[foundg][mbf]

            # distances of injections found in the bin that fail the
            # threshold; these count as missed at this threshold
            mbfm = numpy.logical_and(binval[foundm] > left,
                                     binval[foundm] < right)
            f_distm = dist[foundm][mbfm]
            m_dist_full = numpy.append(m_dist, f_distm)

            # Choose the volume estimation method; each call yields the
            # volume and its error
            if args.integration_method == 'shell':
                vol, vol_err = sensitivity.volume_shell(f_dist, m_dist_full)
            elif args.integration_method == 'pylal':
                vol, vol_err = sensitivity.volume_binned_pylal(
                    f_dist, m_dist_full, bins=args.dist_bins)
            elif args.integration_method == 'mc':
                found_mchirp = mchirp[foundg][mbf]
                missed_mchirp = numpy.append(mchirp[missed][mbm],
                                             mchirp[foundm][mbfm])

                vol, vol_err = sensitivity.volume_montecarlo(
                    f_dist, m_dist_full, found_mchirp, missed_mchirp,
                    args.distance_param, args.distribution,
                    args.limits_param, args.max_param, args.min_param)

            vols.append(vol)
            vol_errors.append(vol_err)

        vols = numpy.array(vols)
        vol_errors = numpy.array(vol_errors)

        # Convert the volumes into the quantity requested for the y-axis
        if args.dist_type == 'distance':
            ylabel = 'Sensitive Distance (Mpc)'
            reach, ehigh, elow = sensitivity.volume_to_distance_with_errors(
                vols, vol_errors)
        elif args.dist_type == 'volume':
            ylabel = "Sensitive Volume (Mpc$^3$)"
            reach, ehigh, elow = vols, vol_errors, vol_errors
        elif args.dist_type == 'vt':
            ylabel = "Volume $\\times$ Time (yr Mpc$^3$)"
            pylab.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
            # 3.16888e-8 is about 1 / (seconds per year); this presumably
            # converts a foreground time stored in seconds to years -
            # TODO confirm the units of the attribute
            t = f.attrs['foreground_time_exc'] * 3.16888e-8
            reach, ehigh, elow = vols * t, vol_errors * t, vol_errors * t

        label = labels[args.bin_type] % (left, right) if do_label else None
        pylab.plot(x_values, reach, label=label, c=c)
        pylab.plot(x_values, reach, alpha=alpha, c='black')
        pylab.fill_between(x_values, reach - elow, reach + ehigh,
                           facecolor=c, edgecolor=c, alpha=alpha)

# Final axis formatting for the figure built above
ax = pylab.gca()

if args.log_dist:
    ax.set_yscale('log')

# Every significance measure except the ranking statistic gets a log x-axis;
# for 'fap' the axis is additionally reversed.
if args.sig_type != 'stat':
    ax.set_xscale('log')
    if args.sig_type == 'fap':
        ax.invert_xaxis()

# Apply only the y-axis limits the user actually supplied
for limit_keyword, limit_value in (('ymin', args.min_dist),
                                   ('ymax', args.max_dist)):
    if limit_value is not None:
        pylab.ylim(**{limit_keyword: limit_value})

pylab.ylabel(ylabel)
pylab.xlabel(xlabel)

pylab.grid()
pylab.legend(loc='lower left')

# Write the figure together with a title, caption and the command line used.
# The caption is built from adjacent string literals, so each piece must end
# with its own separating space.
pycbc.results.save_fig_with_metadata(fig, args.output_file,
     title="Sensitive %s vs %s: binned by %s using %s method"
            % (args.dist_type.title(), args.sig_type.upper(),
               args.bin_type, args.integration_method),
     caption="Sensitive %s as a function of significance: "
             "Darker lines represent the significance without including "
             "injections in their own background, while lighter lines "
             "include each injection individually in the background "
             "estimation for that signal. "
             "The integration method used is "
             "based on %s." % (args.dist_type.title(), args.integration_method),
     cmd=' '.join(sys.argv))

