#!/bin/env python
""" Create a file containing the phase and amplitude, 
correlations between two detectors by
doing a simple monte-carlo

This ignores the SNR ration and SNR magnitude at the moment and only takes
into account the integrated phase and time differences.
"""

import argparse, h5py, numpy.random, pycbc.detector
from numpy.random import normal, uniform

# Command-line interface.  The options that the rest of the script cannot run
# without are marked as required so that a missing one is reported by argparse
# up front instead of surfacing later as an unrelated-looking error.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--ifos', nargs=2, required=True,
                    help="Names of the two detectors, passed to "
                         "pycbc.detector.Detector")
parser.add_argument('--sample-size', type=int, required=True,
                    help="Keep generating samples until at least this many "
                         "have passed the SNR cuts")
parser.add_argument('--snr-threshold', type=float, required=True,
                    help="SNR that a sample must exceed in each detector "
                         "to be kept")
parser.add_argument('--coinc-threshold', type=float,
                    help="Accepted but currently unused by this script")
parser.add_argument('--sample-rate', type=int, default=4096,
                    help="Sets the width (1 / sample rate) of both the "
                         "Gaussian timing error and the time-delay bins. "
                         "Default: %(default)s")
parser.add_argument('--output-file', required=True,
                    help="Path of the HDF5 file to write")
args = parser.parse_args()

# Fixed seed so that repeated runs with the same options give the same output
numpy.random.seed(seed=124)

# Detector objects used by generate_samples for the antenna response and
# the time delay from the earth center
d1 = pycbc.detector.Detector(str(args.ifos[0]))
d2 = pycbc.detector.Detector(str(args.ifos[1]))

def generate_samples(size):
    """Draw ``size`` random sources and return the two-detector observables.

    Sky position, polarization and distance are drawn at random.  Each
    detector's plus/cross response is divided by the distance and given
    unit-variance Gaussian noise, and only the draws whose noisy SNR exceeds
    ``args.snr_threshold`` in both detectors are kept.

    Relies on the module-level ``args``, ``d1`` and ``d2``.

    Returns a tuple ``(mag, ratio, td, phase_diff)`` of arrays with one entry
    per surviving draw:

    mag        -- quadrature sum of the SNRs in the two detectors
    ratio      -- natural log of (SNR in d1 / SNR in d2)
    td         -- d1 minus d2 time delay from the earth center, plus a
                  Gaussian error with standard deviation 1 / sample rate
    phase_diff -- arctan2(plus, cross) in d1 minus the same quantity in d2
    """
    # NOTE: the order of the random draws in this function is significant;
    # it determines the output obtained for a given random seed.
    full_circle = 2 * numpy.pi
    ra = uniform(0, full_circle, size=size)
    dec = numpy.arccos(uniform(-1., 1., size=size)) - numpy.pi/2
    pol = uniform(0, full_circle, size=size)
    # Cube root of a uniform variate, scaled by 1 / threshold
    dist = uniform(0, 1**3.0, size=size) ** (1.0/3.0) / args.snr_threshold

    # Antenna responses of both detectors and the arrival time difference
    plus1 = numpy.zeros(size)
    cross1 = numpy.zeros(size)
    plus2 = numpy.zeros(size)
    cross2 = numpy.zeros(size)
    delay = numpy.zeros(size)
    for i, (r, d, p) in enumerate(zip(ra, dec, pol)):
        plus1[i], cross1[i] = d1.antenna_pattern(r, d, p, 0)
        plus2[i], cross2[i] = d2.antenna_pattern(r, d, p, 0)
        arrival1 = d1.time_delay_from_earth_center(r, d, 0)
        arrival2 = d2.time_delay_from_earth_center(r, d, 0)
        delay[i] = arrival1 - arrival2

    def unit_noise():
        # Independent unit-variance Gaussian error for one SNR component
        return normal(0, scale=1.0, size=size)

    # Noisy plus/cross SNR components, then the noisy time difference
    sp1 = plus1 / dist + unit_noise()
    sp2 = plus2 / dist + unit_noise()
    sc1 = cross1 / dist + unit_noise()
    sc2 = cross2 / dist + unit_noise()
    noisy_td = delay + normal(0, scale=1.0 / args.sample_rate, size=size)

    # Keep only the draws that are above threshold in both detectors
    threshold_sq = args.snr_threshold ** 2.0
    snrsq1 = sp1**2.0 + sc1**2.0
    snrsq2 = sp2**2.0 + sc2**2.0
    keep = numpy.logical_and(snrsq1 > threshold_sq, snrsq2 > threshold_sq)

    s1 = snrsq1[keep] ** 0.5
    s2 = snrsq2[keep] ** 0.5

    phase1 = numpy.arctan2(sp1[keep], sc1[keep])
    phase2 = numpy.arctan2(sp2[keep], sc2[keep])
    phase_diff = phase1 - phase2
    ratio = numpy.log(s1 / s2)
    mag = (s1**2.0 + s2**2.0) ** 0.5
    return mag, ratio, noisy_td[keep], phase_diff

# Number of sources drawn per call to generate_samples.  Only the draws that
# pass the SNR cuts come back, so the loop repeats until enough have been
# collected; the final total may exceed args.sample_size.
chunk = 10000
num = 0
ms, rs, ts, ps = [], [], [], []
while num < args.sample_size:
    m, r, t, p = generate_samples(chunk)
    ms.append(m)
    rs.append(r)
    ts.append(t)
    ps.append(p)
    num += len(m)
    # Progress report; the parenthesised single-argument form prints the
    # same thing under both Python 2 and Python 3
    print(num)

mag = numpy.concatenate(ms)
ratio = numpy.concatenate(rs)
td = numpy.concatenate(ts)
phase_diff = numpy.concatenate(ps)

# Time-delay bins are one sample wide and span the observed range; the phase
# difference bins cover -2*pi to 2*pi in steps of 4*pi / 200.
tstep = 1.0 / args.sample_rate
tbins = numpy.arange(td.min(), td.max() + tstep, tstep)
pbins = numpy.arange(-2.0 * numpy.pi, 2.0 * numpy.pi + 0.05,
                     4.0 * numpy.pi / 200)
h, _ = numpy.histogramdd((td, phase_diff), bins=(tbins, pbins))

# The context manager makes sure the file is flushed and closed even if one
# of the writes fails.
with h5py.File(args.output_file, 'w') as f:
    f['map'] = h
    f['tbins'] = tbins
    f['pbins'] = pbins
    f.attrs['ifo0'] = args.ifos[0]
    f.attrs['ifo1'] = args.ifos[1]
    f.attrs['stat'] = "phasetd_newsnr"

