#!/usr/bin/env python
# Copyright (C) 2015 Alexander Harvey Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Calculate the SNR and CHISQ timeseries for either a chosen template, or 
a specific Nth loudest coincident event.
"""
import sys, logging, argparse, numpy, itertools, pycbc, h5py
from pycbc import vetoes, psd, waveform, version, strain, scheme, fft, filter, events
from pycbc.types import zeros, float32, complex64, TimeSeries

def get_time(statmap, n_loudest, ifo):
    """Look up the Nth loudest foreground event in a statmap file.

    Parameters
    ----------
    statmap : str
        Path to the HDF file holding the coincident triggers.
    n_loudest : int
        Rank of the wanted event when the foreground is ordered by
        descending ``foreground/stat`` (0 selects the largest statistic).
    ifo : str
        Detector whose trigger time is wanted. It has to match the file's
        ``detector_1`` or ``detector_2`` attribute.

    Returns
    -------
    tuple
        ``(time, template_id)`` of the selected event, or ``(None, None)``
        when the file holds no event of the requested rank.

    Raises
    ------
    KeyError
        If ``ifo`` is neither of the two detectors named in the file.
    """
    # The context manager guarantees the file handle is released on every
    # exit path; all returned values are read into memory before it closes.
    with h5py.File(statmap, 'r') as f:
        try:
            n = f['foreground/stat'][:].argsort()[::-1][n_loudest]
        except (IndexError, KeyError):
            # Fewer than n_loudest + 1 events (or no foreground statistic
            # at all): report "no such event" rather than failing.
            return None, None

        t = {f.attrs['detector_1']: f['foreground/time1'][:][n],
             f.attrs['detector_2']: f['foreground/time2'][:][n],
             }

        return t[ifo], f['foreground/template_id'][:][n]

def select_segment(fname, segment, ifo, time):
    """Return the ``(start, end)`` of the named segment containing ``time``.

    Parameters
    ----------
    fname : str
        Segment file passed on to ``events.select_segments_by_definer``.
    segment : str
        Name of the segment list to read from the file.
    ifo : str
        Detector the segments belong to.
    time : float
        Time that must fall inside the returned segment.

    Returns
    -------
    tuple
        ``(start, end)`` of the segment containing ``time``.

    Raises
    ------
    ValueError
        If no segment contains ``time`` (including an empty segment list).
    """
    segs = events.select_segments_by_definer(fname, segment, ifo)

    # NOTE(review): the lookup below assumes the segments are ordered by
    # ascending start time -- confirm against select_segments_by_definer.
    starts = numpy.array([sg[0] for sg in segs])
    ends = numpy.array([sg[1] for sg in segs])

    # side='right' makes a time equal to a segment start select that
    # segment instead of the one before it.
    idx = numpy.searchsorted(starts, time, side='right') - 1

    # idx == -1 means `time` precedes every segment; indexing with it would
    # silently wrap around to the last segment.
    if idx < 0 or time > ends[idx]:
        raise ValueError("time %s is not inside any '%s' segment for %s"
                         % (time, segment, ifo))
    return (starts[idx], ends[idx])

# Build the command line interface. The template can be described either
# directly through the mass/spin options or indirectly through a statmap
# file together with a template bank.
parser = argparse.ArgumentParser(usage='',
    description="Single template gravitational-wave followup")
parser.add_argument('--version', action='version', 
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--output-file')
parser.add_argument("-V", "--verbose", action="store_true", 
                  help="print extra debugging information", default=False )
parser.add_argument("--low-frequency-cutoff", type=float,
                  help="The low frequency cutoff to use for filtering (Hz)")
parser.add_argument("--chisq-bins", default="0", type=str, help=
                    "Number of frequency bins to use for power chisq.")
parser.add_argument("--psd-recalculate-segments", type=int, 
                    help="Number of segments to use before recalculating the PSD", default=0)
parser.add_argument("--approximant", type=str,
                  help="The name of the approximant to use for filtering. ")
parser.add_argument("--mass1", type=float, 
                  help="The mass of the first component object. "
                       "Do not use if giving a statmap file")
parser.add_argument("--mass2", type=float,
                  help="The mass of the second component object. "
                    "Do not use if giving a statmap file.")
parser.add_argument("--spin1z", type=float, default=0,
                  help="The aligned spin of the first component object. "
                    "Do not use if giving a statmap file.")
parser.add_argument("--spin2z", type=float, default=0,
                  help="The aligned spin of the second component object. "
                    "Do not use if giving a statmap file.")
parser.add_argument("--order", type=int,
                  help="The integer half-PN order at which to generate"
                       " the approximant. Default is -1 which indicates to use"
                       " approximant defined default.", default=-1, 
                       choices = numpy.arange(-1, 9, 1))
parser.add_argument("--taper-template", choices=["start","end","startend"],
                    help="For time-domain approximants, taper the start and/or"
                    " end of the waveform before FFTing.")

# alternate option set to pull from statmap file
parser.add_argument("--statmap-file",
        help="Alternative to giving mass/spin parameters. The HDF file "
             "containing the clustered coincident triggers")
parser.add_argument("--n-loudest", type=int,
        help="An integer to choose which coincident trigger is followed up"
             ", only used with --statmap-file")
parser.add_argument("--bank-file",
        help="HDF format template bank file, only used with --statmap-file")
parser.add_argument("--inspiral-segments",
        help="XML file containing the inspiral analysis segments. "
             "Only used with the --statmap-file option")
parser.add_argument("--segment-name", 
        help="name of the segmentlist to read from the inspiral segment file")

# Add the option groups shared with other pycbc executables (PSD, strain,
# segmentation, processing scheme and FFT backend)
psd.insert_psd_option_group(parser)
strain.insert_strain_option_group(parser)
strain.StrainSegments.insert_segment_option_group(parser)
scheme.insert_processing_option_group(parser)
fft.insert_fft_option_group(parser)
opt = parser.parse_args()

# Output HDF file; it stays open because the results are written to it at
# the end of the script.
f = h5py.File(opt.output_file, 'w')
# The detector name is taken from the first two characters of the channel
ifo = opt.channel_name[0:2]

# If we are reading from the coinc files ######################################
if opt.statmap_file:
    time, tid = get_time(opt.statmap_file, opt.n_loudest, ifo)

    # No event of the requested rank: there is nothing to follow up, so
    # stop here (the output file has already been created above).
    if time is None:
        sys.exit()

    # Replace the template parameters by those of the event's template.
    # The bank is only needed for this lookup, so close it straight away.
    with h5py.File(opt.bank_file, 'r') as b:
        opt.mass1 = float(b['mass1'][:][tid])
        opt.mass2 = float(b['mass2'][:][tid])
        opt.spin1z = float(b['spin1z'][:][tid])
        opt.spin2z = float(b['spin2z'][:][tid])

    # Restrict the analysed time to the segment around the event, moved
    # inwards by the padding at both ends.
    seg = select_segment(opt.inspiral_segments, opt.segment_name, ifo, time)
    opt.gps_start_time, opt.gps_end_time = seg[0] + opt.pad_data, seg[1] - opt.pad_data
    f.attrs['event_time'] = time

###############################################################################

# Sanity-check the parsed options. Every verifier takes the same
# (opt, parser) pair and they run in the order listed here.
for _verify in (psd.verify_psd_options,
                strain.verify_strain_options,
                strain.StrainSegments.verify_segment_options,
                scheme.verify_processing_options,
                fft.verify_fft_options):
    _verify(opt, parser)

# Logging is configured only after all options have been validated
pycbc.init_logging(opt.verbose)

def associate_psd(strain_segments, gwstrain, segments, nsegs, flen, delta_f, flow):
    """Estimate PSDs for groups of segments and attach them to the segments.

    The slices in ``strain_segments.full_segment_slices`` are split into
    consecutive groups of ``nsegs``. One PSD is estimated per group from the
    strain spanned by that group, and every entry of ``segments`` whose
    ``seg_slice`` belongs to the group gets a single precision copy of it
    as its ``psd`` attribute.

    The PSD estimation options are read from the module level ``opt``.

    Parameters
    ----------
    strain_segments : object
        Provides the ``full_segment_slices`` sequence of slices.
    gwstrain : sliceable
        Strain data; indexed with the sample range of each group.
    segments : iterable
        Objects with a ``seg_slice`` attribute; ``psd`` is set on them.
    nsegs : int
        Number of segments sharing one PSD. 0 means a single PSD for all
        segments.
    flen, delta_f, flow :
        Passed through unchanged to ``psd.from_cli``.

    Returns
    -------
    list
        The estimated PSDs, one per group, in group order.

    Raises
    ------
    ValueError
        If ``nsegs`` is negative.
    """
    logging.info("Computing noise PSD")

    if nsegs < 0:
        raise ValueError("nsegs must be non-negative, got %s" % nsegs)

    slices = list(strain_segments.full_segment_slices)
    nsegs = nsegs if nsegs != 0 else len(slices)

    # Consecutive chunks of nsegs slices; the last chunk may be shorter.
    # Plain slicing keeps this independent of the Python 2 only
    # itertools.izip_longest.
    groups = []
    if slices:
        groups = [slices[i:i + nsegs] for i in range(0, len(slices), nsegs)]

    if groups and len(groups[-1]) != len(groups[0]):
        logging.warning('PSD recalculation does not divide equally among '
                        'analysis segments. Make sure that this is what you '
                        'want')

    psds = []
    for psegs in groups:
        # Strain from the start of the first to the end of the last slice
        strain_part = gwstrain[psegs[0].start:psegs[-1].stop]
        ppsd = psd.from_cli(opt, flen, delta_f, flow, strain_part, pycbc.DYN_RANGE_FAC)
        psds.append(ppsd)
        for seg in segments:
            if seg.seg_slice in psegs:
                seg.psd = ppsd.astype(float32)
    return psds

# Build the processing scheme selected on the command line; it is entered as
# a context manager around the filtering further down.
ctx = scheme.from_cli(opt)
# Obtain the strain data described by the strain options
gwstrain = strain.from_cli(opt, pycbc.DYN_RANGE_FAC)
# Split the strain into segments as described by the segment options
strain_segments = strain.StrainSegments.from_cli(opt, gwstrain)

class t(object):
    """Empty container whose instances carry arbitrary attributes."""
 
# Minimal stand-in for a template bank row: the option parsers below read
# the template parameters from ``row.params``.
row = t()
row.params = t()
for _name in ('mass1', 'mass2', 'spin1z', 'spin2z'):
    setattr(row.params, _name, getattr(opt, _name))

chisq_bins = vetoes.SingleDetPowerChisq.parse_option(row, opt.chisq_bins)

# An approximant option containing 'params' is an expression that has to be
# resolved for this particular template
if 'params' in opt.approximant:
    opt.approximant = waveform.FilterBank.parse_option(row, opt.approximant)

# Filter every data segment against the template inside the selected
# processing scheme and store the SNR and chisq time series.
with ctx:
    fft.from_cli(opt)
    flow = opt.low_frequency_cutoff
    flen = strain_segments.freq_len
    delta_f = strain_segments.delta_f

    logging.info("Making frequency-domain data segments")
    segments = strain_segments.fourier_segments()

    logging.info("Calculating the PSDs")
    psds = associate_psd(strain_segments, gwstrain, segments,
                         opt.psd_recalculate_segments,
                         flen, delta_f, flow)

    logging.info("Making template: %s", opt.approximant)
    template = waveform.get_waveform_filter(zeros(flen, dtype=complex64),
                                    approximant=opt.approximant,
                                    mass1=opt.mass1, mass2=opt.mass2,
                                    spin1z=opt.spin1z, spin2z=opt.spin2z,
                                    taper=opt.taper_template,
                                    f_lower=flow, delta_f=delta_f,
                                    delta_t=gwstrain.delta_t)

    snrs, chisqs = [], []
    for s_num, stilde in enumerate(segments):
        logging.info("Filtering segment %s", s_num)
        snr, corr, norm = filter.matched_filter_core(template, stilde,
                                    psd=stilde.psd,
                                    low_frequency_cutoff=flow)
        # Apply the normalisation returned alongside the SNR series
        snr *= norm
        logging.info("calculating chisq")
        chisq = vetoes.power_chisq(template, stilde, chisq_bins, stilde.psd,
                                    low_frequency_cutoff = flow)
        # Divide by 2 * bins - 2 (presumably the number of degrees of
        # freedom, i.e. this stores the reduced chisq -- confirm)
        chisq /= chisq_bins * 2 - 2

        # Keep only the part of the segment flagged for analysis
        snrs.append(snr[stilde.analyze])
        chisqs.append(chisq[stilde.analyze])

    # Take the time axis from the stored series themselves instead of
    # relying on a loop variable still being bound after the loops.
    start_time = float(snrs[0].start_time)
    delta_t = snrs[0].delta_t

    f['snr'] = numpy.concatenate([part.numpy() for part in snrs])
    f['snr'].attrs['start_time'] = start_time
    f['snr'].attrs['delta_t'] = delta_t

    f['chisq'] = numpy.concatenate([part.numpy() for part in chisqs])
    f['chisq'].attrs['start_time'] = start_time
    f['chisq'].attrs['delta_t'] = delta_t
    f.attrs['approximant'] = opt.approximant
    f.attrs['ifo'] = ifo

# Close explicitly so that everything is flushed to disk before exiting
f.close()
logging.info("Finished")
