#!/usr/bin/python
""" Plot found and missed injections: a chosen injection parameter against
a chosen injected distance measure.
"""
import h5py, numpy, logging, os.path, argparse, sys, matplotlib
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plot
import pycbc.results.followup, pycbc.pnutils, pycbc.results, pycbc.version

# Command-line interface.
# The input and output files are mandatory: without them the script used to
# fail only later on (when opening the file, or when inspecting the output
# name after all of the plotting work), with an unhelpful error. Letting
# argparse reject the call up front gives a clear usage message instead.
parser = argparse.ArgumentParser()
parser.add_argument('--injection-file', required=True,
                    help='The hdf injection file to plot')
parser.add_argument('--axis-type', default='mchirp',
                    help='Injection parameter shown on the x-axis')
parser.add_argument('--distance-type', default='decisive_distance',
                    help='Distance measure shown on the y-axis')
parser.add_argument('--verbose', action='count',
                    help='Print informational log messages')
parser.add_argument('--log-distance', action='store_true', default=False,
                    help='Use a logarithmic scale for the distance axis')
# NOTE: --dynamic is accepted but never read anywhere else in this script.
parser.add_argument('--dynamic', action='store_true', default=False)
parser.add_argument('--gradient-far', action='store_true',
                    help='Show far of found injections as a gradient')
parser.add_argument('--output-file', required=True,
                    help='Path of the plot file to write')
parser.add_argument('--missed-on-top', action='store_true',
                    help="Make the missed injections be plotted on top of the found ones")
parser.add_argument('--version', action='version', version=pycbc.version.git_verbose_msg)
args = parser.parse_args()

# Logging is only configured when --verbose is given; otherwise the root
# logger keeps its default level and the info messages are not shown.
if args.verbose:
    log_level = logging.INFO
    logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

logging.info('Read in the data')
# Every dataset is copied into memory with [:], so the file can be closed as
# soon as the reads are finished instead of staying open for the whole run.
with h5py.File(args.injection_file, 'r') as f:
    time = f['injections/end_time'][:]
    # 'found' and 'missed' are used below to index the per-injection arrays.
    found = f['found_after_vetoes/injection_index'][:]
    missed = f['missed/after_vetoes'][:]
    # Read once; indexed in step with 'found' by the plotting code.
    ifar_found = f['found_after_vetoes/ifar'][:]

    # This hardcodes HL search !!!!!, replace with function that takes or/sky/det
    hdist = f['injections/eff_dist_h'][:]
    ldist = f['injections/eff_dist_l'][:]
    s1z = f['injections/spin1z'][:]
    s2z = f['injections/spin2z'][:]
    dist = f['injections/distance'][:]
    m1, m2 = f['injections/mass1'][:], f['injections/mass2'][:]

# Candidate x-axis quantities, keyed by the value passed to --axis-type.
vals = {}
# Only the chirp mass is needed; the second return value is discarded.
vals['mchirp'], _ = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)
# Mass-weighted combination of the two z-spin components.
vals['spin'] = (m1 * s1z + m2 * s2z) / (m1 + m2)
vals['time'] = time
vals['total_mass'] = m1 + m2

# Candidate y-axis quantities, keyed by the value passed to --distance-type.
dvals = {}
# Element-wise larger of the two per-detector effective distances.
dvals['decisive_distance'] = numpy.maximum(hdist, ldist)
# The decisive distance passed through pycbc.pnutils.chirp_distance together
# with the chirp mass.
dvals['chirp_distance'] = pycbc.pnutils.chirp_distance(dvals['decisive_distance'],
                                                       vals['mchirp'])

# Axis labels. Every key of vals and dvals needs an entry here, because the
# label is looked up directly from the --axis-type / --distance-type value.
labels = {'mchirp': 'Chirp Mass',
          'decisive_distance': 'Injected Decisive Distance (Mpc)',
          'chirp_distance': 'Injected Decisive Chirp Distance (Mpc)',
          'time': 'Time (s)',
          'spin': 'Weighted Spin',
          'total_mass': 'Total Mass',
          }

# The title names first whichever population is drawn on top.
fig_title = ('Missed and Found Injections' if args.missed_on_top
             else 'Found and Missed Injections')

# For speed don't bother plotting really distant missed points
# missed = missed[dvals[args.distance_type][missed] < dvals[args.distance_type][found].max() * 1.1]

fig = plot.figure()
# These booleans are handed to the scatter calls as zorder values, so the
# population holding True is the one rendered above the other.
zmissed, zfound = args.missed_on_top, not args.missed_on_top

# Draw the missed and found injections. Two colouring schemes exist for the
# found ones: three discrete IFAR bins (default) or a continuous false alarm
# rate gradient (--gradient-far). Both branches define mpoints, points and
# caption, which the rest of the script relies on.
if not args.gradient_far:
    # Bin the found injections by IFAR: 0.0 by default, 0.5 above 100 and
    # 1.0 above 1000. The later assignment overrides the earlier one for
    # injections that satisfy both thresholds.
    color = numpy.zeros(len(found))
    color[ifar_found > 100] = 0.5
    color[ifar_found > 1000] = 1.0

    mpoints = plot.scatter(vals[args.axis_type][missed], dvals[args.distance_type][missed],
                           marker='x', color='black', label='missed', zorder=zmissed)
    points = plot.scatter(vals[args.axis_type][found], dvals[args.distance_type][found],
                          c=color, linewidth=0, vmin=0, vmax=1,
                          marker='o', label='found', zorder=zfound)
    # NOTE(review): the colour names in this caption depend on the default
    # colormap in use, since no cmap is passed to scatter -- confirm that
    # they match the rendered plot.
    caption = (fig_title + ": Black x's are missed injections. "
              "Blue circles are found with IFAR < 100 years, green are < 1000 years, and "
              "red are found with IFAR >=1000 years. ")
else:
    fvals = vals[args.axis_type][found]
    fdval = dvals[args.distance_type][found]
    # Colour by the inverse of the IFAR, labelled as false alarm rate on the
    # colorbar below.
    color = 1.0 / ifar_found

    # sort so quiet found is on top
    csort = color.argsort()
    fvals = fvals[csort]
    fdval = fdval[csort]
    color = color[csort]

    mpoints = plot.scatter(vals[args.axis_type][missed], dvals[args.distance_type][missed],
                           marker='x', color='red', label='missed', zorder=zmissed)
    points = plot.scatter(fvals, fdval, c=color, linewidth=0,
                          norm=matplotlib.colors.LogNorm(),
                          marker='o', label='found', zorder=zfound)
    caption = (fig_title + ": Red x's are missed injections. "
               "Circles are found injections. The color indicates the value of "
               "the false alarm rate. " )
    plot.subplots_adjust(right=0.99)
    try:
        c = plot.colorbar()
        c.set_label('False Alarm Rate $(yr^{-1})$')
    except TypeError:
        # Can't make colorbar if no quiet found injections
        if len(fvals):
            raise

# Finish the caption with a note on which population is drawn on top.
if args.missed_on_top:
    caption += "Missed injection are shown on top of found injections."
else:
    caption += "Found injections are shown on top of missed injections."

ax = plot.gca()
plot.xlabel(labels[args.axis_type])
plot.ylabel(labels[args.distance_type])
plot.grid()

if args.log_distance:
    ax.set_yscale('log')

# The upper limit is 20% above the most distant missed injection. With no
# missed injections there is nothing to take the maximum of (max() of an
# empty array raises ValueError), so only the lower limit is set and the
# upper one is left to matplotlib.
missed_distance = dvals[args.distance_type][missed]
if len(missed_distance):
    plot.ylim(ymin=1, ymax=missed_distance.max()*1.2)
else:
    plot.ylim(ymin=1)

# Decide on format-specific handling from the real file extension. A plain
# substring test on the whole path would also match a directory component
# such as "site.html/plot.png" and wrongly pull in the mpld3 branch.
extension = os.path.splitext(args.output_file)[1]

fig_kwds = {}
if extension == '.png':
    fig_kwds['dpi'] = 200

if extension == '.html':
    plot.subplots_adjust(left=0.1, right=0.8, top=0.9, bottom=0.1)
    # mpld3 is imported here so it is only needed when writing html output.
    import mpld3, mpld3.plugins, mpld3.utils
    mpld3.plugins.connect(fig, mpld3.plugins.MousePosition(fmt='.5g'))
    legend = mpld3.plugins.InteractiveLegendPlugin([mpoints, points],
                                                   ['missed', 'found'],
                                                   alpha_unsel=0.1)
    mpld3.plugins.connect(fig, legend)

# Write the figure together with its title, caption and the command line
# that produced it.
pycbc.results.save_fig_with_metadata(fig, args.output_file,
                     fig_kwds=fig_kwds,
                     title='%s: %s vs %s' % (fig_title, args.axis_type, args.distance_type),
                     cmd=' '.join(sys.argv),
                     caption=caption)

