#!/usr/bin/env python
#
# Copyright (C) 2013 Chris Pankow
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import sys
from optparse import OptionParser, Option, OptionGroup

from scipy import random
import numpy

from glue.ligolw import lsctables
from glue.ligolw import utils
from glue.ligolw import ligolw
from glue.ligolw import ilwd
from glue.segments import segment
from glue.lal import LIGOTimeGPS as GPS
from glue.ligolw.utils import process
from pylal.antenna import response

import lal
import lalburst
import lalsimulation

# FIXME: Move to a module

# Column header line of a BurstMDC (GravEn) style log: a leading "#" field
# followed by tab-separated column names. write_burst_mdc_log() emits it
# verbatim as the first line of the log file.
header = "#\tGravEn_SimID\tSimHrss\tSimEgwR2\tGravEn_Ampl\tInternal_x\tInternal_phi\tExternal_x\tExternal_phi\tExternal_psi\tFrameGPS\tEarthCtrGPS\tSimName\tSimHpHp\tSimHcHc\tSimHpHc\tGEO\tGEOctrGPS\tGEOfPlus\tGEOfCross\tGEOtErr\tGEOcErr\tGEOsnr\tH1\tH1ctrGPS\tH1fPlus\tH1fCross\tH1tErr\tH1cErr\tH1snr\tH2\tH2ctrGPS\tH2fPlus\tH2fCross\tH2tErr\tH2cErr\tH2snr\tL1\tL1ctrGPS\tL1fPlus\tL1fCross\tL1tErr\tL1cErr\tL1snr\tV1\tV1ctrGPS\tV1fPlus\tV1fCross\tV1tErr\tV1cErr\tV1snr"

def write_burst_mdc_row( row, start=0 ):
    """
    Format one injection as a tab-separated row of a BurstMDC style (GravEn) log.

    row is a sim_burst table row; simulation_id, hrss, egw_over_rsquared, ra, dec, psi, waveform and get_time_geocent() are read from it. start is written unchanged into the frame GPS column (default 0). The formatted row is returned as a string without a trailing newline.

    Field order:
    simulation_id, hrss, egw, graven_amp, internal_x, internal_phi, external_x (dec), external_phi (ra), external_psi, frame_gps, earth_center_gps, waveform, simhphp, simhxhx, simhphx, then for each of GEO, H1, H2, L1, V1: label, fplus, fcross, 0, 0, 0

    NOTE(review): the module level header names seven columns per detector (label, ctrGPS, fPlus, fCross, tErr, cErr, snr) while only six fields per detector are written here, so the per-detector columns do not line up with the header. Confirm the intended layout before relying on column positions.
    """
    earth_center_gps = float(row.get_time_geocent())
    # The factor of 1.8597e-21 is to convert from units of M_s/pc^2 to SI units
    egw = (row.egw_over_rsquared or 0)*1.8597e-21
    # Unset (None) amplitudes are logged as zero
    hrss = row.hrss or 0

    # FIXME: What are these?
    graven_amp = internal_x = internal_phi = 0

    # One "label, f+, fx, 0, 0, 0" group per detector; the three trailing
    # zeros are placeholders, nothing is computed for them here.
    # FIXME: This is wrong by the light travel time from earth center to det
    detector_fields = []
    for label, detector in (("GEO", "G1"), ("H1", "H1"), ("H2", "H2"), ("L1", "L1"), ("V1", "V1")):
        fplus, fcross, _, _ = response(earth_center_gps, row.ra, row.dec, 0, row.psi, 'radians', detector)
        detector_fields.append("%s\t%f\t%f\t%f\t%f\t%f" % (label, fplus, fcross, 0, 0, 0))

    # The total hrss returned by measure_hrss is not logged; row.hrss is.
    _, simhphp, simhxhx, simhphx = measure_hrss(row)

    common_fields = "%d\t%g\t%g\t%f\t%f\t%f\t%f\t%f\t%f\t%10.10f\t%10.10f\t%s\t%g\t%g\t%g" % (int(row.simulation_id), hrss, egw, graven_amp, internal_x, internal_phi, row.dec, row.ra, row.psi, start, earth_center_gps, row.waveform, simhphp, simhxhx, simhphx)
    return "\t".join([common_fields] + detector_fields)

def write_burst_mdc_log(fname, rows):
    """
    Write out a set of BurstMDC (GravEn) rows into fname.

    fname is the path of the log to create; it is opened in "w" mode, so an existing file is overwritten. rows is an iterable of already formatted rows (for example the strings returned by write_burst_mdc_row). The module level column header is written first, followed by one row per line.
    """
    # The context manager closes the file even if writing a row raises,
    # which the previous open()/close() pair did not guarantee.
    with open(fname, "w") as f:
        f.write("%s\n" % (header,))
        for row in rows:
            f.write("%s\n" % (row,))

def measure_egw_rsq(row, rate=16384.0):
    """
    Measure the energy emitted in gravitational waves divided by the distance squared (documented by the original author as being in M_solar / pc^2).

    The attributes of row named after the sim_burst table columns are copied onto a freshly created SWIG SimBurst structure, the burst is generated with lalburst.GenerateSimBurst and the two polarizations are handed to lalsimulation.MeasureEoverRsquared, whose result is returned. Rate is the sampling rate in Hz (default is 16kHz).
    """
    sim_burst = lalburst.CreateSimBurst()
    for column in lsctables.SimBurstTable.validcolumns:
        try:
            value = getattr(row, column)
            setattr(sim_burst, column, value)
        except (AttributeError, TypeError):
            # AttributeError: we didn't define it
            # TypeError: the structure is different than the TableRow
            continue

    delta_t = 1.0/rate
    hplus, hcross = lalburst.GenerateSimBurst(sim_burst, delta_t)
    return lalsimulation.MeasureEoverRsquared(hplus, hcross)

def measure_hrss(row, rate=16384.0):
    """
    Measure the various components of hrss for a given input row and return them as the tuple (hrss, h+^2, hx^2, hphx).

    The attributes of row named after the sim_burst table columns are copied onto a freshly created SWIG SimBurst structure, the burst is generated with lalburst.GenerateSimBurst and the pieces are obtained from repeated calls to lalsimulation.MeasureHrss. Rate is the sampling rate in Hz (default is 16kHz).
    """
    sim_burst = lalburst.CreateSimBurst()
    for column in lsctables.SimBurstTable.validcolumns:
        try:
            setattr(sim_burst, column, getattr(row, column))
        except (AttributeError, TypeError):
            # AttributeError: we didn't define it
            # TypeError: the structure is different than the TableRow
            continue

    delta_t = 1.0/rate
    hplus, hcross = lalburst.GenerateSimBurst(sim_burst, delta_t)
    # A second, zeroed copy of both polarizations stands in for "no signal"
    # FIXME: Totally inefficent --- but can we deep copy a SWIG SimBurst?
    zero_plus, zero_cross = lalburst.GenerateSimBurst(sim_burst, delta_t)
    zero_plus.data.data *= 0
    zero_cross.data.data *= 0

    # H+ hrss only
    plus_sq = lalsimulation.MeasureHrss(hplus, zero_cross)**2
    # Hx hrss only
    cross_sq = lalsimulation.MeasureHrss(zero_plus, hcross)**2
    # sqrt(|H+|^2 + |Hx|^2)
    total = lalsimulation.MeasureHrss(hplus, hcross)

    # Overwrite h+ with |h+| + |hx|; its square minus the total hrss^2 leaves
    # twice the cross term. This must stay after the measurements above
    # because it destroys the original h+ samples.
    hplus.data.data = numpy.abs(hcross.data.data) + numpy.abs(hplus.data.data)
    # |H+Hx|
    cross_term = (lalsimulation.MeasureHrss(hplus, zero_cross)**2 - total**2)/2
    return total, plus_sq, cross_sq, cross_term

def uniform_dec(num):
    """
    Draw num declinations distributed uniformly in sin(dec), i.e. values between -pi/2 and pi/2.
    """
    # Uniform deviate on [-1, 1), interpreted as the cosine of a polar angle
    cos_polar = 2 * random.random_sample(num) - 1
    polar = numpy.arccos(cos_polar)
    # Declination is measured from the equator rather than from the pole
    return (numpy.pi / 2.) - polar

def uniform_theta(num):
    """
    Draw num angles between 0 and pi whose cosine is uniformly distributed.
    """
    # Uniform deviate on [-1, 1) used as cos(theta)
    cos_theta = 2 * random.random_sample(num) - 1
    return numpy.arccos(cos_theta)

def uniform_phi(num):
    """
    Draw num angles uniformly from the interval [0, 2pi).
    """
    doubled = random.random_sample(num) * 2
    return doubled * numpy.pi

def uniform_interval(interval, num):
    """
    Draw num values uniformly from interval, which is indexed as a 2-tuple (lower bound first, upper bound second). See also:

    http://docs.scipy.org/doc/numpy/reference/generated/numpy.random.random_sample.html#numpy.random.random_sample
    """
    lower = interval[0]
    width = interval[1] - lower
    # Stretch the unit deviates to the interval width, then shift to its start
    return width * random.random_sample(num) + lower

# FIXME: Make different injection sources subclass this
class source(object):
    """
    Generic container for different source types. Currently, it checks for the waveform type and initializes itself appropriately. In the future, different sources should subclass this and override the generation routines.

    Every parameter range handed to the constructor is either a glue segment (values are drawn from the continuous interval) or an indexable sequence of allowed values (entries are drawn with replacement). All draw methods return self.expnum values.
    """
    def __init__(self, waveform, rate, tstart, tstop, **kwargs):
        """
        Initialize a source class. waveform is the name of the waveform family. Rate is the astrophysical rate (if applicable) of the source, in events per year. tstart and tstop is the segment of time in which to generate the source. Additional arguments are applied on a case-by-case basis depending on the source class.

        SineGaussian: q, frequency, hrss
        StringCusp: frequency, amplitude
        Impulse: hrss
        BTLWNB: duration, bandwidth, frequency, egw/hrss, distance
        adhoc: hrss, filepath

        For BTLWNB the amplitude is taken from hrss when wnb_use_hrss is true and from egw otherwise. Any waveform name is accepted as an ad-hoc family as long as filepath is given.

        Raises ValueError if a required keyword is missing or the waveform family is unknown. A BTLWNB source lacking the selected amplitude keyword (hrss or egw) raises KeyError.
        """
        self.waveform = waveform
        self.wnb_use_hrss = kwargs.get("wnb_use_hrss", False)
        if waveform == "SineGaussian":
            if self._missing(kwargs, "q", "frequency", "hrss"):
                raise ValueError("SineGaussian distribution requires a range of q, f, and hrss.")
            self.q_range = kwargs["q"]
            self.f_range = kwargs["frequency"]
            self.h_range = kwargs["hrss"]
        elif waveform == "StringCusp":
            if self._missing(kwargs, "frequency", "amplitude"):
                raise ValueError("StringCusp distribution requires a range of amplitude and f.")
            self.f_range = kwargs["frequency"]
            # Cusp amplitudes share the h_range slot used for hrss elsewhere
            self.h_range = kwargs["amplitude"]
        elif waveform == "BTLWNB":
            if self._missing(kwargs, "duration", "bandwidth", "frequency"):
                raise ValueError("BTLWNB distribution requires a range of frequency, bandwidth, and duration.")
            self.b_range = kwargs["bandwidth"]
            self.d_range = kwargs["duration"]
            self.f_range = kwargs["frequency"]
            if self.wnb_use_hrss:
                self.h_range = kwargs["hrss"]
            else:
                self.egw_range = kwargs["egw"]
        elif waveform == "Impulse":
            if "hrss" not in kwargs:
                raise ValueError("Impulse distribution requires a range of amplitude.")
            self.h_range = kwargs["hrss"]
        elif "filepath" in kwargs:
            if "hrss" not in kwargs:
                raise ValueError("Ad-hoc distribution requires a range of amplitude.")
            self.h_range = kwargs["hrss"]
        else:
            raise ValueError("Unknown waveform family %s." % waveform)

        # distance is optional; volume_distributed_distances() needs it
        if "distance" in kwargs:
            self.distance = kwargs["distance"]
        self.rate = rate
        self.tstart = tstart
        self.tstop = tstop
        # Expected number of events in the interval: rate is per year while
        # the interval is in seconds. Rounded up to a whole number.
        self.expnum = int(numpy.ceil((self.tstop-self.tstart) * self.rate / 365.0 / 24 / 3600 ))

    @staticmethod
    def _missing(kwargs, *names):
        """
        Return True if at least one of names is not a key of kwargs.
        """
        return any(name not in kwargs for name in names)

    def _choose(self, values):
        """
        Draw self.expnum entries from the sequence values, with replacement.
        """
        return [values[i] for i in random.randint(0, len(values), self.expnum)]

    def _uniform(self, bounds):
        """
        Draw self.expnum values uniformly from the interval bounds.
        """
        return uniform_interval(bounds, self.expnum)

    def _log_uniform(self, bounds):
        """
        Draw self.expnum values uniformly in log10 between bounds[0] and bounds[1].
        """
        log10b = segment(numpy.log10(bounds[0]), numpy.log10(bounds[1]))
        return 10**uniform_interval(log10b, self.expnum)

    def _draw(self, values, from_segment):
        """
        Draw from values: through the callable from_segment if values is a segment, otherwise by picking entries of the sequence.
        """
        if isinstance(values, segment):
            return from_segment(values)
        return self._choose(values)

    def log_hrss(self):
        """
        Draw uniformly in the log of a predefined hrss range
        """
        return self._log_uniform(self.h_range)

    def log_egw(self):
        """
        Draw uniformly in the log of a predefined E_gw range
        """
        return self._log_uniform(self.egw_range)

    def draw_imp_params(self):
        """
        Draw a set of parameters (hrss) appropriate for a set of impulse type injections.
        """
        return self._draw(self.h_range, self._log_uniform)

    def draw_sg_params(self):
        """
        Draw a set of parameters appropriate for a set of sinegaussian type injections. Returns the tuple (q, frequency, hrss).
        """
        # The order of the draws is kept fixed so that a given seed keeps
        # producing the same injections.
        q = self._draw(self.q_range, self._uniform)
        h = self._draw(self.h_range, self._log_uniform)
        f = self._draw(self.f_range, self._uniform)
        return q, f, h

    def draw_cusp_params(self):
        """
        Draw a set of parameters appropriate for a set of stringcusp type injections. Returns the tuple (amplitude, frequency).
        """
        a = self._draw(self.h_range, self._log_uniform)
        f = self._draw(self.f_range, self._cusp_fhigh)
        return a, f

    def _cusp_fhigh(self, bounds):
        """
        Draw self.expnum high frequency cut-offs for string cusps between bounds[0] and bounds[1].
        """
        # Translated from the binj.c source random_string_cusp_fhigh function
        #
        # \theta is the angle the line of sight makes with the cusp.  \theta^{2}
        # is distributed uniformly, and the high frequency cut-off of the
        # injection is \propto \theta^{-3}.  So we first solve for the limits of
        # \theta^{2} from the low- and high bounds of the frequency band of
        # interest, pick a uniformly-distributed number in that range, and
        # convert back to a high frequency cut-off.
        #
        thetasqmin, thetasqmax = bounds[0]**(-2.0/3.0), bounds[1]**(-2.0/3.0)
        thetasq = uniform_interval((thetasqmin, thetasqmax), self.expnum)
        return thetasq**(-3.0/2.0)

    def draw_wnb_params(self):
        """
        Draw a set of parameters appropriate for a set of BTLWNB type injections. Returns the tuple (bandwidth, duration, frequency, amplitude), where the amplitude is an hrss if wnb_use_hrss is set and an energy otherwise. The amplitude is an empty list if the corresponding range is None.
        """
        b = self._draw(self.b_range, self._uniform)
        d = self._draw(self.d_range, self._uniform)
        # FIXME: Make this logarithmic?
        f = self._draw(self.f_range, self._uniform)

        amplitude_range = self.h_range if self.wnb_use_hrss else self.egw_range
        e = []
        if isinstance(amplitude_range, segment) or amplitude_range is not None:
            e = self._draw(amplitude_range, self._log_uniform)

        return b, d, f, e

    def uniform_sky(self):
        """
        Get a set of (RA, declination, polarization) randomized appopriately to astrophysical sources isotropically distributed in the sky.
        """
        expnum = self.expnum
        ra = uniform_phi(expnum)
        dec = uniform_dec(expnum)
        pol = uniform_phi(expnum)
        return ra, dec, pol

    def uniform_time(self):
        """
        Get a set of randomized event times: a random integer second drawn from [tstart, tstop) plus a random fraction of a second.
        """
        return random.randint(self.tstart, self.tstop, self.expnum) + random.rand(self.expnum)

    def volume_distributed_distances(self):
        """
        Get a set of event distances which is randomized uniformly in the volume, out to self.distance. Requires that distance was passed to the constructor.
        """
        return random.power(3, self.expnum) * self.distance # 3 because it is this number minus 1 according to scipy doc


# Command line interface. The parser description comes from the module
# docstring (None if the module has none).
parser = OptionParser(description = __doc__)

# Options shared by all injection families
general_section = OptionGroup(parser, "General", "This group controls the general behavior of injections, such as their GPS timing and other common variables.")
# The GPS bounds are parsed as integers, so the help text must not promise
# sub-second precision.
general_section.add_option("--gps-start-time", metavar = "seconds", type = "int", help = "Set the start time of the segment to analyze in integer GPS seconds (required).")
general_section.add_option("--gps-end-time", metavar = "seconds", type = "int", help = "Set the end time of the segment to analyze in integer GPS seconds (required).")
general_section.add_option("--burst-family", metavar="str", action="append", help = "Set the name of the burst waveform family, required, valid choices are SineGaussian, StringCusp, BTLWNB, and Impulse.")
general_section.add_option("--output", metavar = "filename", default="binj.xml.gz", help = "Set the name of the LIGO light-weight XML file to output")
general_section.add_option("--seed", metavar = "int", type = "int", default = 1, help = "set the random seed default = 1")
general_section.add_option("--time-slide-id", metavar= "int", type = "int", default = 0, help = "Use the time slide id, default is 0.")
general_section.add_option("--write-mdc-log", metavar = "str", help = "Write a BurstMDC log to this file compatible with the XML output.")

# FIXME: New subsection?
# Event timing: how many injections are made and where they are placed
general_section.add_option("--event-time-type", metavar = "type", default = "random", help = "How to determine the event injection times. Type 'random' means use event rate to generate number of injections and place that number of injections randomly throughout the interval. Type 'fixed' means use the event rate to generate number of injections and sub-interval, then, beginning from the start of the interval add the sub-interval, place injection, and repeat. Default is 'random'.")
general_section.add_option("--event-rate", metavar = "events / yr", type = "float", default = 1, help = "Rate of events in 1/yr, default = 1")
general_section.add_option("--jitter", metavar = "s", type = "float", default = 0.0, help = "Window around event to add a random 'jitter' to the event time. Default is zero. Only applies to event-time-type='fixed'.")
parser.add_option_group(general_section)

# Options steering the file based ("ad-hoc") injections
adhoc_section = OptionGroup(
    parser,
    "File based ad-hoc injection options",
    "This group controls the distribution of file based 'ad-hoc' injections.",
)
adhoc_section.add_option(
    "--adhoc-fix-hrss",
    action="append",
    type="float",
    metavar="strain",
    help="Fix the hrss value of the ad-hoc injections. Give several times with different values to build up an array of hrss to choose from. Overrides min and max options.",
)
adhoc_section.add_option(
    "--adhoc-min-hrss",
    type="float",
    metavar="strain",
    help="Fix the min hrss value of the ad-hoc injections.",
)
adhoc_section.add_option(
    "--adhoc-max-hrss",
    type="float",
    metavar="strain",
    help="Fix the max hrss value of the ad-hoc injections.",
)
adhoc_section.add_option(
    "--adhoc-file-location",
    action="append",
    metavar="path",
    help="File paths to ad-hoc injections. Specify as 'family name'='path'.",
)
parser.add_option_group(adhoc_section)

# Options steering the SineGaussian injections: amplitude, Q, central
# frequency (each as fixed values or a min/max range) and polarization
sg_section = OptionGroup(
    parser,
    "SineGaussian options",
    "This group controls the distribution of SineGaussian injections.",
)
sg_section.add_option(
    "--sg-fix-hrss",
    action="append",
    type="float",
    metavar="strain",
    help="Fix the hrss value of the SineGaussian injections. Give several times with different values to build up an array of hrss to choose from. Overrides min and max options.",
)
sg_section.add_option(
    "--sg-min-hrss",
    type="float",
    metavar="strain",
    help="Fix the min hrss value of the SineGaussian injections.",
)
sg_section.add_option(
    "--sg-max-hrss",
    type="float",
    metavar="strain",
    help="Fix the max hrss value of the SineGaussian injections.",
)
sg_section.add_option(
    "--sg-fix-q",
    action="append",
    type="float",
    metavar="Q",
    help="Fix the Q value of the SineGaussian injections. Give several times with different values to build up an array of Qs to choose from. Overrides min and max options.",
)
sg_section.add_option(
    "--sg-min-q",
    type="float",
    metavar="Q",
    help="Fix the min Q value of the SineGaussian injections.",
)
sg_section.add_option(
    "--sg-max-q",
    type="float",
    metavar="Q",
    help="Fix the max Q value of the SineGaussian injections.",
)
sg_section.add_option(
    "--sg-fix-f",
    action="append",
    type="float",
    metavar="Hz",
    help="Fix the central frequency value of the SineGaussian injections. Give several times with different values to build up an array of frequencies to choose from. Overrides min and max options.",
)
sg_section.add_option(
    "--sg-min-f",
    type="float",
    metavar="Hz",
    help="Fix the min frequency value of the SineGaussian injections.",
)
sg_section.add_option(
    "--sg-max-f",
    type="float",
    metavar="Hz",
    help="Fix the max frequency value of the SineGaussian injections.",
)
sg_section.add_option(
    "--sg-polarization-type",
    default='elliptical',
    metavar="type",
    help="Fix the polarization type of the SineGaussian injections. Valid types are linear, circular, and elliptical. Default is elliptical.",
)
parser.add_option_group(sg_section)

# Options steering the band and time limited white noise bursts: maximum
# distance plus energy, hrss, bandwidth, duration and central frequency
# (each as fixed values or a min/max range)
wnb_section = OptionGroup(
    parser,
    "Band and time limited white noise burst (BTLWNB) options",
    "This group controls the distribution of BTLWNB injections.",
)
wnb_section.add_option(
    "--wnb-max-dist",
    type="float",
    default=1.0,
    metavar="Mpc",
    help="Maximum distance for which to produce WNB in Mpc, default is 1.",
)
wnb_section.add_option(
    "--wnb-fix-egw",
    action="append",
    type="float",
    metavar="M_s",
    help="Fix the energy (in solar masses with c=1) value of the BTLWNB injections. Give several times with different values to build up an array of energies to choose from. Overrides min and max options.",
)
wnb_section.add_option(
    "--wnb-min-egw",
    type="float",
    metavar="M_s",
    help="Fix the min energy value of the BTLWNB injections.",
)
wnb_section.add_option(
    "--wnb-max-egw",
    type="float",
    metavar="M_s",
    help="Fix the max energy value of the BTLWNB injections.",
)
wnb_section.add_option(
    "--wnb-fix-hrss",
    action="append",
    type="float",
    metavar="strain",
    help="Fix the hrss value of the BTLWNB injections. Give several times with different values to build up an array of hrss to choose from. Overrides min and max options.",
)
wnb_section.add_option(
    "--wnb-min-hrss",
    type="float",
    metavar="strain",
    help="Fix the min hrss value of the BTLWNB injections.",
)
wnb_section.add_option(
    "--wnb-max-hrss",
    type="float",
    metavar="strain",
    help="Fix the max hrss value of the BTLWNB injections.",
)
wnb_section.add_option(
    "--wnb-fix-bandwidth",
    action="append",
    type="float",
    metavar="Hz",
    help="Fix the bandwidth value of the BTLWNB injections. Give several times with different values to build up an array of bandwidths to choose from. Overrides min and max options.",
)
wnb_section.add_option(
    "--wnb-min-bandwidth",
    type="float",
    metavar="Hz",
    help="Fix the min bandwidth value of the BTLWNB injections.",
)
wnb_section.add_option(
    "--wnb-max-bandwidth",
    type="float",
    metavar="Hz",
    help="Fix the max bandwidth value of the BTLWNB injections.",
)
wnb_section.add_option(
    "--wnb-fix-duration",
    action="append",
    type="float",
    metavar="s",
    help="Fix the duration value of the BTLWNB injections. Give several times with different values to build up an array of durations to choose from. Overrides min and max options.",
)
wnb_section.add_option(
    "--wnb-min-duration",
    type="float",
    metavar="s",
    help="Fix the min duration value of the BTLWNB injections.",
)
wnb_section.add_option(
    "--wnb-max-duration",
    type="float",
    metavar="s",
    help="Fix the max duration value of the BTLWNB injections.",
)
wnb_section.add_option(
    "--wnb-fix-frequency",
    action="append",
    type="float",
    metavar="Hz",
    help="Fix the central frequency value of the BTLWNB injections. Give several times with different values to build up an array of frequencies to choose from. Overrides min and max options.",
)
wnb_section.add_option(
    "--wnb-min-frequency",
    type="float",
    metavar="Hz",
    help="Fix the min central frequency value of the BTLWNB injections.",
)
wnb_section.add_option(
    "--wnb-max-frequency",
    type="float",
    metavar="Hz",
    help="Fix the max central frequency value of the BTLWNB injections.",
)
parser.add_option_group(wnb_section)

# Options steering the Impulse injections (amplitude only)
imp_section = OptionGroup(
    parser,
    "Impulse options",
    "This group controls the distribution of Impulse injections.",
)
imp_section.add_option(
    "--imp-fix-hrss",
    action="append",
    type="float",
    metavar="strain",
    help="Fix the hrss value of the Impulse injections. Give several times with different values to build up an array of hrss to choose from. Overrides min and max options.",
)
imp_section.add_option(
    "--imp-min-hrss",
    type="float",
    metavar="strain",
    help="Fix the min hrss value of the Impulse injections.",
)
imp_section.add_option(
    "--imp-max-hrss",
    type="float",
    metavar="strain",
    help="Fix the max hrss value of the Impulse injections.",
)
parser.add_option_group(imp_section)

# Options steering the StringCusp injections. For cusps the frequency is the
# high frequency cut-off of the injection (see source.draw_cusp_params), not
# a central frequency, and the help text names the StringCusp family rather
# than BTLWNB.
cusp_section = OptionGroup(parser, "StringCusp options", "This group controls the distribution of StringCusp injections.")
cusp_section.add_option("--cusp-fix-frequency", metavar = "Hz", type = "float", action="append", help = "Fix the high frequency cut-off value of the StringCusp injections. Give several times with different values to build up an array of frequencies to choose from. Overrides min and max options.")
cusp_section.add_option("--cusp-min-frequency", metavar = "Hz", type = "float", help = "Fix the min high frequency cut-off value of the StringCusp injections.")
cusp_section.add_option("--cusp-max-frequency", metavar = "Hz", type = "float", help = "Fix the max high frequency cut-off value of the StringCusp injections.")
cusp_section.add_option("--cusp-fix-amplitude", metavar = "amp", type = "float", action="append", help = "Fix the amplitude value of the StringCusp injections. Give several times with different values to build up an array of amplitude to choose from. Overrides min and max options.")
cusp_section.add_option("--cusp-min-amplitude", metavar = "amp", type = "float", help = "Fix the min amplitude value of the StringCusp injections.")
cusp_section.add_option("--cusp-max-amplitude", metavar = "amp", type = "float", help = "Fix the max amplitude value of the StringCusp injections.")
parser.add_option_group(cusp_section)

#
# Parse options
#

# Parse the command line; any positional arguments end up in filenames.
options, filenames = parser.parse_args()

#
# Setup the output document
#

# Build a LIGO_LW document holding a new sim_burst table; the injections
# generated further down are created through this table (sim.RowType()).
xmldoc = ligolw.Document()
lw = xmldoc.appendChild(ligolw.LIGO_LW())
sim = lsctables.New(lsctables.SimBurstTable)
lw.appendChild(sim)
# Register this program under the name "ligolw_binj" together with the parsed
# options. The returned row provides the process_id assigned to each
# injection row.
procrow = process.register_to_xmldoc(xmldoc, "ligolw_binj", options.__dict__)

# Setup some global parameters
stop, start = options.gps_end_time, options.gps_start_time
# Both GPS bounds are documented as required but have no default; without
# them the interval arithmetic in source() would fail with a TypeError.
if start is None or stop is None:
    sys.exit("Must specify both --gps-start-time and --gps-end-time")
# Seed the global generator so that a given seed reproduces the injections
random.seed(options.seed)

# sys.exit() rather than the exit() helper, which is only installed by the
# site module for interactive use.
if options.adhoc_file_location is None and options.burst_family is None:
    sys.exit("Must specify a burst family or adhoc file locations")

# One source object per requested injection family
srcs = []

# Record keeping for adhoc file-based injections: family name -> file path
adhoc_families = {}

if options.adhoc_file_location:

    for opt in options.adhoc_file_location:
        if "=" not in opt:
            raise ValueError("Ad-hoc file locations must be given as 'family name'='path', got '%s'." % opt)
    # Split on the first "=" only, so that paths containing "=" stay intact
    adhoc_families = dict([opt.split("=", 1) for opt in options.adhoc_file_location])

    # The hrss constraint is common to all ad-hoc families; a complete
    # min/max pair takes precedence over the list of fixed values.
    if options.adhoc_min_hrss is not None and options.adhoc_max_hrss is not None:
        h_values = segment(options.adhoc_min_hrss, options.adhoc_max_hrss)
    elif options.adhoc_fix_hrss is not None:
        h_values = options.adhoc_fix_hrss
    else:
        raise ValueError("Current options do not constrain adhoc hrss values.")

    for family, path in adhoc_families.items():
        srcs.append(source(family, options.event_rate, start, stop, hrss=h_values, filepath=path))

    # Ad-hoc injections alone are a valid request; give the loop over the
    # built-in families something empty to iterate over.
    if options.burst_family is None:
        options.burst_family = []


def _segment_or_choices(minimum, maximum, fixed, description):
    """
    Turn one min / max / fix option triple into a parameter range.

    Returns segment(minimum, maximum) if both bounds were given, otherwise the list of fixed values. Raises ValueError naming description if neither form was supplied.

    NOTE(review): the option help states that the fix options override min and max, while a complete min/max pair wins here. That precedence is kept as it was; confirm which of the two is intended.
    """
    if minimum is not None and maximum is not None:
        return segment(minimum, maximum)
    if fixed is not None:
        return fixed
    raise ValueError("Current options do not constrain %s values." % description)

# set up the types of sources
# NOTE(review): family names other than the four handled below are skipped
# without any message.
for srcname in options.burst_family:
    if srcname == "SineGaussian":
        q_values = _segment_or_choices(options.sg_min_q, options.sg_max_q, options.sg_fix_q, "SG q")
        f_values = _segment_or_choices(options.sg_min_f, options.sg_max_f, options.sg_fix_f, "SG frequency")
        h_values = _segment_or_choices(options.sg_min_hrss, options.sg_max_hrss, options.sg_fix_hrss, "SG hrss")

        srcs.append(source(srcname, options.event_rate, start, stop, q=q_values, frequency=f_values, hrss=h_values))

    elif srcname == "StringCusp":
        f_values = _segment_or_choices(options.cusp_min_frequency, options.cusp_max_frequency, options.cusp_fix_frequency, "cusp frequency")
        a_values = _segment_or_choices(options.cusp_min_amplitude, options.cusp_max_amplitude, options.cusp_fix_amplitude, "cusp amplitude")

        srcs.append(source(srcname, options.event_rate, start, stop, frequency=f_values, amplitude=a_values))

    elif srcname == "Impulse":
        h_values = _segment_or_choices(options.imp_min_hrss, options.imp_max_hrss, options.imp_fix_hrss, "Impulse hrss")

        srcs.append(source(srcname, options.event_rate, start, stop, hrss=h_values))

    elif srcname == "BTLWNB":
        d_values = _segment_or_choices(options.wnb_min_duration, options.wnb_max_duration, options.wnb_fix_duration, "WNB duration")
        # Each range is checked against its own upper bound; the bandwidth
        # and frequency checks previously looked at --wnb-max-duration.
        b_values = _segment_or_choices(options.wnb_min_bandwidth, options.wnb_max_bandwidth, options.wnb_fix_bandwidth, "WNB bandwidth")
        f_values = _segment_or_choices(options.wnb_min_frequency, options.wnb_max_frequency, options.wnb_fix_frequency, "WNB frequency")

        # The amplitude is given either as an energy or as an hrss. Order of
        # preference: egw range, hrss range, fixed egw, fixed hrss.
        wnb_use_hrss, h_values, e_values = False, None, None
        if options.wnb_min_egw is not None and options.wnb_max_egw is not None:
            e_values = segment(options.wnb_min_egw, options.wnb_max_egw)
        elif options.wnb_min_hrss is not None and options.wnb_max_hrss is not None:
            h_values = segment(options.wnb_min_hrss, options.wnb_max_hrss)
            wnb_use_hrss = True
        elif options.wnb_fix_egw is not None:
            e_values = options.wnb_fix_egw
        elif options.wnb_fix_hrss is not None:
            # This branch used to re-test wnb_fix_egw and was unreachable,
            # so --wnb-fix-hrss alone was rejected.
            h_values = options.wnb_fix_hrss
            wnb_use_hrss = True
        else:
            raise ValueError("Current options do not constrain WNB E_GW values.")

        srcs.append(source(srcname, options.event_rate, start, stop, distance=options.wnb_max_dist, duration=d_values, bandwidth=b_values, frequency=f_values, egw=e_values, hrss=h_values, wnb_use_hrss = wnb_use_hrss))

# For every configured source, draw the per-injection parameters and append
# one row per injection time to the sim table.
for src in srcs:

    # injection times
    if options.event_time_type == "random":
        time = src.uniform_time()
    elif options.event_time_type == "fixed":
        # src.expnum times at start + i*interval (i = 1..expnum), each
        # shifted by (random.rand() - 0.5) * options.jitter
        interval = (stop - start) / src.expnum
        time = numpy.array([start + i*interval + options.jitter * (random.rand() - 0.5) for i in range(1, src.expnum + 1)])
    else:
        # Otherwise `time` would be undefined when the row loop starts
        raise ValueError("Invalid event time type: %s" % options.event_time_type)

    # waveform-specific parameters, one array entry per injection
    if src.waveform == "SineGaussian":
        q, f, hrss = src.draw_sg_params()
    elif src.waveform == "StringCusp":
        amp, f = src.draw_cusp_params()
    elif src.waveform == "Impulse":
        hrss = src.draw_imp_params()
    elif src.waveform == "BTLWNB":
        # The fourth value is copied to row.hrss or used as E_GW below,
        # depending on which WNB amplitude options were given
        band, dur, f, egw = src.draw_wnb_params()
        dist = src.volume_distributed_distances()
    elif src.waveform in adhoc_families:
        # NOTE(review): hrss is drawn here but no branch below copies it
        # onto the row for ad hoc families -- confirm this is intended
        hrss = src.draw_imp_params()
    else:
        raise ValueError("Invalid waveform family indicator.")

    ra, dec, pol = src.uniform_sky()

    for i, t in enumerate(time):
        row = sim.RowType()
        print("Generating %s at %f" % (src.waveform, t))

        # Required columns not defined makes ligolw unhappy
        for a in lsctables.SimBurstTable.validcolumns:
            setattr(row, a, None)

        # string parameters
        row.waveform = src.waveform

        # time parameters
        row.set_time_geocent(GPS(float(t)))

        # location / orientation
        row.ra = ra[i]
        row.dec = dec[i]
        row.psi = pol[i]

        # misc
        row.simulation_id = sim.get_next_id()
        row.waveform_number = random.randint(0, int(2**32) - 1)
        row.process_id = procrow.process_id
        row.time_slide_id = ilwd.ilwdchar("time_slide:time_slide_id:%d" % options.time_slide_id)

        if row.waveform == "SineGaussian":
            row.q = q[i]
            row.frequency = f[i]
            if options.sg_polarization_type == "linear":
                row.pol_ellipse_e = 1.0
                row.pol_ellipse_angle = 0
            elif options.sg_polarization_type == "circular":
                row.pol_ellipse_e = 0.0
                row.pol_ellipse_angle = 0
            elif options.sg_polarization_type == "elliptical":
                row.pol_ellipse_e = uniform_interval((0, 1), 1)[0]
                row.pol_ellipse_angle = uniform_interval((0, 2*numpy.pi), 1)[0]
            row.hrss = hrss[i]

        elif row.waveform == "StringCusp":
            row.frequency = f[i]
            row.amplitude = amp[i]
            # These are ignored, but for sake of completeness (as in lalapps_binj)
            row.pol_ellipse_e = 1.0
            row.pol_ellipse_angle = 0.0

        elif row.waveform == "Impulse":
            row.hrss = hrss[i]

        elif row.waveform == "BTLWNB":
            row.frequency = f[i]
            while dur[i]**2/4.0 - 1/numpy.pi**2/band[i]**2 < 0:
                # FIXME: Envelope is too small -- this is the
                # wrong way to fix it
                # Only this injection's duration is stretched; rescaling
                # the whole array would also alter later injections.
                dur[i] *= 2

            row.duration = dur[i]
            row.bandwidth = band[i]
            row.pol_ellipse_e = 0.0
            row.pol_ellipse_angle = 0.0
            # dist in Mpc
            # expected units are M_solar / pc^2
            if options.wnb_fix_hrss or options.wnb_min_hrss or options.wnb_max_hrss:
                # An hrss option was given: the drawn value is copied to hrss
                row.hrss = egw[i]
                # Maggiore (2008), eqn 7.111, adjusted for M_solar / c^2
                row.egw_over_rsquared = row.hrss**2 * numpy.pi**2 * row.frequency**2 * lal.C_SI / lal.G_SI / lal.MSUN_SI * lal.PC_SI**2
                # Not strictly necessary, but the hrss we measure from the
                # generated waveform usually ends up being a bit different than
                # the input
                row.hrss = measure_hrss(row)[0]
            else:
                # 1e-12 converts 1/Mpc^2 to 1/pc^2
                row.egw_over_rsquared = egw[i] / dist[i]**2 * 1e-12
                row.hrss = measure_hrss(row)[0]

        sim.append(row)

# Write the XML document to options.output; the gz flag is set when the
# output file name ends in "gz"
utils.write_filename(
    xmldoc,
    options.output,
    gz=options.output.endswith("gz")
)

#
# Write a BurstMDC (GravEn) style ASCII log
#
if options.write_mdc_log:
    # One formatted log entry per injection row, in table order
    mdc_log = [write_burst_mdc_row(row, start) for row in sim]
    write_burst_mdc_log(options.write_mdc_log, mdc_log)
