#!/usr/bin/python
# Copyright (C) 2015 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
import h5py
import numpy
import sys
import lal
import matplotlib as mpl; mpl.use('Agg')
import pylab
import pycbc.results
import pycbc.version
from glue import segments

def calculate_time_slide_duration(ifo1_segments, ifo2_segments, offset=0):
    ''' Returns the amount of coincident time between two segmentlists
    after the second one has been slid by ``offset``.

    Parameters
    ----------
    ifo1_segments : segmentlist
        Segments that are held fixed.
    ifo2_segments : segmentlist
        Segments that are slid. ``shift(offset)`` is called on this object
        and ``shift(-offset)`` is called on it again before returning, even
        if computing the intersection raises.
    offset : number, optional
        Amount passed to ``shift``. The default of 0 applies no slide.

    Returns
    -------
    duration
        ``abs()`` of the intersection of ``ifo1_segments`` with the shifted
        ``ifo2_segments``.
    '''

    # slide the second segmentlist
    shifted_segments = ifo2_segments.shift(offset)

    try:
        # determine how much coincident time there is
        return abs(ifo1_segments & shifted_segments)
    finally:
        # unshift the shifted segmentlist; doing this in a finally clause
        # keeps the caller's object from being left slid if the
        # intersection fails
        ifo2_segments.shift(-offset)

# parse command line
# NOTE: adjacent string literals are joined with no separator, so every
# fragment except the last ends with an explicit space
parser = argparse.ArgumentParser(usage='pycbc_page_ifar [--options]',
                    description='Plots a cumulative histogram of IFAR for '
                          'coincident foreground triggers and a subset of '
                          'the coincident time slide triggers.')
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--trigger-file', type=str, required=True,
                    help='Path to coincident trigger HDF file.')
parser.add_argument('--output-file', type=str, required=True,
                    help='Path to output plot.')
parser.add_argument('--decimation-factor', type=int, required=True,
                    help='Decimation factor used in background estimation. '
                          'Decimation factor means that every nth time '
                          'slide kept all of its coincident triggers in '
                          'the HDF file.')
opts = parser.parse_args()

# open the coincident trigger file read-only
fp = h5py.File(opts.trigger_file, 'r')

# foreground IFAR values in ascending order, paired with a cumulative number
# that counts down from the number of foreground triggers to 1
fore_ifar = numpy.sort(fp['foreground/ifar'][:])
fore_cumnum = numpy.arange(len(fore_ifar), 0, -1)

# expected background: 100 log-spaced IFAR values from 1e-8 to 1e2 and, for
# each, the foreground time divided by that IFAR and by lal.YRJUL_SI
expected_ifar = numpy.logspace(-8, 2, num=100)
expected_cumnum = (fp.attrs['foreground_time'] / expected_ifar) / lal.YRJUL_SI

# time slide ID and IFAR of every background trigger
back_tsid = fp['background/timeslide_id'][:]
back_ifar = fp['background/ifar'][:]

# build one segmentlist per IFO from the start/end arrays stored in the file
h1_group = fp['segments']['H1']
h1_segments = segments.segmentlist(
    [segments.segment(seg_start, seg_end)
     for seg_start, seg_end in zip(h1_group['start'][:], h1_group['end'][:])])

l1_group = fp['segments']['L1']
l1_segments = segments.segmentlist(
    [segments.segment(seg_start, seg_end)
     for seg_start, seg_end in zip(l1_group['start'][:], l1_group['end'][:])])

# make figure
fig = pylab.figure(1)

# plot the expected background
pylab.loglog(expected_ifar, expected_cumnum, linestyle='--', linewidth=2,
             color='black', label='Expected Background')

# plot the 1x and 2x sqrt(N) counting-error bands around the expected
# background as filled polygons.  Each polygon traces the upper edge from
# left to right and then the lower edge back from right to left.  It is built
# with numpy directly instead of pylab.poly_between, which is no longer
# available in recent Matplotlib releases.
sqrt_cumnum = numpy.sqrt(expected_cumnum)
xs = numpy.concatenate((expected_ifar, expected_ifar[::-1]))
for nsigma, band_alpha, band_label in ((1, 0.4, '$N^{1/2}$ Errors'),
                                       (2, 0.2, '$2N^{1/2}$ Errors')):
    error_plus = expected_cumnum + nsigma * sqrt_cumnum
    error_minus = expected_cumnum - nsigma * sqrt_cumnum

    # replace non-positive lower bounds with a small positive floor so the
    # band can still be drawn on the logarithmic axis
    error_minus = numpy.where(error_minus <= 0, 1e-5, error_minus)

    ys = numpy.concatenate((error_plus, error_minus[::-1]))
    pylab.fill(xs, ys, facecolor='y', alpha=band_alpha, label=band_label)

# build the list of time slide IDs to plot: every multiple of the decimation
# factor from -id_limit to +id_limit steps, skipping 0.  Floor division is
# used throughout so the IDs stay integers under both Python 2 and Python 3.
# NOTE(review): id_limit is derived from the number of background triggers,
# not from the time slide IDs stored in the file -- confirm this is intended.
decimation = opts.decimation_factor
id_limit = (len(back_tsid) // 2) // decimation
tsids = [step * decimation for step in range(-id_limit, id_limit + 1)
         if step != 0]

# these attributes are the same for every time slide, so read them once
timeslide_interval = fp.attrs['timeslide_interval']
foreground_time = fp.attrs['foreground_time']

for tsid in tsids:

    # calculate the amount of coincident time in this time slide
    offset = tsid * timeslide_interval
    back_dur = calculate_time_slide_duration(h1_segments, l1_segments,
                                             offset=offset)

    # a time slide with no coincident time cannot be rescaled (the
    # correction factor below would divide by zero), so there is nothing
    # to plot for it
    if back_dur <= 0:
        continue

    # find all triggers in this time slide
    ts_indx = numpy.where(back_tsid == tsid)

    # apply the correction factor for this time slide to its IFAR
    # you need a correction factor because the analyzed time of the time slide
    # is not the same as the analyzed time of the foreground
    ts_ifar = back_ifar[ts_indx] * (foreground_time / back_dur)
    ts_ifar.sort()

    # get the cumulative number of triggers in this time slide
    ts_cumnum = numpy.arange(len(ts_ifar), 0, -1)

    # plot the time slide triggers
    pylab.loglog(ts_ifar, ts_cumnum, color='gray', alpha=0.4)

# overlay the foreground triggers as unconnected blue triangles
pylab.loglog(fore_ifar, fore_cumnum, label='Foreground', marker='^',
             color='blue', linestyle='None')

# when there is at least one foreground trigger, fit the axes to the
# foreground: y runs from 0.8 to 1.1 times the trigger count and x starts at
# 0.9 times the smallest foreground IFAR
num_fore = len(fore_cumnum)
if num_fore > 0:
    pylab.ylim(0.8, 1.1 * num_fore)
    pylab.xlim(0.9 * min(fore_ifar))

# grid, legend and axis labels
pylab.grid()
pylab.legend()
pylab.xlabel('Inverse False Alarm Rate (yr)')
pylab.ylabel('Cumulative Number')

# save the figure together with its title, caption and the command line used
# NOTE: adjacent string literals are joined with no separator, so every
# fragment except the last ends with an explicit space; the %d is filled with
# the number of time slide IDs that were looped over
caption = ('This is a cumulative histogram of triggers. The blue triangles '
           'represent coincident foreground triggers. The dashed line '
           'represents the expected background, the expected background is '
           'determined by the foreground time. The shaded regions represent '
           'counting errors. The gray lines are time slides treated as zero '
           'lag, here there are %d time slides plotted.' % len(tsids))
pycbc.results.save_fig_with_metadata(fig, opts.output_file,
     title='Cumulative Number vs. IFAR',
     caption=caption,
     cmd=' '.join(sys.argv))
