#!/usr/bin/python
""" Make table of the foreground coincident events
"""
import argparse

# Command-line interface: which injection file to read, which quantities go
# on each axis, and where to write the figure.
parser = argparse.ArgumentParser()
# Both file options are required: without them the script can only fail
# later with an unhelpful error, so let argparse report the problem up front.
parser.add_argument('--injection-file', required=True,
                    help='The hdf injection file to plot')
parser.add_argument('--axis-type', default='mchirp',
                    help='Injection parameter to plot on the x axis')
parser.add_argument('--distance-type', default='decisive_distance',
                    help='Distance measure to plot on the y axis')
parser.add_argument('--verbose', action='count',
                    help='Print progress messages while running')
parser.add_argument('--log-distance', action='store_true', default=False,
                    help='Use a logarithmic scale for the distance axis')
parser.add_argument('--dynamic', action='store_true', default=False,
                    help='Write an interactive plot; needed for html output')
parser.add_argument('--output-file', required=True,
                    help='File to write: a .png image, or a .html page '
                         'when --dynamic is given')
args = parser.parse_args()

import h5py, numpy, logging, os.path
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plot
import pycbc.results.followup, pycbc.pnutils

# Progress messages are only configured (at INFO level, with a timestamp)
# when --verbose was given at least once
if args.verbose:
    logging.basicConfig(format='%(asctime)s : %(message)s',
                        level=logging.INFO)
    
# Per-injection arrays that can be plotted: `vals` holds the candidates for
# the x axis (selected by --axis-type), `dvals` the candidates for the y axis
# (selected by --distance-type).
vals = {}
dvals = {}
logging.info('Read in the data')
# Every dataset is copied into memory with [:], so the file can be closed as
# soon as the reads are done.
with h5py.File(args.injection_file, 'r') as f:
    end_time = f['injections/end_time'][:]
    # `found` and `missed` are used below as indices into the injection arrays
    found = f['found_after_vetoes/injection_index'][:]
    missed = f['missed/after_vetoes'][:]
    ifar_found = f['found_after_vetoes/ifar'][:]

    # This hardcodes HL search !!!!!, replace with function that takes or/sky/det
    hdist = f['injections/eff_dist_h'][:]
    ldist = f['injections/eff_dist_l'][:]
    s1z = f['injections/spin1z'][:]
    s2z = f['injections/spin2z'][:]
    m1, m2 = f['injections/mass1'][:], f['injections/mass2'][:]

# The decisive distance is the larger of the two effective distances
dvals['decisive_distance'] = numpy.maximum(hdist, ldist)
vals['mchirp'], _ = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)
# Mass-weighted combination of the two z-component spins
vals['spin'] = (m1 * s1z + m2 * s2z) / (m1 + m2)
vals['time'] = end_time
vals['total_mass'] = m1 + m2
dvals['chirp_distance'] = pycbc.pnutils.chirp_distance(dvals['decisive_distance'], vals['mchirp'])

# Axis labels; there must be an entry for every key of `vals` and `dvals`
labels={'mchirp': 'Chirp Mass',
        'decisive_distance': 'Injected Decisive Distance',
        'chirp_distance': 'Injected Decisive Chirp Distance',
        'time': 'Time (s)',
        'spin': 'Weighted Spin',
        'total_mass': 'Total Mass',
       }

# For speed don't bother plotting really distant missed points: keep only
# those closer than 1.1 times the most distant found injection. With no found
# injections there is no reference distance, so every missed point is kept.
if len(found) > 0:
    plot_dist = dvals[args.distance_type]
    missed = missed[plot_dist[missed] < plot_dist[found].max() * 1.1]

# Shade each found injection by its ifar: 0 by default, 0.5 above 100 and
# 1.0 above 1000 (the later assignment overrides the earlier one).
color = numpy.zeros(len(found))
color[ifar_found > 100] = 0.5
color[ifar_found > 1000] = 1.0

# Scatter the selected distance measure against the selected injection
# parameter on a single wide figure.
fig = plot.figure(figsize=[10,5])
xdata = vals[args.axis_type]
ydata = dvals[args.distance_type]
# Missed injections are drawn as black crosses
mpoints = plot.scatter(xdata[missed], ydata[missed], marker='x',
                       color='black', label='missed', s=60)
# Found injections are drawn as borderless circles shaded by `color`, which
# is mapped over the fixed range [0, 1]
points = plot.scatter(xdata[found], ydata[found], c=color, s=40,
                      linewidth=0, vmin=0, vmax=1, marker='o',
                      label='found')
ax = plot.gca()
ax.set_xlabel(labels[args.axis_type])
ax.set_ylabel(labels[args.distance_type])
ax.set_title('Found and Missed Injections')
ax.grid()
# The axes stop at 80% of the figure width, leaving a margin on the right
fig.subplots_adjust(left=0.1, right=0.8, top=0.9, bottom=0.1)

if args.log_distance:
    ax.set_yscale('log')
# The lower limit of 1 is applied for both linear and log scales
ax.set_ylim(ymin=1)

# Pick the writer from the output file's extension. The suffix is tested
# (rather than searching anywhere in the name) so that a path such as
# 'run.png.d/plot.html' is not mistaken for a png.
if args.output_file.endswith('.png'):
    # Static image, written through pycbc's helper together with a title
    pycbc.results.save_fig_with_metadata(fig, args.output_file,
                     fig_kwds={'dpi':500},
                     title='%s vs %s' % (args.axis_type, args.distance_type))
elif args.output_file.endswith('.html') and args.dynamic:
    # Interactive page; mpld3 is imported here so that it is only needed
    # when html output is actually requested.
    import mpld3
    import mpld3.plugins
    mpld3.plugins.connect(fig, mpld3.plugins.MousePosition(fmt='.5g'))
    legend =  mpld3.plugins.InteractiveLegendPlugin([mpoints, points],
                                                    ['missed', 'found'],
                                                    alpha_unsel=0.1)
    mpld3.plugins.connect(fig, legend)
    # Close the output file deterministically once the page is written
    with open(args.output_file, 'w') as html_file:
        mpld3.save_html(fig, html_file)
else:
    # Reached for any other extension, and for .html without --dynamic
    raise ValueError('Cannot save as chosen file type')

