#!/bin/env python
""" Apply a trials factor based on the number of available detector combinations at the
time of coincidence. This clusters to find the most significant foreground, but
leaves the background triggers alone. """

import h5py, numpy, argparse, logging, pycbc, pycbc.events, pycbc.io, lal
import pycbc.version

# Command-line interface. Every option other than --verbose (and --version)
# is dereferenced later in the script, so they are declared required: a
# missing option now gives an argparse usage error up front instead of a
# failure part-way through the run.
parser = argparse.ArgumentParser()
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--verbose', action='store_true',
                    help="Enable verbose logging")
parser.add_argument('--statmap-files', nargs='+', required=True,
                    help="List of coinc files to be redistributed")
parser.add_argument('--cluster-window', type=float, required=True,
                    help="Time window passed to the foreground clustering")
parser.add_argument('--output-file', required=True,
                    help="name of output file")
parser.add_argument('--ifos', nargs='+', required=True,
                    help="list of interferometers in the input files")
args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Open every input statmap file read-only and create the combined output file
files = []
for statmap_name in args.statmap_files:
    files.append(h5py.File(statmap_name, 'r'))

f = h5py.File(args.output_file, "w")

logging.info('Copying segments and attributes to %s' % args.output_file)
# Each input file gets one top-level group in the output, named by joining
# the ifos (in --ifos order) that appear in that file's 'segments' group.
# This removes some duplication present in the earlier files.
for fi in files:
    present_ifos = [det for det in args.ifos if det in fi['segments']]
    combo_name = ''.join(present_ifos)
    # Only the first ifo's segments are copied over; the choice is arbitrary
    # as all of them are currently the same
    source_segs = 'segments/{}'.format(present_ifos[0])
    dest_segs = '{}/segments'.format(combo_name)
    for edge in ('end', 'start'):
        f[dest_segs + '/' + edge] = fi[source_segs + '/' + edge][:]
    # The file-level attributes are stored on the same combination group
    combo_group = f[combo_name]
    for attr_name in fi.attrs:
        combo_group.attrs[attr_name] = fi.attrs[attr_name]

logging.info('Generating list of datasets in input files')

key_set = pycbc.io.name_all_datasets(files)

logging.info('Copying foreground non-ifo-specific data')
# Combine each foreground column from all input files into the output file,
# skipping anything outside 'foreground' and any dataset whose name contains
# one of the ifo names (the per-ifo groups are handled separately)
for key in key_set:
    if not key.startswith('foreground'):
        continue
    if any(ifo in key for ifo in args.ifos):
        continue
    pycbc.io.combine_and_copy(f, files, key)

logging.info('Collating triggers into single structure')

# Per-ifo accumulators, keyed by ifo name. They start empty and grow by one
# chunk per input file, so every ifo ends up with one entry per foreground
# event across all files.
all_trig_times = {}
all_trig_ids = {}
for ifo in args.ifos:
    all_trig_times[ifo] = numpy.array([], dtype=numpy.uint32)
    all_trig_ids[ifo] = numpy.array([], dtype=numpy.uint32)

# For each file, append the trigger time and id data for each ifo
# If an ifo does not participate in any given coinc then fill with -1 values
for f_in in files:
    foreground = f_in['foreground']
    for ifo in args.ifos:
        if ifo in foreground:
            new_times = foreground['{}/time'.format(ifo)][:]
            new_ids = foreground['{}/trigger_id'.format(ifo)][:]
        else:
            # One -1 placeholder per event in this file, sized from the 'fap'
            # column. A signed dtype is required: multiplying an unsigned
            # array by -1 raises OverflowError under NumPy 2, and int64 is
            # what that expression produced under NumPy 1.x.
            filler = numpy.full(foreground['fap'].shape, -1,
                                dtype=numpy.int64)
            new_times = filler
            new_ids = filler
        all_trig_times[ifo] = numpy.concatenate([all_trig_times[ifo],
                                                 new_times])
        all_trig_ids[ifo] = numpy.concatenate([all_trig_ids[ifo], new_ids])
    # Nothing further is read from the input file after this point
    f_in.close()

logging.info('Clustering triggers for loudest ifar value')

# Store the collated per-ifo columns next to the other foreground datasets
for ifo in args.ifos:
    f['foreground/{}/time'.format(ifo)] = all_trig_times[ifo]
    f['foreground/{}/trigger_id'.format(ifo)] = all_trig_ids[ifo]

# Record array with fields 'ifar' then 'stat', used as the quantity to
# cluster on. numpy.rec is the public namespace for fromarrays; numpy.core
# is private and deprecated in recent NumPy releases.
ifar_stat = numpy.rec.fromarrays([f['foreground/ifar'][:],
                                  f['foreground/stat'][:]],
                                 names='ifar,stat')

# all_times is a tuple of trigger time arrays, one per ifo in --ifos order.
# It is materialised (not a generator) so it can be iterated more than once.
all_times = tuple(f['foreground/%s/time' % ifo][:] for ifo in args.ifos)

def argmax(v):
    """Return the index of the element that sorts last in `v`.

    The position is taken from numpy.argsort rather than numpy.argmax, so
    any array that numpy can sort is accepted. An empty `v` raises
    IndexError.
    """
    ascending_order = numpy.argsort(v)
    return ascending_order[-1]

# Currently only clustering zerolag, i.e. foreground, so every event is given
# a timeslide id of zero, and zero is also passed as the fourth argument
zerolag_slide_ids = numpy.zeros(len(ifar_stat))
cidx = pycbc.events.cluster_coincs_multiifo(
    ifar_stat, all_times, zerolag_slide_ids, 0, args.cluster_window, argmax)

def filter_dataset(h5file, name, idx):
    """Replace the dataset `name` in `h5file` with its `idx` selection.

    The full dataset is read into memory and indexed with `idx`. The stored
    dataset is then deleted and recreated, because the selection generally
    has a different size from the original. `idx` is returned unchanged.
    """
    contents = h5file[name][:]
    selection = contents[idx]
    del h5file[name]
    h5file[name] = selection
    return idx
    
# Downsample the foreground columns to only the loudest ifar between the multiple files
# The key listings are snapshotted with list() first, because filter_dataset
# deletes and recreates members of the very group being walked.
for key in list(f['foreground'].keys()):
    if key not in args.ifos:
        filter_dataset(f, 'foreground/%s' % key, cidx)
    else:  # key is an ifo
        for k in list(f['foreground/%s' % key].keys()):
            filter_dataset(f, 'foreground/{}/{}'.format(key, k), cidx)

logging.info('Applying trials factor')

# Re-read the per-ifo event times from the output file, since the stored
# columns have been filtered by the clustering step
clustered_times = [f['foreground/%s/time' % ifo][:] for ifo in args.ifos]

# One test time per event: the first value returned by
# mean_if_greater_than_zero for that event's tuple of per-ifo times
test_times = numpy.array(
    [pycbc.events.mean_if_greater_than_zero(event_times)[0]
     for event_times in zip(*clustered_times)])

# Trials factor is how many possible 2+IFO combinations are 'on'
#  at the time of the coincidence
trials_factors = numpy.zeros_like(f['foreground/ifar'][:])
# Every top-level group other than foreground/background is an ifo
# combination carrying its own segment lists
for key in f:
    if key.startswith(('foreground', 'background')):
        continue
    seg_ends = numpy.array(f['%s/segments/end' % key][:])
    seg_starts = numpy.array(f['%s/segments/start' % key][:])
    in_segment = pycbc.events.indices_within_times(test_times, seg_starts,
                                                   seg_ends)
    # Count this combination once for each event that falls in its segments
    trials_factors[in_segment] += 1

# Both ifar columns are divided by the per-event trials factor
for column in ('foreground/ifar', 'foreground/ifar_exc'):
    f[column][:] = f[column][:] / trials_factors

#TODO: work out how to calculate fap given that coinc_time changes for each detector combination
#      commented out parts here need re-working out

#TODO: add in background combinations
# If there is a background set (full_data as opposed to injection run), then recalculate
# the values for its triggers as well

f.close()
logging.info('done')
