#!/usr/bin/python
""" This program adds single detector hdf trigger files together.
"""
import numpy, argparse, h5py, logging

def changes(arr):
    """Return the boundaries of the runs of equal values in ``arr``.

    Parameters
    ----------
    arr : numpy.ndarray
        One-dimensional array; equal values are expected to be adjacent
        (e.g. the array is sorted).

    Returns
    -------
    numpy.ndarray
        Sorted, de-duplicated indices ``b`` such that ``arr[b[i]:b[i+1]]``
        is one run of identical values. The first entry is 0 and the last
        is ``len(arr)``. An empty input yields ``[0]``.
    """
    # arr[i] != arr[i + 1] marks the last element of a run, so the next
    # run starts one position later.
    run_starts = numpy.where(arr[:-1] != arr[1:])[0] + 1
    edges = numpy.concatenate(([0], run_starts, [len(arr)]))
    # unique() collapses the duplicated 0 that an empty input produces.
    return numpy.unique(edges)

def collect(key, files):
    """Concatenate the dataset ``key`` from every file that contains it.

    Parameters
    ----------
    key : str
        HDF5 path of the dataset to read.
    files : iterable of str
        Paths of the HDF5 files to read, in the order the data should
        appear in the result. Files lacking ``key`` are skipped.

    Returns
    -------
    numpy.ndarray
        The per-file arrays joined with ``numpy.concatenate``.

    Raises
    ------
    ValueError
        If none of the files contains ``key``.
    """
    data = []
    for fname in files:
        # Open read-only explicitly (older h5py releases fall back to a
        # read/write mode when none is given) and let the context manager
        # close the file even if the read fails.
        with h5py.File(fname, 'r') as fin:
            if key in fin:
                data.append(fin[key][:])
    if not data:
        raise ValueError("dataset %r not found in any of the input files"
                         % key)
    return numpy.concatenate(data)

def region(f, key, boundaries):
    """Write region references that split ``f[key]`` at ``boundaries``.

    For every consecutive pair ``(boundaries[i], boundaries[i + 1])`` a
    region reference to that slice of the dataset ``f[key]`` is created.
    The references are stored in a new dataset named ``key + '_template'``
    in the same file.
    """
    source = f[key]
    # Pair each boundary with its successor: [b0:b1], [b1:b2], ...
    slices = [source.regionref[lo:hi]
              for lo, hi in zip(boundaries[:-1], boundaries[1:])]
    ref_dtype = h5py.special_dtype(ref=h5py.RegionReference)
    f.create_dataset(key + '_template', data=slices, dtype=ref_dtype)

# Command line interface. The three path options are used unconditionally
# further down, so they are declared as required to get a clear usage error
# instead of a failure deep inside h5py.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--trigger-files', nargs='+', required=True,
                    help='single detector HDF trigger files to merge')
parser.add_argument('--output-file', required=True,
                    help='path of the merged HDF file to create')
parser.add_argument('--bank-file', required=True,
                    help='HDF template bank file providing template_hash')
parser.add_argument('--verbose', '-v', action='count',
                    help='accepted on the command line; currently has no '
                         'effect on the logging level of this script')
args = parser.parse_args()

# Progress messages are always emitted at INFO level.
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.INFO)

# Output file; it stays open until every column has been written.
f = h5py.File(args.output_file, 'w')

logging.info("getting the list of columns from a representative file")
# Use the first input file whose top-level group contains 'template_hash'
# as the template for the list of trigger columns.
trigger_columns = []
for fname in args.trigger_files:
    with h5py.File(fname, 'r') as f2:
        # The first top-level group is taken as the detector group
        # (presumably the ifo name -- confirm against the file writer).
        # list() is needed because keys() is not indexable or mutable on
        # Python 3 / recent h5py.
        ifo = list(f2.keys())[0]
        columns = list(f2[ifo].keys())
    if 'template_hash' not in columns:
        # Empty group, or a file without per-trigger data: try the next one.
        continue
    # 'search' and 'template_hash' are handled separately below, so they
    # are not copied as ordinary trigger columns.
    trigger_columns = [name for name in columns
                       if name not in ('search', 'template_hash')]
    break

for col in trigger_columns:
    logging.info("trigger column: %s", col)

logging.info('reading the metadata from the files')
# Gather the search start/end times of every input file, in input order,
# and store the combined arrays in the output file.
start_key = '%s/search/start_time' % ifo
end_key = '%s/search/end_time' % ifo
start = numpy.array([], dtype=numpy.float64)
end = numpy.array([], dtype=numpy.float64)
for filename in args.trigger_files:
    # The context manager closes the input even if a dataset is missing.
    with h5py.File(filename, 'r') as data:
        start = numpy.append(start, data[start_key][:])
        end = numpy.append(end, data[end_key][:])
f[start_key] = start
f[end_key] = end


logging.info('set up sorting of triggers and template ids')
# Read the bank hashes and close the bank file straight away.
with h5py.File(args.bank_file, 'r') as bank:
    hashes = bank['template_hash'][:]

# Sort all triggers by template hash; trigger_sort is reused below to put
# every trigger column into the same order.
trigger_hashes = collect('%s/template_hash' % ifo, args.trigger_files)
trigger_sort = trigger_hashes.argsort()
trigger_hashes = trigger_hashes[trigger_sort]

# Boundaries of the runs of identical hashes, and for each run the position
# of its hash in the bank.
# NOTE(review): searchsorted assumes `hashes` is sorted ascending and that
# every trigger hash is present in the bank -- confirm for the bank files.
template_boundaries = changes(trigger_hashes)
template_ids = numpy.searchsorted(hashes,
                                  trigger_hashes[template_boundaries[:-1]])

# One boundary per bank template (including templates without triggers),
# plus a final entry marking the end of the trigger arrays.
full_boundaries = numpy.searchsorted(trigger_hashes, hashes)
full_boundaries = numpy.concatenate([full_boundaries, [len(trigger_hashes)]])
del trigger_hashes

# Expand the per-run template id to one id per trigger.
idlen = (template_boundaries[1:] - template_boundaries[:-1])
f.create_dataset('%s/template_id' % ifo,
                 data=numpy.repeat(template_ids, idlen),
                 compression='gzip', shuffle=True)
f['%s/template_boundaries' % ifo] = full_boundaries

logging.info('reading the trigger columns from the input files')
# Copy each trigger column into the output file in hash-sorted order and
# attach the per-template region references to it.
for name in trigger_columns:
    path = '%s/%s' % (ifo, name)
    logging.info('reading %s', name)
    values = collect(path, args.trigger_files)[trigger_sort]
    logging.info('writing %s to file', name)
    f.create_dataset(path, data=values, compression='gzip')
    # Release the in-memory copy before building the region references.
    del values
    region(f, path, full_boundaries)
f.close()
logging.info('done')
