#!/usr/bin/env python

# Copyright (C) 2015 Tito Dal Canton
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Plot PyCBC's single-detector triggers over the search parameter space.
"""

import logging
import argparse
import inspect
import numpy as np
import matplotlib
matplotlib.use('agg')
import pylab as pl
from matplotlib.colors import LogNorm
from matplotlib.ticker import LogLocator
import h5py
import pycbc.pnutils
import pycbc.events
import pycbc.results


class SingleDetTriggerSet(object):
    """
    Provides easy access to the parameters of single-detector CBC triggers.

    Every plottable quantity is exposed as a read-only property that returns
    one value per selected trigger. Trigger-level quantities are read from
    the detector group of the trigger file and filtered through
    ``veto_mask``; template-level quantities are read from the bank file and
    looked up through each trigger's ``template_id``.

    Datasets are read from the HDF5 files on every property access; nothing
    is cached on the instance.
    """

    def __init__(self, trig_file, bank_file, veto_file, segment_name, detector):
        """
        Open the trigger and bank files and build the trigger selection.

        Parameters
        ----------
        trig_file : str
            Path to the HDF5 file holding the single-detector triggers.
        bank_file : str
            Path to the HDF5 file holding the template bank.
        veto_file : str or None
            Path to a veto segment file. When empty or None, every trigger
            is selected.
        segment_name : str or None
            Name of the segment list, forwarded to
            ``pycbc.events.veto.indices_outside_segments``.
        detector : str
            Name of the group to read in the trigger file; also forwarded
            as ``ifo`` when applying vetoes.
        """
        logging.info('Loading triggers')
        self.trigs_f = h5py.File(trig_file, 'r')
        self.trigs = self.trigs_f[detector]

        if veto_file:
            logging.info('Loading veto segments')
            # Only the returned indices are needed; the segment list that
            # comes back with them is discarded.
            self.veto_mask, _ = pycbc.events.veto.indices_outside_segments(
                self.trigs['end_time'][:], [veto_file],
                ifo=detector, segment_name=segment_name)
        else:
            # No vetoes: a slice spanning the whole dataset selects all
            # triggers.
            self.veto_mask = slice(len(self.trigs['end_time']))

        logging.info('Loading bank')
        self.bank = h5py.File(bank_file, 'r')

    @classmethod
    def get_param_names(cls):
        "Returns a list of plottable CBC parameter variables."
        # Every property defined on the class is considered plottable.
        return [name for name, member in inspect.getmembers(cls)
                if isinstance(member, property)]

    def _trig_column(self, name):
        "Read trigger dataset `name` and keep only the selected triggers."
        return np.array(self.trigs[name])[self.veto_mask]

    def _bank_column(self, name):
        "Read bank dataset `name` and look it up for each selected trigger."
        return np.array(self.bank[name])[self.template_id]

    @property
    def template_id(self):
        "Bank index of the template associated with each selected trigger."
        return self._trig_column('template_id')

    @property
    def mass1(self):
        "Bank `mass1` value of each selected trigger's template."
        return self._bank_column('mass1')

    @property
    def mass2(self):
        "Bank `mass2` value of each selected trigger's template."
        return self._bank_column('mass2')

    @property
    def spin1z(self):
        "Bank `spin1z` value of each selected trigger's template."
        return self._bank_column('spin1z')

    @property
    def spin2z(self):
        "Bank `spin2z` value of each selected trigger's template."
        return self._bank_column('spin2z')

    @property
    def mtotal(self):
        "Sum of `mass1` and `mass2`."
        return self.mass1 + self.mass2

    @property
    def mchirp(self):
        "First output of pycbc.pnutils.mass1_mass2_to_mchirp_eta."
        return pycbc.pnutils.mass1_mass2_to_mchirp_eta(
            self.mass1, self.mass2)[0]

    @property
    def eta(self):
        "Second output of pycbc.pnutils.mass1_mass2_to_mchirp_eta."
        return pycbc.pnutils.mass1_mass2_to_mchirp_eta(
            self.mass1, self.mass2)[1]

    @property
    def effective_spin(self):
        "Mass-weighted average of `spin1z` and `spin2z`."
        # FIXME assumes aligned spins
        weighted_spin = self.spin1z * self.mass1 + self.spin2z * self.mass2
        return weighted_spin / self.mtotal

    @property
    def end_time(self):
        "`end_time` dataset of the selected triggers."
        return self._trig_column('end_time')

    @property
    def template_duration(self):
        "`template_duration` dataset of the selected triggers."
        return self._trig_column('template_duration')

    @property
    def snr(self):
        "`snr` dataset of the selected triggers."
        return self._trig_column('snr')

    @property
    def rchisq(self):
        "`chisq` divided by (2 * `chisq_dof` - 2)."
        return self._trig_column('chisq') \
            / (self._trig_column('chisq_dof') * 2 - 2)

    @property
    def newsnr(self):
        "Result of pycbc.events.newsnr applied to `snr` and `rchisq`."
        return pycbc.events.newsnr(self.snr, self.rchisq)


# Command-line interface. The axis choices are the properties exposed by
# SingleDetTriggerSet, so adding a property there makes it plottable here.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--single-trig-file', type=str, required=True,
                    help='Path to file containing single-detector triggers in HDF5 format')
parser.add_argument('--bank-file', type=str, required=True,
                    help='Path to file containing template bank in HDF5 format')
parser.add_argument('--veto-file', type=str,
                    help='Optional path to file containing veto segments')
parser.add_argument('--segment-name', default=None, type=str,
                    help='Optional, name of segment list to use for vetoes')
parser.add_argument('--output-file', type=str, required=True,
                    help='Destination path for plot')
parser.add_argument('--x-var', required=True,
                    choices=SingleDetTriggerSet.get_param_names(),
                    help='Parameter to plot on the x-axis')
parser.add_argument('--y-var', required=True,
                    choices=SingleDetTriggerSet.get_param_names(),
                    help='Parameter to plot on the y-axis')
parser.add_argument('--z-var', required=True,
                    help='Quantity to plot on the color scale',
                    choices=['density', 'max(snr)', 'max(newsnr)'])
parser.add_argument('--detector', type=str, required=True,
                    help='Detector')
parser.add_argument('--grid-size', type=int, default=80,
                    help='Bin resolution (larger = smaller bins)')
parser.add_argument('--log-x', action='store_true',
                    help='Use log scale for x-axis')
parser.add_argument('--log-y', action='store_true',
                    help='Use log scale for y-axis')
parser.add_argument('--min-z', type=float, help='Optional minimum z value')
parser.add_argument('--max-z', type=float, help='Optional maximum z value')
opts = parser.parse_args()

logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

trigs = SingleDetTriggerSet(opts.single_trig_file, opts.bank_file,
                            opts.veto_file, opts.segment_name, opts.detector)

x = getattr(trigs, opts.x_var)
y = getattr(trigs, opts.y_var)

# Keyword arguments shared by every hexbin call. The color limits are
# deliberately NOT stored here: they are applied per branch below, either
# on the LogNorm instance or as plain vmin/vmax, so that they are never
# passed twice or combined with a norm.
hexbin_style = {
    'gridsize': opts.grid_size,
    'mincnt': 1,
    'linewidths': 0.03
}
if opts.log_x:
    hexbin_style['xscale'] = 'log'
if opts.log_y:
    hexbin_style['yscale'] = 'log'

logging.info('Plotting')
fig = pl.figure()
ax = fig.gca()

if opts.z_var == 'density':
    # The color scale starts at 1 unless the user asks for another minimum.
    # A limit left as None on the norm is autoscaled from the data.
    density_vmin = 1 if opts.min_z is None else opts.min_z
    norm = LogNorm(vmin=density_vmin, vmax=opts.max_z)
    hb = ax.hexbin(x, y, norm=norm, **hexbin_style)
    fig.colorbar(hb, ticks=LogLocator(subs=range(10)))
elif opts.z_var == 'max(snr)':
    norm = LogNorm(vmin=opts.min_z, vmax=opts.max_z)
    hb = ax.hexbin(x, y, C=trigs.snr, norm=norm, reduce_C_function=max,
                   **hexbin_style)
    fig.colorbar(hb, ticks=LogLocator(subs=range(10)))
elif opts.z_var == 'max(newsnr)':
    # Linear color scale: no norm is given, so the limits go straight to
    # hexbin (None means autoscale).
    hb = ax.hexbin(x, y, C=trigs.newsnr, reduce_C_function=max,
                   vmin=opts.min_z, vmax=opts.max_z, **hexbin_style)
    fig.colorbar(hb)

ax.set_xlabel(opts.x_var)
ax.set_ylabel(opts.y_var)
ax.set_title(opts.z_var + ' for ' + opts.detector)
title = '%s of %s triggers over %s and %s' \
    % (opts.z_var, opts.detector, opts.x_var, opts.y_var)
pycbc.results.save_fig_with_metadata(fig, opts.output_file, title=title,
                                     caption='', fig_kwds={'dpi': 200})

logging.info('Done')
