#! python
""" Code to derive instrumental Stokes parameters
    
Context : SRP
Module  : SRPTNGPAOLOInstrStokes
Author  : Stefano Covino
Date    : 23/05/2017
E-mail  : stefano.covino@brera.inaf.it
URL     : http://www.merate.mi.astro.it/utenti/covino
Purpose : Derive instrumental Q and U Stokes parameters

Usage   : SRPTNGPAOLOInstrStokes [-h] [-a] [-c Q eQ U eU V eV] -f file -o file
            [-v] [--version] [-w wave]
            -a Append data to output
            -c Calibration Q, U and V values
            -f Input FITS photometry file
            -o Output FITS file
            -w Observation wavelength (micron)
    
History : (29/02/2012) First version.
        : (02/08/2012) Minor bug fix.
        : (27/09/2012) MJD in output.
        : (25/03/2013) Analyze spectra too.
        : (09/04/2013) Bug correction for photometric files.
        : (25/09/2013) Update for latest atpy version.
        : (13/02/2017) Python3 porting.
        : (14/02/2017) Minor bug fix.
        : (19/04/2017) SRPFITS.
        : (23/05/2017) astropy.table
"""

# Script version string, reported by the --version command-line option.
__version__ = '1.2.0'


import argparse, math, os, sys, warnings
import ephem, numpy
from astropy.table import Table, Column, vstack, TableMergeError
import SRPTNG as ST
import SRPTNG.PAOLO as STP
from SRPTNG.PAOLO.AverHourAngle import AverHourAngle
from SRPTNG.PAOLO.AverParallacticAngle import AverParallacticAngle
from SRPTNG.GetObj import GetObj
from SRPTNG.GetTNGSite import GetTNGSite
from SRPFITS.Photometry.Mag2Counts import Mag2Counts



# Command-line interface; see the module docstring for the usage summary.
# Arguments are registered in the same order as they appear in --help.
parser = argparse.ArgumentParser()
parser.add_argument("-a", "--append",
                    action="store_true",
                    help="Append data to output")
parser.add_argument("-c", "--calibqu",
                    action="store", type=float, nargs=6,
                    metavar=('Q', 'eQ', 'U', 'eU', 'V', 'eV'),
                    help="Calibration Q, U and V values")
parser.add_argument("-f", "--fitsphotfile",
                    action="store", required=True, metavar='file',
                    help="Input FITS photometry/spectroscopy file")
parser.add_argument("-o", "--outfile",
                    action="store", required=True, metavar='file',
                    help="Output FITS file")
parser.add_argument("-v", "--verbose",
                    action="store_true",
                    help="Fully describe operations")
parser.add_argument("--version",
                    action="version", version=__version__)
parser.add_argument("-w", "--wave",
                    action="store", type=float, metavar='wave',
                    help="Observation wavelength (micron)")
options = parser.parse_args()


#
# Load the input photometry/spectroscopy table; an unreadable file is
# reported as a usage error (parser.error exits the script).
try:
    tphot = Table.read(options.fitsphotfile, format='fits')
except IOError:
    parser.error("Invalid input FITS file.")
else:
    if options.verbose:
        print("Input FITS photometry file: %s" % options.fitsphotfile)
#
# Observation metadata read from the table header (tphot.meta).  Every key
# is mandatory: any failure while reading them aborts with a usage error.
# The sequence keyword is a whitespace-separated list of labels, one per
# polarimetric position; each label suffixes the per-position column names.
try:
    hdr = tphot.meta
    sequence = hdr[STP.SEQUENCE].split()
    filter = hdr[STP.FILTER]
    ra = hdr[STP.RA]
    dec = hdr[STP.DEC]
    date = hdr[STP.DATE]
    time = hdr[STP.TIME]
    expt = hdr[STP.EXPTIME]
    posang = hdr[STP.POSANG]
    object = hdr[STP.OBJECT]
    pstop = hdr[STP.PSTOP]
    polslide = hdr[STP.POLSLIDE]
    rot4 = hdr[STP.ROTLAM4]
    rot2 = hdr[STP.ROTLAM2]
    mjd = hdr[STP.MJD]
except Exception:
    parser.error("Invalid data in FITS table.")
#
# Load the per-position intensities of the polarimetric sequence.
# Each sequence label ii selects a column pair suffixed '_'+ii:
#   - spectra carry STP.Flux / STP.eFlux columns, used directly;
#   - photometry carries STP.Mag / STP.eMag columns, converted to counts
#     with Mag2Counts.
# Flags used by the rest of the script:
#   nophot = True -> flux (spectroscopic) columns were found
#   nospec = True -> magnitude (photometric) columns were used
if len(sequence) < 4:
    # Q and U are built from sequence positions 0..3.
    parser.error("At least 4 entries needed in the sequence keyword, found %d." % len(sequence))
nospec = False
nophot = False
Fl = []
eFl = []
#
try:
    for ii in sequence:
        Fl.append(tphot[STP.Flux+'_'+ii])
        eFl.append(tphot[STP.eFlux+'_'+ii])
    nophot = True
except KeyError:
    # Flux columns missing or incomplete: fall back to magnitudes, starting
    # from empty lists so no partial flux data is mixed in.
    Fl = []
    eFl = []
    try:
        for ii in sequence:
            fl, efl = Mag2Counts(tphot[STP.Mag+'_'+ii], tphot[STP.eMag+'_'+ii])
            # Append only after a successful lookup, so a missing column can
            # never reuse values from the previous label.
            Fl.append(numpy.array(fl))
            eFl.append(numpy.array(efl))
        nospec = True
    except KeyError:
        # ii is the label whose magnitude columns are missing.
        parser.error("Invalid columns %s,%s in FITS table" % (STP.Mag+'_'+ii, STP.eMag+'_'+ii))
#
# Instrumental Stokes parameters: Q from sequence positions 0 and 2, U from
# positions 1 and 3, each as a normalized difference X = (a-b)/(a+b).
# Uncertainty model: the relative errors of numerator (a-b) and denominator
# (a+b) are added in quadrature, each with variance S = ea**2 + eb**2:
#     eX = |X| * sqrt(S/(a-b)**2 + S/(a+b)**2)
# This is algebraically equal to sqrt(S*(1+X**2)) / |a+b|, which is the form
# used here.  The unsimplified form evaluates 0*inf = NaN when a == b, and so
# discarded sources with exactly zero instrumental polarization.
# Divisions by a+b == 0 still yield inf/NaN; those rows are rejected later
# through QUflag.  The resulting RuntimeWarnings are silenced only inside
# this block.  catch_warnings restores the previous warning filters on exit
# rather than wiping them with resetwarnings().
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    Q = (Fl[0]-Fl[2])/(Fl[0]+Fl[2])
    U = (Fl[1]-Fl[3])/(Fl[1]+Fl[3])
    eQ = numpy.sqrt((eFl[0]**2+eFl[2]**2)*(1.0+Q**2))/numpy.fabs(Fl[0]+Fl[2])
    eU = numpy.sqrt((eFl[1]**2+eFl[3]**2)*(1.0+U**2))/numpy.fabs(Fl[1]+Fl[3])
#
def _finite_mask(values):
    """Return a boolean ndarray that is True where *values* is neither NaN nor infinite."""
    bad = numpy.isnan(values) | numpy.isinf(values)
    return numpy.where(bad, False, True)


# Keep only the rows whose Q, U and both uncertainties are all finite.
QNB, UNB, eQNB, eUNB = (_finite_mask(v) for v in (Q, U, eQ, eU))
QUflag = numpy.where(QNB & eQNB & UNB & eUNB, True, False)
#
Qf, Uf, eQf, eUf = (arr[QUflag] for arr in (Q, U, eQ, eU))
#if polslide.upper().find(STP.LAMBDA2) >= 0:
#    if options.verbose:
#        print "Lambda/2 correction applied..."
#    nQ = []
#    nU = []
#    neQ = []
#    neU = []
#    for el in range(len(Q)):
#        sto = numpy.matrix([1.,Q[el],U[el],0.0]).transpose()
#        nsto = MuellerHalfWavePlateMatrix(math.radians(rot2)).I*sto
#        nQ.append(nsto[1,0])
#        nU.append(nsto[2,0])
#        lQ = GenGaussSet(Q[el],eQ[el],1000)
#        lU = GenGaussSet(U[el],eU[el],1000)
#        lsQ = []
#        lsU = []
#        for i in range(1000):
#            sto = numpy.matrix([1.,lQ[i],lU[i],0.0]).transpose()
#            nsto = MuellerHalfWavePlateMatrix(math.radians(rot2)).I*sto
#            lsQ.append(nsto[1,0])
#            lsU.append(nsto[2,0])
#        neQ.append(ScoreatPercentile(lsQ)[3])
#        neU.append(ScoreatPercentile(lsU)[3])
#    Q = nQ
#    eQ = neQ
#    U = nU
#    eU = neU
#
# Output table: one row per source whose Q, U and errors are all finite.
# Per-source values (Id, X, Y) come from the first position of the sequence.
# Header values (object, plate, rotator angles, MJD) are given as length-1
# columns, relying on astropy broadcasting them to the table length.  That
# cannot work on a zero-row table, so if every source was rejected, stop
# with a clear message instead of failing inside astropy.
if not numpy.any(QUflag):
    parser.error("No source with finite Q and U values in FITS table.")
tnew = Table()
tnew[STP.Id] = Column(tphot[STP.Id+'_'+sequence[0]][QUflag], dtype=numpy.int16)
tnew[STP.X] = Column(tphot[STP.X+'_'+sequence[0]][QUflag], dtype=numpy.float32)
tnew[STP.Y] = Column(tphot[STP.Y+'_'+sequence[0]][QUflag], dtype=numpy.float32)
tnew[STP.OBJECT] = Column([object], dtype=numpy.dtype('|S25'))
tnew[STP.Q] = Column(Qf, dtype=numpy.float64)
tnew[STP.eQ] = Column(eQf, dtype=numpy.float64)
tnew[STP.U] = Column(Uf, dtype=numpy.float64)
tnew[STP.eU] = Column(eUf, dtype=numpy.float64)
#tnew[STP.V] = Column([0.0])
#tnew[STP.eV] = Column([0.0])
tnew[STP.POLSLIDE] = Column([polslide], dtype=numpy.dtype('|S10'))
tnew[STP.ROTLAM4] = Column([rot4], dtype=numpy.float32)
tnew[STP.ROTLAM2] = Column([rot2], dtype=numpy.float32)
tnew[STP.MJD] = Column([mjd], dtype=numpy.float64)
# Total brightness: magnitudes for photometry, fluxes for spectroscopy.
if nospec:
    tnew[STP.TotMag] = Column(tphot[STP.TotMag][QUflag], dtype=numpy.float32)
    tnew[STP.eTotMag] = Column(tphot[STP.eTotMag][QUflag], dtype=numpy.float32)
elif nophot:
    tnew[STP.TotFlux] = Column(tphot[STP.TotFlux][QUflag], dtype=numpy.float32)
    tnew[STP.eTotFlux] = Column(tphot[STP.eTotFlux][QUflag], dtype=numpy.float32)
#
# Observatory site and target object passed to AverHourAngle and
# AverParallacticAngle; site.date is set from the header DATE and TIME
# values via ephem.Date.  DeprecationWarnings emitted while doing so are
# silenced only inside this block.  catch_warnings restores the previous
# warning filters on exit instead of discarding every filter with
# resetwarnings() and forcing 'always' afterwards.
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=DeprecationWarning)
    site = GetTNGSite()
    nb = GetObj(ra,dec)
    site.date = ephem.Date(date+' '+time)
#
# Hour angle and parallactic angle from the AverHourAngle /
# AverParallacticAngle helpers (target, site, exposure time) — presumably
# averaged over the exposure.  Each value is repeated on every output row.
n_rows = len(Qf)
hourangle = AverHourAngle(nb, site, expt)
tnew[STP.HOURANG] = Column(numpy.array([hourangle] * n_rows), dtype=numpy.float32)
if options.verbose:
    print("Observation hour angle: %.1f" % hourangle)
#
parangle = AverParallacticAngle(nb, site, expt)
tnew[STP.PARANG] = Column(numpy.array([parangle] * n_rows), dtype=numpy.float32)
if options.verbose:
    print("Observation parallactic Angle: %.1f" % parangle)
#
# Observation wavelength (micron) stored in the WAVE column.
#  - photometry (nospec): a single value for all rows, taken from -w, else the
#    filter entry of ST.LRSFiltCentrWaveDict, else 0.55;
#  - spectroscopy (nophot): per-row values from the input WAVE column scaled
#    by 1e-4 (presumably Angstrom -> micron), unless -w is given or the
#    column is missing, in which case one value is used for every row.
if nospec:
    if options.wave is not None:
        wave = options.wave
    else:
        try:
            wave = ST.LRSFiltCentrWaveDict[filter]
        except KeyError:
            wave = 0.55
    tnew[STP.WAVE] = Column(numpy.array(len(Qf)*[wave]), dtype=numpy.float32)
elif nophot:
    if options.wave is not None:
        wave = options.wave
    else:
        try:
            wave = tphot[STP.WAVE]*1e-4
        except KeyError:
            wave = 0.55
    if numpy.ndim(wave) == 0:
        # A scalar cannot be masked with QUflag: expand it to one value per
        # input row first (previously this raised TypeError).
        wave = numpy.full(QUflag.shape, wave)
    tnew[STP.WAVE] = Column(wave[QUflag], dtype=numpy.float32)
if options.verbose:
    if nospec:
        print("Observation wavelength: %.3f" % wave)
    elif nophot:
        print("Observation wavelength: spectral range")
#
# Derotator position angle (stored as a length-1 column) and, in verbose
# mode, a summary of the instrument set-up read from the header.
tnew[STP.POSANG] = Column([posang], dtype=numpy.float32)
if options.verbose:
    print("Derotator offset: %.1f" % posang)
    for fmt, value in (("Pupil stop: %s", pstop),
                       ("Plate     : %s", polslide),
                       ("Lambda/4  : %.1f", rot4),
                       ("Lambda/2  : %.1f", rot2)):
        print(fmt % value)
#

# Optional calibration Stokes values given with -c (exactly six numbers, in
# the order Q, eQ, U, eU, V, eV), each stored as a length-1 column.
if options.calibqu:
    calib_names = (STP.CalQ, STP.eCalQ, STP.CalU, STP.eCalU, STP.CalV, STP.eCalV)
    for colname, colvalue in zip(calib_names, options.calibqu):
        tnew[colname] = Column([colvalue], dtype=numpy.float64)
    if options.verbose:
        print("Calibrated Stokes parameters added.")
#
# Write the output table, or append it to an existing output file with -a.
# When appending, the new table is first round-tripped through the scratch
# FITS file STP.tempfile, presumably so that its columns come back in the
# same form Table.read gives for the existing file, before stacking them.
if options.append and os.path.exists(options.outfile):
    tnew.write(STP.tempfile,format='fits',overwrite=True)
    try:
        tnew2 = Table.read(STP.tempfile,format='fits')
    finally:
        # Remove the scratch file even if reading it back fails.
        os.remove(STP.tempfile)
    #
    try:
        tapp = Table.read(options.outfile, format='fits')
    except IOError:
        parser.error("Invalid FITS file to append.")
    #
    try:
        ttot = vstack([tapp,tnew2],metadata_conflicts='warn')
    except TableMergeError:
        parser.error("Tables to be appended are not compatible.")
    ttot.write(options.outfile,format='fits',overwrite=True)
else:
    tnew.write(options.outfile,format='fits',overwrite=True)
#
# Final report: a descriptive sentence in verbose mode, otherwise a terse
# "<rows> <file>" line.
if options.verbose:
    report = "%d (new) entries saved in file %s"
else:
    report = "%d %s"
print(report % (len(tnew), options.outfile))
#
