#!/usr/bin/env python

# Copyright (C) 2015 Tito Dal Canton
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Plot single-detector inspiral triggers in the time-frequency plane along with
a spectrogram of the strain data.
"""

import logging
import argparse
import numpy as np
import matplotlib
matplotlib.use('agg')
import pylab as pl
import matplotlib.mlab as mlab
import h5py
import lalinspiral
import pycbc.events
import pycbc.pnutils
import pycbc.strain


# Command-line interface. The strain-related options (GPS start/end time,
# sample rate, channel name, ...) are contributed by the PyCBC strain option
# group and are read further down through the same `opts` namespace.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--trig-file', required=True,
                    help='HDF5 file containing single triggers')
parser.add_argument('--output-file', required=True, help='Output plot')
parser.add_argument('--bank-file', required=True,
                    help='HDF5 file containing template bank')
parser.add_argument('--veto-file', help='LIGOLW file containing veto segments')
parser.add_argument('--f-low', type=float, default=30,
                    help='Low-frequency cutoff')
parser.add_argument('--rank', choices=['snr', 'newsnr'], default='newsnr',
                    help='Ranking statistic for sorting triggers')
parser.add_argument('--num-loudest', type=int, default=1000,
                    help='Number of loudest triggers to plot')
# The detector name selects the group read from the trigger file and is also
# passed as the ifo when applying the veto segments.
parser.add_argument('--detector', type=str, required=True,
                    help='Detector whose triggers are read from the trigger '
                         'file and to which the veto segments apply')
pycbc.strain.insert_strain_option_group(parser)
opts = parser.parse_args()

# Timestamped INFO-level messages so progress can be followed on the console.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

# Build the strain series from the strain command-line options.
strain = pycbc.strain.from_cli(opts, pycbc.DYN_RANGE_FAC)

# Midpoint of the requested GPS interval; used as the time origin of the plot.
center_time = 0.5 * (opts.gps_start_time + opts.gps_end_time)

logging.info('Loading trigs')
# Copy the needed columns into memory inside a context manager so the HDF5
# file is closed as soon as reading is done.
with h5py.File(opts.trig_file, 'r') as trig_f:
    trigs = trig_f[opts.detector]
    snr = np.array(trigs['snr'])
    # Reduced chi-squared: chisq divided by (2 * chisq_dof - 2).
    rchisq = np.array(trigs['chisq']) / (np.array(trigs['chisq_dof']) * 2 - 2)
    end_time = np.array(trigs['end_time'])
    template_ids = np.array(trigs['template_id'])

if opts.veto_file:
    logging.info('Loading veto segments')
    # Keep only the triggers whose end time lies outside the veto segments.
    # The end times are already in memory, so they are not read again.
    locs, _ = pycbc.events.veto.indices_outside_segments(
        end_time, [opts.veto_file], ifo=opts.detector)
    end_time = end_time[locs]
    snr = snr[locs]
    rchisq = rchisq[locs]
    template_ids = template_ids[locs]

# Restrict to triggers strictly inside the requested GPS interval.
mask = np.logical_and(end_time > opts.gps_start_time,
                      end_time < opts.gps_end_time)
end_time = end_time[mask]
snr = snr[mask]
rchisq = rchisq[mask]
template_ids = template_ids[mask]

if len(snr) == 0:
    # Nothing to rank; carry on so that the spectrogram is still produced.
    logging.warning('No triggers left between GPS %s and %s',
                    opts.gps_start_time, opts.gps_end_time)
    rank = snr
elif opts.rank == 'snr':
    rank = snr
elif opts.rank == 'newsnr':
    # NOTE(review): newsnr may hand back a scalar for a single trigger in
    # some PyCBC versions (to be confirmed); atleast_1d keeps the indexing
    # below valid either way.
    rank = np.atleast_1d(pycbc.events.newsnr(snr, rchisq))

# Indices of the loudest triggers, in decreasing order of rank.
sorter = np.argsort(rank)[::-1][:opts.num_loudest]
end_time = end_time[sorter]
rank = rank[sorter]
rchisq = rchisq[sorter]
template_ids = template_ids[sorter]

# NaN stands in for "no maximum" when there is nothing to plot; it still
# formats cleanly with a float format specifier.
max_rank = max(rank) if len(rank) > 0 else np.nan

logging.info('Loading bank')
# Read the component masses into memory and close the bank file right away.
with h5py.File(opts.bank_file, 'r') as bank:
    mass1s, mass2s = np.array(bank['mass1']), np.array(bank['mass2'])

# Upper frequency of each trigger's track: Schwarzschild ISCO frequency
# evaluated at the total mass of the trigger's template.
f_highs = pycbc.pnutils.f_SchwarzISCO(
        mass1s[template_ids] + mass2s[template_ids])

fig = pl.figure(figsize=(20, 10))
fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95)
ax = fig.gca()

logging.info('Plotting strain spectrogram')
power, freqs, seg_times = mlab.specgram(
    strain, NFFT=1024, noverlap=1000, Fs=opts.sample_rate)

# Normalise every frequency row by its median over time, so the colour scale
# shows the ratio to the median rather than the raw power.
row_median = np.median(power, axis=1)
power /= row_median[:, np.newaxis]

# Express spectrogram times relative to the centre of the plotted interval.
rel_times = seg_times + opts.gps_start_time - center_time
log_ratio = np.log10(power)
pc = ax.pcolor(rel_times, freqs, log_ratio, vmin=0, vmax=3, cmap='afmhot_r')

logging.info('Plotting %d trigs', len(end_time))
# The top-ranked trigger stands out as an opaque red curve; every other
# trigger is drawn as a faint blue curve underneath it.
loudest_style = dict(color='#ff0000', zorder=3, lw=2)
other_style = dict(color='#0000ff', zorder=2, alpha=0.02)
for trig_time, trig_rank, bank_idx, f_top in zip(end_time, rank,
                                                 template_ids, f_highs):
    # Time-frequency track of the template, with time measured from the
    # centre of the plot and frequency running from f_low to f_top.
    track_t, track_f = pycbc.pnutils.get_inspiral_tf(
        trig_time - center_time, mass1s[bank_idx], mass2s[bank_idx],
        opts.f_low, f_top)
    style = loudest_style if trig_rank == max_rank else other_style
    ax.plot(track_t, track_f, '-', **style)

# Show exactly the requested GPS interval (relative to its midpoint) and the
# frequencies from the low-frequency cutoff up to half the sample rate.
ax.set_xlim(opts.gps_start_time - center_time, opts.gps_end_time - center_time)
ax.set_ylim(opts.f_low, opts.sample_rate / 2)
ax.set_yscale('log')
ax.grid(ls='solid', alpha=0.2)
ax.set_xlabel('Time - %.3f (s)' % center_time)
ax.set_ylabel('Frequency (Hz)')
# Report the number of triggers actually drawn, which can be smaller than
# --num-loudest when fewer triggers survive the cuts.
num_plotted = len(end_time)
title = '%s - loudest %d triggers by %s - max %s = %.2f' \
        % (opts.channel_name, num_plotted, opts.rank, opts.rank, max_rank)
ax.set_title(title)
note = ("Curves show the PN inspiral only and terminate at the Schwarzschild "
        "ISCO. Spin effects neglected.")
fig.text(0.05, 0.01, note, fontsize=7, transform=fig.transFigure)
cb = fig.colorbar(pc, fraction=0.04, pad=0.01)
cb.set_label('log$_{10}$(ratio to median)')
fig.savefig(opts.output_file, dpi=200)

logging.info('Done')