#!/usr/bin/env python

# Copyright 2015 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

from __future__ import division

import sys
import argparse, logging

from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt
import numpy as np
from scipy.stats import kstest

from pylal import SnglInspiralUtils as sniuls, trigger_fits as trstats, rate

from pycbc import io, events
import pycbc.version

#### DEFINITIONS AND FUNCTIONS ####

# Lookup table from the value of --sngl-stat to the callable that evaluates
# that statistic.  Every callable is invoked positionally as (snr, rchisq),
# optionally followed by a third 'factor' argument (see get_stat).
stat_dict = dict(
    new_snr=events.newsnr,
    # NOTE: "effective_snr" is still listed among the --sngl-stat choices
    # but has no entry here while the line below stays disabled.
    #effective_snr=events.effsnr,
    # plain SNR: the reduced chisq argument is accepted but ignored
    snr=lambda snr, rchisq: snr,
    # SNR divided by the square root of the reduced chisq
    snronchi=lambda snr, rchisq: snr / (rchisq ** 0.5),
)

def get_stat(statchoice, snr, rchisq, fac):
    """Evaluate the single-trigger statistic selected by `statchoice`.

    Parameters
    ----------
    statchoice : str
        Key into the module-level `stat_dict` table.
    snr, rchisq :
        Forwarded unchanged as the first two positional arguments of the
        selected statistic function (the caller passes SNR and reduced
        chisq values).
    fac : float or None
        Optional extra factor.  When not None it is forwarded as a third
        positional argument; this is only permitted for "new_snr" and
        "effective_snr".

    Returns
    -------
    Whatever the selected statistic function returns.

    Raises
    ------
    RuntimeError
        If `fac` is given for a statistic that does not take one, or if
        `statchoice` has no implementation in `stat_dict`.
    """
    if fac is not None and statchoice not in ("new_snr", "effective_snr"):
        raise RuntimeError("Can't use --stat-factor with this statistic!")

    # A statistic may be offered on the command line without having an
    # entry in stat_dict; fail with an explicit message instead of a bare
    # KeyError from the lookup.
    if statchoice not in stat_dict:
        raise RuntimeError("Statistic '%s' is not implemented!" % statchoice)
    stat_func = stat_dict[statchoice]

    if fac is None:
        return stat_func(snr, rchisq)
    return stat_func(snr, rchisq, fac)

def get_bins(opt, pmin, pmax):
    """Build the parameter binning requested on the command line.

    Parameters
    ----------
    opt : object
        Parsed options; the attributes `bin_spacing`, `num_bins` and
        `irregular_bins` are read.
    pmin, pmax :
        Lower and upper ends of the binned range.  Used for "linear" and
        "log" spacing only; "irregular" bins take their edges from
        `opt.irregular_bins`.

    Returns
    -------
    A `rate.LinearBins`, `rate.LogarithmicBins` or `rate.IrregularBins`
    instance, according to `opt.bin_spacing`.

    Raises
    ------
    RuntimeError
        If irregular spacing is requested without a list of edges, or if
        `opt.bin_spacing` is missing or not one of the known choices.
    """
    spacing = opt.bin_spacing
    if spacing == "linear":
        return rate.LinearBins(pmin, pmax, opt.num_bins)
    if spacing == "log":
        return rate.LogarithmicBins(pmin, pmax, opt.num_bins)
    if spacing == "irregular":
        if opt.irregular_bins is None:
            raise RuntimeError("Must specify a list of irregular bin edges!")
        edges = [float(b) for b in opt.irregular_bins.split(",")]
        return rate.IrregularBins(edges)
    # Do not fall off the end and return None: the caller would only fail
    # later with an unrelated-looking AttributeError.
    raise RuntimeError("Unrecognized or missing --bin-spacing: %r" % (spacing,))

#### MAIN ####

# Command-line interface.  Options whose help text says "Required" are
# enforced by argparse so that a missing one gives a usage error rather than
# a TypeError further down the script.
parser = argparse.ArgumentParser(usage="",
    description="Perform maximum-likelihood fits of single inspiral trigger "
                "distributions to various functions")

parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
#parser.add_argument("-V", "--verbose", action="store_true",
#                    help="Print extra debugging information.", default=False)
parser.add_argument("--inputs", nargs="+", required=True,
                    help="Input file or space-separated list of input files "
                    "containing single triggers.  Currently .xml(.gz) and "
                    ".sqlite formats supported.  Required")
parser.add_argument("--output", required=True,
                    help="Location for output file containing fit coefficients"
                    ".  Required")
parser.add_argument("--plot-dir", default=None,
                    help="Plot the fits made, the variation of fitting "
                    "coefficients and the Kolmogorov-Smirnov test values "
                    "and save the plots to the specified directory.")
parser.add_argument("--user-tag", default="",
                    help="Put a possibly informative string in the names of "
                    "plots files.")
parser.add_argument("--ifos", nargs="+", required=True,
                    help="Ifo or space-separated list of ifos to select "
                    "triggers to be fit.  Required.")
parser.add_argument("--fit-function",
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit")
parser.add_argument("--sngl-stat", default="new_snr",
                    choices=["snr", "snronchi", "effective_snr", "new_snr"],
                    help="Function of SNR and chisq to perform fits with.")
#parser.add_argument("--chisq-choice", default="power",
#                    choices=["power","bank","auto","maxpowerbank"],
#                    help="Which chisq value or values to form the "
#                    "single-trigger statistic with")
parser.add_argument("--stat-factor", type=float,
                    help="Adjustable magic number used in some sngl "
                    "statistics.  Values commonly used: 6 for new_snr, 250 "
                    "or 50 for effective_snr.")
# type=float matters: without it a value given on the command line stays a
# string and the numerical comparison against trigger SNRs is meaningless.
parser.add_argument("--snr-threshold", type=float, default=5.5,
                    help="Only fit triggers with SNR above this threshold.")
parser.add_argument("--stat-threshold", nargs="+", type=float, required=True,
                    help="Only fit triggers with statistic value above this "
                    "threshold : can be a space-separated list, then a fit "
                    "will be done for each threshold.  Required.  Typical "
                    "values 6.5 6.75 7")
parser.add_argument("--fit-param", required=True,
                    help="Parameter over which to estimate variation of the "
                    "fits. Required. Should be a column in the "
                    "SnglInspiralTable")
# FIXME - or a math function of columns. Ex. 1./mtotal")
parser.add_argument("--bin-spacing", choices=["linear", "log", "irregular"],
                    help="How to space parameter bin edges.")
# Exactly one way of specifying the bins must be given
binopt = parser.add_mutually_exclusive_group(required=True)
binopt.add_argument("--num-bins", type=int,
                    help="Number of regularly spaced bins to use over the "
                    " parameter.")
binopt.add_argument("--irregular-bins",
                    help="Comma-separated list of parameter bin edges. "
                    "Required if --bin-spacing = irregular")

opt = parser.parse_args()

# convert list of strings into floats (?)
#opt.stat_threshold = [float(t) for t in opt.stat_threshold]

# Human-readable versions of the statistic and parameter names, used in
# plot titles, axis labels and plot file names
statname = opt.sngl_stat.replace("_", " ")
paramname = opt.fit_param.replace("_", " ")

#NOT SURE WHAT TO USE LOGGING STUFF HERE FOR ...
#if opt.verbose:
#    log_level = logging.DEBUG
#else:
#    log_level = logging.WARN
#logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

# Plot paths are built further down by plain string concatenation onto
# opt.plot_dir, so make sure the directory ends with a path separator;
# otherwise the file names would be glued onto the directory name.
if opt.plot_dir:
    outdir = opt.plot_dir if opt.plot_dir.endswith('/') else opt.plot_dir + '/'
    opt.plot_dir = outdir

## Check option logic
if opt.bin_spacing is None:
    raise RuntimeError("Must specify --bin-spacing!")
if opt.bin_spacing == "irregular" and opt.irregular_bins is None:
    raise RuntimeError("Must specify a list of irregular bin edges!")
if opt.bin_spacing in ("linear", "log") and opt.num_bins is None:
    raise RuntimeError("Must specify --num-bins for linear or log bin "
                       "spacing!")

# Work out the trigger file format from the extension of the first input
name_parts = opt.inputs[0].split('.')
if name_parts[-1] == "sqlite":
    trigformat = "sqlite"
    table = "sngl_inspiral"
    if len(opt.inputs) > 1:
        raise RuntimeError("Sorry, can't deal with more than one .sqlite "
                           "input file")
    columns = ['ifo', 'end_time', 'end_time_ns', 'snr', 'chisq',
               'chisq_dof', 'bank_chisq', 'bank_chisq_dof', 'cont_chisq',
               'cont_chisq_dof', 'mchirp', 'eta', 'mtotal', 'mass1',
               'mass2', 'template_duration']
    # make sure the column we bin over is actually read from the database
    if opt.fit_param not in columns:
        columns.append(opt.fit_param)
elif name_parts[-1] == "xml" or name_parts[-2:] == ["xml", "gz"]:
    trigformat = "xml"
else:
    raise RuntimeError("I didn't recognize the input file format!")

# initialize result storage: each is keyed by ifo
parbins = {}
counts = {}
templates = {}
fits = {}
stdev = {}
ks_prob = {}

# one colour per parameter bin in the cumulative histogram plots
histcolors = ['r',(1.0,0.6,0),'y','g','c','b','m','k',(0.8,0.25,0),(0.25,0.8,0)]

# Open the output only once the options have been validated, so that a bad
# command line does not create or truncate the output file.
outfile = open(opt.output, 'w')

# header
outfile.write("# ifo threshold lower upper templates triggers alpha sig_alpha ks_prob\n")

# Main loop: for each ifo read the triggers, bin them over the fit parameter,
# fit the statistic distribution above each threshold in each bin, write the
# results to the output file and optionally plot the cumulative histograms.
for ifo in opt.ifos:
    # FIXME - Would like a uniform input interface from trigger files,
    # including HDF5, to the 'get_column' syntax used by SnglInspiralUtils
    if trigformat == "sqlite":
        sngls = io.sqlite.TableData(table, columns)
        sngls.connect_db(opt.inputs[0])
        sngls.restrict_rows(conditions=["ifo='"+str(ifo)+"'",
                            "snr>"+str(opt.snr_threshold)])
        sngls.retrieve_rows()
    elif trigformat == "xml":
        sngls = sniuls.ReadSnglInspiralFromFiles(opt.inputs,
            filterFunc=lambda s: s.snr>opt.snr_threshold)
        sngls.ifocut(ifo, inplace=True)

    # common prefix of all plot files for this ifo
    if opt.plot_dir:
        plotbase = opt.plot_dir + ifo + '-' + opt.user_tag

    snr = sngls.get_column('snr')
    if not len(snr):
        # write() works identically under Python 2 and 3
        sys.stderr.write('skipping %s\n' % ifo)
        continue
    chisq = sngls.get_column('chisq')
    chisq_dof = sngls.get_column('chisq_dof')
    mass1 = sngls.get_column('mass1')
    mass2 = sngls.get_column('mass2')
    parvals = sngls.get_column(opt.fit_param)
    rchisq = chisq / (2 * chisq_dof - 2)
    # bank_rchisq = sngls.get_column('bank_chisq') / (26 * 2)  # hack since pycbc doesn't output correct bank chisq dof
    # max_rchisq = np.maximum(rchisq, bank_rchisq)
    statv = get_stat(opt.sngl_stat, snr, rchisq, opt.stat_factor)

    plotrange = np.linspace(0.95 * min(statv), 1.05 * max(statv), 100)
    pmin = min(parvals)
    pmax = max(parvals)
    # bins extend slightly beyond the extreme parameter values
    pbins = get_bins(opt, 0.99 * pmin, 1.01 * pmax)
    # remember the binning used for this ifo
    parbins[ifo] = pbins
    # list of bin indices
    binind = [pbins[c] for c in pbins.centres()]
    # assign indices to parameter values based on which bin they lie in
    pind = np.array([pbins[par] for par in parvals])
    # (index, lower edge, upper edge) per bin; built once as a list because
    # it is iterated several times below
    bin_edges = list(zip(binind, pbins.lower(), pbins.upper()))

    templates[ifo] = {}
    counts[ifo] = {}
    fits[ifo] = {}
    stdev[ifo] = {}
    ks_prob[ifo] = {}

    # determine trigger counts first to get plot limits
    for th in opt.stat_threshold:
        counts[ifo][th] = {}
        for i, lower, upper in bin_edges:
            vals_inbin = statv[pind == i]
            counts[ifo][th][i] = sum(vals_inbin >= th)
    maxcount = max([counts[ifo][th][i] for i in binind for th in
                    opt.stat_threshold])

    # now do fits and plot them
    for th in opt.stat_threshold:
        fits[ifo][th] = {}
        stdev[ifo][th] = {}
        ks_prob[ifo][th] = {}

        for i, lower, upper in bin_edges:
            in_bin = pind == i
            # determine number of templates generating the triggers involved
            # (a template is identified by its (mass1, mass2) pair)
            tmpls = set(zip(mass1[in_bin], mass2[in_bin]))
            templates[ifo][i] = len(tmpls)
            vals_inbin = statv[in_bin]
            #if len(valsinBin) < nmax:
            #  print 'Not enough triggers in bin', i, '- skipping this one!'
            #  continue
            #vals_inbin.sort()
            ## calculate threshold via Nth loudest trig
            #thresh = valsinBin[-nmax]
            #numoverThresh = nmax

            # do the fit
            alpha, sig_alpha = trstats.fit_above_thresh(
                                              opt.fit_function, vals_inbin, th)
            #alpha, sig_alpha = fit_exponential_nmax(valsinBin, N=nmax)
            fits[ifo][th][i] = alpha
            stdev[ifo][th][i] = sig_alpha
            _, ks_prob[ifo][th][i] = trstats.KS_test(
                                       opt.fit_function, vals_inbin, alpha, th)
            outfile.write("%s %.2f %.3g %.3g %d %d %.3f %.3f %.3g\n" %
              (ifo, th, lower, upper, len(tmpls), counts[ifo][th][i], alpha,
               sig_alpha, ks_prob[ifo][th][i]))

            # add histogram to plot
            if opt.plot_dir:
                # cycle through the palette so that more bins than colours
                # does not raise an IndexError
                color = histcolors[i % len(histcolors)]
                nabove = counts[ifo][th][i]

                histcounts, edges = np.histogram(vals_inbin, bins=50)
                # number of triggers at or above each histogram bin edge
                cum_counts = histcounts[::-1].cumsum()[::-1]
                binlabel = r"%.2g - %.2g" % (lower, upper)
                plt.semilogy(edges[:-1], cum_counts, linewidth=2,
                             color=color, label=binlabel, alpha=0.6)

                # fitted cumulative distribution and its +-1 sigma variants
                plt.semilogy(plotrange, nabove * \
                  trstats.cum_fit(opt.fit_function, plotrange, alpha, th),
                  '--', color=color,
                  label=r'$\alpha = $%.2f $\pm$ %.2f' % (alpha, sig_alpha))
                plt.semilogy(plotrange, nabove * \
                  trstats.cum_fit(opt.fit_function, plotrange, alpha + \
                  sig_alpha, th), ':', alpha=0.6, color=color)
                plt.semilogy(plotrange, nabove * \
                  trstats.cum_fit(opt.fit_function, plotrange, alpha - \
                  sig_alpha, th), ':', alpha=0.6, color=color)

        # one figure per ifo and threshold, containing all parameter bins
        if opt.plot_dir:
            leg = plt.legend(labelspacing=0.2)
            plt.setp(leg.get_texts(), fontsize=11)
            # keep the upper limit above the lower one even if no trigger
            # passed any threshold
            plt.ylim(0.7, 2 * max(maxcount, 1))
            plt.xlim(0.9*min(opt.stat_threshold), max(plotrange))
            plt.grid()
            plt.title(ifo + " " + statname + " distribution split by " + \
                  paramname)
            plt.xlabel(statname, size="large")
            plt.ylabel("Cumulative number", size="large")
            plt.savefig(plotbase + '_' + opt.sngl_stat + '_cdf_by_' + \
              paramname[0:3] + '_fit_thresh_' + str(th) + '.png')
            plt.close()

# make plots of alpha and KS significance for ifos having triggers
if opt.plot_dir:
    for ifo in fits:
        # Build the file prefix for *this* ifo.  Reusing the value left over
        # from the trigger loop would write every ifo's plots to the same
        # file names.
        plotbase = opt.plot_dir + ifo + '-' + opt.user_tag
        # Prefer the binning recorded for this ifo in parbins when there is
        # one; otherwise fall back to the bins left over from the trigger
        # loop.
        bins = parbins.get(ifo, pbins)
        centres = bins.centres()
        bin_indices = [bins[c] for c in centres]

        # fit parameter vs binned parameter, one curve per threshold
        for th in opt.stat_threshold:
            plt.errorbar(centres, [fits[ifo][th][i] for i in bin_indices],
              yerr=[stdev[ifo][th][i] for i in bin_indices], fmt='+-',
              label=ifo + ' fit above %.2f' % th)
        if opt.bin_spacing == 'log':
            plt.semilogx()
        plt.grid()
        plt.legend(loc='best')
        plt.xlabel(paramname, size='large')
        plt.ylabel("fit parameter ('alpha')", size='large')
        plt.savefig(plotbase + '_alpha_vs_' + paramname[0:3] + '.png')
        plt.close()

        # KS test p-value vs binned parameter, one curve per threshold
        for th in opt.stat_threshold:
            plt.plot(centres, [ks_prob[ifo][th][i] for i in bin_indices],
              '+--', label=ifo+' KS prob, thresh %.2f' % th)
        if opt.bin_spacing == 'log':
            plt.loglog()
        else:
            plt.semilogy()
        plt.grid()
        leg = plt.legend(loc='best', labelspacing=0.2)
        plt.setp(leg.get_texts(), fontsize=11)
        plt.xlabel(paramname, size='large')
        plt.ylabel('KS test p-value')
        plt.savefig(plotbase + '_KS_prob_vs_' + paramname[0:3] + '.png')
        plt.close()

# flush the fit results to disk before announcing completion
outfile.close()

# single-argument call form prints the same text under Python 2 and 3
print('Done!')
