#!/bin/env python
""" Apply a trials factor based on the number of available detector combinations at the
time of coincidence. This clusters to find the most significant foreground, but
leaves the background triggers alone. """

import h5py, numpy, argparse, logging, pycbc, pycbc.events, pycbc.io, lal
import pycbc.version

# Command-line interface.  Every option other than --version/--verbose is
# needed further down (input files, output file, detector list and
# clustering window), so they are marked as required: a forgotten option is
# then reported by argparse instead of failing later with an obscure error.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--verbose', action='store_true',
                    help="enable verbose logging")
parser.add_argument('--statmap-files', nargs='+', required=True,
                    help="List of coinc files to be redistributed")
parser.add_argument('--cluster-window', type=float, required=True,
                    help="window used when clustering the foreground "
                         "coincidences")
parser.add_argument('--output-file', required=True,
                    help="name of output file")
parser.add_argument('--ifos', nargs='+', required=True,
                    help="list of interferometers in the input files")
args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Open every input statmap file read-only and create the combined output file.
files = [h5py.File(n, 'r') for n in args.statmap_files]

f = h5py.File(args.output_file, "w")

logging.info('Copying segments and attributes to %s', args.output_file)
# Move segments information into the final file - remove some duplication in
# earlier files
for fi in files:
    # The segment names have to be read from the *input* file: the output
    # file has only just been created and has no 'segments' group yet.
    for key in fi['segments']:
        for edge in ('end', 'start'):
            dset_name = 'segments/%s/%s' % (key, edge)
            # h5py refuses to create a name that already exists, so a
            # segment list present in several inputs is stored only once
            # (the copy from the first file that provides it is kept).
            if dset_name not in f:
                f[dset_name] = fi[dset_name][:]
        # Foreground/background segment lists carry no per-combination
        # attributes, so nothing more is recorded for them.
        if key.startswith(('foreground', 'background')):
            continue
        # Record the input file's attributes on a top-level group named
        # after the segment key, creating that group on first use.
        for attr_name in fi.attrs:
            if key not in f:
                f.create_group(key)
            f[key].attrs[attr_name] = fi.attrs[attr_name]

logging.info('Generating list of datasets in input files')

key_set = pycbc.io.name_all_datasets(files)

logging.info('Copying foreground non-ifo-specific data')
# Every foreground column is combined across the input files and copied to
# the output, with the exception of columns whose name mentions one of the
# detectors (the per-IFO groups).
for dataset_name in key_set:
    if not dataset_name.startswith('foreground'):
        continue
    if any(ifo in dataset_name for ifo in args.ifos):
        continue
    pycbc.io.combine_and_copy(f, files, dataset_name)

logging.info('Collating triggers into single structure')

# Per-detector accumulators.  They start empty; numpy.concatenate promotes
# the dtype as data from the input files is appended.
all_trig_times = {}
all_trig_ids = {}
for ifo in args.ifos:
    all_trig_times[ifo] = numpy.array([], dtype=numpy.uint32)
    all_trig_ids[ifo] = numpy.array([], dtype=numpy.uint32)

# For each file, append the trigger time and id data for each ifo
# If an ifo does not participate in any given coinc then fill with -1 values
for f_in in files:
    # One placeholder is needed per foreground event in this file.
    fore_shape = f_in['foreground/fap'].shape
    for ifo in args.ifos:
        if ifo in f_in['foreground']:
            new_times = f_in['foreground/{}/time'.format(ifo)][:]
            new_ids = f_in['foreground/{}/trigger_id'.format(ifo)][:]
        else:
            # -1 cannot be held by an unsigned dtype: multiplying a uint32
            # array by -1 was silently upcast to int64 by older NumPy and
            # raises OverflowError under NumPy 2 promotion rules.  Build the
            # signed placeholder explicitly instead.
            new_times = numpy.full(fore_shape, -1, dtype=numpy.int64)
            new_ids = numpy.full(fore_shape, -1, dtype=numpy.int64)
        all_trig_times[ifo] = numpy.concatenate([all_trig_times[ifo],
                                                 new_times])
        all_trig_ids[ifo] = numpy.concatenate([all_trig_ids[ifo], new_ids])
    # Nothing further is read from this input file.
    f_in.close()

logging.info('Clustering triggers for loudest ifar value')

# Store the collated per-detector trigger times and ids in the output file.
for ifo in args.ifos:
    f['foreground/{}/time'.format(ifo)] = all_trig_times[ifo]
    f['foreground/{}/trigger_id'.format(ifo)] = all_trig_ids[ifo]

# Record array pairing ifar with stat for every foreground event.
# numpy.rec is the public home of fromarrays; numpy.core is a private
# namespace whose use is deprecated in recent NumPy releases.
ifar_stat = numpy.rec.fromarrays([f['foreground/ifar'][:],
                                  f['foreground/stat'][:]],
                                 names='ifar,stat')

# all_times is a tuple of trigger time arrays, one per ifo in args.ifos
# order.  It is materialised (rather than left as a one-shot generator) so
# that it can safely be iterated more than once.
all_times = tuple(f['foreground/%s/time' % ifo][:] for ifo in args.ifos)

def argmax(v):
    """Return the index of the element of `v` that sorts last.

    The position is taken from a full argsort rather than numpy.argmax;
    this script hands it a record array of (ifar, stat) values, for which
    argsort orders the entries field by field.
    """
    ascending_order = numpy.argsort(v)
    return ascending_order[-1]

# Only the zerolag (i.e. foreground) coincidences are clustered at present,
# so every event is given a timeslide id of zero.
zerolag_slide_ids = numpy.zeros(len(ifar_stat))
cidx = pycbc.events.cluster_coincs_multiifo(
    ifar_stat, all_times, zerolag_slide_ids, 0, args.cluster_window, argmax)

def filter_dataset(h5file, name, idx):
    """Replace dataset `name` in `h5file` by its elements selected by `idx`.

    Parameters
    ----------
    h5file : mapping of names to array-like datasets
        Container that supports item access, deletion and assignment.
    name : str
        Name of the dataset to filter.
    idx : index
        Anything usable to index a numpy array.  An empty array of
        non-integer dtype, such as ``numpy.array([])``, is accepted and
        selects nothing.

    Returns
    -------
    The `idx` argument, unchanged.
    """
    selection = idx
    if isinstance(idx, numpy.ndarray) and idx.size == 0 \
            and idx.dtype.kind == 'f':
        # numpy rejects float arrays as indices even when they are empty;
        # an empty selection should simply yield an empty dataset.
        selection = idx.astype(numpy.intp)
    # Dataset needs to be deleted and remade as it is a different size
    filtered_dset = h5file[name][:][selection]
    del h5file[name]
    h5file[name] = filtered_dset
    return idx

# Downsample the foreground columns to only the loudest ifar between the
# multiple files.  filter_dataset deletes and recreates each dataset, so the
# member names are snapshotted with list() before the groups are modified.
# The return value of filter_dataset is not needed here (and was previously
# bound to a name shadowing the builtin `id`).
for key in list(f['foreground'].keys()):
    if key not in args.ifos:
        filter_dataset(f, 'foreground/%s' % key, cidx)
    else:  # key is an ifo
        for k in list(f['foreground/%s' % key].keys()):
            filter_dataset(f, 'foreground/{}/{}'.format(key, k), cidx)

logging.info('Applying trials factor')

# Per-detector trigger times as they stand after clustering; these are used
# to work out one test time per surviving event.
clustered_times = (f['foreground/%s/time' % ifo][:] for ifo in args.ifos)

# Trials factor is how many possible 2+IFO combinations are 'on'
#  at the time of the coincidence
foreground_ifar = f['foreground/ifar'][:]
trials_factors = numpy.zeros_like(foreground_ifar)
test_times = numpy.array(
    [pycbc.events.mean_if_greater_than_zero(event_times)[0]
     for event_times in zip(*clustered_times)])

# Each non-foreground/background segment list stands for one ifo
# combination; every event inside its segments gains one trial.
for key in f['segments']:
    if not key.startswith(('foreground', 'background')):
        start_times = f['segments/%s/start' % key][:]
        end_times = f['segments/%s/end' % key][:]
        idx_within_segment = pycbc.events.indices_within_times(
            test_times, start_times, end_times)
        trials_factors[idx_within_segment] += 1

f['foreground/ifar'][:] = foreground_ifar / trials_factors
f['foreground/ifar_exc'][:] = f['foreground/ifar_exc'][:] / trials_factors

# TODO: work out how to calculate fap, given that coinc_time changes for
#       each detector combination; fap is not recalculated here.

# TODO: add in background combinations.
# If there is a background set (full_data as opposed to injection run), then
# the values for its triggers should be recalculated as well.

f.close()
logging.info('done')
