#!/usr/bin/env python

# Copyright (C) 2014 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
import glob
import os

import numpy as np
import matplotlib as mpl; mpl.use('Agg')
import matplotlib.pylab as plt

from glue import lal
from glue import segments

from pylal import SnglInspiralUtils

# Floor value added to every per-minute count so that minutes without any
# trigger still have a positive rate and can be drawn on a log axis.
MIN_RATE = 1e-6


def parse_command_line(argv=None):
    """Parse the command line options of pycbc_plot_rate.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse; defaults to ``sys.argv[1:]`` when None.

    Returns
    -------
    argparse.Namespace
        Namespace with ``trigger_files``, ``output_file``,
        ``gps_start_time`` and ``gps_end_time``.
    """
    parser = argparse.ArgumentParser(
        usage='pycbc_plot_rate [--options]',
        description="Plot trigger rate versus time.")
    # every option is needed to produce a plot, so let argparse report a
    # missing one instead of failing later on a None value
    parser.add_argument('--trigger-files', type=str, nargs="+",
                        required=True,
                        help='Input xml files with SnglInspiral triggers.')
    parser.add_argument('--output-file', type=str, required=True,
                        help='Output image file.')
    parser.add_argument('--gps-start-time', type=int, required=True,
                        help='Time to start plotting.')
    parser.add_argument('--gps-end-time', type=int, required=True,
                        help='Time to end plotting.')
    opts = parser.parse_args(argv)

    # rates are averaged over whole minutes, so at least one is required
    if opts.gps_end_time - opts.gps_start_time < 60:
        parser.error('--gps-end-time must be at least 60 seconds after '
                     '--gps-start-time')
    return opts


def minute_averaged_rates(end_times, gps_start_time, gps_end_time,
                          min_rate=MIN_RATE):
    """Bin trigger end times into whole minutes and convert to rates.

    Only whole minutes between ``gps_start_time`` and ``gps_end_time`` are
    binned; a trailing partial minute is dropped.  Triggers that fall
    outside the binned span are ignored rather than wrapping around to the
    end of the list (negative index) or raising an IndexError.

    Parameters
    ----------
    end_times : iterable of int
        End time of each trigger, in the same units as the GPS bounds.
    gps_start_time : int
        Start of the span to bin.
    gps_end_time : int
        End of the span to bin.
    min_rate : float, optional
        Initial value of every per-minute count.

    Returns
    -------
    times : list of float
        Start of each minute bin, in hours since ``gps_start_time``.
    rates : list of float
        Per-minute count (including ``min_rate``) divided by 60.
    """
    # floor division keeps the bin count an integer on Python 2 and 3
    num_min = max((gps_end_time - gps_start_time) // 60, 0)
    times = [minute / 60.0 for minute in range(num_min)]
    counts = [min_rate] * num_min

    # loop over triggers and count them
    for end_time in end_times:
        minute = int((end_time - gps_start_time) // 60)
        if 0 <= minute < num_min:
            counts[minute] += 1

    # convert counts to rate in Hz
    rates = [count / 60.0 for count in counts]
    return times, rates


def plot_rates(times, rates, gps_start_time, output_file,
               min_rate=MIN_RATE):
    """Scatter-plot the minute-averaged rates and save the figure.

    Parameters
    ----------
    times : list of float
        Bin times in hours since ``gps_start_time``; must not be empty.
    rates : list of float
        Rate for each bin; same length as ``times``.
    gps_start_time : int
        Start time, used for the x-axis label.
    output_file : str
        Path of the image file to write.
    min_rate : float, optional
        Used to set the lower y-axis limit (``min_rate * 10``).
    """
    # create pylab figure
    plt.figure(1, figsize=(8.1, 4.0))

    # plot data
    plt.scatter(np.array(times), np.array(rates), s=0.6)

    # format the plot
    plt.yscale('log')
    plt.ylim(min_rate * 10, max(rates) * 10)
    plt.xlim(0, times[-1])
    plt.ylabel('Minute-averaged rate (Hz)')
    plt.xlabel('Time since %s (hrs)'%gps_start_time)

    # save plot
    plt.savefig(output_file)
    plt.close()


def main():
    """Read the triggers, bin them per minute and write the rate plot."""
    opts = parse_command_line()

    # read inspiral triggers
    # NOTE(review): the "or []" guards against the reader returning None
    # when no triggers are found -- confirm against the pylal version in use
    inspiral_table = SnglInspiralUtils.ReadSnglInspiralFromFiles(
        opts.trigger_files) or []

    times, rates = minute_averaged_rates(
        (trig.end_time for trig in inspiral_table),
        opts.gps_start_time, opts.gps_end_time)

    # number of minute bins; kept from the original script's output and
    # written so that it is valid on both Python 2 and Python 3
    print(len(rates))

    plot_rates(times, rates, opts.gps_start_time, opts.output_file)


if __name__ == "__main__":
    main()
