#!/usr/bin/python
""" Make plot of search sensitive distance
"""
import argparse, h5py, numpy, logging, os.path, matplotlib, sys
matplotlib.use('Agg')
from matplotlib.pyplot import cm
import pylab, pycbc.pnutils, pycbc.results, pycbc
from pycbc import sensitivity

# Command line interface for the sensitive distance plot.  The bin and
# significance boundaries are converted to float while parsing so that
# malformed input is rejected before any data is read.
parser = argparse.ArgumentParser()
parser.add_argument('--injection-file',
                    help="HDF format injection result file.")
parser.add_argument('--bins', nargs='*', type=float,
                    help="Parameter bin boundaries, at least two are needed "
                         "to define a bin. Ex. 1.3 2.6 4.1")
parser.add_argument('--bin-type', choices=['spin', 'mchirp',
                                           'total_mass', 'eta'],
                    default='mchirp',
                    help="Optional, injection parameter used to define the "
                         "bins")
parser.add_argument('--sig-bins', nargs='*', type=float,
                    help="Optional, boundaries of x-axis significance bins")
parser.add_argument('--sig-type', choices=['ifar', 'fap'], default='ifar',
                    help="Optional, x-axis significance type")
parser.add_argument('--min-dist', type=float,
                    help="Optional, minimum sensitive distance to plot")
parser.add_argument('--max-dist', type=float,
                    help="Optional, maximum sensitive distance to plot")
parser.add_argument('--verbose', action='count')
parser.add_argument('--method', choices=['pylal', 'shell', 'mc'],
                    default='pylal',
                    help="Optional, choice of volume estimation")
parser.add_argument('--dist-bins', type=int, default=100,
                    help="Optional, used in conjunction with pylal binned"
                         " volume estimation method. Number of distance bins.")
# The two options below are only passed to the 'mc' volume estimator.
parser.add_argument('--distance-param', type=str,
                    help="Optional, used in conjunction with the Monte Carlo"
                         " (mc) volume estimation method")
parser.add_argument('--distance-distribution', type=str,
                    help="Optional, used in conjunction with the Monte Carlo"
                         " (mc) volume estimation method")
parser.add_argument('--output-file')
args = parser.parse_args()

# Each plotted bin is delimited by a consecutive pair of boundaries, so fewer
# than two boundaries leaves nothing to plot.  Report that as a usage error
# here instead of failing later while setting up the figure.
if args.bins is None or len(args.bins) < 2:
    parser.error('--bins requires at least two boundaries, e.g. '
                 '--bins 1.3 2.6 4.1')

# Logging is only configured when --verbose is given at least once.
if args.verbose:
    logging.basicConfig(format='%(asctime)s : %(message)s',
                        level=logging.INFO)
    
logging.info('Read in the data')
f = h5py.File(args.injection_file, 'r')


def _read_all(path):
    """Return the complete contents of the dataset at ``path`` in ``f``."""
    return f[path][:]


# Injection end_time dataset (read here, not used further in this script)
time = _read_all('injections/end_time')

# Indices into the injection arrays: found (at any FAR) and missed
found = _read_all('found_after_vetoes/injection_index')
missed = _read_all('missed/after_vetoes')

# Injection parameters, loaded so that injections can be binned by them
dist = _read_all('injections/distance')
m1 = _read_all('injections/mass1')
m2 = _read_all('injections/mass2')
s1z = _read_all('injections/spin1z')
s2z = _read_all('injections/spin2z')
s1x = _read_all('injections/spin1x')
s2x = _read_all('injections/spin2x')
s1y = _read_all('injections/spin1y')
s2y = _read_all('injections/spin2y')
inc = _read_all('injections/inclination')
mchirp, eta = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)

# Per-component spin built from the x and z components and the inclination
#FIXME, this is only correct for SpinTaylor Convention!!!!!
s1 = s1x * numpy.sin(inc) + s1z * numpy.cos(inc)
s2 = s2x * numpy.sin(inc) + s2z * numpy.cos(inc)

# Quantity to bin on, keyed by the --bin-type choice; 'spin' is the
# mass-weighted mean of the two component spins computed above
values = {
    'mchirp': mchirp,
    'eta': eta,
    'total_mass': m1 + m2,
    'spin': (m1 * s1 + m2 * s2) / (m1 + m2),
}

# Legend label templates, one per --bin-type choice.  Each template takes the
# (left, right) bin boundaries.  Raw strings keep the mathtext backslashes
# literal without relying on unrecognised escape sequences.
labels = {}
labels['mchirp'] = r"$ M_{chirp} \in [%5.2f, %5.2f] M_\odot $"
labels['total_mass'] = r"$ M_{total} \in [%5.2f, %5.2f] M_\odot $"
labels['spin'] = r"$Eff Spin \in [%5.2f, %5.2f] $"
labels['eta'] = r"$ \eta \in [%5.2f, %5.2f] $"

# For each --sig-type choice: the dataset holding the significance of the
# found injections, the dataset holding its "_exc" counterpart, the x-axis
# label, and a factory for the default thresholds used when --sig-bins is
# not given.
_SIG_TYPES = {
    'ifar': ('found_after_vetoes/ifar',
             'found_after_vetoes/ifar_exc',
             'Inverse False Alarm Rate (years)',
             lambda: 10 ** (numpy.arange(0, 4, .05))),
    'fap': ('found_after_vetoes/fap',
            'found_after_vetoes/fap_exc',
            'False Alarm Probability',
            lambda: 10.0 ** -numpy.arange(1, 7)),
}

_sig_path, _sig_exc_path, xlabel, _default_thresholds = _SIG_TYPES[args.sig_type]
sig = f[_sig_path][:]
sig_exc = f[_sig_exc_path][:]

# Significance thresholds at which the sensitive distance is evaluated
if args.sig_bins:
    x_values = [float(v) for v in args.sig_bins]
else:
    x_values = _default_thresholds()

# Nothing later in this script reads from the injection file, so release the
# handle now rather than leaving it open until the interpreter exits.
f.close()
   
# One colour per parameter bin, taken in order from the rainbow colormap
color = iter(cm.rainbow(numpy.linspace(0, 1, len(args.bins) - 1)))

# The two curves drawn for every bin, as
# (significance values, add a legend label?, alpha): ``sig`` is drawn first
# and labelled, ``sig_exc`` second, unlabelled and fainter.
curve_specs = [(sig, True, .8), (sig_exc, False, .3)]

fig = pylab.figure()

# Walk over consecutive pairs of bin boundaries
for lower_edge, upper_edge in zip(args.bins[:-1], args.bins[1:]):
    c = next(color)
    left, right = float(lower_edge), float(upper_edge)
    binval = values[args.bin_type]

    # Distances of the missed injections lying strictly inside the bin.
    # This does not depend on the significance, so it is done once per bin.
    mbm = numpy.logical_and(binval[missed] > left, binval[missed] < right)
    m_dist = dist[missed][mbm]

    # Nothing is drawn for a bin with fewer than two missed injections
    if len(m_dist) < 2:
        continue

    for sig_val, do_label, alpha in curve_specs:
        dists, low_errors, high_errors = [], [], []

        # One sensitive distance estimate per significance threshold
        for x_val in x_values:
            # Split the found injections into those passing the threshold
            # and those failing it (which are then treated as missed)
            if args.sig_type == 'ifar':
                passing = sig_val >= x_val
                failing = sig_val < x_val
            elif args.sig_type == 'fap':
                passing = sig_val <= x_val
                failing = sig_val > x_val
            foundg = found[passing]
            foundm = found[failing]

            # Found injections inside the bin that pass the threshold
            mbf = numpy.logical_and(binval[foundg] > left,
                                    binval[foundg] < right)
            f_dist = dist[foundg][mbf]

            # Found injections inside the bin that fail the threshold
            mbfm = numpy.logical_and(binval[foundm] > left,
                                     binval[foundm] < right)
            f_distm = dist[foundm][mbfm]

            # Effective missed set: truly missed plus below-threshold found
            m_dist_full = numpy.append(m_dist, f_distm)

            # Estimate the sensitive volume with the requested method
            if args.method == 'shell':
                vol, vol_err = sensitivity.volume_shell(f_dist, m_dist_full)
            elif args.method == 'pylal':
                vol, vol_err = sensitivity.volume_binned_pylal(
                    f_dist, m_dist_full, bins=args.dist_bins)
            elif args.method == 'mc':
                found_mchirp = mchirp[foundg][mbf]
                missed_mchirp = numpy.append(mchirp[missed][mbm],
                                             mchirp[foundm][mbfm])
                vol, vol_err = sensitivity.volume_montecarlo(
                    f_dist, m_dist_full, found_mchirp, missed_mchirp,
                    args.distance_param, args.distance_distribution)

            sdist, ehigh, elow = sensitivity.volume_to_distance_with_errors(
                vol, vol_err)

            dists.append(sdist)
            low_errors.append(elow)
            high_errors.append(ehigh)

        if do_label:
            label = labels[args.bin_type] % (left, right)
        else:
            label = None

        # Coloured line, a black line over it, then the shaded error band
        pylab.plot(x_values, dists, label=label, c=c)
        pylab.plot(x_values, dists, alpha=alpha, c='black')
        pylab.fill_between(x_values,
                           numpy.subtract(dists, low_errors),
                           numpy.add(dists, high_errors),
                           facecolor=c, edgecolor=c, alpha=alpha)
                           
ax = pylab.gca()
ax.set_xscale('log')

# For 'fap' an injection passes a threshold when its value is at or below it,
# so the axis is flipped to put the stricter thresholds on the right, as they
# are for 'ifar'.
if args.sig_type == 'fap':
    ax.invert_xaxis()

# Apply whichever y-limits were given.  A limit left as None is not changed,
# and an explicit limit of 0 is honoured.
if args.min_dist is not None or args.max_dist is not None:
    pylab.ylim(args.min_dist, args.max_dist)

pylab.ylabel('Sensitive Distance (Mpc)')
pylab.xlabel(xlabel)

pylab.grid()
pylab.legend(loc='lower left')

# Save the figure together with a title, a caption and the command line used
pycbc.results.save_fig_with_metadata(fig, args.output_file,
     title="Sensitive Distance vs %s: binned by %s, %s estimation"
            % (args.sig_type, args.bin_type, args.method),
     caption="Sensitive distance as a function of significance: "
             "Darker lines represent the significance without including "
             "injections in their own background, while lighter lines "
             "include each injection (individual) in its own background. ",
     cmd=' '.join(sys.argv))
             
             
