#!/usr/bin/env python

import numpy, h5py, argparse, matplotlib
matplotlib.use('Agg')
import pylab
from pycbc.events import veto


# Command-line interface. The trigger file and the output path have no
# usable default (h5py.File / savefig would fail on None), so argparse
# enforces them up front with a clear usage error.
parser = argparse.ArgumentParser(
    description='Plot the SNR vs. reduced chi^2 distribution of '
                'single-detector triggers.')
parser.add_argument('--trigger-file', required=True,
                    help='Single ifo trigger file')
parser.add_argument('--veto-file',
                    help='Optional, file of veto segments to remove triggers')
parser.add_argument('--min-snr', type=float,
                    help='Optional, Minimum SNR to plot')
parser.add_argument('--output-file', required=True,
                    help='Path of the output image')
# Contour values are kept as the strings given on the command line; they are
# converted with float() where used and shown verbatim in the plot labels.
parser.add_argument('--newsnr-contours', nargs='*', default=[],
                    help="List of newsnr values to draw contours at.")
args = parser.parse_args()

# Open the trigger file. Its first top-level key is used as the detector
# (ifo) name, and `f` is bound to that group so the rest of the script reads
# per-trigger datasets from it. The file is deliberately left open because
# later code reads further datasets through `f`.
hdf = h5py.File(args.trigger_file, 'r')
# h5py's keys() returns a view on Python 3, which cannot be indexed directly.
ifo = list(hdf.keys())[0]
f = hdf[ifo]
snr = f['snr'][:]

# chisq == 0 is a sentinel meaning chisq was not actually calculated for that
# trigger; record those entries before normalizing.
chisq = f['chisq'][:]
no_chisq = chisq == 0
chisq_dof = f['chisq_dof'][:]
# Convert to reduced chisq by dividing by 2 * chisq_dof - 2 degrees of freedom.
chisq /= (chisq_dof * 2 - 2)
# Give uncalculated entries a small positive placeholder: the plot uses
# logarithmic chisq axes, which cannot represent zero.
chisq[no_chisq] = .1

def snr_from_chisq(chisq, newsnr, q=6.):
    """Return, for each reduced-chisq value, the SNR lying on a newSNR contour.

    For entries with ``chisq > 1`` the result is
    ``newsnr / (0.5 * (1 + chisq ** (q / 2))) ** (-1 / q)``; every other
    entry (including ``chisq <= 1``) is simply ``newsnr``.

    Parameters
    ----------
    chisq : numpy.ndarray
        One-dimensional array of reduced chisq values.
    newsnr : float or str
        Contour level; anything accepted by ``float()`` (the command line
        passes strings).
    q : float, optional
        Exponent of the re-weighting, 6 by default.

    Returns
    -------
    numpy.ndarray
        Float array with one SNR value per entry of ``chisq``.
    """
    level = float(newsnr)
    out = numpy.zeros(len(chisq))
    out += level

    # Only triggers with chisq above 1 are re-weighted.
    high = chisq > 1.
    base = 0.5 * (1. + chisq[high] ** (q / 2.))
    out[high] = level / base ** (-1. / q)
    return out

# Optionally drop triggers whose end times fall inside the veto segments.
if args.veto_file:
    end_times = f['end_time'][:]
    locs, _ = veto.indices_outside_segments(end_times, [args.veto_file],
                                            ifo=ifo)
    snr = snr[locs]
    chisq = chisq[locs]

# Optionally keep only triggers strictly louder than --min-snr.
if args.min_snr is not None:
    locs = snr > args.min_snr
    snr = snr[locs]
    chisq = chisq[locs]

# Every range and axis limit below is derived from the surviving triggers;
# report an empty selection clearly instead of a bare numpy reduction error.
if len(snr) == 0:
    raise SystemExit('No triggers left to plot after applying the veto '
                     'file / --min-snr cut')

# Chisq sample points for the newSNR contours, spanning the data's chisq
# range. numpy.logspace takes base-10 exponents, so log10 (not the natural
# log) is required for the grid to match [chisq.min(), chisq.max()].
chisq_grid = numpy.logspace(numpy.log10(chisq.min()),
                            numpy.log10(chisq.max()), 300)
snr_label_threshold = snr.max() * 0.8
for i, cval in enumerate(args.newsnr_contours):
    snrv = snr_from_chisq(chisq_grid, cval)
    pylab.plot(snrv, chisq_grid, color='black', lw=0.5)
    # Only the first contour's label names the statistic; the rest show
    # just the value.
    if i == 0:
        label = "$\\hat{\\rho} = %s$" % cval
    else:
        label = "$%s$" % cval
    # Place the label where the contour first exceeds 80% of the loudest
    # trigger's SNR. If it never does, use the contour's last (highest
    # chisq) point instead of failing with an IndexError.
    past_threshold = numpy.flatnonzero(snrv > snr_label_threshold)
    if len(past_threshold):
        label_pos_idx = past_threshold[0]
    else:
        label_pos_idx = len(snrv) - 1
    pylab.text(snrv[label_pos_idx], chisq_grid[label_pos_idx], label,
               fontsize=6,
               horizontalalignment='center', verticalalignment='center',
               bbox=dict(facecolor='white', lw=0, pad=0, alpha=0.9))

# 2-D histogram of the triggers on log-log axes, coloured by log count.
pylab.hexbin(snr, chisq, gridsize=300, xscale='log', yscale='log', lw=0.04,
             mincnt=1, norm=matplotlib.colors.LogNorm())

ax = pylab.gca()
pylab.grid()
ax.set_xscale('log')
cb = pylab.colorbar()
pylab.xlim(snr.min(), snr.max() * 1.1)
pylab.ylim(chisq.min(), chisq.max() * 1.1)
cb.set_label('Trigger Density')
pylab.xlabel('Signal-to-Noise Ratio')
pylab.ylabel('Reduced $\\chi^2$')
pylab.title('%s: Single Detector Trigger Distribution' % ifo)
pylab.savefig(args.output_file, dpi=300)
