#!/bin/env  python
"""
This program combines the coincident trigger files generated
by pycbc_coinc_findtrigs to generate a mapping between SNR and FAP, and
also produces the combined foreground and background triggers.
"""
import argparse, h5py, logging, itertools, copy
from scipy.interpolate import interp1d  
from pycbc.future import numpy
from itertools import izip
from pycbc.events import veto, coinc
    
def load_coincs(coinc_files):
    """ Read and merge the coincident triggers stored in a set of HDF5 files.

    Files that cannot be opened, or that lack any of the expected datasets,
    are skipped with a warning. A file is only merged once all of its
    datasets have been read, so the returned columns always stay aligned
    with each other.

    Parameters
    ----------
    coinc_files : iterable
        File names, each one handed to ``h5py.File`` in read mode.

    Returns
    -------
    tuple
        ``(stat, time1, time2, timeslide_id, template_id, decimation_factor,
        trigger_id1, trigger_id2, attr, seg_start, seg_end, segments)``.
        The first eight entries are the datasets of the same name,
        concatenated over every file that was read completely. ``attr`` is a
        dict copy of the attributes of the last such file. The final three
        entries are taken from the last file that could be opened: its
        ``segments/coinc/start`` and ``segments/coinc/end`` arrays and its
        ``segments`` group. That file is left open so the group stays usable.

    Raises
    ------
    ValueError
        If none of the files could be read.
    """
    # Listed in the order in which the merged columns are returned.
    names = ('stat', 'time1', 'time2', 'timeslide_id', 'template_id',
             'decimation_factor', 'trigger_id1', 'trigger_id2')
    columns = dict((name, []) for name in names)
    attr = None
    last_opened = None

    for cfile in coinc_files:
        logging.info('reading %s' % cfile)
        try:
            handle = h5py.File(cfile, "r")
        except Exception as err:
            # Best effort: an unreadable input is reported and skipped.
            logging.warning('skipping %s, could not open it: %s'
                            % (cfile, err))
            continue

        # Only the most recently opened file is needed after the loop (its
        # segments group is returned), so release the one it supersedes.
        if last_opened is not None:
            last_opened.close()
        last_opened = handle

        try:
            # Read everything before storing anything: a file that is
            # missing a dataset must not contribute a partial set of columns.
            data = [handle[name][:] for name in names]
            file_attr = dict(handle.attrs)
        except Exception as err:
            logging.warning('skipping %s, could not read it: %s'
                            % (cfile, err))
            continue

        for name, values in zip(names, data):
            columns[name].append(values)
        attr = file_attr

    if attr is None:
        raise ValueError('none of the coinc files could be read')

    merged = tuple(numpy.concatenate(columns[name]) for name in names)
    return merged + (attr,
                     last_opened['segments/coinc/start'][:],
                     last_opened['segments/coinc/end'][:],
                     last_opened['segments'])

def calculate_fan_map(combined_stat, dec):
    """ Build the interpolant that maps a combined ranking statistic value
    to its false alarm number (FAN).

    Parameters
    ----------
    combined_stat : numpy array
        Ranking statistic of each trigger.
    dec : numpy array
        Decimation factor of each trigger, in the same order as
        `combined_stat`.

    Returns
    -------
    scipy.interpolate.interp1d
        Interpolating function from statistic value to FAN. Values outside
        the range spanned by `combined_stat` evaluate to 1.
    """
    order = combined_stat.argsort()
    sorted_stat = combined_stat[order]
    sorted_dec = dec[order]
    # Cumulative sum taken from the loudest trigger downwards: each entry is
    # the summed decimation factor of every trigger ranked at or above it.
    fan = sorted_dec[::-1].cumsum()[::-1]
    return interp1d(sorted_stat, fan, bounds_error=False, fill_value=1)

def sec_to_year(sec):
    """ Convert a duration in seconds to years, using 3.15569e7 s per year.
    """
    seconds_per_year = 3.15569e7
    return sec / seconds_per_year

parser = argparse.ArgumentParser()
# General required options. The ones the script cannot run without are marked
# as required so that argparse reports a missing option up front, instead of
# the script failing later on a None value.
parser.add_argument('--coinc-files', nargs='+', required=True,
                    help='List of coincidence files used to calculate the FAP, FAR, etc.')
# Any number of --verbose flags switches on INFO level logging
parser.add_argument('--verbose', action='count')
# NOTE(review): --cluster-window is handed straight to coinc.cluster_coincs;
# confirm whether that function accepts None before making it required too.
parser.add_argument('--cluster-window', type=float,
                    help='Size in seconds to maximize coinc triggers')
parser.add_argument('--veto-window', type=float, required=True,
                    help='window around each zerolag trigger to window out')
parser.add_argument('--output-file', required=True)
args = parser.parse_args()

if args.verbose:
    log_level = logging.INFO
    logging.basicConfig(format='%(asctime)s : %(message)s',
                        level=log_level)

logging.info("Loading coinc triggers")
stat, t1, t2, sid, tid, dec, i1, i2, attrs, start, end, seg = load_coincs(args.coinc_files)
logging.info("We have %s triggers" % len(stat))

# Triggers with timeslide id 0 are the foreground (zerolag) triggers
unclustered_fore_locs = numpy.where((sid == 0))[0]
logging.info('%s unclustered foreground triggers' % len(unclustered_fore_locs))

logging.info("Clustering coinc triggers (inclusive of zerolag)")
cid = coinc.cluster_coincs(stat, t1, t2, sid,
                     attrs['timeslide_interval'],
                     args.cluster_window)
stat_i, t1_i, t2_i, sid_i, tid_i, dec_i, i1_i, i2_i = stat[cid], t1[cid], t2[cid], sid[cid], tid[cid], dec[cid], i1[cid], i2[cid]

# Build a veto window of +/- args.veto_window around the mean of the two
# trigger times of every unclustered foreground coincidence
ft1, ft2 = t1[unclustered_fore_locs], t2[unclustered_fore_locs]
vt = (ft1 + ft2) / 2.0
veto_start, veto_end = vt - args.veto_window, vt + args.veto_window

# A coincidence is removed if either of its two trigger times falls inside
# one of the veto windows
v1 = veto.indices_within_times(t1, veto_start, veto_end)
v2 = veto.indices_within_times(t2, veto_start, veto_end)
vidx = numpy.unique(numpy.concatenate([v1, v2]))

# vi is True for every trigger that survives the veto. The builtin bool is
# used as the dtype because the numpy.bool alias is not available in every
# numpy release.
vi = numpy.ones(len(t1), dtype=bool)
vi[vidx] = False

# Total duration covered by the coalesced veto windows
veto_time = abs(veto.start_end_to_segments(veto_start, veto_end).coalesce())

logging.info("Clustering coinc triggers (exclusive of zerolag)")
cid = coinc.cluster_coincs(stat[vi], t1[vi], t2[vi], sid[vi],
                     attrs['timeslide_interval'],
                     args.cluster_window)
stat_e, t1_e, t2_e, sid_e, tid_e, dec_e, i1_e, i2_e = stat[vi][cid], t1[vi][cid], t2[vi][cid], sid[vi][cid], tid[vi][cid], dec[vi][cid], i1[vi][cid], i2[vi][cid]

logging.info("Dumping foreground triggers")
f = h5py.File(args.output_file, "w")

# Carry the detector names and the timeslide interval over from the input
for attr_name in ('detector_1', 'detector_2', 'timeslide_interval'):
    f.attrs[attr_name] = attrs[attr_name]

# Clustered triggers in the zero timeslide are the foreground
fore_locs = (sid_i == 0)
logging.info("%s clustered foreground triggers" % fore_locs.sum())

# Copy the start and end arrays of every group under 'segments' to the output
for key in seg.keys():
    group = seg[key]
    f['segments/%s/start' % key] = group['start'][:]
    f['segments/%s/end' % key] = group['end'][:]

# The foreground datasets are only created when there is something to store
if fore_locs.sum() > 0:
    foreground_columns = (
        ('foreground/stat', stat_i),
        ('foreground/time1', t1_i),
        ('foreground/time2', t2_i),
        ('foreground/trigger_id1', i1_i),
        ('foreground/trigger_id2', i2_i),
        ('foreground/template_id', tid_i),
    )
    for path, values in foreground_columns:
        f[path] = values[fore_locs]

# Clustered triggers from any non-zero timeslide make up the background
back_locs = numpy.where((sid_i != 0))[0]

if len(back_locs) > 0:
    logging.info("Dumping background triggers (inclusive of zerolag)")
    f['background/stat'] = stat_i[back_locs]
    f['background/time1'] = t1_i[back_locs]
    f['background/time2'] = t2_i[back_locs]
    f['background/trigger_id1'] = i1_i[back_locs]
    f['background/trigger_id2'] = i2_i[back_locs]
    f['background/timeslide_id'] = sid_i[back_locs]
    f['background/template_id'] = tid_i[back_locs]
    f['background/decimation_factor'] = dec_i[back_locs]

    logging.info("Dumping background triggers (exclusive of zerolag)")
    f['background_exc/stat'] = stat_e
    f['background_exc/time1'] = t1_e
    f['background_exc/time2'] = t2_e
    f['background_exc/trigger_id1'] = i1_e
    f['background_exc/trigger_id2'] = i2_e
    f['background_exc/timeslide_id'] = sid_e
    f['background_exc/template_id'] = tid_e
    f['background_exc/decimation_factor'] = dec_e

    # Longer and shorter of the two 'foreground_time' attributes
    # (presumably one per detector; verify against the file producer)
    maxtime = max(attrs['foreground_time1'], attrs['foreground_time2'])
    mintime = min(attrs['foreground_time1'], attrs['foreground_time2'])

    # The exclusive variants have the total vetoed time removed
    maxtime_exc = maxtime - veto_time
    mintime_exc = mintime - veto_time

    # Background time: the whole number of timeslide intervals that fit in
    # the longer time, multiplied by the shorter time
    background_time = int(maxtime / attrs['timeslide_interval']) * mintime
    coinc_time = float(attrs['coinc_time'])

    background_time_exc = int(maxtime_exc / attrs['timeslide_interval']) * mintime_exc
    coinc_time_exc = coinc_time - veto_time

    logging.info("Making mapping from FAN to the combined statistic")
    back_stat = stat_i[back_locs]
    fanmap = calculate_fan_map(back_stat, dec_i[back_locs])
    back_fan = fanmap(back_stat)

    fanmap_exc = calculate_fan_map(stat_e, dec_e)
    back_fan_exc = fanmap_exc(stat_e)

    # The inverse false alarm rate is the background time divided by the
    # false alarm number, stored in years
    f['background/fan'] = back_fan
    f['background/ifar'] = sec_to_year(background_time / back_fan)

    f['background_exc/fan'] = back_fan_exc
    f['background_exc/ifar'] = sec_to_year(background_time_exc / back_fan_exc)

    f.attrs['background_time'] = background_time
    f.attrs['foreground_time'] = coinc_time
    f.attrs['background_time_exc'] = background_time_exc
    f.attrs['foreground_time_exc'] = coinc_time_exc

    logging.info("calculating ifar values")
    fore_stat = stat_i[fore_locs]

    # Foreground triggers are ranked against both background estimates
    fore_fan = fanmap(fore_stat)
    ifar = background_time / fore_fan

    fore_fan_exc = fanmap_exc(fore_stat)
    ifar_exc = background_time_exc / fore_fan_exc

    logging.info("calculating fap values")
    # Coincident time over IFAR, limited to the range [0, 1]
    fap = numpy.clip(coinc_time/ifar, 0, 1)
    fap_exc = numpy.clip(coinc_time_exc/ifar_exc, 0, 1)
    if fore_locs.sum() > 0:
        f['foreground/fan'] = fore_fan
        f['foreground/ifar'] = sec_to_year(ifar)
        f['foreground/fap'] = fap

        f['foreground/fan_exc'] = fore_fan_exc
        f['foreground/ifar_exc'] = sec_to_year(ifar_exc)
        f['foreground/fap_exc'] = fap_exc
else:
    logging.warning("There were no background events, so we could not assign "
                    "any statistic values")

# Close the output explicitly so that everything is flushed to disk before
# success is reported
f.close()
logging.info("Done")

