#!/usr/bin/env python

# Copyright (C) 2014 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
from ConfigParser import ConfigParser
import glob
import h5py
import matplotlib as mpl; mpl.use('Agg')
import matplotlib.pylab as plt
import numpy as np
import os
import sys

from glue import lal
from glue import segments
from pycbc.events import newsnr

# parse command line
# --trigger-files, --output-file and the two GPS bounds are all used
# unconditionally further down in this script, so they are declared required:
# a missing one is reported as a usage error by argparse instead of surfacing
# later as a traceback.
parser = argparse.ArgumentParser(
    usage='pycbc_plot_glitchgram [--options]',
    description='Plot template duration versus time for each trigger, '
                'and a colorbar to indicate significance.')
parser.add_argument('--trigger-files', type=str, nargs="+", required=True,
                    help='Input trigger files. May be XML or HDF files.')
parser.add_argument('--output-file', type=str, required=True,
                    help='Output image file.')
parser.add_argument('--config-file', type=str,
                    help='Output configuration file for the image file.')
parser.add_argument('--ifo', type=str,
                    help='IFO to plot in HDF trigger files. Required if '
                         'using HDF trigger files.')
parser.add_argument('--gps-start-time', type=int, required=True,
                    help='Time to start plotting.')
parser.add_argument('--gps-end-time', type=int, required=True,
                    help='Time to end plotting.')
parser.add_argument('--new-snr', action='store_true', default=False,
                    help='Plots new SNR instead of SNR.')
opts = parser.parse_args()

# sanity check: --ifo is only meaningful for HDF input, so when it is given
# every trigger file must be an HDF file
if opts.ifo:
    non_hdf_files = [trigger_file for trigger_file in opts.trigger_files
                     if not trigger_file.endswith('hdf')]
    if non_hdf_files:
        # passing a string to sys.exit() writes it to stderr and exits with
        # status 1, so the failure is visible to whatever launched the script
        sys.exit('You provided the IFO option for HDF files but not all '
                 'your trigger files are HDF.')

# Significance bins used when drawing the triggers.
# Every entry is a 5-tuple:
#   (lower significance bound, upper significance bound,
#    marker color, marker size, legend label)
ranges = [
    (0, 8, 'b', 3, '0 to 8'),
    (8, 16, 'g', 15, '8 to 16'),
    (16, float('inf'), 'r', 15, '16 to Infinity'),
]

# initialize arrays for plotting
# (hours since gps_start_time, template duration, SNR or new SNR)
hourEndTime   = np.array([])
tmpltDuration = np.array([])
significance  = np.array([])

# check if trigger files are HDF format
if opts.ifo:

    # loop over HDF trigger files
    for trigger_file in opts.trigger_files:

        # read HDF trigger file; the context manager closes the file as soon
        # as the datasets needed have been copied into memory
        with h5py.File(trigger_file, 'r') as fp:
            ifo_group = fp[opts.ifo]
            file_end_time = ifo_group['end_time'][:]
            file_duration = ifo_group['template_duration'][:]
            file_stat = ifo_group['snr'][:]
            if opts.new_snr:
                reduced_chisq = (ifo_group['chisq'][:]
                                 / (2 * ifo_group['chisq_dof'][:] - 2))
                file_stat = newsnr(file_stat, reduced_chisq)

        # put data into arrays
        hourEndTime = np.append(hourEndTime,
               (file_end_time - opts.gps_start_time)/(60.0*60.0))
        tmpltDuration = np.append(tmpltDuration, file_duration)
        significance = np.append(significance, file_stat)

# otherwise assume trigger files are XML format
else:

    # import pylal dependency for reading XML files
    from pylal import SnglInspiralUtils

    # read inspiral triggers
    inspiralTable = SnglInspiralUtils.ReadSnglInspiralFromFiles(opts.trigger_files)

    # put data into arrays; each array is built in one step because calling
    # np.append once per trigger copies the whole array every time
    trigTimes = [trig.end_time + trig.end_time_ns * 10**(-9)
                 for trig in inspiralTable]
    hourEndTime = ((np.array(trigTimes, dtype=float) - opts.gps_start_time)
                   / (60.0*60.0))
    tmpltDuration = np.array([trig.template_duration
                              for trig in inspiralTable], dtype=float)
    if opts.new_snr:
        significance = np.array([trig.get_new_snr()
                                 for trig in inspiralTable], dtype=float)
    else:
        significance = np.array([trig.snr for trig in inspiralTable],
                                dtype=float)

# make sure the data are numpy arrays so they can be indexed with boolean masks
x = np.asarray(hourEndTime)
y = np.asarray(tmpltDuration)
z = np.asarray(significance)

# create pylab figure
plt.figure(1, figsize=(8.1,4.0))

# plot data: one scatter call per significance range, where a trigger belongs
# to a range when lower bound <= significance < upper bound
for low, high, color, size, label in ranges:
    subset = np.logical_and(z >= low, z < high)
    if subset.any():
        plt.scatter(x[subset], y[subset], edgecolor='none', color=color,
                    s=size, label=label)

# format the plot
plt.yscale('log')
# with no triggers there is no data range to pad, so leave the default limits
# instead of failing on the minimum/maximum of an empty array
if y.size:
    plt.ylim(y.min() * 0.9, y.max() * 1.1)
plt.xlim(0, (opts.gps_end_time- opts.gps_start_time)/(60.0*60.0))
plt.ylabel('Template duration (s)')
plt.xlabel('Time since %s (hrs)'%opts.gps_start_time)

# save plot
plt.savefig(opts.output_file)
plt.close()

# save config file: a single section named after the output image, holding
# the title, caption, render function and HTML template for that image
if opts.config_file:
    section = opts.output_file

    # the caption wording depends on which statistic was plotted
    if opts.new_snr:
        caption = 'Color key is newSNR is less than 8 (blue), newSNR is between 8 and 16 (green), and newSNR is greater than 16 (red).'
    else:
        caption = 'Color key is SNR is less than 8 (blue), SNR is between 8 and 16 (green), and SNR is greater than 16 (red).'

    # options are set in this order so they are written in this order
    section_options = (
        ('title', section.replace('.png', '')),
        ('caption', caption),
        ('render-function', 'render_glitchgram'),
        ('template', 'file_glitchgram.html'),
    )

    cp = ConfigParser()
    cp.add_section(section)
    for option, value in section_options:
        cp.set(section, option, value)

    with open(opts.config_file, 'w') as config_fp:
        cp.write(config_fp)
