#!/usr/bin/python

# Copyright 2020 Gareth S. Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

import numpy as np
import pycbc
from pycbc import bin_utils
from pycbc.events import cuts, triggers, trigger_fits as trstats
from pycbc.io import DictArray
from pycbc.events import ranking
import argparse, logging, os, sys, h5py

# Command line interface: where to find the trigger files, how to split the
# templates into duration bins, and how to fit the triggers in each bin.
parser = argparse.ArgumentParser(usage="",
    description="Fit the distribution of single-detector triggers above a "
                "threshold in template duration bins, using the trigger "
                "files from one day of analysis.")
parser.add_argument("--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--ifos", nargs="+", required=True,
                    help="Which ifo are we fitting the triggers for? "
                         "Required")
parser.add_argument("--top-directory", metavar='PATH', required=True,
                    help="Directory containing trigger files, top directory, "
                         "will contain subdirectories for each day of data. "
                         "Required.")
parser.add_argument("--analysis-date", required=True,
                    help="Date of the analysis, format YYYY_MM_DD. Required")
parser.add_argument("--file-identifier", default="H1L1V1-Live",
                    help="String required in filename to be considered for "
                         "analysis. Default: 'H1L1V1-Live'.")
parser.add_argument("--fit-function", default="exponential",
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit. "
                         "Choose from exponential, rayleigh or power. "
                         "Default: exponential")
parser.add_argument("--duration-bin-edges", nargs='+', type=float,
                    help="Durations to use for bin edges. "
                         "Use if specifying exact bin edges. "
                         "Not compatible with --duration-bin-start, "
                         "--duration-bin-end and --num-duration-bins")
parser.add_argument("--duration-bin-start", type=float,
                    help="Shortest duration to use for duration bins. "
                         "Not compatible with --duration-bin-edges, requires "
                         "--duration-bin-end and --num-duration-bins.")
parser.add_argument("--duration-bin-end", type=float,
                    help="Longest duration to use for duration bins.")
parser.add_argument("--num-duration-bins", type=int,
                    help="How many template duration bins to split the bank "
                         "into before fitting.")
parser.add_argument("--duration-bin-spacing",
                    choices=['linear','log'], default='log',
                    help="How to set spacing for bank split "
                         "if using --duration-bin-start, --duration-bin-end "
                         "and --num-duration-bins.")
parser.add_argument("--fit-threshold", type=float, default=5,
                    help="Lower threshold used in fitting the triggers. "
                         "Default 5.")
parser.add_argument("--cluster", action='store_true',
                    help="Only use maximum of the --sngl-ranking value "
                         "from each file.")
parser.add_argument("--output", required=True,
                    help="File in which to save the output trigger fit "
                         "parameters.")
parser.add_argument("--sngl-ranking", default="newsnr",
                    choices=ranking.sngls_ranking_function_dict.keys(),
                    help="The single-detector trigger ranking to use.")

# Adds the --trigger-cuts / --template-cuts options read back further down
cuts.insert_cuts_option_group(parser)

# The parsed options are sanity-checked immediately after this point
args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Check the bin inputs to see if they are sensible.
# The range options are compared against None rather than tested for
# truthiness, so that an explicit value of 0 (e.g. --duration-bin-start 0
# with linear spacing) counts as "given".
range_options = (args.duration_bin_start,
                 args.duration_bin_end,
                 args.num_duration_bins)
any_range_option = any(opt is not None for opt in range_options)
all_range_options = all(opt is not None for opt in range_options)

if args.duration_bin_edges and any_range_option:
    # duration bin edges specified as well as the linear/logarithmic
    # range options: the two ways of defining the bins are exclusive
    parser.error("Cannot use --duration-bin-edges with "
                 "--duration-bin-start, --duration-bin-end or "
                 "--num-duration-bins.")

if not args.duration_bin_edges and not all_range_options:
    parser.error("--duration-bin-start, --duration-bin-end and "
                 "--num-duration-bins must be set if not using "
                 "--duration-bin-edges.")

# parser.error exits, so from here on either the explicit edges or the full
# set of range options is available
if not args.duration_bin_edges:
    if args.num_duration_bins < 1:
        parser.error("--num-duration-bins must be at least 1, got "
                     f"{args.num_duration_bins}")

    if args.duration_bin_end <= args.duration_bin_start:
        parser.error("--duration-bin-end must be greater than "
                     "--duration-bin-start, got "
                     f"{args.duration_bin_end} and {args.duration_bin_start}")

    if args.duration_bin_spacing == 'log' and args.duration_bin_start <= 0:
        # log10 of a non-positive number would give invalid bin edges
        parser.error("--duration-bin-start must be greater than zero "
                     "when using log spacing, got "
                     f"{args.duration_bin_start}")

# Create the duration bins: n bins need n + 1 edges
if args.duration_bin_edges:
    duration_bin_edges = np.array(args.duration_bin_edges)
elif args.duration_bin_spacing == 'log':
    duration_bin_edges = np.logspace(np.log10(args.duration_bin_start),
                                     np.log10(args.duration_bin_end),
                                     args.num_duration_bins + 1)
else:
    # argparse restricts --duration-bin-spacing to 'log' or 'linear'
    duration_bin_edges = np.linspace(args.duration_bin_start,
                                     args.duration_bin_end,
                                     args.num_duration_bins + 1)

logging.info("Finding files")

# Trigger files for the requested day live in <top-directory>/<analysis-date>.
# Check that the directory exists before listing it, so that a bad path is
# reported with a clear message.
date_directory = os.path.join(args.top_directory, args.analysis_date)

if not os.path.exists(date_directory):
    raise FileNotFoundError(f"The directory {date_directory} does not exist.")

# Only files whose name contains the identifier string are considered
files = [f for f in os.listdir(date_directory)
         if args.file_identifier in f]

logging.info("%d files found", len(files))

# Add template duration cuts according to the bin inputs
args.template_cuts = args.template_cuts or []
args.template_cuts.append(f"template_duration:{min(duration_bin_edges)}:lower")
args.template_cuts.append(f"template_duration:{max(duration_bin_edges)}:upper_inc")

# Efficiency saving: add SNR cut before any others as sngl_ranking can
# only be less than SNR.
args.trigger_cuts = args.trigger_cuts or []
args.trigger_cuts.insert(0, f"snr:{args.fit_threshold}:lower_inc")

# We can cut out all of the stuff with sngl-ranking
# below threshold now as well
args.trigger_cuts.append(f"{args.sngl_ranking}:{args.fit_threshold}:lower_inc")

logging.info("Setting up the cut dictionaries")
trigger_cut_dict, template_cut_dict = cuts.ingest_cuts_option_group(args)

logging.info("Setting up duration bins")
tbins = bin_utils.IrregularBins(duration_bin_edges)

# Also calculate live time so that this fitting can be used in rate estimation
# Live time is not immediately obvious - get an approximation by adding up
# the live time of each 'valid' file
live_time = {ifo: 0 for ifo in args.ifos}

logging.info("Getting events which meet criteria")

# Accumulators filled by the loop over files: the surviving events for each
# ifo, and the number of files looked at so far
events = {}
counter = 0

# Loop through files - add events which meet the immediately gettable
# criteria to the per-ifo `events` accumulator, and add up the live time
for counter, filename in enumerate(files, start=1):
    # Progress report every 1000 files
    if counter % 1000 == 0:
        logging.info("Processed %d files", counter)
        for ifo in args.ifos:
            if ifo not in events:
                # In case of no triggers for an extended period
                logging.info("%s: No data", ifo)
            else:
                logging.info("%s: %d triggers in %.0fs", ifo,
                             events[ifo].data['snr'].size, live_time[ifo])

    f = os.path.join(date_directory, filename)
    # If there is an IOError with the file, don't fail, just carry on.
    # The file is opened once only, and the handle is closed by the
    # `with` block below.
    try:
        fin = h5py.File(f, 'r')
    except IOError:
        logging.info('IOError with file %s', f)
        continue

    # Triggers for this file only, keyed by ifo. Named so as not to shadow
    # the imported `triggers` module.
    file_triggers = {}
    with fin:
        # Does the file have the ifo group and snr dataset?
        for ifo in args.ifos:
            if not (ifo in fin and 'snr' in fin[ifo]):
                continue

            # Eventual FIX ME: live output files should (soon) have the live time
            # added, but for now, extract from the filename
            # Format of the filename is to have the live time after the
            # last dash, followed by '.hdf' at the end of the filename
            lt = int(f.split('-')[-1][:-4])
            live_time[ifo] += lt

            n_triggers = fin[ifo]['snr'].size
            # Skip if there are no triggers
            if not n_triggers:
                continue

            # Read trigger value datasets from file
            file_triggers[ifo] = {
                k: fin[ifo][k][:] for k in fin[ifo].keys()
                if k not in ('loudest', 'stat', 'gates', 'psd')
                and fin[ifo][k].size == n_triggers
            }
            # The n_triggers filter should be enough, but the 'not in'
            # list is for edge cases where the number of gates/loudest etc is
            # the same as the number of triggers

            # The stored chisq is actually reduced chisq, so we need to
            # alter the chisq_dof dataset so we can use the standard conversions.
            # chisq_dof of 1.5 gives the right number (2 * 1.5 - 2 = 1)
            file_triggers[ifo]['chisq_dof'] = \
                1.5 * np.ones_like(file_triggers[ifo]['snr'])

    for ifo, trigs_ifo in file_triggers.items():

        # Apply the cuts to triggers
        keep_idx = cuts.apply_trigger_cuts(trigs_ifo, trigger_cut_dict)

        # The trigger datasets contain the values that we want to use for
        # the template cuts, so they can be used as the template bank
        # in this use case
        keep_idx = cuts.apply_template_cuts(trigs_ifo, template_cut_dict,
                                            template_ids=keep_idx)

        # Skip if no triggers survive the cuts
        if not keep_idx.size:
            continue

        # Keep only the triggers which passed the cuts
        triggers_cut = {k: trigs_ifo[k][keep_idx]
                        for k in trigs_ifo.keys()}

        # Calculate the sngl_ranking value to be used in the fits
        sngls_value = ranking.get_sngls_ranking_from_trigs(
                          triggers_cut, args.sngl_ranking)

        triggers_cut[args.sngl_ranking] = sngls_value

        triggers_da = DictArray(data=triggers_cut)

        # If we are clustering, keep only the trigger with the maximum
        # sngl_ranking value from this file
        if args.cluster:
            max_idx = sngls_value.argmax()
            # Index with a list so that every field stays a one-element
            # array and can still be appended to the accumulated events.
            # The selection must replace triggers_da, as that is the object
            # added to `events` below.
            triggers_da = triggers_da.select([max_idx])

        if ifo in events:
            events[ifo] += triggers_da
        else:
            events[ifo] = triggers_da

logging.info("All events processed")

# Report, for each requested ifo, how many events survived and the
# accumulated live time
logging.info("Number of events which meet all criteria:")
for ifo in args.ifos:
    if ifo in events:
        n_events = len(events[ifo])
        logging.info("%s: %d in %.2fs", ifo, n_events, live_time[ifo])
        continue
    # Nothing was accumulated for this ifo
    logging.info("%s: No data", ifo)

# Split the events into bins by template duration, then fit each bin
logging.info('Sorting events into template duration bins')

n_bins = duration_bin_edges.size - 1

# Per-ifo outputs: one fit coefficient and one count per duration bin, plus
# the sngl_ranking values which went into each bin's fit
alphas = {}
counts = {}
binned_sngl_stats = {}
for ifo in args.ifos:
    alphas[ifo] = np.zeros(n_bins, dtype=np.float32)
    counts[ifo] = np.zeros(n_bins, dtype=np.float32)
    binned_sngl_stats[ifo] = {}

for ifo, ifo_events in events.items():
    # Look up the duration bin index of every event
    durations = ifo_events.data['template_duration']
    event_bin = np.array([tbins[d] for d in durations])

    for bin_num in range(n_bins):
        inbin = event_bin == bin_num
        n_in_bin = np.count_nonzero(inbin)

        if n_in_bin == 0:
            # No triggers, alpha and count are -1
            counts[ifo][bin_num] = -1
            alphas[ifo][bin_num] = -1
            continue

        # Keep hold of the sngl_ranking values used for this bin
        stats_in_bin = ifo_events.data[args.sngl_ranking][inbin]
        binned_sngl_stats[ifo][bin_num] = stats_in_bin

        counts[ifo][bin_num] = n_in_bin
        # Only the first value returned by the fit is stored
        alphas[ifo][bin_num], _ = trstats.fit_above_thresh(
            args.fit_function,
            stats_in_bin,
            args.fit_threshold)

logging.info("Writing results")
with h5py.File(args.output, 'w') as fout:
    for ifo, ifo_events in events.items():
        ifo_group = fout.create_group(ifo)

        # Save the triggers we have used for the fits
        trigger_group = ifo_group.create_group('triggers')
        for key, values in ifo_events.data.items():
            trigger_group[key] = values

        # Fit results and live time for this ifo
        ifo_group['fit_coeff'] = alphas[ifo]
        ifo_group['counts'] = counts[ifo]
        ifo_group.attrs['live_time'] = live_time[ifo]

    # Edges of the template duration bins used for the fits
    fout['bins_upper'] = tbins.upper()
    fout['bins_lower'] = tbins.lower()

    # Record how the file was made
    file_attributes = {
        'analysis_date': args.analysis_date,
        'input': sys.argv,
        'cuts': args.template_cuts + args.trigger_cuts,
        'fit_function': args.fit_function,
        'fit_threshold': args.fit_threshold,
        'sngl_ranking': args.sngl_ranking,
    }
    for attr_name, attr_value in file_attributes.items():
        fout.attrs[attr_name] = attr_value

logging.info("Done")
