#!/usr/bin/env python
from __future__ import division
import os
import numpy as np
import argparse as ap

try:
    import matplotlib.pyplot as plt
    import matplotlib
    matplotlib.use('Agg')
except Exception:
    pass

from bajes.obs.gw.noise import Noise
from bajes.obs.gw.waveform import Waveform, PolarizationTuple
from bajes.obs.gw.strain import Series, fft, ifft
from bajes.obs.gw.detector import Detector, calc_project_array
from bajes.obs.gw.utils import read_asd, read_params
from bajes.pipe import set_logger, ensure_dir

def write_txt_file(name, head, *args):
    """
    Write equally long columns to a tab-separated text file.

    The header is written verbatim first (it has to carry its own newline),
    then one line per row follows. Every value is followed by a tab, hence
    each row ends with a trailing tab before the newline.

    Arguments:
        name  : path of the output file (overwritten if it already exists)
        head  : header string, written as it is
        *args : columns to be written, accessed as args[column][row]
    """

    ncols   = len(args)
    # number of rows, i.e. length of the transposed table of columns
    nrows   = len(np.transpose(args))

    # the context manager closes the handle also when formatting or writing
    # fails halfway through, which the manual open/close pair did not
    with open(name, 'w') as out:
        out.write(head)
        for i in range(nrows):
            out.write("".join("{}\t".format(args[j][i]) for j in range(ncols)))
            out.write("\n")

def make_spectrogram_plot(ifo, time, inj_strain, noise, f_min, outdir):
    """
    Save the spectrogram of the whitened injected strain of one detector.

    The strain is transformed with `fft`, divided by ASD*sqrt(f_Nyquist),
    transformed back with `ifft` and passed to `plt.specgram`. The figure is
    written to `<outdir>/<ifo>_spectrogram.png`.

    Plotting is best-effort: any failure of the plotting step (for instance
    matplotlib not being available) is reported as a warning and does not
    interrupt the caller.

    Arguments:
        ifo        : detector tag, used in the title and in the file name
        time       : time axis of the strain; the sampling step is estimated
                     as the median spacing of this array
        inj_strain : injected time-domain strain
        noise      : object providing `interp_asd_pad(freq)`
        f_min      : lower limit of the plotted frequency axis
        outdir     : output directory
    """

    dt      = np.median(np.diff(time))
    srate   = 1./dt
    fNyq    = srate/2.
    freq , inj_hfft = fft(inj_strain,dt)
    seglen  = 1./np.median(np.diff(freq))
    asd     = noise.interp_asd_pad(freq)
    inj_hfft_whit = inj_hfft/(asd*np.sqrt(fNyq))

    # only the whitened strain is needed, the returned time axis is discarded
    _, inj_strain_whit = ifft(inj_hfft_whit , srate=srate, seglen=seglen)

    Nfft    = int (fNyq)//2
    Novl    = int (Nfft * 0.99)
    window  = np.blackman(Nfft)

    try:
        plt.figure(figsize=(12,9))
        plt.title("{} spectrogram".format(ifo), size = 14)
        # the sampling rate comes from the time axis instead of the global
        # command-line options, so that the function works also when this
        # file is imported; rounding protects against 1/dt falling just
        # below the nominal integer rate
        plt.specgram(inj_strain_whit, NFFT=Nfft, Fs=int(round(srate)), noverlap=Novl,
                     window=window, cmap='PuBu', xextent=[0,seglen])
        plt.yscale('log')
        plt.ylim((f_min,0.5/dt))
        plt.xlabel("time [s]")
        plt.ylabel("frequency [Hz]")
        plt.savefig(outdir + '/{}_spectrogram.png'.format(ifo), dpi=150, bbox_inches='tight')
        plt.close()
    except Exception as err:
        # keep the plot optional, but do not hide the reason of the failure
        import logging
        logging.getLogger(__name__).warning("Unable to produce the %s spectrogram plot: %s", ifo, err)


def make_injection_plot(ifo, time, inj_strain, wave_strain, noise, f_min, outdir, alpha=0.4):
    """
    Save the time-domain and the frequency-domain plots of an injection.

    Two figures are written in `outdir`:
        <ifo>_strains.png : injected strain and projected wave versus time,
                            plus a zoom on the second around the median time
        <ifo>_spectra.png : amplitude spectra of the Tukey-windowed injected
                            strain and of the projected wave, with the ASD

    Plotting is best-effort: a failure of a plotting step (for instance
    matplotlib not being available) is reported as a warning and does not
    interrupt the caller. The FFTs are computed outside the guarded
    sections, so errors raised there propagate.

    Arguments:
        ifo         : detector tag, used in titles and file names
        time        : time axis of the strains
        inj_strain  : injected strain (noise + signal), time domain
        wave_strain : projected signal, time domain
        noise       : object providing `interp_asd_pad(freq)`
        f_min       : lower limit of the plotted frequency axis
        outdir      : output directory
        alpha       : Tukey parameter of the window applied to the injected
                      strain before the FFT
    """

    import logging
    plot_logger = logging.getLogger(__name__)

    try:
        fig = plt.figure(figsize=(12,9))
        ax1 = fig.add_subplot(211)
        ax2 = fig.add_subplot(212)

        # plot injected strain (gray) and projected signal (blue)
        ax1.set_title("{} injection".format(ifo), size = 14)
        ax1.plot(time , inj_strain, c='gray', lw=0.7, label='Injected strain')
        ax1.plot(time , wave_strain, c='slateblue', label='Projected wave')
        ax1.legend(loc='best')
        ax1.set_xlabel('time [s]')
        ax1.set_ylabel('strain')

        # plot the signal in the 1s window centred on the median time (red)
        t_mid    = np.median(time)
        mask_ax2 = np.where((time>=t_mid-0.5)&(time<=t_mid+0.5))
        ax2.plot(time[mask_ax2], wave_strain[mask_ax2], c='r')

        plt.savefig(outdir + '/{}_strains.png'.format(ifo), dpi=100, bbox_inches='tight')
        plt.close()
    except Exception as err:
        # keep the plot optional, but do not hide the reason of the failure
        plot_logger.warning("Unable to produce the %s strains plot: %s", ifo, err)

    # the window functions live in scipy.signal.windows; the top-level
    # scipy.signal.tukey alias is not available in recent SciPy releases
    from scipy.signal.windows import tukey

    dt                   = np.median(np.diff(time))
    freq_proj, hfft_proj = fft(wave_strain, dt)
    freq_inj, hfft_inj   = fft(inj_strain*tukey(len(inj_strain),alpha=alpha), dt)

    try:
        fig = plt.figure(figsize=(12,9))
        ax1 = fig.add_subplot(111)

        # plot spectra of injected strain (gray), projected signal (blue) and ASD (navy)
        ax1.set_title("{} spectra".format(ifo), size = 14)
        ax1.loglog(freq_inj , np.abs(hfft_inj), c='gray', lw=0.7, label='Injected strain')
        ax1.loglog(freq_proj , np.abs(hfft_proj), c='royalblue', label='Projected wave')
        ax1.loglog(freq_inj , noise.interp_asd_pad(freq_inj), c='navy', label='ASD')
        ax1.legend(loc='best')

        ax1.set_xlim((f_min,1./dt/2.))
        ax1.set_xlabel('frequency [Hz]')
        ax1.set_ylabel('amplitude spectrum')

        plt.savefig(outdir + '/{}_spectra.png'.format(ifo), dpi=100, bbox_inches='tight')
        plt.close()
    except Exception as err:
        plot_logger.warning("Unable to produce the %s spectra plot: %s", ifo, err)

class Injection(object):
    """
    Build the injected data (noise + projected signal) for a set of
    detectors and estimate the corresponding signal-to-noise ratios.

    After construction the following dictionaries, keyed by IFO tag, are
    available: `noise_strains`, `wave_strains`, `inj_strains`, `times`,
    `freqs` and `snrs`. The strains are time-domain arrays, or
    frequency-domain arrays when `fdinj_flag` is set.

    The class reports through a module-level `logger`, which has to be
    defined before an instance is created.
    """

    def __init__(self, ifos, dets, noises, data_path, seglen, srate, f_min, f_max, ra, dec, psi, t_gps, wind_flag, zero_flag, fdinj_flag, tukey=None, t_peak=None):
        """
        Generate noise and signal for every detector and compute the SNRs.

        Arguments:
            ifos       : list of detector tags
            dets       : dictionary of detector objects, keyed by IFO tag
            noises     : dictionary of noise objects, keyed by IFO tag
            data_path  : path to the signal to inject; '.txt'/'.dat' files
                         are read as ASCII tables (3 columns [t, hp, hc] or
                         5 columns [f, Re(hp), Im(hp), Re(hc), Im(hc)]),
                         '.ini' files are read as parameter files for the
                         waveform generator; if empty, h=0 is injected
            seglen     : segment length
            srate      : sampling rate
            f_min      : minimum frequency
            f_max      : maximum frequency, srate/2 if None
            ra, dec    : sky location
            psi        : polarization angle
            t_gps      : GPS time of the segment
            wind_flag  : window applied to time-domain ASCII polarizations,
                         one of 'low', 'both', 'none'
            zero_flag  : if true, the noise is identically zero
            fdinj_flag : if true, the injection is stored in frequency domain
            tukey      : Tukey window parameter, 0.4/seglen if None
            t_peak     : time shift passed to the projection of ASCII
                         polarizations, 0 if None

        Raises:
            ValueError   : inconsistent number of detectors/noises, invalid
                           window flag or unsupported file extension
            RuntimeError : ASCII table with an unexpected number of columns,
                           or missing mass parameters in the parameter file
            AttributeError : missing approximant in the parameter file
        """

        # get bajes objects
        self.ifos   = ifos
        self.dets   = dets
        self.noises = noises

        # check
        if len(self.ifos) != len(self.dets.keys()):
            raise ValueError("Number of requested IFOs ({}) does not match the number of detector objects ({}).".format(len(self.ifos), len(self.dets.keys())))
        elif len(self.ifos) != len(self.noises.keys()):
            raise ValueError("Number of requested IFOs ({}) does not match the number of noise objects ({}).".format(len(self.ifos), len(self.noises.keys())))

        # get data properties
        self.seglen = seglen
        self.srate  = srate
        self.f_min  = f_min
        self.ra     = ra
        self.dec    = dec
        self.psi    = psi
        self.t_gps  = t_gps

        if f_max is None:   self.f_max  = self.srate/2.
        else:               self.f_max  = f_max

        if tukey is None:   self.tukey  = 0.4/self.seglen
        else:               self.tukey  = tukey

        # initialize output dictionaries
        self.snrs           = {}
        self.wave_strains   = {}
        self.noise_strains  = {}
        self.inj_strains    = {}
        self.times          = {}
        self.freqs          = {}

        self.fdinj_flag     = fdinj_flag

        Npt                 = int(self.seglen*self.srate)

        ###
        ### Generate noise segments
        ###
        if zero_flag:
            logger.info("Generating zero-noise injection ...")
            for ifo in self.ifos:
                self.noise_strains[ifo] = np.zeros(Npt)
        else:
            logger.info("Generating Gaussian-noise injection ...")
            for ifo in self.ifos:
                self.noise_strains[ifo] = noises[ifo].generate_fake_noise(self.seglen, self.srate, self.t_gps, filter=True)

        if fdinj_flag:
            # frequency-domain injection: store the FFT of the noise
            for ifo in self.ifos:
                fr, nfft                = fft(self.noise_strains[ifo], dt=1./self.srate)
                self.noise_strains[ifo] = nfft

        ###
        ### Generate signal segment
        ###
        if not data_path:

            logger.info("No data file was passed, generating pure noise, i.e. h(t)=0 ...")

            for ifo in self.ifos:
                if fdinj_flag:
                    self.wave_strains[ifo]  = np.zeros(Npt//2+1)
                else:
                    self.wave_strains[ifo]  = np.zeros(Npt)

        else:

            self.data_path  = os.path.abspath(data_path)
            tag             = self.data_path.split('.')[-1]

            ###
            ### Generate signal segment from ASCII file
            ###
            if (tag == 'txt') or (tag == 'dat'):

                logger.info("Reading polarizations from ascii file ...")
                data_table  = np.genfromtxt(self.data_path)
                data_tablet = np.transpose(data_table)

                # if data-table has 3 columns, we assume TD data in the form [t, hp, hc]
                if len(data_tablet) == 3:
                    td_input      = True
                    times, hp, hc = data_tablet
                # if data-table has 5 columns, we assume FD data in the form [f, Re(hp), Im(hp), Re(hc), Im(hc)]
                elif len(data_tablet) == 5:
                    td_input      = False
                    freqs, rhp, ihp, rhc, ihc = data_tablet
                    hp = rhp + 1j*ihp
                    hc = rhc + 1j*ihc
                else:
                    logger.error("Provided data file has inconsistent number of columns. Please use 3 columns for time-domain data or 5 columns for frequency-domain data.")
                    raise RuntimeError("Provided data file has inconsistent number of columns. Please use 3 columns for time-domain data or 5 columns for frequency-domain data.")

                # apply window according to input
                if td_input:

                    # imported under a different name in order not to rebind
                    # the `tukey` argument; scipy.signal.windows is used
                    # since the top-level scipy.signal.tukey alias is not
                    # available in recent SciPy releases
                    from scipy.signal.windows import tukey as tukey_window
                    if wind_flag == 'low':
                        window  = tukey_window(len(hp), alpha=self.tukey)
                        # drop the taper at the end of the segment: every
                        # sample after the last unit-valued one is set to 1
                        imin    = np.max(np.where(window==1))
                        window[imin:] = 1.
                    elif wind_flag == 'both':
                        window  = tukey_window(len(hp), alpha=self.tukey)
                    elif wind_flag == 'none':
                        window  = np.ones(len(hp))
                    else:
                        logger.error("Invalid window flag passed to the injection. Please use 'low', 'both' or 'none'.")
                        raise ValueError("Invalid window flag passed to the injection. Please use 'low', 'both' or 'none'.")

                    hp              = hp * window
                    hc              = hc * window

                # create series in order to cut the strain
                # obs. avoid window since it is already applied
                if td_input:    series_type = 'time'
                else:           series_type = 'freq'

                series_opts = dict(srate=self.srate,   seglen=self.seglen,
                                   f_min=self.f_min,   f_max=self.f_max,
                                   t_gps=self.t_gps,   alpha_taper=0.0)
                # the frequency axis exists only for frequency-domain input;
                # for time-domain input the argument is omitted
                # NOTE(review): this relies on the default value of
                # `importfreqs` in Series for time-domain data -- confirm
                if not td_input:
                    series_opts['importfreqs'] = freqs

                ser_hp  = Series(series_type, hp, **series_opts)
                ser_hc  = Series(series_type, hc, **series_opts)

                if t_peak is None:
                    logger.warning("Requested injection does not specify time_shift of the signal, setting time_shift = 0")
                    t_peak  = 0

                # polarizations and axis do not depend on the detector
                if td_input:    target_hp, target_hc, ax = ser_hp.time_series, ser_hc.time_series, times
                else:           target_hp, target_hc, ax = ser_hp.freq_series, ser_hc.freq_series, freqs

                for ifo in self.ifos:
                    self.wave_strains[ifo]  = calc_project_array(self.dets[ifo],
                                                                 target_hp,     target_hc,  1/self.srate,
                                                                 self.ra,       self.dec,   self.psi,
                                                                 self.t_gps,    t_peak,
                                                                 domain=series_type, ax=ax)

                # convert the projected strain to the requested domain
                for ifo in self.ifos:

                    if (fdinj_flag and td_input):
                        logger.warning("Requested FD injection but TD data are provided. Computing counterpart with FFT.")
                        fr, wfft                = fft(self.wave_strains[ifo], dt=1./self.srate)
                        self.wave_strains[ifo]  = wfft

                    elif (not(fdinj_flag) and not(td_input)):
                        logger.warning("Requested TD injection but FD data are provided. Computing counterpart with IFFT.")
                        # interpolate amplitude and unwrapped phase on the
                        # frequency axis of the segment, zero outside the input band
                        f_out       = np.linspace(0., self.srate/2., int(self.seglen*self.srate//2)+1)
                        data        = self.wave_strains[ifo]
                        amp_out     = np.interp(f_out, freqs, np.abs(data),              left=0., right=0.)
                        phi_out     = np.interp(f_out, freqs, np.unwrap(np.angle(data)), left=0., right=0.)
                        data        = amp_out*np.exp(1j*phi_out)
                        ts, wave_td = ifft(data, srate=self.srate, seglen=self.seglen, t0=self.t_gps)
                        self.wave_strains[ifo] = wave_td

            ###
            ### Generate signal segment from config file through Waveform generator
            ###
            elif tag == 'ini':

                logger.info("Generating polarizations from config file ...")

                params = read_params(self.data_path, 'injection')

                if 'approx' not in params.keys():
                    raise AttributeError("Unable to generate the waveform for injection. Approximant field is missing.")

                # check skylocation,
                # if the parameter file specifies ra, dec or psi, the value
                # is replaced with the one passed to the constructor
                # NOTE(review): when a key is absent from the file it is not
                # added to params -- confirm this is the intended behaviour
                if 'ra' in params:
                    logger.info("Overriding the right ascension value found in the parameter file.")
                    params['ra']    = self.ra
                if 'dec' in params:
                    logger.info("Overriding the declination value found in the parameter file.")
                    params['dec']   = self.dec
                if 'psi' in params:
                    logger.info("Overriding the polarization value found in the parameter file.")
                    params['psi']   = self.psi

                # fix missing information
                params['f_min']     = self.f_min
                params['f_max']     = self.f_max
                params['seglen']    = self.seglen
                params['srate']     = self.srate
                params['t_gps']     = self.t_gps
                params['tukey']     = self.tukey

                # printings
                params_keys = list(params.keys())

                if ('mchirp' not in params_keys) and ('mtot' not in params_keys):
                    logger.error("Unspecified total mass / chirp mass parameter for gravitational-wave injection")
                    raise RuntimeError("Unspecified total mass / chirp mass parameter for gravitational-wave injection")

                if 'q' not in params_keys:
                    logger.error("Unspecified mass ratio parameter for gravitational-wave injection")
                    raise RuntimeError("Unspecified mass ratio parameter for gravitational-wave injection")

                logger.info("Generating injection with the paramters:")
                for key in params_keys:
                    logger.info("{}: {}".format(key, params[key]))

                # frequency axis of the segment restricted to [f_min, f_max]
                _f              = np.linspace(0,self.srate/2,int(self.seglen*self.srate)//2 +1)
                _f_mask         = np.where((_f>=self.f_min)&(_f<=self.f_max))
                wave            = Waveform(_f[_f_mask],  self.srate, self.seglen, params['approx'])
                signal_template = wave.compute_hphc(params)
                hp              = signal_template.plus
                hc              = signal_template.cross
                series          = Series(wave.domain,   hp,
                                         srate          = self.srate,
                                         seglen         = self.seglen,
                                         f_min          = self.f_min,
                                         f_max          = self.f_max,
                                         t_gps          = self.t_gps,
                                         importfreqs    = _f[_f_mask],
                                         alpha_taper    = 0.0)

                for ifo in self.ifos:

                    # initialize detector with empty data
                    self.dets[ifo].store_measurement(series, noises[ifo])
                    # compute the data
                    if fdinj_flag:
                        # interpolate amplitude and unwrapped phase on the
                        # full frequency axis, zero outside the detector band
                        self.freqs[ifo] = np.linspace(0,self.srate/2.,int(self.srate*self.seglen//2)+1)
                        _wave   = self.dets[ifo].project_fdwave(signal_template, params, wave.domain)
                        _amp    = np.abs(_wave)
                        _phi    = np.unwrap(np.angle(_wave))
                        _amp    = np.interp(self.freqs[ifo], self.dets[ifo].freqs, _amp, right=0., left=0.)
                        _phi    = np.interp(self.freqs[ifo], self.dets[ifo].freqs, _phi, right=0., left=0.)
                        self.wave_strains[ifo] = _amp*np.exp(1j*_phi)
                    else:
                        self.wave_strains[ifo]  = self.dets[ifo].project_tdwave(signal_template, params, wave.domain)

            else:

                # unsupported extension: stop here, otherwise the missing
                # signal would surface later as an unrelated KeyError
                logger.error("Impossible to generate injection from {} file. Use txt/dat or ini.".format(tag))
                raise ValueError("Impossible to generate injection from {} file. Use txt/dat or ini.".format(tag))

        ###
        ### Compute final injected data and SNR
        ###
        for ifo in self.ifos:

            self.times[ifo]         = np.arange(Npt,dtype=float)/self.srate - self.seglen/2 + self.t_gps
            self.freqs[ifo]         = np.linspace(0,self.srate/2.,int(self.srate*self.seglen//2)+1)
            self.inj_strains[ifo]   = self.noise_strains[ifo] + self.wave_strains[ifo]

            # compute SNR
            if fdinj_flag:
                _f      = self.freqs[ifo]
                _w      = self.wave_strains[ifo]
                _d      = self.inj_strains[ifo]
            else:
                _f, _w  = fft(self.wave_strains[ifo], 1./self.srate)
                _f, _d  = fft(self.inj_strains[ifo], 1./self.srate)

            _i              = np.where(_f>=self.f_min)
            psd             = noises[ifo].interp_psd_pad(_f[_i])
            d_inner_h       = (4/self.seglen)*np.real(np.sum(np.conj(_w[_i])*_d[_i]/psd))
            h_inner_h       = (4/self.seglen)*np.real(np.sum(np.conj(_w[_i])*_w[_i]/psd))

            if h_inner_h == 0.:
                # identically vanishing signal (e.g. pure-noise injection):
                # report zero instead of the NaN produced by 0/0
                self.snrs[ifo]  = 0.
            else:
                self.snrs[ifo]  = d_inner_h/np.sqrt(h_inner_h)

            logger.info("  - SNR in {} = {:.3f} ".format(ifo, self.snrs[ifo]))

        # print network SNR
        if self.snrs:
            net_snr = np.sqrt(sum([ self.snrs[ifo]**2. for ifo in self.snrs ]))
            logger.info("  - SNR in the Network = {:.3f} ".format(net_snr))

    def write_injections(self, outdir):
        """
        Write signal, noise and injected data of every detector in `outdir`.

        For each IFO the files `<ifo>_signal.txt`, `<ifo>_noise.txt` and
        `<ifo>_INJECTION.txt` are written: three columns (frequency, real
        and imaginary part) for frequency-domain injections, two columns
        (time, strain) otherwise. For time-domain injections the strain,
        spectra and spectrogram plots are produced as well.

        Arguments:
            outdir : output directory
        """

        for ifo in self.ifos:

            if self.fdinj_flag:

                # write signal
                if self.wave_strains:
                    write_txt_file(outdir + '/{}_signal.txt'.format(ifo),
                                   '#freq\t re(strain)\t im(strain)\n',
                                   self.freqs[ifo], np.real(self.wave_strains[ifo]),  np.imag(self.wave_strains[ifo]))

                # write noise
                if self.noise_strains:
                    write_txt_file(outdir + '/{}_noise.txt'.format(ifo),
                                   '#freq\t re(strain)\t im(strain)\n',
                                   self.freqs[ifo], np.real(self.noise_strains[ifo]),  np.imag(self.noise_strains[ifo]))

                # write data
                write_txt_file(outdir + '/{}_INJECTION.txt'.format(ifo),
                               '#freq\t re(strain)\t im(strain)\n',
                               self.freqs[ifo], np.real(self.inj_strains[ifo]),  np.imag(self.inj_strains[ifo]))

            else:

                # write signal
                if self.wave_strains:
                    write_txt_file(outdir + '/{}_signal.txt'.format(ifo),
                                   '#time\t strain\n',
                                   self.times[ifo], self.wave_strains[ifo])

                # write noise
                if self.noise_strains:
                    write_txt_file(outdir + '/{}_noise.txt'.format(ifo),
                                   '#time\t strain\n',
                                   self.times[ifo], self.noise_strains[ifo])

                # write data
                write_txt_file(outdir + '/{}_INJECTION.txt'.format(ifo),
                               '#time\t strain\n',
                               self.times[ifo], self.inj_strains[ifo])

                # plots
                make_injection_plot(ifo, self.times[ifo] , self.inj_strains[ifo], self.wave_strains[ifo], self.noises[ifo], self.f_min, outdir, alpha=self.tukey)
                make_spectrogram_plot(ifo, self.times[ifo], self.inj_strains[ifo], self.noises[ifo], self.f_min, outdir)

def bajes_inject_parser(args=None):
    """
    Build the command-line parser of bajes_inject and parse the arguments.

    Arguments:
        args : optional list of argument strings; if None (default) the
               arguments are taken from the command line

    Returns:
        the namespace filled by argparse, with the attributes
        ifos, asds, wave, srate, seglen, f_min, f_max, t_gps, seed, zero,
        fd_inj, best_sky_loc, random_sky_loc, ra, dec, t_peak, psi, window,
        tukey, outdir
    """

    parser=ap.ArgumentParser(prog='bajes_inject', usage='bajes_inject [opts]')
    # detectors and ASDs
    parser.add_argument('--ifo',        dest='ifos',    type=str,  action="append",    help="Single IFO tag. This option needs to be passed separately for every ifo in which the injection is requested. The order must correspond to the one in which the '--asd' commands are passed. Available options: ['H1', 'L1', 'V1', 'K1', 'G1'].")
    parser.add_argument('--asd',        dest='asds',    type=str,  action="append",    help="Single path to ASD file. This option needs to be passed separately for every ifo in which the injection is requested.  The order must correspond to the one in which the '--ifo' commands are passed.")

    # data properties
    parser.add_argument('--wave',       dest='wave',    type=str,  default='',     help='path to the data to inject: an ASCII file (.txt/.dat) with 3 columns [t, hp, hc] for time-domain polarizations or 5 columns [f, Re(hp), Im(hp), Re(hc), Im(hc)] for frequency-domain polarizations, or a parameter file (.ini). If empty, pure noise is generated.')
    # sampling rate and segment length have no sensible default and every
    # later step needs them, so a missing value is rejected at parsing time
    parser.add_argument('--srate',      dest='srate',   type=float,   required=True,   help='sampling rate of the injected waveform [Hz] and it will be the srate of the sampling, please check that everything is consistent')
    parser.add_argument('--seglen',     dest='seglen',  type=float,   required=True,   help='length of the segment of the injected waveform [sec], if it is not a power of 2, the final segment will be padded')
    parser.add_argument('--f-min',      dest='f_min',   type=float,   default=20.,            help='minimum frequency [Hz], default 20Hz')
    parser.add_argument('--f-max',      dest='f_max',   type=float,   default=None,           help='maximum frequency [Hz], default srate/2')
    parser.add_argument('--t-gps',      dest='t_gps',   type=float,   default=1187008882,     help='GPS time of the series, default 1187008882 (GW170817)')

    # random seed - used for noise generation and random sky location
    parser.add_argument('--seed',       dest='seed',    type=int,             default=None,   help='seed for random number generator')

    # zero-noise flag
    parser.add_argument('--zero-noise', dest='zero',    action="store_true",    default=False,  help='use zero noise')

    # FD-injection flag
    parser.add_argument('--fd-inject', dest='fd_inj',  action="store_true",    default=False,  help='perform injection in frequcy-domain')

    # sky location options and extrinsics
    parser.add_argument('--best-sky-loc',               dest='best_sky_loc',            type=str,  default='',                  help='inject signal with optimal sky location for requessted detector')
    parser.add_argument('--random-sky-loc',             dest='random_sky_loc',          action="store_true",    default=False,  help='inject signal with random sky location')
    parser.add_argument('--ra',         dest='ra',      default=None,       type=float,                       help='right ascencion location of the injected source. Default optimal location for the first IFO.')
    parser.add_argument('--dec',        dest='dec',     default=None,       type=float,                       help='declination location of the injected source. Default optimal location for the first IFO.')
    parser.add_argument('--tmerg',      dest='t_peak',  default=None,       type=float,                       help='for ascii injections only, specify the location of the merger with respect to the GPS time')
    parser.add_argument('--pol',        dest='psi',     default=0.,         type=float,                       help='polarization angle of the injected source. Default: 0.')

    # time-domain window options
    parser.add_argument('--window',     dest='window',  default='low',      type=str,                      help="Location of the window. Available options: ['low', 'both', 'none']. Default: 'low'.")
    parser.add_argument('--tukey',      dest='tukey',   default=None,       type=float,                       help='tukey window parameter')

    # output directory
    parser.add_argument('-o','--outdir', dest='outdir', default=None,       type=str,                      help='Output directory. Default: None.')

    return parser.parse_args(args)

# Script entry point: parse the options, build detector and noise objects,
# choose the sky location, generate the injection and write it to disk.
# `opts` and `logger` are intentionally module-level names, since other
# parts of this file read them as globals.
if __name__ == "__main__":

    opts    = bajes_inject_parser()
    dets    = {}
    noises  = {}

    # set outdir
    if opts.outdir is None:
        raise ValueError("Passing an output directory is mandatory. Please pass a value through the '--outdir' option.")

    opts.outdir = os.path.abspath(opts.outdir)
    ensure_dir(opts.outdir)

    # set logger
    logger = set_logger(outdir=opts.outdir, label='bajes_inject')
    logger.info("Running bajes inject:")

    # set random seed
    if opts.seed is not None:
        np.random.seed(opts.seed)

    # check provided inputs
    # the '--ifo' and '--asd' lists are None when the options are never passed
    if not opts.ifos:
        logger.error("No detector requested. Please pass at least one '--ifo' option. Aborting.")
        raise RuntimeError("No detector requested. Please pass at least one '--ifo' option. Aborting.")

    if (len(opts.ifos) != len(opts.asds or [])):
        logger.error("Number of requested detectors does not match number of given ASDs. Aborting.")
        raise RuntimeError("Number of requested detectors does not match number of given ASDs. Aborting.")

    # get detector and noise objects
    for ifo, asd_path in zip(opts.ifos, opts.asds):
        logger.info("... setting detector {} for injection ...".format(ifo))
        dets[ifo]   = Detector(ifo,opts.t_gps)
        fr,asd      = read_asd(asd_path, ifo)
        noises[ifo] = Noise(fr, asd, f_min = opts.f_min, f_max = opts.srate/2.)

    # check sky location
    if opts.random_sky_loc:
        logger.info("... injecting signal with random sky location ...")
        opts.ra , opts.dec = 2*np.pi*np.random.uniform(0,1), np.arccos(np.random.uniform(-1,1))-np.pi/2
    elif (opts.ra is not None and opts.dec is not None):
        logger.info("... injecting signal with sky location specified by the user ...")
    else:
        # fall back to the first requested detector if the one passed with
        # '--best-sky-loc' is not among the requested ones
        if opts.best_sky_loc not in dets: opts.best_sky_loc = opts.ifos[0]
        logger.info("... injecting signal with optimal sky location for {} ...".format(opts.best_sky_loc))
        opts.ra , opts.dec = dets[opts.best_sky_loc].optimal_orientation(opts.t_gps)

    logger.info("  - right ascenscion = {:.3f}".format(opts.ra))
    logger.info("  - declination      = {:.3f}".format(opts.dec))
    logger.info("  - polarization     = {:.3f}".format(opts.psi))

    # injection
    logger.info("... injecting waveform into detectors ...")
    if not opts.wave:
        logger.warning("Waveform option is empty, generating pure noise injection.")
    inj = Injection(opts.ifos, dets, noises, opts.wave, opts.seglen, opts.srate, opts.f_min, opts.f_max,
                    opts.ra, opts.dec, opts.psi, opts.t_gps, opts.window, opts.zero,
                    opts.fd_inj, opts.tukey, opts.t_peak)

    # save output
    logger.info("... writing strain data files ...")
    inj.write_injections(opts.outdir)


    logger.info("... waveform injected.")
