#!/usr/bin/env python                    
# Example usage:
# python3 pycbc_template_kde_calc.py --template-file /home/praveen.kumar/focused_search/kde/sbbh2.hdf --signal-file /home/praveen.kumar/focused_search/kde/3OGC_PEtable.txt_ifar_gt0.5.csv --min-mass 5 --fit-param mchirp eta chi_eff --log-param True False False --output-file kdevals_signal.hdf --make-signal-kde --nfold-signal 2 --verbose

import numpy, h5py, operator, argparse, logging
from pycbc import init_logging
import pycbc.conversions as convert
from pycbc import libutils
from pycbc.events import triggers
akde = libutils.import_optional('awkde')
kf = libutils.import_optional('sklearn.model_selection')

# Command-line interface.
# NOTE: description=__doc__ is None unless the module has a docstring.
parser = argparse.ArgumentParser(description=__doc__)
# Only read when --make-signal-kde is given (checked below)
parser.add_argument('--signal-file')
parser.add_argument('--template-file', required=True)
parser.add_argument('--min-mass', type=float, default=None,
                    help='Used only on signal masses: remove all '
                         'signal events with mass2 < min_mass')
# The fold counts default to 2 so that an omitted option does not pass None
# through to the cross-validation splitter.
parser.add_argument('--nfold-signal', type=int, default=2,
                    help='Number of k-folds for signal KDE cross validation')
parser.add_argument('--nfold-template', type=int, default=2,
                    help='Number of k-folds for template KDE cross validation')
parser.add_argument('--fit-param', nargs='+', required=True,
                    help='Parameters over which KDE is calculated')
# One 'True'/'False' string per --fit-param entry, matched by position
parser.add_argument('--log-param', nargs='+', choices=['True', 'False'], required=True)
parser.add_argument('--output-file', required=True)
parser.add_argument('--make-signal-kde', action='store_true')
parser.add_argument('--make-template-kde', action='store_true')
parser.add_argument('--verbose', action='count')
args = parser.parse_args()
init_logging(args.verbose)

# Validate with parser.error rather than assert: asserts vanish under -O and
# give the user a traceback instead of a usage message.
if len(args.fit_param) != len(args.log_param):
    parser.error("--fit-param and --log-param must have the same number "
                 "of values")
if args.make_signal_kde and args.make_template_kde:
    parser.error("Choose only one option out of --make-signal-kde and --make-template-kde")
if args.make_signal_kde and args.signal_file is None:
    parser.error("--signal-file is required with --make-signal-kde")

def kde_awkde(x, x_grid, alp=0.5, gl_bandwidth='silverman', ret_kde=False):
    """Fit an awkde ``GaussianKDE`` to `x` and evaluate it at `x_grid`.

    Parameters
    ----------
    x :
        Training samples, handed unchanged to ``GaussianKDE.fit``.
    x_grid :
        Evaluation points, handed unchanged to ``GaussianKDE.predict``.
    alp :
        Forwarded as the ``alpha`` argument of ``GaussianKDE``.
    gl_bandwidth :
        Forwarded as the ``glob_bw`` argument of ``GaussianKDE``.
    ret_kde : bool
        If true, also return the fitted KDE object.

    Returns
    -------
    The output of ``predict`` on `x_grid`, or the tuple ``(kde, values)``
    when `ret_kde` is true.
    """
    kde = akde.GaussianKDE(glob_bw=gl_bandwidth, alpha=alp, diag_cov=True)
    kde.fit(x)
    # The evaluation call is the same whatever the container type of x_grid,
    # so no type dispatch is needed here.
    y = kde.predict(x_grid)
    if ret_kde:
        return kde, y
    return y


def kfcv_awkde(sample, bwchoice, alphachoice, k=2):
    """
    Compute the K-fold cross-validated log likelihood of an awKDE.

    For each of the `k` shuffled folds the KDE is fitted on the training
    part of `sample` and evaluated on the held-out part; the summed log of
    those values is the fold score. The mean fold score is returned.

    `bwchoice` is either one of the rule names 'silverman' / 'scott' or a
    value convertible to float; `alphachoice` is the sensitivity (alpha)
    parameter passed on to the KDE.
    """
    # Rule names are kept as strings, anything else is coerced to a float
    bandwidth = bwchoice if bwchoice in ('silverman', 'scott') else float(bwchoice)
    splitter = kf.KFold(n_splits=k, shuffle=True, random_state=None)
    fold_scores = [
        numpy.sum(numpy.log(
            kde_awkde(sample[fit_idx], sample[held_out_idx],
                      alp=alphachoice, gl_bandwidth=bandwidth)))
        for fit_idx, held_out_idx in splitter.split(sample)
    ]
    return numpy.mean(fold_scores)


def optimizedparam(sampleval, nfold=2):
    """Pick the KDE bandwidth and alpha with the best cross-validated score.

    Every (bandwidth, alpha) pair on the grid is scored with `kfcv_awkde`
    and the pair with the largest figure of merit is returned.

    Parameters
    ----------
    sampleval :
        Samples passed on to `kfcv_awkde`.
    nfold : int
        Number of cross-validation folds.

    Returns
    -------
    optbw, optalpha
        The selected bandwidth and alpha. Note that `optbw` is either one
        of the strings 'scott' / 'silverman' or a float from the numeric
        grid, so callers must not assume it is numeric.
    """
    # Test wide range of bandwidths in addition to standard values
    bwgrid = ['scott', 'silverman'] + numpy.logspace(-2, 0, 10).tolist()
    # Fix adaptive sensitivity parameter to 1
    alphagrid = [1]
    fom = {
        (gbw, alphaval): kfcv_awkde(sampleval, gbw, alphaval, k=nfold)
        for gbw in bwgrid
        for alphaval in alphagrid
    }
    # max() returns the first maximal entry, so ties go to the earlier
    # grid point (dicts preserve insertion order).
    optbw, optalpha = max(fom, key=fom.get)
    return optbw, optalpha


# Obtaining template parameters
# The template bank is opened once and closed as soon as everything needed
# from it has been read and copied; the output file stays open for the KDE
# sections that follow.
with h5py.File(args.template_file, 'r') as temp_file:
    tid = numpy.arange(len(temp_file['mass1']))  # Array of template ids
    mass_spin = triggers.get_mass_spin(temp_file, tid)

    f_dest = h5py.File(args.output_file, 'w')
    template_pars = []
    for param, slog in zip(args.fit_param, args.log_param):
        pvals = triggers.get_param(param, args, *mass_spin)
        # Write the KDE param values to output file
        f_dest.create_dataset(param, data=pvals)
        if slog == 'False':
            logging.info('Using param: %s', param)
            template_pars.append(pvals)
        elif slog == 'True':
            logging.info('Using log param: %s', param)
            template_pars.append(numpy.log(pvals))
        else:
            # Defensive only: argparse already restricts the choices
            raise ValueError("invalid log param argument, use 'True', or 'False'")

    # Copy standard data to output file
    f_dest.attrs['fit_param'] = args.fit_param
    f_dest.attrs['log_param'] = args.log_param
    temp_file.copy(temp_file["./"], f_dest["./"], "input_template_params")
    # One row per template, one column per fit parameter; built while the
    # bank is still open in case any parameter array is backed by the file.
    temp_samples = numpy.vstack(template_pars).T

# Optimize and evaluate the KDE of the template bank at the template points
if args.make_template_kde:
    logging.info('Starting optimization of template KDE parameters')
    optbw, optalpha = optimizedparam(temp_samples, nfold=args.nfold_template)
    # The optimal bandwidth can be the rule name 'scott'/'silverman', which
    # a %.4f conversion would reject, so only numeric values are rounded.
    bw_text = optbw if isinstance(optbw, str) else '%.4f' % optbw
    logging.info('Bandwidth %s, alpha %.2f', bw_text, optalpha)
    logging.info('Evaluating template KDE')
    template_kde = kde_awkde(temp_samples, temp_samples, alp=optalpha, gl_bandwidth=optbw)
    f_dest.create_dataset("template_kde", data=template_kde)

# Obtaining signal parameters
if args.make_signal_kde:
    signal_pars = []
    # Column names come from the header line of the comma-separated file
    signal_file = numpy.genfromtxt(args.signal_file, dtype=float,
                                   delimiter=',', names=True)
    mass2_sgnl = signal_file['mass2']
    N_original = len(mass2_sgnl)
    # Compare with None so that an explicit --min-mass 0 is still applied
    if args.min_mass is not None:
        idx = mass2_sgnl > args.min_mass
        mass2_sgnl = mass2_sgnl[idx]
        logging.info('%i triggers out of %i with MASS2 > %s',
                     len(mass2_sgnl), N_original, str(args.min_mass))
    else:
        idx = numpy.full(N_original, True)
    if len(mass2_sgnl) == 0:
        raise ValueError("no signal events left to build the signal KDE")
    mass1_sgnl = signal_file['mass1'][idx]
    # Explicit check instead of assert, which is skipped under python -O
    if not numpy.all(mass1_sgnl > mass2_sgnl):
        raise ValueError("signal file must have mass1 > mass2 for every "
                         "selected event")

    for param, slog in zip(args.fit_param, args.log_param):
        pvals = signal_file[param][idx]
        if slog == 'False':
            logging.info('Using param: %s', param)
            signal_pars.append(pvals)
        elif slog == 'True':
            logging.info('Using log param: %s', param)
            signal_pars.append(numpy.log(pvals))
        else:
            # Defensive only: argparse already restricts the choices
            raise ValueError("invalid log param argument, use 'True', or 'False'")

    # One row per signal event, one column per fit parameter
    signal_samples = numpy.vstack(signal_pars).T
    logging.info('Starting optimization of signal KDE parameters')
    optbw, optalpha = optimizedparam(signal_samples, nfold=args.nfold_signal)
    # The optimal bandwidth can be the rule name 'scott'/'silverman', which
    # a %.4f conversion would reject, so only numeric values are rounded.
    bw_text = optbw if isinstance(optbw, str) else '%.4f' % optbw
    logging.info('Bandwidth %s, alpha %.2f', bw_text, optalpha)
    logging.info('Evaluating signal KDE')
    # The signal KDE is evaluated at the template points, not at the signals
    signal_kde = kde_awkde(signal_samples, temp_samples, alp=optalpha, gl_bandwidth=optbw)

    f_dest.create_dataset("signal_kde", data=signal_kde)

f_dest.close()
logging.info('Done!')
