#!/usr/bin/env python

# Copyright (C) 2015 Tito Dal Canton
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Plot single-detector inspiral triggers in the time-frequency plane along with
a spectrogram of the strain data.
"""

import sys
import logging
import argparse
import numpy as np
import matplotlib
matplotlib.use('agg')
import pylab as pl
import matplotlib.mlab as mlab
from matplotlib.colors import LogNorm
from matplotlib.ticker import LogLocator
import h5py
import pycbc.events
import pycbc.pnutils
import pycbc.strain
import pycbc.results
import pycbc.version

# Command-line interface.  Only the options specific to this plot are declared
# here; the strain-reading options are appended by
# pycbc.strain.insert_strain_option_group().  The rest of the script reads
# opts.gps_start_time, opts.gps_end_time, opts.sample_rate and
# opts.channel_name, which are not declared below and are presumably provided
# by that option group.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--trig-file', required=True,
                    help='HDF5 file containing single triggers')
parser.add_argument('--output-file', required=True, help='Output plot')
parser.add_argument('--bank-file', required=True,
                    help='HDF5 file containing template bank')
parser.add_argument('--veto-file', help='LIGOLW file containing veto segments')
parser.add_argument('--f-low', type=float, default=30,
                    help='Low-frequency cutoff')
parser.add_argument('--rank', choices=['snr', 'newsnr'], default='newsnr',
                    help='Ranking statistic for sorting triggers')
parser.add_argument('--num-loudest', type=int, default=1000,
                    help='Number of loudest triggers to plot')
parser.add_argument('--interesting-trig', type=int,
                    help='Index of interesting trigger to highlight')
# This option is required, so give it a help string like every other option:
# it names the group read from the trigger file and the ifo used for vetoes.
parser.add_argument('--detector', type=str, required=True,
                    help='Detector to plot; selects the group read from '
                         'the trigger file and the ifo used with the '
                         'veto file')
parser.add_argument('--center-time', type=float,
                    help='Center plot on the given GPS time')
pycbc.strain.insert_strain_option_group(parser)
opts = parser.parse_args()

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

# Read the strain data described by the strain command-line options
strain = pycbc.strain.from_cli(opts, pycbc.DYN_RANGE_FAC)

# Time origin of the plot: the requested GPS time if one was given,
# otherwise the midpoint of the analysed GPS interval
center_time = opts.center_time
if center_time is None:
    center_time = (opts.gps_start_time + opts.gps_end_time) / 2.

# One wide figure with a single set of axes filling most of the canvas
fig = pl.figure(figsize=(20, 10))
fig.subplots_adjust(top=0.95, bottom=0.05, left=0.05, right=0.95)
ax = fig.gca()

logging.info('Plotting strain spectrogram')
# Short-time PSD of the strain.  Pxx is indexed as [frequency, time]: it is
# drawn below with freq on the y axis and t on the x axis.
Pxx, freq, t = mlab.specgram(strain, NFFT=1024, noverlap=1000,
                             Fs=opts.sample_rate, mode='psd')
# Normalize every frequency row by its median over time.  Broadcasting a
# column vector does the same job as tiling the median across all time bins,
# without allocating a second full-size array.
median_psd = np.median(Pxx, axis=1)
Pxx /= median_psd[:, np.newaxis]
# The color limits belong to the norm: passing vmin/vmax to pcolormesh
# together with norm= is rejected by current matplotlib releases.
norm = LogNorm(vmin=1, vmax=1000)
# The time axis is shifted so that center_time sits at zero
pc = ax.pcolormesh(t + opts.gps_start_time - center_time, freq, Pxx,
                   norm=norm, cmap='afmhot_r')

logging.info('Loading trigs')
# trigs is read again later in the script (gating datasets), so the file
# handle stays open here
trig_f = h5py.File(opts.trig_file, 'r')
trigs = trig_f[opts.detector]

# Per-trigger columns, all read fully into memory
end_time = trigs['end_time'][:]
template_ids = trigs['template_id'][:]
template_duration = trigs['template_duration'][:]
snr = np.array(trigs['snr'])

# Chi-squared divided by (2 * chisq_dof - 2)
chisq_norm = 2 * trigs['chisq_dof'][:] - 2
rchisq = trigs['chisq'][:] / chisq_norm

# Position of every trigger in the file, kept so that a trigger can still be
# identified by its original index after the selections applied later
indices = np.arange(len(end_time))

if opts.veto_file:
    logging.info('Loading veto segments')
    # locs selects the triggers whose end time lies outside the veto segments
    locs, segs = pycbc.events.veto.indices_outside_segments(
        end_time, [opts.veto_file], ifo=opts.detector)
    # Apply the same selection to every per-trigger column in one go
    (end_time, snr, rchisq,
     template_ids, template_duration, indices) = (
        end_time[locs], snr[locs], rchisq[locs],
        template_ids[locs], template_duration[locs], indices[locs])

# Keep triggers ending after the start of the data and no later than one
# template duration past its end, so that tracks which only partly overlap
# the plotted interval are still drawn
after_start = end_time > opts.gps_start_time
before_end = end_time < opts.gps_end_time + template_duration
mask = after_start & before_end

# template_duration is only needed to build the mask and is left unfiltered
end_time, snr, rchisq, template_ids, indices = (
    end_time[mask], snr[mask], rchisq[mask],
    template_ids[mask], indices[mask])

# Draw the inspiral tracks of the loudest triggers and build the plot title.
# With no surviving trigger only the title is set.
if mask.any():
    # Ranking statistic used to select and order the triggers
    if opts.rank == 'snr':
        rank = snr
    elif opts.rank == 'newsnr':
        # Guard against newsnr handing back a bare scalar: the indexing
        # below needs an array
        rank = np.atleast_1d(pycbc.events.newsnr(snr, rchisq))

    # Indices of the num_loudest highest-ranked triggers, loudest first
    sorter = np.argsort(rank)[::-1][:opts.num_loudest]
    sorted_end_time = end_time[sorter]
    sorted_rank = rank[sorter]
    sorted_rchisq = rchisq[sorter]
    sorted_template_ids = template_ids[sorter]

    # Highest rank among the triggers ending before the end of the data;
    # None if every selected trigger ends after it
    in_window = sorted_end_time <= opts.gps_end_time
    max_rank = sorted_rank[in_window].max() if in_window.any() else None

    logging.info('Loading bank')
    # Only the two mass columns are needed, so copy them out and close the
    # bank file right away
    with h5py.File(opts.bank_file, 'r') as bank:
        mass1s, mass2s = np.array(bank['mass1']), np.array(bank['mass2'])

    # Each track stops at the Schwarzschild ISCO frequency of the total mass
    f_highs = pycbc.pnutils.f_SchwarzISCO(
            mass1s[sorted_template_ids] + mass2s[sorted_template_ids])

    logging.info('Plotting %d trigs', len(sorted_end_time))
    for tc, rho, tid, f_high in zip(sorted_end_time, sorted_rank,
                                    sorted_template_ids, f_highs):
        track_t, track_f = pycbc.pnutils.get_inspiral_tf(
                tc - center_time, mass1s[tid], mass2s[tid], opts.f_low, f_high)
        # Compare against None explicitly so that a maximum rank of exactly
        # zero is still highlighted
        if max_rank is not None and rho == max_rank:
            ax.plot(track_t, track_f, '-', color='#ff0000', zorder=3, lw=2)
        else:
            ax.plot(track_t, track_f, '-', color='#0000ff', zorder=2, alpha=0.02)

    # Optionally highlight one trigger, identified by its index in the
    # trigger file, provided it survived the veto and time selections
    interesting_trig_rank = None
    if opts.interesting_trig is not None:
        matches = np.flatnonzero(indices == opts.interesting_trig)
        if matches.size:
            # Use scalars, as in the loop above, rather than length-1 arrays
            sel = matches[0]
            tc = end_time[sel]
            tid = template_ids[sel]
            interesting_trig_rank = rank[sel]
            f_high = pycbc.pnutils.f_SchwarzISCO(mass1s[tid] + mass2s[tid])
            track_t, track_f = pycbc.pnutils.get_inspiral_tf(
                tc - center_time, mass1s[tid], mass2s[tid], opts.f_low, f_high)
            ax.plot(track_t, track_f, '-', color='#00ff00', zorder=3, lw=2)

    if max_rank is not None:
        title = '%s - loudest %d triggers by %s - max %s = %.2f (red curve)' \
            % (opts.channel_name, opts.num_loudest, opts.rank, opts.rank, max_rank)
    else:
        title = '%s - loudest %d triggers by %s' \
            % (opts.channel_name, opts.num_loudest, opts.rank)
    if interesting_trig_rank is not None:
        title += ' - selected %s = %.2f (green curve)' \
            % (opts.rank, interesting_trig_rank)
else:
    title = '%s - no triggers' % opts.channel_name

logging.info('Loading and plotting gates')
# Each kind of gate gets its own hatch pattern; a kind with no datasets in
# the trigger file is skipped
for gate_type, hatch_style in (('file', '\\'), ('auto', '/')):
    gate_group = 'gating/' + gate_type
    try:
        gate_time = trigs[gate_group + '/time'][:]
        gate_width = trigs[gate_group + '/width'][:]
        gate_pad = trigs[gate_group + '/pad'][:]
    except KeyError:
        continue
    # The pad column is read and iterated but not used for drawing: the
    # hatched span covers time +/- width, relative to center_time
    for gt, gw, _gp in zip(gate_time, gate_width, gate_pad):
        span_start = gt - gw - center_time
        span_end = gt + gw - center_time
        ax.axvspan(span_start, span_end, hatch=hatch_style,
                   facecolor='none', edgecolor='#00ff00')

# Footnote shown under the plot and reused in the saved caption
note = ("Curves show the PN inspiral only and terminate at the Schwarzschild "
        "ISCO. Spin effects neglected.")

# Axis ranges: the analysed interval relative to center_time, and from the
# low-frequency cutoff up to half the sample rate
time_span = (opts.gps_start_time - center_time,
             opts.gps_end_time - center_time)
ax.set_xlim(*time_span)
ax.set_ylim(opts.f_low, opts.sample_rate / 2)
ax.set_yscale('log')
ax.grid(ls='solid', alpha=0.2)

# Labels and title
ax.set_title(title)
ax.set_xlabel('Time - %.3f (s)' % center_time)
ax.set_ylabel('Frequency (Hz)')
fig.text(0.05, 0.01, note, fontsize=7, transform=fig.transFigure)

# Colorbar for the normalized spectrogram
cb = fig.colorbar(pc, ticks=LogLocator(subs=range(10)),
                  fraction=0.04, pad=0.01)
cb.set_label('Power density (normalized to its median over time)')

# Caption stored with the figure metadata: a description of the plot
# followed by the footnote text
caption_template = (
    "This plot shows the power spectrogram of the strain data, "
    "normalized to its median over time, as a heatmap. The "
    "time-frequency evolution of each single trigger is shown as a blue "
    "curve. The red curve is the loudest trigger by %s. Only the "
    "loudest %d triggers by %s are shown. ")
caption = caption_template % (opts.rank, opts.num_loudest, opts.rank) + note

# Save the figure together with the command line that produced it
figure_title = 'Strain spectrogram and inspiral tracks for %s' % opts.detector
pycbc.results.save_fig_with_metadata(
    fig, opts.output_file,
    cmd=' '.join(sys.argv),
    title=figure_title,
    caption=caption,
    fig_kwds={'dpi': 200})

logging.info('Done')
