#!/usr/bin/env python
import h5py, argparse, logging, numpy, numpy.random
from pycbc import events, detector
from pycbc.events import veto, coinc
       
# Command line interface of the script. None of the options is marked as
# required by argparse itself.
parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="count")
# Default to an empty list: without it argparse stores None when the option
# is omitted, and the veto files are iterated over further down.
parser.add_argument("--veto-files", nargs='*', default=[],
                    help="Optional veto file. Triggers within these times are ignored")
parser.add_argument('--segment-name', default=None, type=str,
                    help='Optional, name of segment list to use for vetoes')
parser.add_argument("--trigger-files", nargs=2,
                    help="File containing the single-detector triggers")
parser.add_argument("--template-bank",
                    help="Template bank file in HDF format")
parser.add_argument("--ranking-statistic", help="The ranking statistic to use", default='newsnr')
parser.add_argument("--coinc-threshold", type=float, default=0.0,
                    help="Seconds to add to time-of-flight coincidence window")
parser.add_argument("--timeslide-interval", type=float, default=0,
                    help="Interval between timeslides in seconds.")
parser.add_argument("--decimation-factor", type=int,
                    help="The factor to reduce the background trigger rate.")
parser.add_argument("--loudest-keep", type=int,
                    help="Number of the loudest triggers to keep from each template.")
parser.add_argument("--loudest-keep-value", type=float,
                    help="Keep all coincident triggers above this value.")
# Adjacent string literals are joined without a separator, so the first
# fragment must carry its own trailing space.
parser.add_argument("--template-fraction-range", help="Optional, format string to "
                    "analyze part of template bank. Format is PART/NUM_PARTS",
                    default="0/1")
parser.add_argument("--cluster-window", help="Optional, window size in seconds to "
                    "cluster coincidences over the bank", type=float)
parser.add_argument("--output-file", help="File to store the coincident triggers")
parser.add_argument("--statistic-files", help="Statistic mapping files", nargs='*', default=[])
args = parser.parse_args()

# Logging is only configured when --verbose was given at least once; in that
# case timestamped messages down to DEBUG level are emitted.
if args.verbose:
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s : %(message)s')

def get_statistic(option, files):
    """ Build the ranking statistic object selected by name.

    Parameters
    ----------
    option: str
        Name of the statistic: 'newsnr', 'newsnr_cut' or 'phasetd_newsnr'.
    files: list of str
        File names handed unchanged to the statistic's constructor.

    Returns
    -------
    stat: Stat
        Instance of the statistic class that matches `option`.

    Raises
    ------
    ValueError
        If `option` is not one of the known names.
    """
    if option == 'newsnr':
        stat_class = NewSNRStatistic
    elif option == 'newsnr_cut':
        stat_class = NewSNRCutStatistic
    elif option == 'phasetd_newsnr':
        stat_class = PhaseTDStatistic
    else:
        raise ValueError('%s is not an available detection statistic' % option)
    return stat_class(files)

def parse_template_range(num_templates, rangestr):
    """ Convert a PART/NUM_PARTS string into a half-open range of template ids.

    The boundaries are computed with exact integer arithmetic. Going through
    floating point can lose the final template: for example
    1 / 49.0 * 49 evaluates to 0.9999999999999999, which truncates to 0.

    Parameters
    ----------
    num_templates: int
        Total number of templates in the bank.
    rangestr: str
        String of the form 'PART/NUM_PARTS'; PART counts from zero.

    Returns
    -------
    tmin: int
        First template id of the requested part.
    tmax: int
        One past the last template id of the requested part.
    """
    fields = rangestr.split('/')
    part = int(fields[0])
    pieces = int(fields[1])
    # Floor division keeps consecutive parts contiguous and guarantees that
    # the last part ends exactly at num_templates.
    tmin = num_templates * part // pieces
    tmax = num_templates * (part + 1) // pieces
    return tmin, tmax

class Stat(object):
    """ Base class of the ranking statistics.

    Opens every given HDF file read-only and stores the open handles in
    `self.files`, keyed by the value of each file's 'stat' attribute.
    """
    def __init__(self, files):
        opened = {}
        for path in files:
            handle = h5py.File(path, 'r')
            opened[handle.attrs['stat']] = handle
        self.files = opened

class NewSNRStatistic(Stat):
    """ Rank by newsnr in each detector, combined in quadrature. """

    def single(self, trigs):
        """ Compute the single detector statistic from a set of triggers.

        Reads the 'chisq_dof', 'snr' and 'chisq' columns of `trigs` and
        passes the snr and the chisq divided by (2 * chisq_dof - 2) to
        events.newsnr. The result is returned as a float32 array with at
        least one dimension.
        """
        chisq_dof = trigs['chisq_dof']
        dof = 2 * chisq_dof - 2
        snr = trigs['snr']
        reduced_chisq = trigs['chisq'] / dof
        stat = events.newsnr(snr, reduced_chisq)
        return numpy.array(stat, ndmin=1, dtype=numpy.float32)

    def coinc(self, s1, s2, slide, step):
        """ Combine two single detector statistics as the square root of the
        sum of their squares. `slide` and `step` are not used here.
        """
        sum_of_squares = s1 ** 2.0 + s2 ** 2.0
        return sum_of_squares ** 0.5

class NewSNRCutStatistic(Stat):
    """ newsnr ranking that rejects triggers with low newsnr and high
    reduced chisq.
    """

    def single(self, trigs):
        """ Compute newsnr for each trigger; triggers with newsnr below 10
        and reduced chisq above 2 are flagged with the value -1.
        """
        dof = 2 * trigs['chisq_dof'] - 2
        reduced_chisq = trigs['chisq'] / dof
        stat = events.newsnr(trigs['snr'], reduced_chisq)
        rejected = (stat < 10) & (reduced_chisq > 2)
        stat[rejected] = -1
        return stat

    def coinc(self, s1, s2, slide, step):
        """ Quadrature sum of the two single detector statistics, set to
        zero wherever either input carries the -1 rejection flag.
        `slide` and `step` are not used here.
        """
        combined = (s1 ** 2.0 + s2 ** 2.0) ** 0.5
        for single_stat in (s1, s2):
            combined[single_stat == -1] = 0
        return combined

class PhaseTDStatistic(NewSNRStatistic):
    """ newsnr ranking with an extra term looked up from a two dimensional
    map binned in the time and phase difference between the detectors.
    The map is read from the statistic file registered as 'phasetd_newsnr'.
    """

    def single(self, trigs):
        """ Return one row per trigger with the columns
        (newsnr, coa_phase, end_time).
        """
        columns = (NewSNRStatistic.single(self, trigs),
                   trigs['coa_phase'],
                   trigs['end_time'])
        return numpy.array(columns).transpose()

    def coinc(self, s1, s2, slide, step):
        """ Calculate the coincident statistic.

        `s1` and `s2` are arrays laid out as returned by `single`. The time
        difference has `slide * step` subtracted before it is binned.
        """
        def bin_index(edges, values):
            # Position of each value among the bin edges, forced into the
            # range [0, len(edges) - 2].
            idx = numpy.searchsorted(edges, values) - 1
            last = len(edges) - 2
            idx[idx < 0] = 0
            idx[idx > last] = last
            return idx

        time_diff = s1[:,2] - s2[:,2] - slide * step
        phase_diff = s1[:,1] - s2[:,1]

        hist = self.files['phasetd_newsnr']
        tv = bin_index(hist['tbins'][:], time_diff)
        pv = bin_index(hist['pbins'][:], phase_diff)

        # Map values below one are raised to one so that the logarithm
        # below is never negative.
        weight = hist['map'][:][tv, pv]
        weight[weight < 1] = 1

        return (s1[:,0]**2.0 + s2[:,0]**2.0 + 2.0 * numpy.log(weight))**0.5

class ReadByTemplate(object):
    """ Read single detector triggers from an HDF file one template at a
    time, keeping only the triggers that fall inside the valid segments.

    Parameters
    ----------
    filename: str
        Trigger file to read. The first top level key is taken to be the
        detector name.
    bank: str, optional
        Template bank file. When given, the bank values of the active
        template are stored in `self.param` by `set_template`.
    segment_name: str, optional
        Name of the segment list to select from each veto file.
    veto_files: list of str, optional
        Files with veto segments that are removed from the valid times.
        None is treated the same as an empty list.
    """
    def __init__(self, filename, bank=None, segment_name=None, veto_files=None):
        self.filename = filename
        self.file = h5py.File(filename, 'r')
        # list() keeps this working where keys() returns a non-indexable view
        self.ifo = list(self.file.keys())[0]
        # No template is active until set_template is called
        self.template_num = None
        # The bank is only ever read, so never open it writable
        self.bank = h5py.File(bank, 'r') if bank else None

        # Determine the segments which define the boundaries of valid times
        # to use triggers
        key = '%s/search/' % self.ifo
        s, e = self.file[key + 'start_time'][:], self.file[key + 'end_time'][:]
        self.segs = veto.start_end_to_segments(s, e).coalesce()
        for vfile in (veto_files or []):
            veto_segs = veto.select_segments_by_definer(vfile, ifo=self.ifo, 
                                                     segment_name=segment_name)
            self.segs = (self.segs - veto_segs).coalesce()

        # Computed after the loop so that it is also available when no veto
        # files are given; callers unpack it into start and end arrays.
        self.valid = veto.segments_to_start_end(self.segs)
    
    def get_data(self, col, num):
        """ Get a column of data for template with id 'num'
        
        Parameters
        ----------
        col: str
            Name of column to read
        num: int
            The template id to read triggers for
        
        Returns
        -------
        data: numpy.ndarray
            The requested column of data       
        """
        # '<col>_template' holds, per template, the reference used to select
        # that template's entries from the full column
        ref = self.file['%s/%s_template' % (self.ifo, col)][num]
        return self.file['%s/%s' % (self.ifo, col)][ref]   
       
    def set_template(self, num):
        """ Set the active template to read from
        
        Parameters
        ----------
        num: int
            The template id to read triggers for
        
        Returns
        -------
        trigger_id: numpy.ndarray
            The indices of this templates triggers
        """
        self.template_num = num
        times = self.get_data('end_time', num)
        
        # Determine which of these template's triggers are kept after
        # applying vetoes
        if self.valid is not None:
            self.keep = veto.indices_within_times(times, self.valid[0], self.valid[1]) 
            logging.info('applying vetoes')
        else:
            self.keep = numpy.arange(0, len(times))

        if self.bank is not None:
            self.param = {} 
            for col in self.bank:
                self.param[col] = self.bank[col][self.template_num]
            
        # Calculate the trigger id by adding the relative offset in self.keep
        # to the absolute beginning index of this templates triggers stored
        # in 'template_boundaries'
        trigger_id = self.keep + self.file['%s/template_boundaries' % self.ifo][num]
        return trigger_id
        
    def __getitem__(self, col):
        """ Return the column of data for current active template after
        applying vetoes
        
        Parameters
        ----------
        col: str
            Name of column to read
        
        Returns
        -------
        data: numpy.ndarray
            The requested column of data  

        Raises
        ------
        ValueError
            If no template has been selected with set_template yet.
        """
        if self.template_num is None:
            raise ValueError('You must call set_template to first pick the '
                             'template to read data from')    
        data = self.get_data(col, self.template_num)         
        data = data[self.keep] if self.valid is not None else data
        return data

logging.info('Starting...')

# Only the number of templates is needed from the bank here, so the file is
# closed again as soon as it has been counted.
with h5py.File(args.template_bank, "r") as bank_file:
    num_templates = len(bank_file['template_hash'])
tmin, tmax = parse_template_range(num_templates, args.template_fraction_range)
logging.info('Analyzing template %s - %s' % (tmin, tmax-1))

# argparse leaves the option as None when it is not given on the command
# line; the trigger readers iterate over this value.
veto_files = args.veto_files or []

logging.info('Opening first trigger file: %s' % args.trigger_files[0]) 
trigs0= ReadByTemplate(args.trigger_files[0], 
                       args.template_bank, args.segment_name, veto_files)
logging.info('Opening second trigger file: %s' % args.trigger_files[1]) 
trigs1 = ReadByTemplate(args.trigger_files[1], 
                        args.template_bank, args.segment_name, veto_files)
# Times during which both detectors have valid data
coinc_segs = (trigs0.segs & trigs1.segs).coalesce()

rank_method = get_statistic(args.ranking_statistic, args.statistic_files)
det0, det1 = detector.Detector(trigs0.ifo), detector.Detector(trigs1.ifo)
# The coincidence window is the light travel time between the two detectors
# plus the extra allowance from --coinc-threshold.
time_window = det0.light_travel_time_to_detector(det1) + args.coinc_threshold
logging.info('The coincidence window is %3.1f ms' % (time_window * 1000))

# Per-template result chunks; each value is a list of arrays that is
# concatenated once all templates have been processed.
data = {'stat':[], 'decimation_factor':[], 'time1':[], 'time2':[], 
        'trigger_id1':[], 'trigger_id2':[], 'timeslide_id':[], 'template_id':[]}

# Find and rank the coincidences of one template at a time and collect the
# results as per-template chunks in `data`.
for tnum in range(tmin, tmax):
    tid0 = trigs0.set_template(tnum)
    tid1 = trigs1.set_template(tnum)

    if len(tid0) == 0:
        continue

    t0 = trigs0['end_time']
    t1 = trigs1['end_time']
    logging.info('Trigs for template %s, %s:%s %s:%s' % \
                (tnum, trigs0.ifo, len(t0), trigs1.ifo, len(t1)))

    i0, i1, slide = coinc.time_coincidence(t0, t1, time_window, args.timeslide_interval)

    logging.info('Coincident Trigs: %s' % (len(i1)))
    
    logging.info('Calculating Single Detector Statistic')
    s0, s1 = rank_method.single(trigs0), rank_method.single(trigs1) 
    
    logging.info('Calculating Multi-Detector Combined Statistic')
    c = rank_method.coinc(s0[i0], s1[i1], slide, args.timeslide_interval)
    
    # Coincidences in slide zero are the foreground, all others background
    fi = numpy.where(slide == 0)[0]
    bi = numpy.where(slide != 0)[0]
    logging.info('%s foreground triggers' % len(fi))
    logging.info('%s background triggers' % len(bi))

    # Integer typed stand-in for "no indices", so that the index arrays are
    # never promoted to floating point when they are concatenated below.
    no_index = numpy.array([], dtype=bi.dtype)

    if args.loudest_keep:
        # Split the background at position `sep` of the partial sort:
        # `bh` holds the `loudest_keep` largest statistic values and is kept
        # in full, `bl` holds the rest and is decimated or dropped.
        sep = max(len(bi) - args.loudest_keep, 0)
        if len(bi) > 0:
            bsort = numpy.argpartition(c[bi], sep)
            bl = bi[bsort[0:sep]]
            bh = bi[bsort[sep:]]
            del bsort
        else:
            # argpartition rejects kth=0 for an empty array, so templates
            # without background coincidences are handled separately.
            bl, bh = no_index, bi
        del bi
        if args.decimation_factor:
            bl = bl[slide[bl] % args.decimation_factor == 0]
        else:
            bl = no_index
    elif args.loudest_keep_value:
        bl, bh = no_index, bi[c[bi] > args.loudest_keep_value]
    else:
        bl, bh = no_index, bi

    ti = numpy.concatenate([bl, bh, fi]).astype(numpy.uint32)
    logging.info('%s after decimation' % len(ti))
   
    g0 = i0[ti]
    g1 = i1[ti] 
    del i0
    del i1
    
    data['stat'].append(c[ti])
    # `bl` is empty whenever no decimation factor is set, so the stand-in
    # value 1 is never written; it only keeps the array numeric instead of
    # an object array containing None.
    dec_fac = numpy.repeat([args.decimation_factor or 1, 1, 1],
                           [len(bl), len(bh), len(fi)]).astype(numpy.uint32)
    data['decimation_factor'].append(dec_fac)
    data['time1'].append(t0[g0])
    data['time2'].append(t1[g1])
    data['trigger_id1'].append(tid0[g0])
    data['trigger_id2'].append(tid1[g1])
    data['timeslide_id'].append(slide[ti])
    data['template_id'].append(numpy.full(len(ti), tnum, dtype=numpy.uint32))

# Merge the per-template chunks into one array per column.
# numpy.concatenate raises on an empty list, which happens when no template
# contributed anything (for example an empty template range).
for key in data:
    data[key] = numpy.concatenate(data[key]) if data[key] else numpy.array([])

have_coincs = len(data['stat']) > 0

# Indices of the coincidences that survive clustering, or None when no
# clustering was requested or there is nothing to cluster.
cid = None
if args.cluster_window and have_coincs:
    cid = coinc.cluster_coincs(data['stat'], data['time1'], data['time2'], 
                               data['timeslide_id'], args.timeslide_interval, 
                               args.cluster_window)

logging.info('saving coincident triggers')
# The context manager closes the output file even if a write fails
with h5py.File(args.output_file, 'w') as f:
    if have_coincs:
        for key in data:
            f[key] = data[key][cid] if cid is not None else data[key]

    f['segments/coinc/start'], f['segments/coinc/end'] = veto.segments_to_start_end(coinc_segs)

    for t in [trigs0, trigs1]:
        # Derived from t.segs directly: t.valid holds the same conversion
        # but can be None when no veto files were given.
        f['segments/%s/start' % t.ifo], f['segments/%s/end' % t.ifo] = \
            veto.segments_to_start_end(t.segs)

    f.attrs['timeslide_interval'] = args.timeslide_interval
    f.attrs['detector_1'] = det0.name
    f.attrs['detector_2'] = det1.name
    f.attrs['foreground_time1'] = abs(trigs0.segs)
    f.attrs['foreground_time2'] = abs(trigs1.segs)
    f.attrs['coinc_time'] = abs(coinc_segs)

    # Number of slide intervals that fit into the longer of the two
    # single detector analysis times; zero when time slides are disabled.
    if args.timeslide_interval:
        nslides = int(max(abs(trigs0.segs), abs(trigs1.segs)) / args.timeslide_interval)
    else:
        nslides = 0

    f.attrs['num_slides'] = nslides
logging.info('Done')
