#!/usr/bin/env python

import sys
import os
import bisect
from collections import defaultdict
from optparse import OptionParser
import itertools
import hashlib

import numpy
import scipy.optimize

from glue.ligolw import lsctables, table, utils, ligolw, ilwd
table.use_in(ligolw.LIGOLWContentHandler)
lsctables.use_in(ligolw.LIGOLWContentHandler)
from glue import lal as gluelal
from glue.ligolw.utils import process

from pylal import antenna
#from pylal import coherent_inspiral_metric as metric
#from pylal import coherent_inspiral_metric_detector_details as details
from pylal.series import read_psd_xmldoc, LIGOLWContentHandler

import lal
import lalburst
import lalsimulation
import lalmetaio
import lalinspiral

#
# Utility functions
#

def copy_sim_inspiral( row ):
    """
    Build a SWIG wrapped lalmetaio.SimInspiralTable row from a lsctables.SimInspiral row.

    Every column listed in lsctables.SimInspiralTable.validcolumns is copied by attribute name; the string-valued columns are passed through str() first. geocent_end_time is then overwritten with a lal.LIGOTimeGPS built from the row's seconds and nanoseconds.
    """
    string_columns = ("waveform", "source", "numrel_data", "taper")
    swigrow = lalmetaio.SimInspiralTable()
    for column in lsctables.SimInspiralTable.validcolumns.keys():
        value = getattr(row, column)
        if column in string_columns:
            # unicode -> char* doesn't work, so hand SWIG a plain str
            value = str(value)
        setattr(swigrow, column, value)
    # FIXME: This doesn't get copied properly, and so is done manually here.
    end_time = lal.LIGOTimeGPS(row.geocent_end_time, row.geocent_end_time_ns)
    swigrow.geocent_end_time = end_time
    return swigrow

def copy_sim_inspiral_swig( swigrow, table ):
    """
    Build a lsctables.SimInspiral row (an instance of table.RowType) from a SWIG wrapped lalmetaio.SimInspiralTable row.

    Every column listed in lsctables.SimInspiralTable.validcolumns is read from swigrow by attribute name; the string-valued columns are passed through str() before being stored.
    """
    string_columns = ("waveform", "source", "numrel_data", "taper")
    row = table.RowType()
    for column in lsctables.SimInspiralTable.validcolumns.keys():
        value = getattr(swigrow, column)
        if column in string_columns:
            # unicode -> char* doesn't work
            value = str(value)
        setattr(row, column, value)
    return row

def copy_sim_burst_swig( swigrow, table ):
    """
    Build a lsctables.SimBurst row (an instance of table.RowType) from a SWIG wrapped SimBurst row.

    The geocentric GPS time is split into integer seconds and nanoseconds, the simulation and process IDs are wrapped as ilwdchar strings, and the time slide ID is always set to time_slide:time_slide_id:0.
    """
    row = table.RowType()

    # Columns that are copied without any conversion
    for name in ("hrss", "psi", "q", "waveform", "bandwidth", "frequency",
                 "ra", "dec", "time_geocent_gmst", "pol_ellipse_e", "duration"):
        setattr(row, name, getattr(swigrow, name))

    # Columns where a falsy value is stored as an integer zero
    for name in ("amplitude", "egw_over_rsquared", "waveform_number"):
        setattr(row, name, getattr(swigrow, name) or 0)
    row.pol_ellipse_angle = swigrow.pol_ellipse_angle or 0.0

    # Split the GPS time into the two integer columns
    gps = gluelal.LIGOTimeGPS(float(swigrow.time_geocent_gps))
    row.time_geocent_gps = gps.seconds
    row.time_geocent_gps_ns = gps.nanoseconds

    # ID columns
    row.simulation_id = ilwd.ilwdchar("sim_burst:simulation_id:%d" % swigrow.simulation_id)
    row.process_id = ilwd.ilwdchar("process:process_id:%d" % swigrow.process_id)
    row.time_slide_id = ilwd.ilwdchar("time_slide:time_slide_id:0")

    return row

def copy_sim_burst( row ):
    """
    Build a SWIG wrapped SimBurst (via lalburst.CreateSimBurst) from a lsctables.SimBurst row.

    Columns that may be unset in the table row are replaced by zero, the IDs are converted with int(), and the two GPS columns are combined into a single lal.LIGOTimeGPS.
    """
    swigrow = lalburst.CreateSimBurst()

    # Falsy values (e.g. None) become an integer zero
    for name in ("hrss", "amplitude", "egw_over_rsquared", "waveform_number",
                 "time_geocent_gmst"):
        setattr(swigrow, name, getattr(row, name) or 0)

    # Falsy values (e.g. None) become a floating point zero
    for name in ("pol_ellipse_angle", "q", "bandwidth", "frequency",
                 "pol_ellipse_e", "duration"):
        setattr(swigrow, name, getattr(row, name) or 0.0)

    # Copied as they are
    for name in ("psi", "ra", "dec"):
        setattr(swigrow, name, getattr(row, name))

    swigrow.time_geocent_gps = lal.LIGOTimeGPS( row.time_geocent_gps, row.time_geocent_gps_ns )
    swigrow.simulation_id = int(row.simulation_id)
    swigrow.process_id = int(row.process_id)
    swigrow.waveform = str(row.waveform)

    return swigrow

def parse_psd_file( filestr ):
    """
    Map the user-provided PSD file string into a function to be called as PSD(f).

    The file is first read as a LIGO_LW XML document and the first PSD found in it is used. If that fails for any reason, the file is read instead as a two column (frequency, PSD) text file with numpy.loadtxt. The returned function linearly interpolates the tabulated PSD at the requested frequencies with numpy.interp.
    """
    try:
        xmldoc = utils.load_filename(filestr, contenthandler=LIGOLWContentHandler)
        series = list(read_psd_xmldoc(xmldoc).values())[0]
        # Build the frequency axis from integer bin indices: numpy.arange
        # with a floating point step can come out one sample longer or
        # shorter than the data, which would break the interpolation.
        f = numpy.arange(len(series.data)) * series.deltaF
        psd = series.data
    except Exception:
        # Deliberate best-effort fallback to the plain text format. Only
        # Exception is caught so that KeyboardInterrupt and SystemExit still
        # propagate.
        f, psd = numpy.loadtxt( filestr, unpack=True )
    def anon_interp( fvals ):
        return numpy.interp( fvals, f, psd )
    return anon_interp

def parse_psd_func( funcstr ):
    """
    Look up the attribute named funcstr in the lalsimulation module and return it. An AttributeError naming the requested function is raised if lalsimulation has no such attribute.
    """
    if not hasattr( lalsimulation, funcstr ):
        raise AttributeError( "Could not find PSD function %s in lalsimulation" % funcstr )
    return getattr( lalsimulation, funcstr )

def sim_to_sngl_table(sim_type, table_type):
    """
    Create a new table of type table_type whose columns are those shared by sim_type and table_type, plus a fixed set of extra columns which depends on whether sim_type is the SimInspiralTable or the SimBurstTable. None is returned for any other sim_type.
    """
    shared = list(set(table_type.validcolumns) & set(sim_type.validcolumns))

    extra = None
    if sim_type == lsctables.SimInspiralTable:
        extra = ["search", "channel", "eff_distance", "ifo", "event_id", "end_time", "end_time_ns", "snr"]
    elif sim_type == lsctables.SimBurstTable:
        extra = ["search", "channel", "central_freq", "ifo", "event_id", "start_time", "start_time_ns", "peak_time", "peak_time_ns", "snr"]

    if extra is None:
        return None
    return lsctables.New(table_type, shared + extra)

def sim_to_sngl_insp(sim, ifo, sngl_table):
    """
    Take the information in sim and transfer overlapping information to a SnglInspiral row, which is appended to sngl_table and returned.

    sim is the simulation row (table row or SWIG wrapped), ifo is the detector prefix (e.g. "H1") whose first letter selects the per-site end time and effective distance columns of sim, and sngl_table is the table receiving the new row.
    """
    if "SimInspiral" in str(sim): # Don't ask
        sim_columns = lsctables.SimInspiralTable.validcolumns
    else:
        sim_columns = lsctables.SimBurstTable.validcolumns
    column_over = list(set(sngl_table.validcolumns) & set(sim_columns))
    sngl_row = sngl_table.RowType()
    for col in column_over:
        setattr(sngl_row, col, getattr(sim, col))

    # Per-site columns are named after the lower-cased first letter of the ifo
    site = ifo[0].lower()
    sngl_row.set_end( gluelal.LIGOTimeGPS(
            int(getattr(sim, "%s_end_time" % site)),
            getattr(sim, "%s_end_time_ns" % site) ) )
    sngl_row.ifo = ifo
    sngl_row.search = "fake"
    sngl_row.channel = "fake"
    sngl_row.process_id = ilwd.ilwdchar("process:process_id:0")
    # The effective distance is read from the sim argument itself (not from
    # a module level variable), and SWIG wrapped inspiral rows are recognised
    # in the same way as in rescale_injection_snr.
    if isinstance(sim, (lsctables.SimInspiral, lalmetaio.SimInspiralTable)):
        sngl_row.eff_distance = getattr(sim, "eff_dist_%s" % site)
    sngl_table.append(sngl_row)
    return sngl_row

def sim_to_sngl_burst(sim, ifo, sngl_table):
    """
    Create a SnglBurst row from the simulation row sim for detector ifo, append it to sngl_table and return it.

    Columns shared between sngl_table and the matching sim table are copied over. The peak time is the geocentric time of the simulation, the start time is the peak time minus half the duration (or minus half of q/(sqrt(2)*pi*frequency) when no duration is set), and the central frequency is the simulation frequency shifted by half the bandwidth when a bandwidth is set.
    """
    if "SimInspiral" in str(sim): # Don't ask
        sim_columns = lsctables.SimInspiralTable.validcolumns
    else:
        sim_columns = lsctables.SimBurstTable.validcolumns
    sngl_row = sngl_table.RowType()
    for name in list(set(sngl_table.validcolumns) & set(sim_columns)):
        setattr(sngl_row, name, getattr(sim, name))

    geocent = sim.time_geocent_gps
    sngl_row.set_peak(gluelal.LIGOTimeGPS( geocent.gpsSeconds, geocent.gpsNanoSeconds))

    if sim.duration is None:
        # Likely CSG
        half_width = sim.q/(numpy.sqrt(2)*numpy.pi*sim.frequency)/2.0
    else:
        # Likely BTLWNB
        half_width = sim.duration/2.0
    sngl_row.set_start(sngl_row.get_peak() - half_width)

    if sim.bandwidth is None:
        sngl_row.central_freq = sim.frequency
    else:
        sngl_row.central_freq = sim.frequency + sim.bandwidth/2.0

    sngl_row.ifo = ifo
    sngl_row.search = "fake"
    sngl_row.channel = "fake"
    sngl_row.process_id = ilwd.ilwdchar("process:process_id:0")
    sngl_table.append(sngl_row)
    return sngl_row

def hoff_single_ifo_snr(hf, psd=None, f_low=0, f_high=None):
    """
    Determine the single IFO SNR of a frequency domain waveform hf, with a given psd. If no psd is supplied (None), then return the hrss. Integration will take place over f_low (0 by default) to f_high (likely nyquist).

    hf must provide the complex samples as hf.data.data and the frequency resolution as hf.deltaF. psd is an array of PSD values, one per frequency bin.
    FIXME: Assumes the hf, psd have the same deltaF and f0
    """
    # |h(f)|^2 per bin; the product with the conjugate has no imaginary part
    power_spectrum = numpy.real(hf.data.data.conj() * hf.data.data)
    if psd is None:
        return numpy.sqrt(power_spectrum.sum()*hf.deltaF)

    psd = numpy.asarray(psd)
    # Never integrate past the end of the shorter of the two arrays
    n_common = min(len(power_spectrum), len(psd))

    # Determine limits of integration
    int_f_low_bin = int(numpy.round(f_low/hf.deltaF))
    if f_high is None:
        int_f_high_bin = n_common
    else:
        int_f_high_bin = min(int(numpy.round(f_high/hf.deltaF)), n_common)

    # divide and sum for SNR
    nw_ip = power_spectrum[int_f_low_bin:int_f_high_bin] / psd[int_f_low_bin:int_f_high_bin]
    power = 4 * nw_ip.sum() * hf.deltaF
    return numpy.sqrt(power)

def rescale_injection_uniform_snr(inj_row, snr, snr_scale):
    """
    Draw a new SNR uniformly from an interval starting at snr_scale[0] whose width is the absolute difference of the two entries of snr_scale, then rescale inj_row from its current snr to that value with rescale_injection_snr. Returns the new SNR chosen.
    """
    lower = snr_scale[0]
    width = abs(snr_scale[1] - lower)
    new_snr = lower + width * numpy.random.random()
    rescale_injection_snr(inj_row, snr, new_snr)
    return new_snr

def rescale_injection_normal_snr(inj_row, snr, mean, std):
    """
    Draw a new SNR from a normal distribution centered on mean with standard deviation std, then rescale inj_row from its current snr to that value with rescale_injection_snr. Returns the new SNR chosen.
    """
    drawn_snr = numpy.random.normal(mean, std)
    rescale_injection_snr(inj_row, snr, drawn_snr)
    return drawn_snr

def rescale_injection_snr(inj_row, snr, new_snr):
    """
    Rescale a parameter in the inj_row to match a new_snr value. The value of the new_snr argument will scale the injection to the provided value.

    inj_row is modified in place: inspiral rows have their distance divided by the SNR ratio new_snr/snr, burst rows have their hrss and/or amplitude (and egw_over_rsquared, quadratically) multiplied by it, depending on the waveform name. A ValueError is raised for rows this function does not know how to scale.
    """

    snr_ratio = new_snr / snr

    if isinstance(inj_row, (lsctables.SimInspiral, lalmetaio.SimInspiralTable)):
        inj_row.distance /= snr_ratio
        return

    # Look the waveform up defensively so that an object without one reaches
    # the intended ValueError instead of an AttributeError.
    waveform = getattr(inj_row, "waveform", None)
    if waveform is None:
        raise ValueError("Don't know how to scale %s" % str(type(inj_row)))

    if waveform == "BTLWNB":
        inj_row.egw_over_rsquared *= snr_ratio**2
        inj_row.hrss *= snr_ratio
    elif waveform == "SineGaussian":
        inj_row.hrss *= snr_ratio
    elif waveform == "StringCusp":
        # CHECK: Is a string cusp injection SNR directly proportional to its
        # amplitude?
        inj_row.amplitude *= snr_ratio
    elif waveform == "Impulse":
        inj_row.amplitude *= snr_ratio
        inj_row.hrss *= snr_ratio
    else:
        raise ValueError("Don't know how to scale burst waveform %s" % waveform)

#
# Efficiency calculation routines
#

def loglogistic( x, alpha, beta ):
    """
    Evaluate the log-logistic CDF at x, where alpha is the scale parameter and beta the dispersion parameter.
    """
    scaled = x/alpha
    return 1.0/(1+scaled**(-beta))

def loglogisticinv( y, alpha, beta ):
    """
    Invert the log-logistic CDF: return the x at which the CDF with scale alpha and dispersion beta takes the value y.
    """
    inverse_odds = 1.0/y - 1
    return alpha*inverse_odds**(-1.0/beta)

def get_eff_params( xbins, ybins, alpha0=5e-24, beta0=1 ):
    """
    Fit the loglogistic CDF to the points (xbins, ybins) with scipy.optimize.curve_fit, starting from the guess (alpha0, beta0), and return the fitted parameters. The second value returned by curve_fit is discarded.
    """
    initial_guess = (alpha0, beta0)
    params, _ = scipy.optimize.curve_fit( loglogistic, xbins, ybins, initial_guess )
    return params

def determine_10_50_90_eff( params ):
    """
    Given the fit parameters (alpha, beta), invert the efficiency function with loglogisticinv and return a tuple of the statistic values at 10%, 50%, and 90% efficiency.
    """
    a, b = params
    return tuple( loglogisticinv(level, a, b) for level in (0.1, 0.5, 0.9) )

def rescale_number( x, orig_intv, new_intv ):
    """
    Map a number x lying within the original interval orig_intv linearly onto the new interval new_intv, so that it keeps the same relative position (``rank'') between the end points.
    TODO: Allow for a two-sided interval such that the rescaling is about the mean.
    """
    old_start = orig_intv[0]
    new_start = new_intv[0]
    old_width = orig_intv[1] - old_start
    new_width = new_intv[1] - new_start
    offset = (x - old_start)/old_width * new_width
    return offset + new_start

# Command line interface. Options are grouped as: detector network and
# integration setup, efficiency measurement, output control, injection cuts
# and storage, attribute squeezing, and SNR rescaling.
optp = OptionParser()
optp.add_option("--det-psd-func", action="append", help="Set a detector to use in the network with the corresponding PSD function from lalsimulation. Example: --det-psd-func H1=SimNoisePSDaLIGOZeroDetHighPower. Can be set multiple times to add more detectors to the network.")
optp.add_option("--waveform-length", type=float, default=32.0, help="Length of waveform in seconds. Note that this affects the binning of the PSD. Default is 32 s")
optp.add_option("--nyquist-frequency", type=float, default=4096.0, help="Nyquist frequency of the PSD to generate. Default is 4096 Hz.")
optp.add_option("--low-frequency-cutoff", type=float, default=0.0, help="Frequency at which to start the integration. The max of the waveform f_lower or this option is used. Default is 0.")
# No default is set for this option: when it is not given the integration
# runs up to the Nyquist frequency.
optp.add_option("--high-frequency-cutoff", type=float, help="Frequency at which to end the integration. The min of the nyquist frequency or this option is used. Default is the nyquist frequency.")

optp.add_option("--snr-thresh", type=float, default=8.0, help="SNR threshold (network) to use for efficiency measurement. Default is 8.")
optp.add_option("--eff-attribute", default="hrss", help="SimBurst attribute against which to measure efficiency. Default is 'hrss', valid choices are any SimBurst attribute.")
optp.add_option("--eff-bins", default=20, type=int, help="Number of bins over which to measure efficiency. Default is 20.")
optp.add_option("--eff-rank", default="netSNR", help="How to rank injections for efficiency purposes. Valid choices are ``netSNR'' (network SNR), ``cohSNR'' (coherent SNR), and ``eta'' (cWB network amplitude). Default is netSNR.")

optp.add_option("--calculate-only", action="store_true", help="Do not perform any injection cutting or squeezing, calculate values only.")
optp.add_option("--machine-parse", action="store_true", help="Print in a machine parseable way.")
optp.add_option("--no-print-single-ifo", action="store_true", default=False, help="Don't print individual IFO SNRs.")
optp.add_option("--no-print-network", action="store_true", default=False, help="Don't print network SNRs.")
optp.add_option("--no-print-coherent", action="store_true", default=False, help="Don't print coherent SNRs.")
optp.add_option("--no-print-cwb", action="store_true", default=False, help="Don't print CWB amplitudes.")
optp.add_option("--verbose", action="store_true", help="Be verbose.")

optp.add_option("--single-snr-cut-thresh", type=float, help="Cut any injection with a single detector SNR less than this.")
optp.add_option("--coherent-snr-cut-thresh", type=float, help="Cut any injection with a coherent SNR less than this.")
optp.add_option("--coherent-amplitude-cut-thresh", type=float, help="Cut any injection with a cWB coherent amplitude (eta) less than this.")
optp.add_option("--sim-id", type=int, help="Only calculate for simulation_id N")
optp.add_option("--store-sngl", help="Store the values of the simulated event and SNR calculation in a sngl_* event row. Valid choices are sngl_inspiral, sngl_burst.")
optp.add_option("--store-sim", action="store_true", help="Store the value of the SNR calculation in a sim_burst event row in the column ``amplitude''. This is only valid for sim_bursts.")

optp.add_option("--squeeze-attribute", default="hrss", help="Squeeze this attribute for injections so as to make them more detectable. Default is 'hrss'. You might try 'distance' for SimInspirals.")

optp.add_option("--rescale-snr-uniform-min", type=float, help="Rescale all SNRs to be uniform deviates within range specified by this and --rescale-snr-uniform-max")
optp.add_option("--rescale-snr-uniform-max", type=float, help="Rescale all SNRs to be uniform deviates within range specified by this and --rescale-snr-uniform-min")
optp.add_option("--rescale-snr-normal-mean", type=float, help="Rescale all SNRs to be normal deviates with parameters specified by this and --rescale-snr-normal-stddev")
optp.add_option("--rescale-snr-normal-stddev", type=float, help="Rescale all SNRs to be normal deviates with parameters specified by this and --rescale-snr-normal-mean")
optp.add_option("--rescale-snr-network", action="store_true", help="Use network SNR to rescale the injection strength.")

opts, args = optp.parse_args()

#
# Sanity checks
#
# At most one of these four is set: the per-detector or the network variant
# (selected by --rescale-snr-network) of either the uniform or the normal
# rescaling strategy.
uniform_scale, uniform_scale_network, norm_params, norm_params_network = None, None, None, None

# Options are compared against None rather than tested for truth so that a
# legitimate value of 0 still counts as "given", and both halves of each pair
# are checked so that a lone max / stddev is reported as well.
uniform_min_given = opts.rescale_snr_uniform_min is not None
uniform_max_given = opts.rescale_snr_uniform_max is not None
if uniform_min_given != uniform_max_given:
    sys.exit("Uniform rescaling requested, but bounds not completely specified.")
elif uniform_min_given:
    if opts.rescale_snr_network:
        uniform_scale_network = (opts.rescale_snr_uniform_min, opts.rescale_snr_uniform_max)
    else:
        uniform_scale = (opts.rescale_snr_uniform_min, opts.rescale_snr_uniform_max)

normal_mean_given = opts.rescale_snr_normal_mean is not None
normal_stddev_given = opts.rescale_snr_normal_stddev is not None
if normal_mean_given != normal_stddev_given:
    sys.exit("Normal rescaling requested, but parameters not completely specified.")
elif normal_mean_given:
    if opts.rescale_snr_network:
        norm_params_network = (opts.rescale_snr_normal_mean, opts.rescale_snr_normal_stddev)
    else:
        norm_params = (opts.rescale_snr_normal_mean, opts.rescale_snr_normal_stddev)

# The two strategies are mutually exclusive, for the network variants too
if uniform_min_given and normal_mean_given:
    sys.exit("Two rescaling strategy parameter sets requested, please choose either --rescale-snr-normal-* or --rescale-snr-uniform-*")

#
# Detector, PSD setup
#

detectors = {"H1": lal.CachedDetectors[lal.LHO_4K_DETECTOR],
             "L1": lal.CachedDetectors[lal.LLO_4K_DETECTOR],
             "V1": lal.CachedDetectors[lal.VIRGO_DETECTOR]}
uniq_det_comb = list( itertools.combinations( detectors.keys(), 2 ) )

# Default is R=8192, and data with a length of 32 seconds. Sorta excessive.
fnyq = opts.nyquist_frequency
deltaF = 1.0/opts.waveform_length
eff_attr = opts.eff_attribute

f = numpy.linspace(0, int(fnyq/deltaF), int(fnyq/deltaF)+1)*deltaF
detector_psds = {}
if opts.det_psd_func is None:
    raise ArgumentError("Need at least one detector / PSD specified.")
for dspec in opts.det_psd_func:
    det, psdfunc = dspec.split("=")
    # Secondary mapping of psdfuncs to actual numerical arrays
    if os.path.isfile( psdfunc ):
        if opts.verbose:
            print "Detector %s will use %s as the PSD" % (det, psdfunc)
        detector_psds[det] = numpy.array(parse_psd_file( psdfunc )(f))
    else:
        if opts.verbose:
            print "Looking up function %s in lalsimulation for use with %s as the PSD" % (psdfunc, det)
        detector_psds[det] = numpy.array(map(parse_psd_func( psdfunc ), f))
    # Set DC sensitivity to something sensible, since it often comes out nan
    detector_psds[det][0] = detector_psds[det][1]

if len(detector_psds) == 0:
    raise ArgumentError("Need at least one detector / PSD specified.")

for d in detectors.keys():
    if d not in detector_psds.keys():
        del detectors[d]

#
# Generate required FFT plans
#

# samples = rate * length
fwdlen = int(2.0*fnyq/deltaF)
fwdplan = lal.CreateForwardREAL8FFTPlan( fwdlen, 0 )

#
# Get XML document contents
#
# The injection file is the first positional argument; fail with a readable
# message rather than an IndexError when it is missing.
if not args:
    sys.exit("Need an XML file with injections as the first positional argument.")
xmldoc = utils.load_filename(args[0], contenthandler=ligolw.LIGOLWContentHandler)

#
# Initialize random seed
# This is done so as to ensure the same random SNRs are drawn each time for any
# given file
#
# The file is hashed as raw bytes (it may be compressed), and the first 32
# bits of the MD5 digest are used as the seed.
with open(args[0], "rb") as xmlfile:
    file_digest = hashlib.md5(xmlfile.read())
    numpy.random.seed(int(file_digest.hexdigest()[:8], 16))
#
# Get sim tables
#
# All simulations, converted to their SWIG wrapped form: bursts first, then
# inspirals.
sims = []

# Only the table lookup is guarded: a ValueError raised while converting a row
# must not be reported as a missing table.
try:
    sim_burst = table.get_table( xmldoc, lsctables.SimBurstTable.tableName )
except ValueError:
    sim_burst = []
    if opts.verbose:
        print >>sys.stderr, "No SimBurst table found, skipping..."
sims.extend( map(copy_sim_burst, sim_burst) )

try:
    sim_insp = table.get_table( xmldoc, lsctables.SimInspiralTable.tableName )
except ValueError:
    sim_insp = []
    if opts.verbose:
        print >>sys.stderr, "No SimInspiral table found, skipping..."
sims.extend( map(copy_sim_inspiral, sim_insp) )

snr_thresh = opts.snr_thresh
# Per waveform family: all injections kept, those above threshold, and the
# number of injections kept.
injected, efficiencies, count = defaultdict(list), defaultdict(list), defaultdict(lambda: 0)

# Store injection waveform temporarily
injection_waveforms = {}

# Store sims to cut
cut_sims = []

# Store realized SNRs in Sngl* rows
if opts.store_sngl == "sngl_inspiral":
    sngl_table = sim_to_sngl_table(lsctables.SimInspiralTable, lsctables.SnglInspiralTable)
elif opts.store_sngl == "sngl_burst":
    sngl_table = sim_to_sngl_table(lsctables.SimBurstTable, lsctables.SnglBurstTable)
elif opts.store_sngl is not None:
    # Fail now instead of with a NameError on sngl_table after all the SNRs
    # have been calculated.
    sys.exit("Unknown --store-sngl value '%s', valid choices are sngl_inspiral, sngl_burst." % opts.store_sngl)

sim_snr_table = lsctables.New(lsctables.SimBurstTable)

i = 0
for row in sims:
    i += 1
    snr = {}

    if opts.sim_id is not None and opts.sim_id != int(row.simulation_id):
        continue

    for d, det in detectors.iteritems():
        psd = detector_psds[d]

        # inspiral case
        if type(row) == lalmetaio.SimInspiralTable:
            time = row.geocent_end_time
            hp, hx = lalinspiral.SimInspiralChooseWaveformFromSimInspiral( row, 0.5/fnyq )
            hp.epoch += time
            hx.epoch += time
            h = lalsimulation.SimDetectorStrainREAL8TimeSeries( hp, hx,
                    row.longitude, row.latitude, row.polarization, det )
            ra, dec, psi = row.longitude, row.latitude, row.polarization
            f_low_wave = max(opts.low_frequency_cutoff, row.f_lower)
        # burst case
        elif type(row) == lalmetaio.SimBurst:
            time = row.time_geocent_gps
            hp, hx = lalburst.GenerateSimBurst( row, 0.5/fnyq )
            h = lalsimulation.SimDetectorStrainREAL8TimeSeries( hp, hx,
                    row.ra, row.dec, row.psi, det )
            ra, dec, psi = row.ra, row.dec, row.psi
            # FIXME: What's the f_lower for a burst? central_freq - 2*band?
            # more over, that might not make sense for things like cosmic
            # strings
            f_low_wave = max(opts.low_frequency_cutoff, 0)
        else:
            raise TypeError("%s: It's not a burst, and it's not an inspiral... what do you want me to do here?" % type(row))

        # zero pad
        needed_samps = int(2.0*fnyq/deltaF)
        prevlen = h.data.length
        if h.data.length < needed_samps:
            h = lal.ResizeREAL8TimeSeries( h, 0, needed_samps )
        elif h.data.length > needed_samps:
            h = lal.ResizeREAL8TimeSeries( h, h.data.length-needed_samps, needed_samps )

        # Forward FFT
        # adjust heterodyne frequency to match flow
        hf = lal.CreateCOMPLEX16FrequencySeries(
                name = "FD signal",
                epoch = h.epoch,
                f0 = 0,
                deltaF = deltaF,
                sampleUnits = lal.DimensionlessUnit,
                length = h.data.length/2 + 1
        )

        lal.REAL8TimeFreqFFT(hf, h, fwdplan)
        injection_waveforms[d] = hf

        snr[d] = hoff_single_ifo_snr(hf, psd, f_low_wave, min(opts.high_frequency_cutoff or float("inf"), opts.nyquist_frequency))

        #
        # Reassign SNRs to be within a certain range
        #
        old_snr = snr[d]
        if norm_params is not None:
            snr[d] = rescale_injection_normal_snr(row, snr[d], norm_params[0], norm_params[1])
        elif uniform_scale is not None:
            snr[d] = rescale_injection_uniform_snr(row, snr[d], uniform_scale)
        hf.data.data *= snr[d]/old_snr

        #
        # Print out some information
        #
        if opts.machine_parse and not opts.no_print_single_ifo:
            print snr[d]
        elif not opts.no_print_single_ifo:
            print "Waveform %s at %10.3f has SNR in %s of %f" % (row.waveform, time, d, snr[d])

        #
        # Make 'sngl_' rows with the optimal SNR information
        #
        if opts.store_sngl == "sngl_inspiral":
            sngl_row = sim_to_sngl_insp(row, det.frDetector.prefix, sngl_table)
            sngl_row.snr = snr[d]
            sngl_row.event_id = sngl_table.get_next_id()
        elif opts.store_sngl == "sngl_burst":
            sngl_row = sim_to_sngl_burst(row, det.frDetector.prefix, sngl_table)
            sngl_row.snr = snr[d]**2 # SNR column is energy, really
            sngl_row.event_id = sngl_table.get_next_id()

    #
    # Network SNR
    #
    net_snr = numpy.sqrt(sum([ s**2 for s in snr.values()]))
    if opts.store_sim:
        row.amplitude = net_snr
        sim_snr_table.append(row)

    if opts.machine_parse and not opts.no_print_network:
        print "%f" % net_snr
    elif not opts.no_print_network:
        print "Network SNR for %s at %10.3f is %f" % (row.waveform, time, net_snr)

    #
    # Coherent SNR
    #
    time_delays = {}
    responses = {}
    d0 = "H1"
    for d in detectors.keys():
        time_delays[(d0,d)] = antenna.timeDelay( float(time), ra, dec, 'radians', d0, d )
    responses = {}
    # TODO: move this up to the PSD section, since it's only a function of frequency
    S_f_coh = numpy.zeros( (2, len(f)) )
    hp_coh = numpy.zeros( len(f), dtype=numpy.complex128 )
    hx_coh = numpy.zeros( len(f), dtype=numpy.complex128 )
    hp_coh_det = {}
    hx_coh_det = {}
    for d in detectors.keys():
        responses[d] = antenna.response( float(time), ra, dec, 0, psi, 'radians', d )[:2]
        # Creighton and Anderson eq. 7.137
        S_f_coh += numpy.array([ numpy.conj(r)*r/detector_psds[d] for r in responses[d] ])

        # Creighton and Anderson eq. 7.136a (numerator)
        hp_coh_det[d] = injection_waveforms[d].data.data * numpy.exp( 2*numpy.pi * 1j * f * time_delays[(d0,d)] ) * responses[d][0] / detector_psds[d]
        hp_coh += hp_coh_det[d]
        # Creighton and Anderson eq. 7.136b (numerator)
        hx_coh_det[d] = injection_waveforms[d].data.data * numpy.exp( 2*numpy.pi * 1j * f * time_delays[(d0,d)] ) * responses[d][1] / detector_psds[d]
        hx_coh += hx_coh_det[d]
        df = injection_waveforms[d].deltaF

    # Creighton and Anderson eq. 7.136a,b (full)
    hp_coh_f = hp_coh / S_f_coh[0]
    hx_coh_f = hx_coh / S_f_coh[1]

    # Creighton and Anderson eq. 7.138
    coh_snr = 4 * numpy.real(sum(hp_coh_f.conj() * hp_coh_f * S_f_coh[0])) * df
    coh_snr += 4 * numpy.real(sum(hx_coh_f.conj() * hx_coh_f * S_f_coh[1])) * df
    if opts.machine_parse and not opts.no_print_coherent:
        print "%g" % numpy.sqrt(coh_snr)
    elif not opts.no_print_coherent:
        print "Coherent SNR: %g" % numpy.sqrt(coh_snr)

    coh_energy = 0
    coh_energy_mat = {}
    for ((d1, hf_d1), (d2, hf_d2)) in itertools.combinations( tuple(hp_coh_det.iteritems()), 2 ):
        coh_energy_mat[(d1,d2)] = sum(hf_d1.conj() * hf_d2 / S_f_coh[0])*df
        coh_energy += coh_energy_mat[(d1,d2)]

    eta_est = numpy.sqrt(abs(coh_energy)/len(detectors.values()))
    if opts.machine_parse and not opts.no_print_cwb:
        print "%g" % eta_est
    elif not opts.no_print_cwb:
        print "Estimated cWB ranking statistic (1G) %g" % eta_est

    #
    # Reassign SNRs to be within a certain range -- network version
    #
    if norm_params_network is not None:
        rescale_injection_normal_snr(row, net_snr, norm_params_network[0], norm_params_network[1])
    elif uniform_scale_network is not None:
        rescale_injection_uniform_snr(row, net_snr, uniform_scale_network)

    if opts.single_snr_cut_thresh is not None and any(map(lambda s: s < opts.single_snr_cut_thresh, snr.values())):
        cut_sims.append( (type(row), row.simulation_id ) )
        continue
    if opts.coherent_snr_cut_thresh is not None and coherent_snr < opts.coherent_snr_cut_thresh:
        cut_sims.append( (type(row), row.simulation_id ) )
        continue
    if opts.coherent_amplitude_cut_thresh is not None and eta_est < opts.coherent_amplitude_cut_thresh:
        cut_sims.append( (type(row), row.simulation_id ) )
        continue

    # TODO: Use different ranking statistic for efficiency
    if net_snr > snr_thresh:
        efficiencies[row.waveform].append( row )

    injected[row.waveform].append( row )
    count[row.waveform] += 1

#
# Output rescaled SNRs
#
# Any of the four rescaling modes modifies the injections, so the rescaled
# injections are written out for all of them (including the network variant
# of the uniform rescaling).
# NOTE: this rebinds xmldoc to the new document holding the rescaled tables.
rescaling_requested = any( params is not None for params in
        (norm_params_network, norm_params, uniform_scale_network, uniform_scale) )
if rescaling_requested:
    xmldoc = ligolw.Document()
    xmldoc.appendChild(ligolw.LIGO_LW())
    if sim_burst:
        simb_snr_table = lsctables.New(lsctables.SimBurstTable)
    if sim_insp:
        simi_snr_table = lsctables.New(lsctables.SimInspiralTable)
    for sim in sims:
        # Table rows are stored as they are, SWIG wrapped rows are converted
        # back into table rows first
        if isinstance(sim, lsctables.SimInspiral):
            simi_snr_table.append(sim)
        elif isinstance(sim, lalmetaio.SimInspiralTable):
            simi_snr_table.append(copy_sim_inspiral_swig(sim, simi_snr_table))
        elif isinstance(sim, lsctables.SimBurst):
            simb_snr_table.append(sim)
        elif isinstance(sim, lalmetaio.SimBurst):
            simb_snr_table.append(copy_sim_burst_swig(sim, simb_snr_table))

    # Only non-empty tables are attached to the document
    if sim_burst and simb_snr_table:
        xmldoc.childNodes[0].appendChild(simb_snr_table)
    if sim_insp and simi_snr_table:
        xmldoc.childNodes[0].appendChild(simi_snr_table)

    utils.write_filename(xmldoc, "snr_rescaled.xml.gz", gz=True)

#
# Identify injections not passing the SNR thresholds
#
try:
    sim_burst = table.get_table(xmldoc, lsctables.SimBurstTable.tableName)
    sb_ids = map(lambda sb: sb[1], filter(lambda cut: cut[0] == lalmetaio.SimBurst, cut_sims))
    print "%d burst injections would be cut from SNR thresholds." % len(sb_ids)
    n = len(sim_burst)-1
    for i, sb in enumerate(sim_burst[::-1]):
        if int(sb.simulation_id) in sb_ids:
            del sim_burst[n-i]
except ValueError:
    pass
try:
    sim_insp = table.get_table(xmldoc, lsctables.SimInspiralTable.tableName)
    si_ids = map(lambda sb: sb[1], filter(lambda cut: cut[0] == lalmetaio.SimInspiralTable, cut_sims))
    print "%d inspiral injections would be cut from SNR thresholds." % len(si_ids)
    n = len(sim_insp)-1
    for i, si in enumerate(sim_insp[::-1]):
        if int(si.simulation_id) in si_ids:
            del sim_insp[n-i]
except ValueError:
    pass

#
# Record "triggers" with optimal information
#
if opts.store_sngl is not None or opts.store_sim:
    xmldoc = ligolw.Document()
    xmldoc.childNodes.append(ligolw.LIGO_LW())
    if opts.store_sngl is not None:
        xmldoc.childNodes[0].appendChild(sngl_table)
        ifos = reduce(str.__add__, [ifo[0] for ifo in detectors.keys()])

    if opts.store_sim:
        for i, row in enumerate(sim_snr_table):
            sim_snr_table[i] = copy_sim_burst_swig(row, sim_snr_table)
        xmldoc.childNodes[0].appendChild(sim_snr_table)
        ifos = reduce(str.__add__, [ifo[0] for ifo in detector_psds.keys()])

    process.register_to_xmldoc(xmldoc, sys.argv[0], opts.__dict__)
    utils.write_filename(xmldoc, "%s-FAKE_SEARCH.xml.gz" % ifos, gz=True)

if opts.calculate_only:
    exit(0)

if len(cut_sims) > 0:
    print "Writing cut XML file."
    utils.write_filename( xmldoc, "cut_tmp.xml" )

if opts.sim_id is not None:
    exit()

from pylal import imr_utils
from pylal import rate

for wave, eff in efficiencies.iteritems():
    ndet = len(eff)
    if opts.verbose:
        print "Report for waveform family %s" % wave
        print "Efficiency: %d / %d = %f" % (ndet, count[wave], float(ndet)/count[wave])

    # Number of injections for this type
    inj = injected[wave]

    #
    # Measure efficiency
    #
    effbins = imr_utils.guess_nd_bins( inj, {eff_attr: (opts.eff_bins, rate.LogarithmicBins)} )
    eff, err = imr_utils.compute_search_efficiency_in_bins( eff, inj, effbins, lambda sim: (getattr(sim, eff_attr),) )

    #
    # Fit efficiency to loglogistic curve
    #
    cents = eff.centres()[0]
    for c, ef, er in zip(eff.centres()[0], eff.array, err.array):
        if opts.verbose:
            print "%2.1g %2.1f +/- %0.2f" % (c, ef, er)
    try:
        attr_scale = numpy.average(map(lambda sim: getattr(sim, eff_attr), inj))
        a, b = get_eff_params( cents, eff.array, attr_scale )
    except RuntimeError:
        print >>sys.stderr, "Warning, couldn't fit efficiency, check bins. Skipping..."
        continue
    ten, fifty, ninety = determine_10_50_90_eff( (a,b) )
    if opts.verbose:
        print "10%% / 50%% / 90%% efficiency estimate: %2.2g %2.2g %2.2g" % (ten, fifty, ninety)

    if opts.squeeze_attribute is not None:
        #new_range = ( max(0, ninety - (ninety-ten)), ten + (ninety-ten) )
        new_range = ( ten, ninety )
        orig_range = []
        for sim in (sim_insp or []) + (sim_burst or []):
            if sim.waveform != wave:
                continue
            try:
                orig_range.append( getattr( sim, opts.squeeze_attribute ) )
            except:
                # TODO: Warn?
                pass

        orig_range = sorted( orig_range )
        orig_range = (orig_range[0], orig_range[-1])
        if opts.verbose:
            print "Squeeze distribution %s according to efficiency values. Will map %s -> %s" % (opts.squeeze_attribute, str(orig_range), str(new_range))

        for sim in sim_burst:
            if sim.waveform != wave:
                continue
            try:
                setattr( sim, opts.squeeze_attribute, rescale_number( getattr( sim, opts.squeeze_attribute ), orig_range, new_range ) )
            except AttributeError:
                # TODO: Warn?
                pass
        for sim in sim_insp:
            if sim.waveform != wave:
                continue
            try:
                setattr( sim, opts.squeeze_attribute, rescale_number( getattr( sim, opts.squeeze_attribute ), orig_range, new_range ) )
            except AttributeError:
                # TODO: Warn?
                pass

utils.write_filename( xmldoc, "squeeze_tmp.xml" )
