""" Code to show tables
Context : SRP.GW
Module  : SRPGWCandSelect
Author  : Stefano Covino
Date    : 05/04/2019
E-mail  : stefano.covino@brera.inaf.it
URL:    : http://www.merate.mi.astro.it/utenti/covino
Purpose : Data Analysis

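usage: SRPGWCandSelect [-h] [-d dmagcol] [-f zoom scale] -i inputfile [-l]
                       [-m mjdevt] [-n nstart] [-o outlist] -r resfile
                       [-s selpics] [-v] [--version]

optional arguments:
  -h, --help            show this help message and exit
  -d dmagcol, --dmagcol dmagcol
                        Delta(mag) col
  -f zoom scale, --fits zoom scale
                        See FITS frames
  -i inputfile, --inputfile inputfile
                        Input table
  -l, --lightcurves     Generate light-curve files
  -m mjdevt, --mjdevt mjdevt
                        MJD of the event
  -n nstart, --nstart nstart
                        Starting entry to visualize
  -o outlist, --outlist outlist
                        Output file for candidate list
  -r resfile, --resfile resfile
                        Result table (created if it does not exist)
  -s selpics, --selpics selpics
                        Select maximum number of pics to show (the brightest)
  -v, --verbose         Fully describe operations
  --version             show program's version number and exit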


History : (25/01/2016) First version.
        : (02/02/2016) Better management of file paths.
        : (04/02/2016) Better format for the csv file.
        : (06/02/2016) Bug correction.
        : (08/02/2016) Better csv format.
        : (09/02/2016) A comma was missing in the csv file.
        : (10/02/2016) MJD in csv file and good data information.
        : (11/02/2016) More flexible with good data info. Starting point selectable.
        : (12/02/2016) Header in csv file.
        : (13/02/2016) Simbad as info.
        : (17/02/2016) Better plots.
        : (25/02/2016) PSF mags.
        : (02/03/2016) Neighbor information.
        : (04/03/2016) Slightly larger light-curve time axis plot and candidate score.
        : (06/03/2016) Minor improvement.
        : (14/03/2016) Always PSF magnitudes in output, if available.
        : (17/05/2016) Weights no more required.
        : (20/05/2016) MJD start.
        : (07/06/2016) Improved score.
        : (15/07/2016) Bug correction for MJD start.
        : (30/11/2016) No good data area shown in plots.
        : (05/12/2016) Better plot and galaxy information.
        : (07/12/2016) Show test statistics.
        : (18/12/2016) FWHM is shown.
        : (23/12/2016) Chi2 is shown.
        : (24/12/2016) Chi2 in score.
        : (02/03/2017) Update.
        : (03/03/2017) FITS viewer.
        : (05/03/2019) GLADE information.
        : (27/03/2019) Minor bug correction.
        : (31/03/2019) Select maximum number of pics.
        : (05/04/2019) Minor bug in number of selected pics.
"""

__version__ = '1.11.1'


import argparse, os, os.path, time
import numpy
from astropy.table import Column, Table
from astropy import table
import pylab
from matplotlib import gridspec
import SRPGW as GW
from SRPGW.Score import Score
from SRP.SRPTime.UT2MJD import UT2MJD
from SRPGW.Chi2 import Chi2



def nth_mag (maglist, nth):
    """Return the magnitude limit used to pick the `nth` brightest entries.

    The caller keeps the entries whose magnitude is strictly lower than the
    returned limit.

    Parameters
    ----------
    maglist : sequence of numbers
        Magnitudes, one per epoch.
    nth : int
        Maximum number of entries to keep; 0 (or less) means "no limit".

    Returns
    -------
    number
        1000 when no limit is requested or `maglist` is empty, else the
        element of index `nth` of the sorted valid magnitudes, or the faintest
        valid magnitude when there are not more than `nth` of them.
    """
    if nth <= 0:
        # No selection requested: any realistic magnitude is below this limit.
        return 1000
    m = numpy.sort(numpy.array(maglist))
    # Values >= 80 are ignored (presumably placeholders for missing
    # measurements -- TODO confirm against the table producer).
    ms = m[m < 80]
    # Fall back to the full list when no value passes the cut.
    pool = ms if len(ms) > 0 else m
    if len(pool) == 0:
        return 1000
    if len(pool) > nth:
        return pool[nth]
    # Not enough entries to index position `nth`: use the faintest one
    # (the original code raised IndexError when len(pool) == nth).
    return pool[-1]



# Command-line interface: every option is declared as (flags, keyword arguments)
# and registered in this order, which is also the order shown by --help.
parser = argparse.ArgumentParser()
_cli_options = (
    (("-d", "--dmagcol"),
        dict(action="store", help="Delta(mag) col", default=GW.VARINDCOL, metavar='dmagcol')),
    (("-f", "--fits"),
        dict(action="store", nargs=2, help="See FITS frames", metavar=('zoom','scale'))),
    (("-i", "--inputfile"),
        dict(action="store", help="Input table", required=True, metavar='inputfile')),
    (("-l", "--lightcurves"),
        dict(action="store_true", help="Generate light-curve files")),
    (("-m", "--mjdevt"),
        dict(action="store", type=float, help="MJD of the event", metavar='mjdevt')),
    (("-n", "--nstart"),
        dict(action="store", type=int, help="Starting entry to visualize", metavar='nstart')),
    (("-o", "--outlist"),
        dict(action="store", help="Output file for candidate list", metavar='outlist')),
    (("-r", "--resfile"),
        dict(action="store", help="Result table (created if it does not exist)", required=True, metavar='resfile')),
    (("-s", "--selpics"),
        dict(action="store", type=int, default=0, help="Select maximum number of pics to show (the brighest)", metavar='selpics')),
    (("-v", "--verbose"),
        dict(action="store_true", help="Fully describe operations")),
    (("--version",),
        dict(action="version", version=__version__)),
)
for _flags, _kwargs in _cli_options:
    parser.add_argument(*_flags, **_kwargs)
options = parser.parse_args()




#
# Load the input candidate table (parser.error() prints the message and exits).
if not os.path.isfile(options.inputfile):
    parser.error ("Input file does not exist.")
if options.verbose:
    print ("Reading table ", options.inputfile)
dt = Table.read(options.inputfile, format='ascii.ecsv')
#
# The starting entry is 1-based on the command line, 0-based internally.
if options.nstart:
    if options.nstart <= 0 or options.nstart > len(dt):
        parser.error("Not acceptable starting point.")
    nstart = options.nstart - 1
else:
    nstart = 0
#
# 0 means "show every picture", so only negative values are rejected.
if options.selpics < 0:
    parser.error("Maximum number of pics must be positive.")
#
# FITS viewer options: pyds9 is imported only when the viewer is requested.
if options.fits:
    import pyds9
    try:
        fzoom = int(options.fits[0])
    except ValueError:
        parser.error("FITS stamp zoom incorrect.")
    if fzoom <= 0:
        parser.error("FITS stamp zoom must be positive.")
    fscale = options.fits[1]
#
# Present epoch in UT (UT2MJD converts a UT date), used to validate the event MJD.
prestime = time.gmtime()
mjdnow = UT2MJD(float(prestime[0]),float(prestime[1]),float(prestime[2]),float(prestime[3]),float(prestime[4]),float(prestime[5]))
if options.mjdevt is not None:
    # Reject the value when EITHER condition fails: the original test joined
    # them with 'and' and therefore could never be true.
    if options.mjdevt <= 0 or options.mjdevt >= mjdnow:
        parser.error("MJD must be positive and in the past.")
#
def _columns_with (tag, prefix=False, exclude=()):
    """Return the names of the input-table columns containing `tag`, in table order.

    With `prefix` set the name must start with `tag`; names containing any of
    the strings in `exclude` are dropped.
    """
    found = []
    for name in dt.columns.keys():
        if tag not in name:
            continue
        if prefix and not name.startswith(tag):
            continue
        if any(other in name for other in exclude):
            continue
        found.append(name)
    return found


def _matches_mags (cols):
    """Tell whether an optional column set exists and has one column per magnitude column."""
    return len(cols) > 0 and len(cols) == len(mags)


# Mandatory: magnitudes and their errors, one column per epoch.
mags = _columns_with(GW.SRPMAG, prefix=True)
if len(mags) == 0:
    parser.error('Magnitude column(s) not present.')
#
emags = _columns_with(GW.eSRPMAG)
if len(emags) == 0:
    parser.error('Magnitude error column(s) not present.')
if len(mags) != len(emags):
    parser.error('Magnitude and magnitude error column(s) do not match.')
#
# Optional per-epoch information: each flag records whether the set is usable.
weights = _columns_with(GW.WEIGHTCOL, exclude=(GW.WEIGHTACOL, GW.FNAMECOLW))
weflag = _matches_mags(weights)
if not weflag:
    print('Weights column(s) not present.')
#
weightareas = _columns_with(GW.WEIGHTACOL)
# The flag is now set outside the scan loop, like all the other flags.
waflag = _matches_mags(weightareas)
if not waflag:
    print('Weightareas column(s) not present.')
#
gdatas = _columns_with(GW.GDATACOL)
gdflag = _matches_mags(gdatas)
if not gdflag and options.verbose:
    print('Good data column(s) not present.')
#
tstats = _columns_with(GW.TESTCOL)
tsflag = _matches_mags(tstats)
#
fwhms = _columns_with(GW.FWCOL)
fwflag = _matches_mags(fwhms)
if not fwflag and options.verbose:
    print('FWHM column(s) not present.')
#
# Mandatory: epochs, pictures and pixel positions.
epochs = _columns_with(GW.DATECOL)
if len(epochs) == 0:
    parser.error('Epoch column(s) not present.')
#
pics = _columns_with(GW.PICCOL)
if len(pics) == 0:
    parser.error('Picture column(s) not present.')
#
# FITS columns are required only to write the candidate list.
fits = _columns_with(GW.FITSCOL)
if len(fits) == 0 and options.outlist:
    parser.error('FITS column(s) not present.')
#
xpixs = _columns_with(GW.XCOL)
if len(xpixs) == 0:
    parser.error('X pixel column(s) not present.')
#
ypixs = _columns_with(GW.YCOL)
if len(ypixs) == 0:
    parser.error('Y pixel column(s) not present.')
#
mpls = _columns_with(GW.MPCOL)
mplflag = _matches_mags(mpls)
if not mplflag and options.verbose:
    print('Minor planet column(s) not present.')
#
# Optional: PSF magnitudes and their errors.
pmags = _columns_with(GW.PSFMAG, prefix=True)
psfflag = len(pmags) > 0
if not psfflag:
    print('PSF magnitude column(s) not present.')
#
pemags = _columns_with(GW.ePSFMAG)
epsfflag = True
if len(pemags) == 0:
    print('PSF magnitude error column(s) not present.')
    epsfflag = False
if len(pmags) != len(pemags):
    print('PSF magnitude and magnitude error column(s) do not match.')
    epsfflag = False
#
# Galaxy catalogue cross-match columns.
ledaflag = GW.LEDACOL in dt.columns.keys()
#
gladeflag = GW.GLADECOL in dt.columns.keys()

#
# FITS viewer: check the stamp columns, start DS9 and open one frame per stamp.
if options.fits:
    ffiles = [col for col in dt.columns.keys() if GW.FITSCOL in col]
    if not ffiles:
        parser.error('FITS stamp column(s) not present.')
    if len(ffiles) != len(mags):
        parser.error('Magnitude and FITS stamp column(s) do not match.')
    try:
        ds9 = pyds9.DS9()
    except TypeError:
        parser.error("DS9 executable not found.")
    for ds9cmd in ("zoom %d" % fzoom, "scale %s" % fscale, "tile"):
        ds9.set(ds9cmd)
    # One frame already exists, so only the remaining ones are created.
    for _ in range(len(ffiles)-1):
        ds9.set("frame new")
#
if options.verbose:
    print ("Generating variability index...")
# Per-candidate statistics over all the epochs; PSF magnitudes are preferred
# whenever their columns are present.
if psfflag:
    c = [dt[i] for i in pmags]
else:
    c = [dt[i] for i in mags]
lM = numpy.apply_along_axis(numpy.max, 0, c)
lm = numpy.apply_along_axis(numpy.min, 0, c)
mn = numpy.apply_along_axis(numpy.mean, 0, c)
dt[GW.VARINDCOL] = lM-lm
dt[GW.AVEMAG] = mn
dt[GW.MINMAG] = lm
dt.sort(GW.MINMAG)
#
# (Re)compute the chi2 column, replacing any value read from the input file.
res = Chi2(dt)
if GW.CHI2 in dt.columns.keys():
    dt.remove_column(GW.CHI2)
dt[GW.CHI2] = res
#
# The chi2 column always exists at this point. The original code passed the
# minimum-magnitude column to Score() in place of the chi2 values.
schi2 = dt[GW.CHI2]
#
# -99 is handed over when the galaxy information is not available.
if GW.LEDACOL in dt.columns.keys():
    sleda = dt[GW.LEDACOL]
else:
    sleda = -99
#
if GW.GLADECOL in dt.columns.keys():
    glade = dt[GW.GLADECOL]
else:
    glade = -99
#
if GW.MINMAG in dt.columns.keys() and GW.VARINDCOL in dt.columns.keys() and GW.NEICOL in dt.columns.keys() and GW.DNEICOL in dt.columns.keys():
    dt[GW.SCOCOL] = Score(dt[GW.MINMAG],dt[GW.VARINDCOL],dt[GW.NEICOL],dt[GW.DNEICOL],sleda,schi2)
# Highest score first. Without a score column (neighbor information missing
# and no score in the input) the table stays sorted by minimum magnitude
# instead of failing with a KeyError.
if GW.SCOCOL in dt.columns.keys():
    dt.sort(GW.SCOCOL)
    dt.reverse()
elif options.verbose:
    print('Score column not present, candidates sorted by minimum magnitude.')
#
# Result table: continue from an existing file, or start from an empty table
# that has the same columns as the input one.
if not os.path.isfile(options.resfile):
    dto = Table(dt.columns)
    dto.remove_rows(slice(0,len(dto)))
else:
    if options.verbose:
        print ("Reading table ", options.resfile)
    dto = Table.read(options.resfile, format='ascii.ecsv')
#
# Interactive plotting, so that figures show up while waiting for the answer.
pylab.ion()
#
# Main loop: show each candidate (stamps, light curve, summary) and ask whether
# it must be copied to the result table.
for entry, row in enumerate(dt[nstart:], start=nstart+1):
    print ("Processing entry %d of %d:" % (entry,len(dt)))
    # Per-epoch values of the current candidate.
    magvals = []
    emagvals = []
    epvals = []
    wvals = []
    wavals = []
    tvals = []
    pvals = []
    xvals = []
    yvals = []
    gvals = []
    mvals = []
    fwvals = []
    pmagvals = []
    pemagvals = []
    ffilesvals = []
    for i in range(len(mags)):
        magvals.append(row[mags[i]])
        emagvals.append(row[emags[i]])
        epvals.append(row[epochs[i]])
        if weflag:
            wvals.append(row[weights[i]])
        if waflag:
            wavals.append(row[weightareas[i]])
        pvals.append(row[pics[i]])
        xvals.append(row[xpixs[i]])
        yvals.append(row[ypixs[i]])
        if gdflag:
            gvals.append(row[gdatas[i]])
        if tsflag:
            tvals.append(row[tstats[i]])
        if fwflag:
            fwvals.append(row[fwhms[i]])
        if mplflag:
            mvals.append(row[mpls[i]])
        if psfflag and epsfflag:
            pmagvals.append(row[pmags[i]])
            pemagvals.append(row[pemags[i]])
        if options.fits:
            ffilesvals.append(row[ffiles[i]])
#
    # NOTE(review): the whole interactive inspection only runs with -v/--verbose;
    # without it the loop just prints the progress line. Confirm this is intended.
    if options.verbose:
        # Magnitudes ranking the stamps: the PSF ones only when they were
        # actually collected above (values and errors both available). The
        # same list gives the limit and the per-stamp test, so no more than
        # the requested number of stamps is selected.
        if psfflag and epsfflag:
            selmags = pmagvals
        else:
            selmags = magvals
        mlim = nth_mag (selmags, options.selpics)
        #
        # Upper row: picture stamps; lower row: light curve.
        p = pylab.figure(figsize=(20,7))
        gs = gridspec.GridSpec(2,1)
        if options.selpics == 0:
            npanels = len(epvals)
        else:
            npanels = options.selpics
        gssu = gridspec.GridSpecFromSubplotSpec(1, npanels, subplot_spec=gs[0,0])
        gssd = gridspec.GridSpecFromSubplotSpec(1, len(epvals), subplot_spec=gs[1,0])
        #
        pcnt = 0
        for ii in range(len(pvals)):
            if pcnt >= npanels:
                # Never address a panel outside the stamp grid.
                break
            if selmags[ii] < mlim:
                px = p.add_subplot(gssu[0,pcnt])
                try:
                    # The picture file name is the text between double quotes.
                    fn = str(pvals[ii]).split('"')[1]
                    ppx = pylab.imread(fn)
                    pylab.imshow(ppx)
                except (IndexError, FileNotFoundError):
                    # No file name or missing picture: leave the panel empty.
                    pass
                pylab.gca().axes.get_xaxis().set_ticks([])
                pylab.gca().axes.get_yaxis().set_ticks([])
                pcnt = pcnt + 1
        px = p.add_subplot(gssd[0,:])
        pylab.text(0.5,0.5,str(row[GW.IDCOL]),transform=px.transAxes)
        #
        # Aperture magnitudes: errors larger than 0.5 are drawn as triangles
        # without error bars.
        idxgood = numpy.array(emagvals) <= 0.5
        idxupp = numpy.array(emagvals) > 0.5
        xdata = numpy.array(epvals) - epvals[0]
        xdatagood = numpy.array(epvals)[idxgood] - epvals[0]
        xdataupp = numpy.array(epvals)[idxupp] - epvals[0]
        if len(numpy.array(magvals)[idxgood]) > 0:
            px.errorbar(xdatagood,numpy.array(magvals)[idxgood],numpy.array(emagvals)[idxgood],fmt='o',color='blue')
        if len(numpy.array(magvals)[idxupp]) > 0:
            px.plot(xdataupp,numpy.array(magvals)[idxupp],'v',color='red',ms=6)
        #
        # Same for the PSF magnitudes, with different colors.
        if psfflag and epsfflag:
            idxgood = numpy.array(pemagvals) <= 0.5
            idxupp = numpy.array(pemagvals) > 0.5
            xdata = numpy.array(epvals) - epvals[0]
            xdatagood = numpy.array(epvals)[idxgood] - epvals[0]
            xdataupp = numpy.array(epvals)[idxupp] - epvals[0]
            if len(numpy.array(pmagvals)[idxgood]) > 0:
                px.errorbar(xdatagood,numpy.array(pmagvals)[idxgood],numpy.array(pemagvals)[idxgood],fmt='o',color='magenta')
            if len(numpy.array(pmagvals)[idxupp]) > 0:
                px.plot(xdataupp,numpy.array(pmagvals)[idxupp],'v',color='orange',ms=6)
        #
        # Clip the magnitude axis to 10-25 and invert it (bright is up).
        if pylab.ylim()[1] > 25.0:
            pylab.ylim((pylab.ylim()[0],25.0))
        if pylab.ylim()[0] < 10.0:
            pylab.ylim((10,pylab.ylim()[1]))
        pylab.ylim(pylab.ylim()[::-1])
        if options.mjdevt:
            minmjd = options.mjdevt
        else:
            minmjd = epvals[0]
        # NOTE(review): the plotted times are relative to the first epoch even
        # when an event MJD is given, while axis range and label use minmjd.
        plength = max(epvals)-minmjd
        pylab.xlim((-plength/20.,plength*1.05))
        pylab.xlabel ("Days from the 1st epoch (MJD=%.3f)" % minmjd)
        pylab.ylabel ("Magnitudes")
        pylab.subplots_adjust(bottom=0.10,wspace=0.0,hspace=0.05,top=0.97,left=0.05,right=0.97)
        #
        # Shade the epochs whose good-data value is 0.
        if gdflag:
            for iii in range(len(gvals)):
                if gvals[iii] == 0:
                    pylab.axvline(xdata[iii],color='red',lw=15,alpha=0.10,linestyle='--')
        #
        pylab.draw()
        #
        # Load the FITS stamps, one per DS9 frame.
        if options.fits:
            ds9.set("frame first")
            for f in ffilesvals:
                ds9.set("file %s" % f)
                ds9.set("frame next")
        #
        # Text summary of the candidate.
        msg = str(row[GW.IDCOL])
        msg += "\nEpochs:" + "".join("\t%.5f" % e for e in epvals)
        msg += "\nMags:" + "".join("\t%.2f" % m for m in magvals)
        msg += "\neMags:" + "".join("\t%.2f" % m for m in emagvals)
        #
        if psfflag and epsfflag:
            msg += "\nPSF Mags:" + "".join("\t%.2f" % m for m in pmagvals)
            msg += "\nPSF eMags:" + "".join("\t%.2f" % m for m in pemagvals)
        #
        if weflag:
            msg += "\nWeights:" + "".join("\t%.4f" % w for w in wvals)
        #
        if gdflag:
            msg += "\nGood data:" + "".join("\t%d" % g for g in gvals)
        #
        if tsflag:
            msg += "\nTest stat:" + "".join("\t%.1f" % t for t in tvals)
        #
        msg += "\nXYPixel:" + "".join("\t%.2f %.2f" % (x,y) for x,y in zip(xvals,yvals))
        #
        if fwflag:
            msg += "\nFWHM:" + "".join("\t%.1f" % w for w in fwvals)
        #
        if mplflag:
            msg += "\nMinor planets:" + "".join("\t%s" % m for m in mvals)
        #
        if ledaflag and row[GW.LEDACOL] != -99:
            msg += "\nNearby HyperLEDA galaxy: %s" % row[GW.LEDACOL]
        #
        if gladeflag and row[GW.GLADECOL] != -99:
            msg += "\nNearby GLADE galaxy: %s" % row[GW.GLADECOL]
        #
        if GW.NEICOL in dt.columns.keys():
            msg += "\nClosest neighbor: %.1f arcsec, " % row[GW.NEICOL]
        #
        if GW.DNEICOL in dt.columns.keys():
            msg += "delta(mag) neighbor: %.1f\n" % row[GW.DNEICOL]
        else:
            msg += '\n'
        #
        if GW.MINMAG in dt.columns.keys() and GW.VARINDCOL in dt.columns.keys() and GW.CHI2 in dt.columns.keys():
            msg += "Min mag: %.2f, Variability index: %.1f, chi2/dof: %.1f\n" % (row[GW.MINMAG], row[GW.VARINDCOL], row[GW.CHI2])
        #
        if GW.SCOCOL in dt.columns.keys():
            msg += "Score: %.1f\n" % row[GW.SCOCOL]
        #
        if GW.SIMBCOL in dt.columns.keys():
            if row[GW.SIMBCOL] != 'No':
                msg += "Simbad: %s\n" % row[GW.SIMBCOL]
        #
        print (msg)
        # y: keep the candidate, e: stop the inspection, anything else: skip.
        answ = input("Keep it (y/N/e)? ").upper()
        print()
        if answ == 'Y':
            try:
                dto.add_row(row)
            except ValueError:
                # parser.error() accepts a single message string.
                parser.error("Wrong format for table %s" % options.resfile)
            if not os.path.isdir(GW.RESDIR):
                os.mkdir(GW.RESDIR)
            pylab.savefig(os.path.join(GW.RESDIR,str(row[GW.IDCOL])+'.png'))
        # Close the figure in every case, also when leaving the loop.
        pylab.close()
        if answ == 'E':
            break
    #
# Inspection finished: back to non-interactive plotting.
pylab.ioff()
#
# Save the selected candidates, one row per identifier.
if len(dto) == 0:
    ndto = dto
else:
    dto.sort(GW.IDCOL)
    ndto = table.unique(dto,GW.IDCOL)
    #
    ndto.write(options.resfile,format='ascii.ecsv', overwrite=True)
    if not options.verbose:
        print (options.resfile, len(ndto))
    else:
        print ("Table %s with %d entries saved." % (options.resfile, len(ndto)))
#
# Ask the FITS viewer to quit.
if options.fits:
    ds9.set("exit")
#
# Candidate list in csv format: one line per candidate, brightest first.
if options.outlist:
    if GW.MINMAG in ndto.columns.keys():
        ndto.sort(GW.MINMAG)
    else:
        ndto.sort(GW.SRPMAG+'_1')
    #
    # The file is opened in text mode, so '\n' is already translated to the
    # platform line ending; the 'with' block closes it even on errors.
    with open(options.outlist,'w') as g:
        # The last five header fields repeat for each epoch ("Emag,Pics" was
        # written "Emag.Pics", merging two columns of the header).
        g.write("#Id,Field,RA,DEC,X,Y,NEpochs,MJDSTART,Mag,Emag,Pics,FITS\n")
        for en in ndto:
            # Field: last '_'-separated token of the first frame file name, extension removed.
            fld = os.path.splitext(os.path.basename(en[GW.FNAMECOLF+'_1']))[0].split('_')[-1]
            msg = "%s,%s,%.5f,%.5f,%.3f,%.3f,%d," % (en[GW.IDCOL],fld,en[GW.RACOL],en[GW.DECCOL],en[GW.XCOL+'_1'],en[GW.YCOL+'_1'],len(mags))
            perepoch = []
            for i in range(len(mags)):
                # PSF magnitudes are written whenever they are available.
                if psfflag and epsfflag:
                    mg = en[pmags[i]]
                    emg = en[pemags[i]]
                else:
                    mg = en[mags[i]]
                    emg = en[emags[i]]
                perepoch.append("%.5f,%.2f,%.2f,%s,%s" % (en[epochs[i]],mg,emg,en[pics[i]],en[fits[i]]))
            g.write(msg + ','.join(perepoch) + '\n')
    if options.verbose:
        print("Output list file saved in ", options.outlist)
#
# Light curves: one '<id>.lc' text file per selected candidate, with one
# "epoch magnitude error" line per epoch.
if options.lightcurves:
    os.makedirs(GW.LCDIR, exist_ok=True)
    #
    # Test and sort the same table (the original tested dto and sorted ndto).
    if GW.MINMAG in ndto.columns.keys():
        ndto.sort(GW.MINMAG)
    else:
        ndto.sort(GW.SRPMAG+'_1')
    #
    for en in ndto:
        lcname = os.path.join(GW.LCDIR,str(en[GW.IDCOL]))+'.lc'
        # The 'with' block closes the file even if a row cannot be formatted.
        with open(lcname,'w') as g:
            for i in range(len(epochs)):
                # PSF magnitudes are written whenever they are available.
                if psfflag and epsfflag:
                    mg = en[pmags[i]]
                    emg = en[pemags[i]]
                else:
                    mg = en[mags[i]]
                    emg = en[emags[i]]
                g.write("%.5f %.3f %.3f\n" % (en[epochs[i]], mg, emg))
#




