#!/usr/bin/env python
""" Calculate psd estimates for analysis segments
"""
import logging, argparse, numpy, h5py, itertools, multiprocessing, time
import pycbc, pycbc.psd, pycbc.strain, pycbc.events
from pycbc.version import git_verbose_msg as version
from pycbc.fft.fftw import set_measure_level
from ligo.segments import segmentlist
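# Use the lowest FFTW planning effort (measure level 0, i.e. FFTW_ESTIMATE)
# so that worker processes do not spend time measuring FFT plans up front.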
set_measure_level(0)

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--version', action='version', version=version)
parser.add_argument('--verbose', action="store_true")
parser.add_argument("--low-frequency-cutoff", type=float, required=True,
                    help="The low frequency cutoff to use for filtering (Hz)")
parser.add_argument("--analysis-segment-file",  required=True,
                    help="File defining the segments to estimate PSDs over")
parser.add_argument("--segment-name", help="Name of segment list to use")
parser.add_argument("--cores", default=1, type=int)
parser.add_argument("--output-file", required=True)

pycbc.psd.insert_psd_option_group(parser, output=False)
pycbc.strain.insert_strain_option_group(parser, gps_times=False)
pycbc.strain.StrainSegments.insert_segment_option_group(parser)

args = parser.parse_args()
pycbc.init_logging(args.verbose)

pycbc.psd.verify_psd_options(args, parser)
pycbc.strain.StrainSegments.verify_segment_options(args, parser)
        
def grouper(n, iterable):
    """ Split an iterable into chunks of at most n elements. """
    args = [iter(iterable)] * n
    return list([e for e in t if e is not None]
                for t in itertools.zip_longest(*args))
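# For example, grouper(2, [1, 2, 3]) returns [[1, 2], [3]].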

def get_psd(input_tuple):
    """ Get the PSDs for the given data chunk. This follows the same rules
    as pycbc_inspiral for determining where to calculate PSDs
    """
    seg, i = input_tuple
    # Give each worker its own cache directory so that parallel processes
    # do not clobber each other's compiled-code (weave) cache.
    pycbc.multiprocess_cache_dir()
    
    logging.info('%d: getting strain for %.1f-%.1f (%.1f s)', i, seg[0],
                 seg[1], abs(seg))
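    # Shrink the requested span by pad_data so that the padded read done by
    # pycbc.strain.from_cli stays within the science segment.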
    args.gps_start_time = int(seg[0]) + args.pad_data
    args.gps_end_time = int(seg[1]) - args.pad_data

    # This helps when the filesystem is unreliable by retrying the read.
    # Retries recurse, so their number is bounded by Python's recursion
    # limit (1000 by default; it is not infinite).
    try:
        gwstrain = pycbc.strain.from_cli(args, pycbc.DYN_RANGE_FAC)
    except RuntimeError:
        time.sleep(10)
        return get_psd((seg, i))

    logging.info('%d: determining strain segmentation', i)
    strain_segments = pycbc.strain.StrainSegments.from_cli(args, gwstrain)

    flow = args.low_frequency_cutoff
    flen = strain_segments.freq_len
    delta_f = strain_segments.delta_f

    logging.info('%d: calculating psd', i)
    psds_and_times = pycbc.psd.generate_overlapping_psds(args, gwstrain,
                      flen, delta_f, flow, dyn_range_factor=pycbc.DYN_RANGE_FAC)
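    # Each element of psds_and_times is a (start_index, end_index, psd)
    # tuple giving the range of strain samples the PSD estimate covers.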

    lpsd = []
    for start_idx, end_idx, psd in psds_and_times:
        start_time = gwstrain.start_time + start_idx/gwstrain.sample_rate
        end_time = gwstrain.start_time + end_idx/gwstrain.sample_rate
        lpsd.append((psd.numpy(), psd.delta_f, int(start_time), int(end_time)))

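    # lpsd holds one (psd_array, delta_f, start_gps, end_gps) tuple per
    # overlapping PSD estimate in this analysis segment.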
    return lpsd

# Determine which times to calculate PSDs for; the detector prefix
# (e.g. 'H1') is the first two characters of the channel name.
ifo = args.channel_name[0:2]
segments = pycbc.events.select_segments_by_definer(args.analysis_segment_file,
                                                   args.segment_name, ifo=ifo)

# Remove duplicate segments, which can appear when the template bank is
# split; the resulting order is arbitrary, which is fine because each PSD
# records its own start and end times.
segments = segmentlist(frozenset(segments))

# Calculate the PSDs
logging.info('%d PSDs to calculate', len(segments))

if len(segments) > 0:
    pool = multiprocessing.Pool(args.cores)

    # KLUDGE! Run the first segment in the parent process to ensure that
    # the weave cache is populated. This is a short-term fix until the
    # issues tracked at https://github.com/ligo-cbc/pycbc/issues/501
    # are resolved. It wastefully runs the first segment twice, which
    # could be avoided, but since this is a short-term kludge anyway I'd
    # prefer to be minimally invasive.
    # TODO: Remove this when issue 501 is resolved.
    get_psd((segments[0], 0))

    psds = pool.map_async(get_psd, zip(segments, range(len(segments))))
    psds = psds.get()
else:
    psds = []
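# psds now holds one list of (psd, delta_f, start_time, end_time) tuples per
# analysis segment; these are flattened into numbered datasets below.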

# Store the PSDs in an HDF5 file, including some basic metadata
f = h5py.File(args.output_file, 'w')
psd_group = f.create_group(ifo + '/psds')
inc, start, end = 0, [], []
for gpsd in psds:
    for psd_numpy, psd_delta_f, s, e in gpsd:
        logging.info('writing psd %d', inc)
        key = str(inc)
        start.append(int(s))
        end.append(int(e))
        psd_group.create_dataset(key, data=psd_numpy, compression='gzip',
                                 compression_opts=9, shuffle=True)
        psd_group[key].attrs['epoch'] = int(s)
        psd_group[key].attrs['delta_f'] = float(psd_delta_f)
        inc += 1

f[ifo + '/start_time'] = numpy.array(start, dtype=numpy.uint32)
f[ifo + '/end_time'] = numpy.array(end, dtype=numpy.uint32)
f.attrs['low_frequency_cutoff'] = args.low_frequency_cutoff
f.attrs['dynamic_range_factor'] = pycbc.DYN_RANGE_FAC
f.close()
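
# A sketch of reading the PSDs back with h5py, assuming the file was written
# for the 'H1' detector ('psds.hdf' stands in for whatever --output-file was
# given; the group layout matches the writes above):
#
#   with h5py.File('psds.hdf', 'r') as hf:
#       for key in hf['H1/psds']:
#           dset = hf['H1/psds/' + key]
#           psd, delta_f = dset[:], dset.attrs['delta_f']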

logging.info('Done!')

