#!/bin/env  python
"""
This program combines the coincident output files generated
by pycbc_coinc_findtrigs to generate a mapping between SNR and FAP, along
with producing the combined foreground and background triggers.
"""
import argparse, h5py, logging, itertools, copy
from scipy.interpolate import interp1d  
from pycbc.future import numpy
from itertools import izip
from pycbc.events import veto, coinc
    
def load_coincs(coinc_files):
    """Read and concatenate the coinc trigger columns of several HDF5 files.

    Files that cannot be opened, or that lack any of the expected columns,
    are skipped with a logged warning. A file only contributes if *all* of
    its columns were read, so the concatenated columns always stay aligned.

    Parameters
    ----------
    coinc_files : iterable of str
        Paths of the coincidence files to read.

    Returns
    -------
    data : dict
        Column name -> array concatenated over every file that was read.
    attrs : dict
        Copy of the HDF5 attributes of the last file that was read.
    start, end : arrays
        Contents of 'segments/coinc/start' and 'segments/coinc/end' from
        that same file.
    segments : h5py group
        The 'segments' group of that same file. The file is deliberately
        left open so the caller can keep reading from this group.

    Raises
    ------
    ValueError
        If none of the given files could be read.
    """
    cols = ['stat', 'time1', 'time2', 'trigger_id1', 'trigger_id2',
            'timeslide_id', 'template_id', 'decimation_factor']
    data = dict((key, []) for key in cols)

    source = None  # last file whose columns were all read successfully
    for cfile in coinc_files:
        logging.info('reading %s' % cfile)
        handle = None
        try:
            handle = h5py.File(cfile, "r")
            # Read every column before committing any of them, so a failure
            # part-way through cannot leave the columns with unequal lengths.
            columns = dict((key, handle[key][:]) for key in cols)
        except Exception as err:
            # Best effort, as before: skip unreadable files. The exception
            # types h5py raises differ between versions, hence the broad
            # catch -- but unlike a bare except this no longer swallows
            # KeyboardInterrupt / SystemExit, and the skip is reported.
            logging.warning('skipping %s: %s', cfile, err)
            if handle is not None:
                handle.close()
            continue

        for key in cols:
            data[key].append(columns[key])

        # The column data are in-memory copies, so only the file whose
        # attributes and segments get returned has to stay open.
        if source is not None:
            source.close()
        source = handle

    if source is None:
        raise ValueError('none of the given coinc files could be read')

    for key in cols:
        data[key] = numpy.concatenate(data[key])

    return (data, dict(source.attrs), source['segments/coinc/start'][:],
            source['segments/coinc/end'][:], source['segments'])

def calculate_fan_map(combined_stat, dec):
    """Build an interpolator from ranking statistic to false alarm number.

    For each background statistic value the false alarm number (FAN) is the
    sum of the decimation factors of every trigger at or above that value.

    Parameters
    ----------
    combined_stat : array
        Ranking statistic of the background triggers.
    dec : array
        Decimation factor (weight) of each background trigger.

    Returns
    -------
    scipy.interpolate.interp1d
        Callable mapping a statistic to its FAN; statistics outside the
        range spanned by ``combined_stat`` map to 1.
    """
    order = combined_stat.argsort()
    ranked_stat = combined_stat[order]
    ranked_weight = dec[order]
    # Accumulate from the loudest trigger downwards, then flip back so the
    # counts line up with the ascending statistic values.
    louder_or_equal = ranked_weight[::-1].cumsum()[::-1]
    return interp1d(ranked_stat, louder_or_equal,
                    bounds_error=False, fill_value=1)

def sec_to_year(sec):
    """Convert a duration in seconds to years (one year = 3.15569e7 s)."""
    seconds_per_year = 3.15569e7
    return sec / seconds_per_year

# Command line interface. Options that the script always dereferences are
# marked as required so that a missing one is reported as a usage error
# instead of surfacing later as an error on a None value.
parser = argparse.ArgumentParser()
# General required options
parser.add_argument('--verbose', action='count',
                    help='Enable INFO level logging')
parser.add_argument('--cluster-window', type=float,
                    help='Size in seconds to maximize coinc triggers')
parser.add_argument('--zero-lag-coincs', nargs='+', required=True,
                    help="Files containing the injection zerolag coincidences")
parser.add_argument('--mixed-coincs-inj-full', nargs='+', required=True,
                    help="Files containing the mixed injection/clean data "
                         "time slides")
parser.add_argument('--mixed-coincs-full-inj', nargs='+', required=True,
                    help="Files containing the mixed clean/injection data "
                         "time slides")
parser.add_argument('--full-data-background', required=True,
                    help='background file from full data for use in analyzing injection coincs')
# Only read when there is at least one zerolag coinc, so left optional.
parser.add_argument('--veto-window', type=float,
                    help='window around each zerolag trigger to window out')
parser.add_argument("--ranking-statistic-threshold", type=float,
                    help="Minimum value of the ranking statistic to calculate"
                         " a unique inclusive background.")
parser.add_argument('--output-file', required=True,
                    help='HDF5 file to write the results to')
args = parser.parse_args()

# Logging is only configured (at INFO level) when --verbose was given.
if args.verbose:
    log_level = logging.INFO
    logging.basicConfig(level=log_level, format='%(asctime)s : %(message)s')


def _cluster(trigs, slide_interval):
    """Cluster one set of coinc triggers; returns what cluster_coincs returns."""
    return coinc.cluster_coincs(trigs['stat'], trigs['time1'], trigs['time2'],
                                trigs['timeslide_id'], slide_interval,
                                args.cluster_window)


# Zerolag coincs. Their file also supplies the attributes and segments that
# the rest of the script uses, including the time slide interval applied to
# all three trigger sets.
logging.info("Loading coinc zerolag triggers")
zdata, attrs, start, end, seg = load_coincs(args.zero_lag_coincs)
interval = attrs['timeslide_interval']
zcid = _cluster(zdata, interval)

# Mixed coincs read from the --mixed-coincs-full-inj files.
logging.info("Loading coinc full inj triggers")
fidata = load_coincs(args.mixed_coincs_full_inj)[0]
ficid = _cluster(fidata, interval)

# Mixed coincs read from the --mixed-coincs-inj-full files.
logging.info("Loading coinc inj full triggers")
izdata = load_coincs(args.mixed_coincs_inj_full)[0]
ifcid = _cluster(izdata, interval)

# Create the output file and carry over the metadata of the zerolag input.
f = h5py.File(args.output_file, "w")

for attr_name in ('detector_1', 'detector_2', 'timeslide_interval'):
    f.attrs[attr_name] = attrs[attr_name]

# Copy over the segment for coincs and singles
for seg_name in seg:
    seg_group = seg[seg_name]
    f['segments/%s/start' % seg_name] = seg_group['start'][:]
    f['segments/%s/end' % seg_name] = seg_group['end'][:]

logging.info('writing zero lag triggers')

# Only the clustered zerolag coincs (those selected by zcid) are stored.
if len(zdata['stat']) != 0:
    for column, values in zdata.items():
        f['foreground/%s' % column] = values[zcid]
        

logging.info('calculating statistics excluding zerolag')

# The zerolag-excluded background of the full data analysis provides both
# the background triggers and the live times used for the significance.
fb = h5py.File(args.full_data_background, "r")
back_stat = fb['background_exc/stat'][:]
dec_fac = fb['background_exc/decimation_factor'][:]
background_time = float(fb.attrs['background_time_exc'])
coinc_time = float(fb.attrs['foreground_time_exc'])
fanmap_exc = calculate_fan_map(back_stat, dec_fac)

# The same two times are recorded under both the '_exc' and the plain
# attribute names.
for attr_name, duration in (('background_time_exc', background_time),
                            ('foreground_time_exc', coinc_time),
                            ('background_time', background_time),
                            ('foreground_time', coinc_time)):
    f.attrs[attr_name] = duration

if len(zdata['stat']) > 0:
    # Significance of every clustered zerolag coinc against the full-data
    # background alone.
    fore_fan = fanmap_exc(zdata['stat'][zcid])
    ifar_exc = background_time / fore_fan
    # Kept under its own name: the per-trigger loop below falls back to
    # these values for quiet triggers, so they must not be overwritten by
    # the 'fap' output array allocated further down.
    fap_exc = numpy.clip(coinc_time / ifar_exc, 0, 1)
    f['foreground/fan_exc'] = fore_fan
    f['foreground/ifar_exc'] = sec_to_year(ifar_exc)
    f['foreground/fap_exc'] = fap_exc

    logging.info('calculating injection backgrounds')
    # Each zerolag coinc is placed at the mean of its two trigger times and
    # given a window of +/- veto_window around that time.
    ftimes = (zdata['time1'][zcid] + zdata['time2'][zcid]) / 2
    fstats = zdata['stat'][zcid]
    start, end = ftimes - args.veto_window, ftimes + args.veto_window

    fan = numpy.zeros(len(ftimes), dtype=numpy.float32)
    ifar = numpy.zeros(len(ftimes), dtype=numpy.float32)
    fap = numpy.zeros(len(ftimes), dtype=numpy.float32)

    # We are relying on the injection data set to be the first one,
    # this is determined
    # by the argument order to pycbc_coinc_findtrigs
    ifstat = izdata['stat'][ifcid]
    if_time = izdata['time1'][ifcid]
    ifsort = if_time.argsort()
    ifsorted = if_time[ifsort]
    # Index ranges, in time-sorted order, of the mixed coincs whose time1
    # lies inside each zerolag window.
    if_start = numpy.searchsorted(ifsorted, start)
    if_end = numpy.searchsorted(ifsorted, end)

    fistat = fidata['stat'][ficid]
    fi_time = fidata['time1'][ficid]
    fisort = fi_time.argsort()
    fisorted = fi_time[fisort]
    fi_start = numpy.searchsorted(fisorted, start)
    fi_end = numpy.searchsorted(fisorted, end)

    threshold = args.ranking_statistic_threshold
    for i, fstat in enumerate(fstats):
        # If the trigger is quiet enough, then don't calculate a separate
        # background type, as it would not be significantly different
        if threshold is not None and fstat < threshold:
            fan[i] = fore_fan[i]
            ifar[i] = ifar_exc[i]
            fap[i] = fap_exc[i]
            # Skip the dedicated background, otherwise the values just
            # copied would be recomputed and overwritten below.
            continue

        # Mixed coincs falling inside this trigger's window.
        v1 = fisort[fi_start[i]:fi_end[i]]
        v2 = ifsort[if_start[i]:if_end[i]]

        # Background for this trigger: the full-data background plus the
        # nearby mixed coincs, each of the latter with unit weight.
        inj_stat = numpy.concatenate([ifstat[v2], fistat[v1], back_stat])
        inj_dec = numpy.concatenate([numpy.repeat(1, len(v1) + len(v2)),
                                     dec_fac])
        fanmap = calculate_fan_map(inj_stat, inj_dec)

        fan[i] = fanmap(fstat)
        ifar[i] = background_time / fan[i]
        fap[i] = numpy.clip(coinc_time / ifar[i], 0, 1)

    f['foreground/fan'] = fan
    f['foreground/ifar'] = sec_to_year(ifar)
    f['foreground/fap'] = fap

# Close explicitly so the results are flushed to disk before reporting
# completion.
f.close()
fb.close()
logging.info("Done")
    
