#!python

# Copyright 2024 Max Trevor
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""
Calculate the noise trigger rate adjustment as a function of data-quality state
for a day of PyCBC Live triggers.
"""

import logging
import argparse
import hashlib

import numpy as np

import pycbc
from pycbc.io import HFile

import gwdatafind

from gwpy.timeseries import TimeSeriesDict
from gwpy.segments import Segment, DataQualityFlag

# Command-line interface. Every option is required; the frame type, flag
# name and channel names may contain an '{ifo}' placeholder, which is
# filled in with the value of --ifo before use.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--trigger-file",
    required=True,
    help="File of PyCBC Live triggers to read. Must contain the "
         "'<ifo>/dq_state' and '<ifo>/dq_state_template' datasets.",
)
parser.add_argument(
    "--ifo",
    required=True,
    help="Name of the detector to process. It replaces '{ifo}' in the "
         "frame type, flag name and channel names, and its first "
         "character is used as the site when locating frame files.",
)
parser.add_argument(
    "--gps-start-time",
    required=True,
    type=int,
    help="GPS time (integer seconds) at which the analysed span starts.",
)
parser.add_argument(
    "--gps-end-time",
    required=True,
    type=int,
    help="GPS time (integer seconds) at which the analysed span ends.",
)
parser.add_argument(
    "--template-bin-file",
    required=True,
    help="File defining the template bins. Must carry the 'f_lower', "
         "'bank_file' and 'background_bins' attributes and one group "
         "whose 'bin<N>' members each hold a 'tids' dataset.",
)
parser.add_argument(
    "--analysis-frame-type",
    required=True,
    help="Frame type used to locate the frame files holding the "
         "data-quality channels.",
)
parser.add_argument(
    "--analysis-flag-name",
    required=True,
    help="Name of the flag whose active segments are counted as "
         "observing time.",
)
parser.add_argument(
    "--dq-channel",
    required=True,
    help="Channel holding the data-quality value. Samples less than or "
         "equal to --dq-thresh are treated as flagged.",
)
parser.add_argument(
    "--dq-ok-channel",
    required=True,
    help="Companion channel to --dq-channel. Only times where this "
         "channel equals 1 can count towards the flagged time.",
)
parser.add_argument(
    "--dq-thresh",
    required=True,
    type=float,
    help="Threshold compared against the samples of --dq-channel.",
)
parser.add_argument(
    "--output",
    required=True,
    help="Path of the results file to write. The file is opened in "
         "write mode, so an existing file is replaced.",
)
pycbc.add_common_pycbc_options(parser)
args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Observing time is the part of the requested GPS span during which the
# analysis flag is active.
day_seg = Segment(args.gps_start_time, args.gps_end_time)
ar_flag_name = args.analysis_flag_name.format(ifo=args.ifo)
observing_flag = DataQualityFlag.query(
    ar_flag_name,
    day_seg,
)
livetime = observing_flag.livetime
observing_segs = observing_flag.active
logging.info(f'Found {livetime} seconds of observing time at {args.ifo}.')

# Resolve the '{ifo}' placeholder in the frame type and channel names once.
frame_type = args.analysis_frame_type.format(ifo=args.ifo)
dq_channel = args.dq_channel.format(ifo=args.ifo)
dq_ok_channel = args.dq_ok_channel.format(ifo=args.ifo)

# gwdatafind is queried with the first character of the detector name
site = args.ifo[0]

# Walk the observing segments and accumulate the time that is dq flagged:
# the dq channel is at or below threshold AND the dq-ok channel equals 1.
flagged_time = 0
for obs_seg in observing_segs:
    seg_start = obs_seg[0]
    seg_end = obs_seg[1]

    frame_urls = gwdatafind.find_urls(site, frame_type, seg_start, seg_end)
    channel_data = TimeSeriesDict.read(
        frame_urls,
        channels=[dq_channel, dq_ok_channel],
        start=seg_start,
        end=seg_end,
    )

    dq_ok_flag = (channel_data[dq_ok_channel] == 1).to_dqflag()
    dq_flag = (channel_data[dq_channel] <= args.dq_thresh).to_dqflag()
    flagged_time += (dq_flag & dq_ok_flag).livetime
logging.info(f'Found {flagged_time} seconds of dq flagged time at {args.ifo}.')

# Time spent in each dq state, ordered [not flagged, flagged]
bg_livetime = livetime - flagged_time
state_time = np.array([bg_livetime, flagged_time])

# Load the template bins: a mapping from bin name to the array of template
# ids in that bin, plus the settings stored as attributes on the file.
logging.info(f'Reading template bins from {args.template_bin_file}')
with HFile(args.template_bin_file, 'r') as bin_file:
    bin_attrs = bin_file.attrs
    f_lower = bin_attrs['f_lower']
    bank_file = bin_attrs['bank_file']
    bin_string = bin_attrs['background_bins']

    # only ever one group in this file, so take the first
    top_level_names = list(bin_file)
    grp = bin_file[top_level_names[0]]
    num_bins = len(grp)
    template_bins = {name: grp[name]['tids'][:] for name in grp}

# total number of template ids across all bins
num_templates = sum(len(tids) for tids in template_bins.values())

# For each bin, count the total number of triggers and the number of dq
# flagged triggers. Both arrays are indexed by the bin number parsed from
# the bin name, so this assumes the bins are numbered 0..num_bins-1
# (NOTE(review): confirm against the code that writes the bin file).
bin_total_triggers = np.zeros(num_bins)
bin_dq_triggers = np.zeros(num_bins)

logging.info(f'Reading triggers from {args.trigger_file}')
with HFile(args.trigger_file, 'r') as trigf:
    # Resolve the two datasets once; looking them up by path inside the
    # template loop would repeat the same lookup for every template.
    dq_state_template = trigf[f'{args.ifo}/dq_state_template']
    dq_state = trigf[f'{args.ifo}/dq_state']

    for bin_name, template_ids in template_bins.items():
        # bins are named as 'bin{bin_num}'
        bin_num = int(bin_name[3:])
        for template_id in template_ids:
            # get dq states for all triggers with this template
            # dq state is either 0 or 1 for each trigger
            # (dq_key is presumably a region reference into dq_state)
            dq_key = dq_state_template[template_id]
            dq_states = dq_state[dq_key]

            # update trigger counts; summing the 0/1 states counts the
            # flagged triggers
            bin_total_triggers[bin_num] += len(dq_states)
            bin_dq_triggers[bin_num] += np.sum(dq_states)

# Write the livetimes, per-bin trigger counts and normalized rates, and the
# settings used, to the output file.
logging.info(f'Writing results to {args.output}')
with HFile(args.output, 'w') as outf:
    ifo_group = outf.create_group(args.ifo)
    ifo_group.create_dataset('observing_livetime', data=livetime)
    ifo_group.create_dataset('dq_flag_livetime', data=flagged_time)

    bins_out = ifo_group.create_group('bins')
    for bin_name, tids in template_bins.items():
        # bins are named as 'bin{bin_num}'
        idx = int(bin_name[3:])
        n_total = bin_total_triggers[idx]
        n_flagged = bin_dq_triggers[idx]

        this_bin = bins_out.create_group(bin_name)
        this_bin.create_dataset('tids', data=tids)
        this_bin.create_dataset('total_triggers', data=n_total)
        this_bin.create_dataset('dq_triggers', data=n_flagged)

        # Trigger rate in each dq state, ordered [not flagged, flagged],
        # divided by the bin's mean rate over all observing time.
        # NOTE(review): a zero state time or a bin with no triggers makes
        # these divisions produce nan/inf - confirm readers expect that.
        counts_by_state = np.array([n_total - n_flagged, n_flagged])
        rate_by_state = counts_by_state / state_time
        bin_mean_rate = n_total / livetime
        this_bin.create_dataset('dq_rates', data=rate_by_state / bin_mean_rate)

    # Settings recorded as file attributes, in the order they are written
    run_settings = {
        'dq_thresh': args.dq_thresh,
        'dq_channel': dq_channel,
        'dq_ok_channel': dq_ok_channel,
        'gps_start_time': args.gps_start_time,
        'gps_end_time': args.gps_end_time,
        'f_lower': f_lower,
        'bank_file': bank_file,
        'background_bins': bin_string,
    }
    for attr_name, attr_value in run_settings.items():
        outf.attrs[attr_name] = attr_value

    # hash is used to check if different files have compatible settings;
    # the GPS times are deliberately left out of it
    hashed_names = ('dq_thresh', 'dq_channel', 'dq_ok_channel',
                    'f_lower', 'bank_file', 'background_bins')
    setting_str = ' '.join(str(run_settings[name]) for name in hashed_names)
    outf.attrs['settings_hash'] = hashlib.sha256(
        setting_str.encode()
    ).hexdigest()
