#!/bin/env python
""" Create a file containing the phase and amplitude
correlations between two detectors by
doing a simple Monte Carlo
"""

import argparse, h5py, numpy.random, pycbc.detector, logging, multiprocessing
from numpy.random import normal, uniform, power
# Command-line interface. Options that have no usable default are marked
# required, so a missing one is reported by argparse up front instead of
# surfacing later as an operation on None.
parser = argparse.ArgumentParser()
parser.add_argument('--ifos', nargs=2, required=True,
                    help="The two ifos to generate a histogram for")
parser.add_argument('--sample-size', type=int, required=True,
                    help="Approximate number of independent samples to draw for the distribution")
parser.add_argument('--snr-threshold', type=float, required=True,
                    help="SNR cutoff to apply to the drawn SNR values.")
parser.add_argument('--phase-bins', type=int, default=100,
                    help="Number of bins to distribute phase differences into")
parser.add_argument('--sample-rate', type=int, default=4096,
                    help="Sample rate used to determine the bin size in the time difference")
parser.add_argument('--max-combined-snr', type=float, default=100,
                    help="Maximum combined newsnr to draw signals out to")
parser.add_argument('--amplitude-ratio-bins', type=int, default=50,
                    help="Number of bins to use for storing amplitude ratios")
parser.add_argument('--amplitude-magnitude-bins', type=int, default=50,
                    help="Number of bins to use for storing the combined SNR magnitudes")
parser.add_argument('--fixed-timing-error', default=1.0/4096, type=float,
                    help="Fixed value for a timing error")
parser.add_argument('--coinc-threshold', default=0, type=float,
                    help="seconds to add to the TOF coinc window")
parser.add_argument('--seed', type=int, default=124,
                    help="Seed for the numpy random number generator")
parser.add_argument('--output-file', required=True,
                    help="Path of the HDF5 file to write")
parser.add_argument('--cores', default=1, type=int,
                    help="Number of processes to use")
parser.add_argument('--verbose', action='store_true',
                    help="Enable verbose logging output")
args = parser.parse_args()

# The sample count is divided by the number of cores later on, so reject
# values that would cause a division by zero or a negative task count.
if args.cores < 1:
    parser.error("--cores must be a positive integer")

# Seed the global numpy generator so that a run is reproducible.
numpy.random.seed(args.seed)

# The two detectors, and the largest arrival-time difference that is binned:
# the light travel time between them plus the user-supplied extra window.
d1 = pycbc.detector.Detector(str(args.ifos[0]))
d2 = pycbc.detector.Detector(str(args.ifos[1]))
maxdt = d1.light_travel_time_to_detector(d2) + args.coinc_threshold

# The amplitude bins below are only increasing when the combined-SNR ceiling
# lies above the smallest combined SNR two above-threshold triggers can have
# (sqrt(2) * threshold); otherwise the edge calculation fails obscurely.
if not 0 < args.snr_threshold * 2 ** 0.5 < args.max_combined_snr:
    parser.error("--snr-threshold must be positive and --max-combined-snr "
                 "must exceed sqrt(2) * --snr-threshold")

# Calculate the edges of the bins
max_csnr = args.max_combined_snr
# Largest log amplitude ratio: one detector at threshold while the combined
# SNR sits at its maximum.
max_ratio = numpy.log((max_csnr**2.0 - args.snr_threshold**2.0) ** 0.5 / args.snr_threshold)
# Time-difference edges; note that `num` here is the number of edges, whereas
# the other axes use (number of bins + 1) edges.
tbins = numpy.linspace(-maxdt, maxdt, num=int(args.sample_rate * 2 * maxdt))
# Phase-difference edges covering [0, 2*pi]
pbins = numpy.linspace(0, 2.0 * numpy.pi, num=args.phase_bins + 1)
# Log amplitude-ratio edges, symmetric about zero
rbins = numpy.linspace(-max_ratio, max_ratio, num=args.amplitude_ratio_bins + 1)
# Combined-SNR edges, logarithmically spaced from sqrt(2) * threshold to max
mbins = numpy.logspace(numpy.log(args.snr_threshold * 2 ** 0.5), numpy.log(max_csnr),
                       num=args.amplitude_magnitude_bins + 1, base=numpy.e)
hist_bins = (tbins, pbins, rbins, mbins)

# Echo the bin edges; the single-argument call form prints identically
# under Python 2 and Python 3.
print(tbins)
print(pbins)
print(rbins)
print(mbins)
pycbc.init_logging(args.verbose)

def generate_hist(size):
    """Accumulate a histogram of Monte-Carlo samples over `hist_bins`.

    The samples are requested from `generate_samples` in chunks of at most
    20000 so that only one chunk has to be held in memory at a time. The
    chunk sizes add up to exactly `size`; the final chunk is shortened
    rather than overshooting the request.

    Parameters
    ----------
    size : int
        Total number of samples to request from `generate_samples`.

    Returns
    -------
    numpy.ndarray
        Counts on the grid defined by `hist_bins`. An all-zero grid is
        returned when `size` is not positive, so the result can always be
        summed with other histograms.
    """
    max_chunk = 20000
    data = None
    remaining = size
    while remaining > 0:
        chunksize = min(max_chunk, remaining)
        remaining -= chunksize
        chunk_hist, _ = numpy.histogramdd(generate_samples(chunksize), bins=hist_bins)
        if data is None:
            data = chunk_hist
        else:
            data += chunk_hist
    if data is None:
        # Nothing was requested: return an empty histogram of the right shape
        data = numpy.zeros([len(edges) - 1 for edges in hist_bins])
    return data

def generate_samples(size):
    """Draw simulated two-detector triggers for `size` random sources.

    Each of the `size` sky locations / orientations is reused for 1000
    distance draws, and only the draws whose SNR exceeds
    `args.snr_threshold` in both detectors are kept.

    Parameters
    ----------
    size : int
        Number of independent sky location / orientation draws.

    Returns
    -------
    tuple of four numpy arrays (one entry per kept draw)
        Time difference between d1 and d2 including the timing error, phase
        difference wrapped into [0, 2*pi), natural log of the SNR ratio
        s1 / s2, and the quadrature sum of the two SNRs.
    """
    logging.info('generating %s samples', size)

    # Random sky positions, inclinations and polarization angles
    ra = uniform(0, 2 * numpy.pi, size=size)
    dec = numpy.arccos(uniform(-1., 1., size=size)) - numpy.pi/2
    inc = numpy.arccos(uniform(-1., 1., size=size))
    pol = uniform(0, 2 * numpy.pi, size=size)
    cos_inc = numpy.cos(inc)
    incl_factor = 0.5 * (1.0 + cos_inc * cos_inc)

    # Antenna responses of both detectors and the arrival-time difference
    # for every sky position
    fplus_a, fcross_a, fplus_b, fcross_b, delays = [], [], [], [], []
    for sky_ra, sky_dec, sky_pol in zip(ra, dec, pol):
        plus, cross = d1.antenna_pattern(sky_ra, sky_dec, sky_pol, 0)
        fplus_a.append(plus)
        fcross_a.append(cross)
        plus, cross = d2.antenna_pattern(sky_ra, sky_dec, sky_pol, 0)
        fplus_b.append(plus)
        fcross_b.append(cross)

        arrival_a = d1.time_delay_from_earth_center(sky_ra, sky_dec, 0)
        arrival_b = d2.time_delay_from_earth_center(sky_ra, sky_dec, 0)
        delays.append(arrival_a - arrival_b)

    # Scale fp, fc to a volumetric distribution of SNRs and add Gaussian
    # errors in SNR. Every orientation is reused for `copies` distance draws.
    copies = 1000
    total = copies * size
    dist = power(3, total) / args.snr_threshold

    def noisy(response):
        # Tile the per-source response over all distance draws, scale by
        # distance, then add the (likewise tiled) unit-variance noise.
        scaled = numpy.resize(response, len(dist)) / dist
        return scaled + numpy.resize(normal(0, scale=1.0, size=size), total)

    # The order of these calls fixes the order of the random draws
    sp1 = noisy(fplus_a * cos_inc)
    sp2 = noisy(fplus_b * cos_inc)
    sc1 = noisy(fcross_a * incl_factor)
    sc2 = noisy(fcross_b * incl_factor)
    td = numpy.resize(delays, total) + normal(0, scale=args.fixed_timing_error, size=total)

    # Keep only the draws that are above the SNR threshold in both detectors
    threshold_sq = args.snr_threshold ** 2.0
    loud_a = sp1**2.0 + sc1**2.0 > threshold_sq
    loud_b = sp2**2.0 + sc2**2.0 > threshold_sq
    keep = numpy.logical_and(loud_a, loud_b)
    sp1, sp2, sc1, sc2, td = [values[keep] for values in (sp1, sp2, sc1, sc2, td)]

    s1 = (sp1**2.0 + sc1**2.0)**0.5
    s2 = (sp2**2.0 + sc2**2.0)**0.5

    phase_diff = (numpy.arctan2(sc1, sp1) - numpy.arctan2(sc2, sp2)) % (numpy.pi * 2)
    ratio = numpy.log(s1 / s2)
    mag = (s1**2.0 + s2**2.0)**0.5

    logging.info('keeping %s values', len(ratio))
    return td, phase_diff, ratio, mag

# Split the requested samples into `fiddle` equally sized tasks per core;
# this just makes sure that 3 chunks are submitted to each process.
fiddle = 3
# Floor division keeps the task size an integer on Python 2 and Python 3
core_size = args.sample_size // args.cores // fiddle
chunk_data = [core_size] * args.cores * fiddle


def _generate_seeded_hist(task):
    """Reseed this process's numpy generator, then histogram one task.

    `task` is a (size, seed) pair. Reseeding per task matters because pool
    workers start from a copy of the parent's generator state; without it
    different workers would produce identical samples.
    """
    size, seed = task
    numpy.random.seed(seed)
    return generate_hist(size)


if args.cores == 1:
    # Single process: keep drawing from the generator seeded at start-up
    h = [generate_hist(size) for size in chunk_data]
else:
    # Give every task its own seed derived from --seed (kept within the
    # 32-bit range numpy.random.seed accepts) so results are reproducible
    # and the tasks are statistically independent.
    tasks = [(size, (args.seed + 1 + index) % 2 ** 32)
             for index, size in enumerate(chunk_data)]
    pool = multiprocessing.Pool(args.cores)
    try:
        h = pool.map(_generate_seeded_hist, tasks)
    finally:
        # Release the worker processes whether or not the map succeeded
        pool.close()
        pool.join()

# Combine the per-task histograms into one
h = numpy.sum(h, axis=0)

# Store the histogram together with its bin edges and identifying metadata;
# the context manager guarantees the file is flushed and closed.
with h5py.File(args.output_file, 'w') as f:
    f['map'] = h
    f['tbins'] = tbins
    f['pbins'] = pbins
    f['rbins'] = rbins
    f['mbins'] = mbins
    f.attrs['ifo0'] = args.ifos[0]
    f.attrs['ifo1'] = args.ifos[1]
    f.attrs['stat'] = "phasetd_newsnr"

