#!/usr/bin/python

"""Associate coincident triggers with injections listed in one or more LIGOLW
files.
"""

import argparse, h5py, logging, types, numpy, os.path
from glue.ligolw import ligolw, table, lsctables, utils as ligolw_utils
from ligo import segments
from pycbc import events
from pycbc.events import indices_within_segments as veto_indices
from pycbc.types import MultiDetOptionAction
import pycbc.version

class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
    """Dummy content-handler subclass needed for loading LIGOLW files.

    It adds nothing to the base handler; it exists so that it can be handed
    to ``lsctables.use_in`` and then supplied as the ``contenthandler`` when
    the injection XML files are loaded.
    """
lsctables.use_in(LIGOLWContentHandler)

def hdf_append(f, key, value):
    """Append ``value`` to the array stored under ``key`` in ``f``.

    If ``key`` is not yet present in ``f`` it is created holding ``value``.
    Otherwise the existing contents are read out in full, concatenated with
    ``value``, and the entry is deleted and re-created with the combined
    array.
    """
    if key not in f:
        f[key] = value
        return

    # Existing entry: rebuild it from the old data followed by the new data.
    combined = numpy.concatenate([f[key][:], value])
    del f[key]
    f[key] = combined

# Expose hdf_append as an ``append`` method on h5py.File objects. A plain
# function stored on the class is bound to the *instance* on attribute access,
# so ``fo.append(key, value)`` calls ``hdf_append(fo, key, value)``. Wrapping
# it with ``types.MethodType(hdf_append, h5py.File)`` instead would bind the
# class itself as the first argument and make every call fail.
h5py.File.append = hdf_append

def keep_ind(times, start, end):
    """Return the indices of ``times`` that fall inside any given segment.

    Parameters
    ----------
    times : numpy.ndarray
        Times to test; does not need to be sorted.
    start : numpy.ndarray
        Segment start times, one entry per segment.
    end : numpy.ndarray
        Segment end times, paired element-wise with ``start``.

    Returns
    -------
    numpy.ndarray
        Indices into the original (unsorted) ``times`` array of every time
        ``t`` with ``start[i] <= t <= end[i]`` for at least one segment ``i``.
        Each index appears at most once, even if segments overlap.
    """
    time_sorting = times.argsort()
    sorted_times = times[time_sorting]

    # Positions in the sorted array bracketing each segment. 'left' for the
    # start and 'right' for the end makes both segment boundaries inclusive.
    left = numpy.searchsorted(sorted_times, start, side='left')
    right = numpy.searchsorted(sorted_times, end, side='right')

    indices = numpy.array([], dtype=numpy.uint32)
    for li, ri in zip(left, right):
        seg_indices = numpy.arange(li, ri, 1).astype(numpy.uint32)
        # union1d de-duplicates, so overlapping segments are handled.
        indices = numpy.union1d(seg_indices, indices)

    # Map positions in the sorted array back to positions in the input.
    return time_sorting[indices]

def xml_to_hdf(table, hdf_file, hdf_key, columns):
    """Copy the named columns of an XML table into an HDF file.

    For every name in ``columns`` the column is fetched with
    ``table.get_column``, converted to a float32 array (the only type
    supported at the moment) and appended to ``hdf_file`` under the key
    formed by joining ``hdf_key`` and the column name.
    """
    for col in columns:
        values = numpy.array(table.get_column(col), dtype=numpy.float32)
        dataset_key = os.path.join(hdf_key, col)
        hdf_append(hdf_file, dataset_key, values)

# Command-line interface. Trigger and injection files are given as two lists
# which are paired up by position further down in the script.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--trigger-files', required=True, nargs='+')
parser.add_argument('--injection-files', required=True, nargs='+')
parser.add_argument('--veto-file')
parser.add_argument(
    '--segment-name', default=None,
    help='Name of segment list to use for vetoes. Optional')
parser.add_argument('--injection-window', required=True, type=float)
parser.add_argument(
    '--optimal-snr-column', nargs='+', action=MultiDetOptionAction,
    metavar='DETECTOR:COLUMN',
    help='Names of the sim_inspiral columns containing the optimal SNRs.')
parser.add_argument(
    '--redshift-column', default=None,
    help='Name of sim_inspiral column containing redshift. Optional')
parser.add_argument('--verbose', action='count')
parser.add_argument('--output-file', required=True)
args = parser.parse_args()

# Logging is only configured when at least one --verbose flag was given.
if args.verbose:
    log_level = logging.INFO
    logging.basicConfig(level=log_level,
                        format='%(asctime)s : %(message)s')

# Single output file accumulating the results of every trigger/injection file
# pair. Opened in 'w' mode.
fo = h5py.File(args.output_file, 'w')
# Running offset added to the per-file injection indices so that they refer
# to rows of the concatenated 'injections/*' datasets in the output.
injection_index = 0
# Trigger and injection files are paired by position; zip stops at the
# shorter of the two lists.
for trigger_file, injection_file in zip(args.trigger_files,
                                        args.injection_files):
    logging.info('Read in the coinc data: %s' % trigger_file)
    f = h5py.File(trigger_file, 'r')

    # Detect whether trigger file is two-ifo or multi-ifo style
    if 'foreground/time1' in f:
        multi_ifo_style = False
        ifo_list = ['H1', 'L1']
    else:
        multi_ifo_style = True
        # Get list of groups which contain subgroup 'time' - these will be the IFOs
        ifo_list = [key for key in f['foreground'] if 'time' in f['foreground/%s/' % key]]
        assert len(ifo_list) > 1

    template_id = f['foreground/template_id'][:]
    stat = f['foreground/stat'][:]
    ifar_exc = f['foreground/ifar_exc'][:]
    ifar = f['foreground/ifar'][:]
    fap = f['foreground/fap'][:]
    fap_exc = f['foreground/fap_exc'][:]
    if multi_ifo_style:
        # using multi-ifo-style trigger file input
        ifo_times = ()
        time_dict = {}
        trig_dict = {}
        for ifo in ifo_list:
            # Read each dataset once and reuse the array for both purposes.
            time_dict[ifo] = f['foreground/%s/time' % ifo][:]
            trig_dict[ifo] = f['foreground/%s/trigger_id' % ifo][:]
            ifo_times += (time_dict[ifo],)
        # One representative time per coincident event, computed from the
        # per-detector times of that event.
        time = numpy.array([events.mean_if_greater_than_zero(vals)[0]
                            for vals in zip(*ifo_times)])
        # We will discard injections which cannot be associated with a
        # coincident event, thus combine segments over all combinations
        # of coincident detectors to determine which times to keep
        any_seg = segments.segmentlist([])
        for key in f['segments']:
            if key == 'foreground':
                continue
            else:
                starts = f['/segments/%s/start' % key][:]
                ends = f['/segments/%s/end' % key][:]
                any_seg += events.start_end_to_segments(starts,ends)
        ana_start, ana_end = events.segments_to_start_end(any_seg)
    else:
        # Using the old 2-ifo organisation
        time1 = f['foreground/time1'][:]
        time2 = f['foreground/time2'][:]
        trig1 = f['foreground/trigger_id1'][:]
        trig2 = f['foreground/trigger_id2'][:]
        ana_start = f['segments/coinc/start'][:]
        ana_end = f['segments/coinc/end'][:]
        time = 0.5 * (time1 + time2)

    # All foreground arrays are indexed through this ordering below.
    time_sorting = time.argsort()

    logging.info('Read in the injection file')
    indoc = ligolw_utils.load_filename(injection_file, False,
                                            contenthandler=LIGOLWContentHandler)
    sim_table = table.get_table(indoc, lsctables.SimInspiralTable.tableName)
    inj_time = numpy.array(sim_table.get_column('geocent_end_time')
                           + 1e-9 * sim_table.get_column('geocent_end_time_ns'),
                           dtype=numpy.float64)

    logging.info('Determined the found injections by time')
    # For each injection, the half-open range [left, right) of time-sorted
    # foreground events lying within +/- injection_window of the injection.
    left = numpy.searchsorted(time[time_sorting],
                              inj_time - args.injection_window, side='left')
    right = numpy.searchsorted(time[time_sorting],
                               inj_time + args.injection_window, side='right')
    found = numpy.where((right-left) == 1)[0]
    missed = numpy.where((right-left) == 0)[0]
    ambiguous = numpy.where((right-left) > 1)[0]
    # Injections matching more than one coinc event are counted as missed.
    missed = numpy.concatenate([missed, ambiguous])
    logging.info('Found: %s, Missed: %s Ambiguous: %s'
                 % (len(found), len(missed), len(ambiguous)))

    if len(ambiguous) > 0:
        logging.warning('More than one coinc trigger found associated to injection')

    logging.info('Removing injections outside of analyzed time')
    ki = keep_ind(inj_time, ana_start, ana_end)
    found_within_time = numpy.intersect1d(ki, found)
    missed_within_time = numpy.intersect1d(ki, missed)
    logging.info('Found: %s, Missed: %s'
                 % (len(found_within_time), len(missed_within_time)))

    logging.info('Removing injections in vetoed time')

    # NOTE(review): --veto-file is optional, so args.veto_file may be None
    # here; confirm how veto_indices treats a None entry.
    vi = numpy.array([])
    for ifo in ifo_list:
        vi = numpy.append(vi, veto_indices(inj_time, [args.veto_file],
                                        ifo=ifo, segment_name=args.segment_name))

    found_after_vetoes = numpy.delete(found_within_time,
                                      numpy.where(numpy.in1d(found_within_time, vi))[0])
    missed_after_vetoes = numpy.delete(missed_within_time,
                                       numpy.where(numpy.in1d(missed_within_time, vi))[0])
    logging.info('Found: %s, Missed: %s' % (len(found_after_vetoes),
                                            len(missed_after_vetoes)))

    # Position (in the time-sorted foreground arrays) of the single event
    # associated with each found injection.
    found_fore = numpy.arange(0, len(stat), 1)[left[found]]
    found_fore_v = numpy.arange(0, len(stat), 1)[left[found_after_vetoes]]

    logging.info('Saving injection information')
    columns = ['mass1', 'mass2', 'spin1x', 'spin1y',
               'spin1z', 'spin2x', 'spin2y', 'spin2z',
               'eff_dist_l', 'eff_dist_h', 'eff_dist_v',
               'inclination', 'polarization', 'coa_phase',
               'latitude', 'longitude', 'distance']
    xml_to_hdf(sim_table, fo, 'injections', columns)
    hdf_append(fo, 'injections/end_time', inj_time)

    # pick up optimal SNRs
    if multi_ifo_style:
        for ifo, column in args.optimal_snr_column.items():
            hdf_append(fo, 'injections/optimal_snr_%s' % ifo,
                       sim_table.get_column(column))
    else:
        # Two-ifo files label detectors by number rather than by name.
        ifo_map = {f.attrs['detector_1']: 1,
                   f.attrs['detector_2']: 2}
        for ifo, column in args.optimal_snr_column.items():
            hdf_append(fo, 'injections/optimal_snr_%d' % ifo_map[ifo],
                       sim_table.get_column(column))

    # pick up redshift
    if args.redshift_column:
        hdf_append(fo, 'injections/redshift',
                   sim_table.get_column(args.redshift_column))

    # Copy the analysis metadata attributes from the trigger file.
    if multi_ifo_style:
        for key in f['segments'].keys():
            if key == 'foreground':
                continue
            if key not in fo:
                fo.create_group(key)
            fo[key].attrs['pivot'] = f[key].attrs['pivot']
            fo[key].attrs['fixed'] = f[key].attrs['fixed']
            fo[key].attrs['foreground_time_exc'] = f[key].attrs['foreground_time_exc']
    else:
        fo.attrs['detector_1'] = f.attrs['detector_1']
        fo.attrs['detector_2'] = f.attrs['detector_2']
        fo.attrs['foreground_time_exc'] = f.attrs['foreground_time_exc']

    hdf_append(fo, 'missed/all', missed + injection_index)
    hdf_append(fo, 'missed/within_analysis', missed_within_time + injection_index)
    hdf_append(fo, 'missed/after_vetoes', missed_after_vetoes + injection_index)
    hdf_append(fo, 'found/template_id', template_id[time_sorting][found_fore])
    hdf_append(fo, 'found/injection_index', found + injection_index)
    hdf_append(fo, 'found/stat', stat[time_sorting][found_fore])
    hdf_append(fo, 'found/ifar', ifar[time_sorting][found_fore])
    hdf_append(fo, 'found/ifar_exc', ifar_exc[time_sorting][found_fore])
    hdf_append(fo, 'found/fap', fap[time_sorting][found_fore])
    # Written for parity with 'found_after_vetoes/fap_exc' below.
    hdf_append(fo, 'found/fap_exc', fap_exc[time_sorting][found_fore])
    hdf_append(fo, 'found_after_vetoes/template_id',
               template_id[time_sorting][found_fore_v])
    hdf_append(fo, 'found_after_vetoes/injection_index',
               found_after_vetoes + injection_index)
    hdf_append(fo, 'found_after_vetoes/stat', stat[time_sorting][found_fore_v])
    hdf_append(fo, 'found_after_vetoes/ifar', ifar[time_sorting][found_fore_v])
    hdf_append(fo, 'found_after_vetoes/ifar_exc', ifar_exc[time_sorting][found_fore_v])
    hdf_append(fo, 'found_after_vetoes/fap', fap[time_sorting][found_fore_v])
    hdf_append(fo, 'found_after_vetoes/fap_exc', fap_exc[time_sorting][found_fore_v])
    if multi_ifo_style:
        for ifo in ifo_list:
            hdf_append(fo, 'found/%s/time' % ifo,
                       time_dict[ifo][time_sorting][found_fore])
            hdf_append(fo, 'found/%s/trigger_id' % ifo,
                       trig_dict[ifo][time_sorting][found_fore])
            hdf_append(fo, 'found_after_vetoes/%s/time' % ifo,
                       time_dict[ifo][time_sorting][found_fore_v])
            hdf_append(fo, 'found_after_vetoes/%s/trigger_id' % ifo,
                       trig_dict[ifo][time_sorting][found_fore_v])
    else:
        hdf_append(fo, 'found/time1', time1[time_sorting][found_fore])
        hdf_append(fo, 'found/time2', time2[time_sorting][found_fore])
        hdf_append(fo, 'found/trigger_id1', trig1[time_sorting][found_fore])
        hdf_append(fo, 'found/trigger_id2', trig2[time_sorting][found_fore])
        hdf_append(fo, 'found_after_vetoes/time1', time1[time_sorting][found_fore_v])
        hdf_append(fo, 'found_after_vetoes/time2', time2[time_sorting][found_fore_v])
        hdf_append(fo, 'found_after_vetoes/trigger_id1',
                   trig1[time_sorting][found_fore_v])
        hdf_append(fo, 'found_after_vetoes/trigger_id2',
                   trig2[time_sorting][found_fore_v])

    # Everything needed from the trigger file has been copied into memory or
    # into the output file, so release the handle before the next iteration.
    f.close()

    injection_index += len(sim_table)
