""" Code to select FITS tables
Context : SRP.GW
Module  : SRPGWSelect
Author  : Stefano Covino
Date    : 28/03/2019
E-mail  : stefano.covino@brera.inaf.it
URL:    : http://www.merate.mi.astro.it/utenti/covino
Purpose : Select FITS tables

usage: SRPGWSelect [-h] [-d] -i inputfile -o outfile -s selstr [-v]
                   [--version]

optional arguments:
  -h, --help            show this help message and exit
  -d, --debug           Debug information
  -i inputfile, --inputfile inputfile
                        Input table
  -o outfile, --outfile outfile
                        Output table
  -s selstr, --selstr selstr
                        Selection string
  -v, --verbose         Fully describe operations
  --version             show program's version number and exit


History : (13/11/2015) First version.
		: (02/01/2016) Possibility of asymmetric filtering.
        : (18/01/2016) Different default values.
        : (02/02/2016) Sigma clipped statistics.
        : (03/02/2016) Sigma-clipped statistics optional.
        : (06/02/2016) Better binning.
        : (08/02/2016) Sigma clipping at the same level of the chosen sigma.
        : (10/02/2016) Good data management in statistics.
        : (22/02/2016) Good data selectable.
        : (15/07/2016) Magnitude and score selection.
        : (15/12/2016) Save adopted limits.
        : (02/03/2017) Update.
        : (21/12/2017) Better evaluation of magnitude difference.
        : (28/03/2019) Varindices in evaluation.
"""

__version__ = '3.0.0'


import argparse, os, sys
import numpy
from astropy.table import vstack, Table
from astropy import table
from astropy.stats import sigma_clipped_stats
from astropy.stats import histogram
import SRPGW as GW
from SRPGW.GetCommandStr import GetCommandStr
from SRP.SRPStatistics.StDev import StDev
from SRPGW.CheckMagDiff import CheckMagDiff
import warnings
import gc




# Command-line interface. Options are registered in the same order as they
# are listed in the generated help text.
parser = argparse.ArgumentParser()
# Columns and statistics
parser.add_argument(
    "-c", "--col",
    nargs=2,
    default=(GW.SRPMAG, GW.eSRPMAG),
    metavar=('mag', 'emag'),
    help="Columns to be analyzed",
)
parser.add_argument(
    "-C", "--clipping",
    action="store_true",
    help="Sigma clipped statistics",
)
parser.add_argument(
    "-d", "--dir",
    choices=['all', 'high', 'low'],
    default='all',
    help="Direction of filtering",
)
parser.add_argument(
    "-G", "--gooddata",
    action="store_true",
    help="Gooddata statistics",
)
# Input table (mandatory)
parser.add_argument(
    "-i", "--inputfile",
    required=True,
    metavar='inputfile',
    help="Input table",
)
# Magnitude selection
parser.add_argument(
    "-m", "--magsigma",
    type=float,
    default=10.,
    metavar='magsigma',
    help="n sigma from the average for magnitude selection",
)
parser.add_argument(
    "-M", "--MagSelection",
    action="store_true",
    help="Magnitude selection",
)
parser.add_argument(
    "-n", "--nbin",
    type=int,
    default=1000,
    metavar='nbin',
    help="n objects per bin",
)
# Output table (mandatory)
parser.add_argument(
    "-o", "--outfile",
    required=True,
    metavar='outfile',
    help="Output table",
)
# Score selection
parser.add_argument(
    "-s", "--scosigma",
    type=float,
    default=10.,
    metavar='scosigma',
    help="n sigma from the average for score selection",
)
parser.add_argument(
    "-S", "--ScoSelection",
    action="store_true",
    help="Score selection",
)
parser.add_argument(
    "-v", "--verbose",
    action="store_true",
    help="Fully describe operations",
)
parser.add_argument("--version", action="version", version=__version__)
# Variability-index selection
parser.add_argument(
    "-x", "--varindexsigma",
    type=float,
    default=5.,
    metavar='varsigma',
    help="n sigma from the average for variable index selection",
)
parser.add_argument(
    "-X", "--VarindexSelection",
    action="store_true",
    help="Varindex selection",
)
options = parser.parse_args()


#
def _mean_std(values, nsigma, clipped):
    """Return the (mean, std) pair of values.

    With clipped set, the statistics come from sigma_clipped_stats with a
    clipping threshold of nsigma and at most 5 iterations; otherwise the
    plain numpy mean and standard deviation are returned.
    """
    if clipped:
        try:
            mean, _median, std = sigma_clipped_stats(values, sigma=nsigma, maxiters=5)
        except TypeError:
            # Older astropy releases spell the iteration keyword 'iters'
            mean, _median, std = sigma_clipped_stats(values, sigma=nsigma, iters=5)
        return mean, std
    return numpy.mean(values), numpy.std(values)


def _bin_limits(sorted_values, nbin):
    """Return bin edges taken every nbin-th entry of sorted_values.

    The last edge is replaced by the maximum of the values, so a trailing
    partial bin is merged into the previous one. An empty input, or one
    with no more than nbin entries, yields fewer than two edges, i.e. no bin.
    """
    if len(sorted_values) == 0:
        return []
    lims = [sorted_values[i] for i in range(0, len(sorted_values), nbin)]
    lims[-1] = numpy.max(sorted_values)
    return lims


if options.inputfile and options.outfile:
    #
    # Option validation
    #
    if not os.path.isfile(options.inputfile):
        parser.error ("Input file list does not exist.")
    #
    # Any one of the three selections is enough to produce an output table
    if not (options.MagSelection or options.ScoSelection or options.VarindexSelection):
        parser.error ("Magnitude, score and/or varindex selection must be selected.")
    #
    if options.nbin <= 0:
        parser.error ("Number of objects per bin must be greater than 0.")
    #
    if options.magsigma < 0.:
        parser.error ("Number of sigmas for magnitude selection must be greater than 0.")
    #
    if options.scosigma < 0.:
        parser.error ("Number of sigmas for score selection must be greater than 0.")
    #
    if options.varindexsigma < 0.:
        parser.error ("Number of sigmas for varindex selection must be greater than 0.")
    #
    if options.verbose:
        print ("Reading table ", options.inputfile)
    dt = Table.read(options.inputfile, format='ascii.ecsv')
    #
    # Every selection appends its sub-tables here; they are stacked at the end
    seltab = []
    #
    if options.MagSelection:
        if options.verbose:
            print ("Magnitude selection...")
        #
        magcol = options.col[0]
        emagcol = options.col[1]
        #
        # Magnitude and error columns are those whose name starts with the
        # requested prefixes
        mags = [e for e in dt.columns.keys() if e.startswith(magcol)]
        if len(mags) == 0:
            parser.error('Magnitude column(s) not present.')
        #
        emags = [e for e in dt.columns.keys() if e.startswith(emagcol)]
        if len(emags) == 0:
            parser.error('Magnitude error column(s) not present.')
        #
        if options.gooddata:
            print ("Good data statistics selected.")
        gdatas = [g for g in dt.columns.keys() if GW.GDATACOL in g]
        # Good-data flags are usable only if there is one per magnitude column
        gdflag = True
        if len(gdatas) == 0:
            gdflag = False
            if options.verbose and options.gooddata:
                print('Good data column(s) not present.')
        if len(gdatas) != len(mags):
            gdflag = False
            if options.verbose and options.gooddata:
                print('Good data and magnitude column(s) do not match.')
        #
        if GW.MAGLIMS in dt.meta:
            if options.verbose:
                print("Pre-coded mag limits detected.")
            datamag = dt.meta[GW.MAGLIMS].split()
            # Six numbers per record: the two magnitude column indices, the
            # mean-magnitude interval and the low/high limits. Only the
            # high limit (sixth value) is handed to CheckMagDiff.
            for minp in range(0,len(datamag),6):
                ii = int(float(datamag[minp]))
                ll = int(float(datamag[minp+1]))
                # Plain boolean mask (not wrapped in a list, which recent
                # numpy would read as a 2-D index). Magnitudes of 90 or more
                # are discarded -- presumably a missing-data sentinel.
                gooddata = (dt[mags[ii]] < 90) & (dt[mags[ll]] < 90)
                if gdflag and options.gooddata:
                    gooddata = gooddata & (dt[gdatas[ii]] == True) & (dt[gdatas[ll]] == True)
                dtpr = dt[gooddata]
                dtpr['_mean'] = (dtpr[mags[ii]] + dtpr[mags[ll]])/2.
                if len(dtpr['_mean']) == 0:
                    # Nothing usable for this record; the others may still be
                    continue
                tabl = dtpr[(dtpr['_mean'] >= float(datamag[minp+2])) & (dtpr['_mean'] < float(datamag[minp+3]))]
                #
                stab = tabl[CheckMagDiff(tabl[mags[ii]],tabl[emags[ii]],tabl[mags[ll]],
                        tabl[emags[ll]],float(datamag[minp+5]),options.dir)]
                stab.remove_column('_mean')
                seltab.append(stab)
                del tabl
                gc.collect()
            gc.collect()
        else:
            # Limits are derived from the data and recorded for the output
            # table metadata
            listmaglims = []
            #
            for ii in range(len(mags)):
                for ll in range(ii+1,len(mags)):
                    if options.verbose:
                        print ("%s vs %s..." % (mags[ii], mags[ll]))
                    gooddata = (dt[mags[ii]] < 90) & (dt[mags[ll]] < 90)
                    if gdflag and options.gooddata:
                        gooddata = gooddata & (dt[gdatas[ii]] == True) & (dt[gdatas[ll]] == True)
                    dtpr = dt[gooddata]
                    dtpr['_mean'] = (dtpr[mags[ii]] + dtpr[mags[ll]])/2.
                    if len(dtpr['_mean']) == 0:
                        # Skip only this pair; later pairs may have data
                        continue
                    mincol = numpy.min(dtpr['_mean'])
                    maxcol = numpy.max(dtpr['_mean'])
                    if options.verbose:
                        print ("\tMag min: %.3f, mag max: %.3f" % (mincol, maxcol))
                    dtpr.sort('_mean')
                    lims = _bin_limits(dtpr['_mean'], options.nbin)
                    for l in range(len(lims)-1):
                        if options.verbose:
                            print ("\tInterval from %.3f to %.3f..." % (lims[l],lims[l+1]))
                        # Upper edge is exclusive
                        tabl = dtpr[(dtpr['_mean'] >= lims[l]) & (dtpr['_mean'] < lims[l+1])]
                        warnings.filterwarnings('ignore', category=UserWarning, append=True)
                        warnings.filterwarnings('ignore', category=RuntimeWarning, append=True)
                        mean, std = _mean_std(tabl[mags[ii]] - tabl[mags[ll]], options.magsigma, options.clipping)
                        lowlim = mean-options.magsigma*std
                        highlim = mean+options.magsigma*std
                        #
                        stab = tabl[CheckMagDiff(tabl[mags[ii]],tabl[emags[ii]],tabl[mags[ll]],tabl[emags[ll]],highlim,options.dir)]
                        if options.verbose:
                            if options.dir == 'all':
                                print ("\t\tMean: %.3f, std: %.3f. Limits at %.3f and %.3f. # objects: %d" % (mean,std,lowlim,highlim, len(stab)))
                            elif options.dir == 'high':
                                print ("\t\tMean: %.3f, std: %.3f. Limits at %.3f. # objects: %d" % (mean,std,highlim, len(stab)))
                            elif options.dir == 'low':
                                print ("\t\tMean: %.3f, std: %.3f. Limits at %.3f. # objects: %d" % (mean,std,lowlim, len(stab)))
                        warnings.resetwarnings()
                        stab.remove_column('_mean')
                        listmaglims.append((ii,ll,lims[l],lims[l+1],lowlim,highlim))
                        seltab.append(stab)
                        del tabl
                        gc.collect()
                    del gooddata, dtpr
                    gc.collect()
    #
    if options.ScoSelection:
        if options.verbose:
            print ("Score selection...")
        #
        if GW.MINMAG not in dt.columns.keys():
            parser.error("Minimum magnitudes not available.")
        #
        if GW.SCOCOL not in dt.columns.keys():
            parser.error("Score not available.")
        #
        if GW.SCOLIMS in dt.meta:
            if options.verbose:
                print("Pre-coded score limits detected.")
            # Three numbers per record: magnitude interval and score threshold
            datasco = dt.meta[GW.SCOLIMS].split()
            for sinp in range(0,len(datasco),3):
                tabl = dt[(dt[GW.MINMAG] >= float(datasco[sinp])) & (dt[GW.MINMAG] < float(datasco[sinp+1]))]
                stab = tabl[tabl[GW.SCOCOL] >= float(datasco[sinp+2])]
                seltab.append(stab)
                del tabl
                gc.collect()
            gc.collect()
        else:
            listscolims = []
            dt.sort(GW.MINMAG)
            lims = _bin_limits(dt[GW.MINMAG], options.nbin)
            #
            for l in range(len(lims)-1):
                if options.verbose:
                    print ("\tScore interval from %.3f to %.3f..." % (lims[l],lims[l+1]))
                tabl = dt[(dt[GW.MINMAG] >= lims[l]) & (dt[GW.MINMAG] < lims[l+1])]
                warnings.filterwarnings('ignore', category=UserWarning, append=True)
                warnings.filterwarnings('ignore', category=RuntimeWarning, append=True)
                mean, std = _mean_std(tabl[GW.SCOCOL], options.scosigma, options.clipping)
                # Keep entries at least scosigma standard deviations above the mean
                stab = tabl[(tabl[GW.SCOCOL] - mean) >= (options.scosigma*std)]
                warnings.resetwarnings()
                if options.verbose:
                    print ("\t\tScore mean: %.3f, std: %.3f. Limits at %.3f. # objects: %d" % (mean,std,mean+options.scosigma*std, len(stab)))
                listscolims.append((lims[l],lims[l+1],mean+options.scosigma*std))
                seltab.append(stab)
                del tabl
                gc.collect()
            gc.collect()
    #
    # Limits computed here for each variability index, keyed by index name so
    # that indices with pre-coded limits cannot shift the bookkeeping
    varlims = {}
    if options.VarindexSelection:
        if options.verbose:
            print ("Varindex selection...")
        #
        if GW.MINMAG not in dt.columns.keys():
            parser.error("Minimum magnitudes not available.")
        #
        for v in GW.fedescr:
            vname = v['name']
            if options.verbose:
                print ("'{}' selection...".format(vname))
            #
            if vname in dt.meta:
                if options.verbose:
                    print("Pre-coded '{}' limits detected.".format(vname))
                datasco = dt.meta[vname].split()
                for sinp in range(0,len(datasco),3):
                    tabl = dt[(dt[GW.MINMAG] >= float(datasco[sinp])) & (dt[GW.MINMAG] < float(datasco[sinp+1]))]
                    stab = tabl[tabl[vname] >= float(datasco[sinp+2])]
                    seltab.append(stab)
                    del tabl
                    gc.collect()
                gc.collect()
            else:
                vlims = []
                dt.sort(GW.MINMAG)
                lims = _bin_limits(dt[GW.MINMAG], options.nbin)
                #
                for l in range(len(lims)-1):
                    if options.verbose:
                        print ("\t'%s' interval from %.3f to %.3f..." % (vname,lims[l],lims[l+1]))
                    tabl = dt[(dt[GW.MINMAG] >= lims[l]) & (dt[GW.MINMAG] < lims[l+1])]
                    warnings.filterwarnings('ignore', category=UserWarning, append=True)
                    warnings.filterwarnings('ignore', category=RuntimeWarning, append=True)
                    mean, std = _mean_std(tabl[vname], options.varindexsigma, options.clipping)
                    threshold = mean+options.varindexsigma*std
                    stab = tabl[(tabl[vname] - mean) >= (options.varindexsigma*std)]
                    warnings.resetwarnings()
                    if options.verbose:
                        print ("\t\t'%s' mean: %.3f, std: %.3f. Limits at %.3f. # objects: %d" % (vname,mean,std,threshold, len(stab)))
                    vlims.append((lims[l],lims[l+1],threshold))
                    seltab.append(stab)
                    del tabl
                    gc.collect()
                varlims[vname] = vlims
                gc.collect()
#
    # No sub-table at all means that no bin could be built
    if not seltab:
        print("No sources selected (try to reduce the number of objects per bin).")
        sys.exit(0)
    try:
        restab = vstack(seltab)
    except TypeError:
        print("No sources selected (try to reduce the number of objects per bin).")
        sys.exit(0)
    #
    # Remove duplicate events
    #
    restabo = table.unique(restab,GW.IDCOL)
    #
    # Save the limits computed in this run in the output metadata
    if not (GW.MAGLIMS in dt.meta) and options.MagSelection:
        restabo.meta[GW.MAGLIMS] = ''.join("%.3f %.3f %.3f %.3f %.3f %.3f " % i for i in listmaglims)
    if not (GW.SCOLIMS in dt.meta) and options.ScoSelection:
        restabo.meta[GW.SCOLIMS] = ''.join("%.3f %.3f %.3f " % i for i in listscolims)
    for vname, vlims in varlims.items():
        restabo.meta[vname] = ''.join("%.3f %.3f %.3f " % i for i in vlims)
    #
    if options.verbose:
        print ("Saving results...")
    restabo.write(options.outfile,format='ascii.ecsv', overwrite=True)
    if options.verbose:
        print ("Table %s with %d entries saved." % (options.outfile, len(restabo)))
    else:
        print (options.outfile, len(restabo))
else:
    parser.print_help()
#
