""" Code to perform photometry on FITS frames

Context : SRP
Module  : SRPTNGPAOLOSpectrumMatch
Author  : Stefano Covino
Date    : 26/05/2017
E-mail  : stefano.covino@brera.inaf.it
URL:    : http://www.merate.mi.astro.it/utenti/covino
Purpose : Associate spectra on PAOLO polarimeter frames.

Usage   : SRPTNGPAOLOSpectrumMatch [-h] [-e errspec errspec errspec errspec] -f
            file -o file -s spec spec spec spec
            [-S seq seq seq seq] [-v] [-x] [--version]
    
            -e Input FITS error spectra
            -f Input FITS file
            -o Output FITS file
            -s Input FITS spectra
            -S Sequence of quadrants (1 is the lowest [Qo Uo Qe Ue]; default 2 4 1 3)
            -v Fully describe operations
            -x Cross correlate input spectra
    
History : (27/03/2013) First version.
        : (08/04/2013) Cross-correlation option.
        : (25/09/2013) Update for the latest atpy release.
        : (13/02/2017) Python3 porting.
        : (14/02/2017) Minor bug.
        : (18/05/2017) Minor update.
        : (26/05/2017) Minor update.
"""

__version__ = '1.1.5'


import argparse, math, sys
from scipy.interpolate import interp1d
import atpy, numpy
import SRPTNG as ST
import SRPTNG.PAOLO as STP
from SRPFITS.Fits.GetHeaderValue import GetHeaderValue
from SRPFITS.Fits.IsFits import IsFits
from SRP.SRPTime.UT2MJD import UT2MJD
from SRPFITS.Fits.GetSpectrum import GetSpectrum
from SRP.SRPStatistics.XCorr_1D import XCorr_1D


# orig 4,2,3,1  (presumably the earlier default for --sequence; kept for reference)

# Command-line interface; option summary is repeated in the module docstring.
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--errspectra",
                    action="store", nargs=4, metavar='errspec',
                    help="Input FITS error spectra")
parser.add_argument("-f", "--fitsinputfile",
                    action="store", required=True, metavar='file',
                    help="Input FITS file")
parser.add_argument("-o", "--outfile",
                    action="store", required=True, metavar='file',
                    help="Output FITS file")
parser.add_argument("-s", "--inputspectra",
                    action="store", nargs=4, required=True, metavar='spec',
                    help="Input FITS spectra")
parser.add_argument("-S", "--sequence",
                    action="store", nargs=4, type=int, default=(2, 4, 1, 3),
                    choices=(1, 2, 3, 4), metavar='seq',
                    help="Sequence of quadrants (1 is the lowest [Qo Uo Qe Ue])")
parser.add_argument("-v", "--verbose", action="store_true",
                    help="Fully describe operations")
parser.add_argument("-x", "--crosscorr", action="store_true",
                    help="Cross correlate input spectra")
parser.add_argument("--version", action="version", version=__version__)
args = parser.parse_args()


#
if not IsFits(args.fitsinputfile):
    parser.error("Invalid input FITS file.")
if args.verbose:
    print("Input FITS file %s" % args.fitsinputfile)
#
# Load every input spectrum into a (wave, flux, err_wave, err_flux) tuple.
# Without error spectra, the flux wavelength axis is reused and the errors
# are set to zero.
totmatch = []
for idx, specfile in enumerate(args.inputspectra):
    lm, dt = GetSpectrum(specfile)
    # The spectra are used as NumPy arrays later on (L83, L127). For arrays,
    # ``arr != None`` compares element-wise and cannot be used in an ``if``,
    # so test identity with None instead. parser.error() exits the script.
    if lm is None or dt is None:
        parser.error("Invalid input FITS spectrum {}".format(specfile))
    if args.verbose:
        print("Input FITS spectrum {}: {}".format(idx + 1, specfile))
    if args.errspectra:
        # Report the error-spectrum file itself, not the paired flux spectrum.
        errfile = args.errspectra[idx]
        lme, dte = GetSpectrum(errfile)
        if lme is None or dte is None:
            parser.error("Invalid input FITS error spectrum {}".format(errfile))
        if args.verbose:
            print("Input FITS error spectrum {}: {}".format(idx + 1, errfile))
        totmatch.append((lm, dt, lme, dte))
    else:
        totmatch.append((lm, dt, lm, numpy.zeros_like(dt)))
#
# Common output wavelength grid. It spans every (error) spectrum, and its
# nominal step is the finest CDELT1 found among the input files.
if args.errspectra:
    speclist = args.inputspectra + args.errspectra
else:
    speclist = args.inputspectra
steps = [GetHeaderValue(specname, 'CDELT1')[0] for specname in speclist]
lstep = min(steps)
#
lstart = min(min(spec[0][0], spec[2][0]) for spec in totmatch)
lend = max(max(spec[0][-1], spec[2][-1]) for spec in totmatch)
#
# numpy.linspace requires an integer sample count. numpy.ceil returns a
# float, which current NumPy rejects with a TypeError, so cast explicitly.
# NOTE(review): with ceil((lend-lstart)/lstep) samples the actual spacing,
# (lend-lstart)/(n-1), is slightly coarser than lstep. Confirm this is intended.
ljoint = numpy.linspace(lstart, lend, int(numpy.ceil((lend - lstart) / lstep)))
# Evaluate every spectrum and its error on the common grid with linear
# interpolation. Outside a spectrum's own wavelength range the value is 0.
ssp = []    # flux interpolators, one per input spectrum
essp = []   # error interpolators, one per input spectrum
itpdspec = []
for wave, flux, ewave, eflux in totmatch:
    flux_interp = interp1d(wave, flux, kind='linear', bounds_error=False, fill_value=0.0)
    err_interp = interp1d(ewave, eflux, kind='linear', bounds_error=False, fill_value=0.0)
    ssp.append(flux_interp)
    essp.append(err_interp)
    itpdspec.append((flux_interp(ljoint), err_interp(ljoint)))
#
# Optional cross-correlation (-x). XCorr_1D compares each spectrum with the
# first *input* spectrum (index 0, not args.sequence[0]). Its result is then
# added to that spectrum's wavelength axis before re-interpolating on ljoint.
# Without -x, every shift stays 0.
xfact = [0., 0., 0., 0.]
if args.crosscorr:
    reference = itpdspec[0][0]
    for idx in range(1, len(itpdspec)):
        xfact[idx] = XCorr_1D(itpdspec[idx][0], reference, ljoint)
        if args.verbose:
            print("Cross-correlation with first spectrum in sequence: %.3f" % (xfact[idx]))
    #
    itpdspec = []
    for idx, (wave, flux, ewave, eflux) in enumerate(totmatch):
        shift = xfact[idx]
        flux_interp = interp1d(wave + shift, flux, kind='linear', bounds_error=False, fill_value=0.0)
        err_interp = interp1d(ewave + shift, eflux, kind='linear', bounds_error=False, fill_value=0.0)
        itpdspec.append((flux_interp(ljoint), err_interp(ljoint)))
#
# Sum the flux and the error spectra of the quadrants listed in args.sequence
# (quadrant q maps to itpdspec[q - 1]).
# NOTE(review): errors are added linearly rather than in quadrature. Confirm
# this matches what the error spectra represent.
totcnt = sum((itpdspec[quad - 1][0] for quad in args.sequence), numpy.zeros_like(ljoint))
etotcnt = sum((itpdspec[quad - 1][1] for quad in args.sequence), numpy.zeros_like(ljoint))
#
# Output table. For each quadrant in the sequence it holds constant
# placeholder Id/X/Y columns (all ones) plus that quadrant's interpolated
# flux and error. Totals and the wavelength grid follow.
tnew = atpy.Table(name=args.outfile)
for quad in args.sequence:
    flux, eflux = itpdspec[quad - 1]
    tnew.add_column('%s_%d' % (STP.Id, quad), numpy.ones_like(ljoint, dtype=numpy.int16), dtype=numpy.int16)
    tnew.add_column('%s_%d' % (STP.X, quad), numpy.ones_like(ljoint), unit='pixel', dtype=numpy.float32)
    tnew.add_column('%s_%d' % (STP.Y, quad), numpy.ones_like(ljoint), unit='pixel', dtype=numpy.float32)
    tnew.add_column('%s_%d' % (STP.Flux, quad), flux, dtype=numpy.float32)
    tnew.add_column('%s_%d' % (STP.eFlux, quad), eflux, dtype=numpy.float32)
#
tnew.add_column(STP.TotFlux, totcnt, dtype=numpy.float32)
tnew.add_column(STP.eTotFlux, etotcnt, dtype=numpy.float32)
#
tnew.add_column(STP.WAVE, ljoint, unit='micron', dtype=numpy.float32)
#
# Record the quadrant order and the per-quadrant shifts as space-separated
# strings. Each entry is followed by a space, including the last one.
tnew.add_keyword(STP.SEQUENCE, "".join("%d " % quad for quad in args.sequence))
tnew.add_keyword(STP.XCORR, "".join("%.3f " % xfact[quad - 1] for quad in args.sequence))
#
# Provenance: the input frame plus every spectrum / error spectrum used,
# numbered from 1 in command-line order.
tnew.add_keyword('FITSFILE', args.fitsinputfile)
#
for num, specfile in enumerate(args.inputspectra, start=1):
    tnew.add_keyword('SPECFL_%d' % num, specfile)
    if args.errspectra:
        tnew.add_keyword('ERRFL_%d' % num, args.errspectra[num - 1])
#
def _copy_keyword(src, dst, convert=None):
    """Copy header keyword *src* of the input frame into the table as *dst*.

    The value is passed through *convert* (when given) before it is stored.
    Nothing is written when GetHeaderValue yields None for the keyword.
    Returns the raw, unconverted header value (possibly None).
    """
    value = GetHeaderValue(args.fitsinputfile, src)[0]
    if value is not None:
        tnew.add_keyword(dst, value if convert is None else convert(value))
    return value


# Keywords are written in this exact order.
_copy_keyword(ST.EXPTIME, STP.EXPTIME)
# NOTE(review): RA is multiplied by 15, presumably because the header stores
# it in hours. Confirm.
_copy_keyword(ST.RADEG, STP.RA, lambda v: v * 15.0)
_copy_keyword(ST.DECDEG, STP.DEC)
_copy_keyword(ST.POSANG, STP.POSANG)
# These angles go through math.degrees, so they are assumed to be radians in
# the input header. TODO confirm.
_copy_keyword(ST.AZ, STP.AZ, math.degrees)
_copy_keyword(ST.ALT, STP.ALT, math.degrees)
_copy_keyword(ST.ROTPOS, STP.DEROT, math.degrees)
_copy_keyword(ST.PARANG, STP.PARANG, math.degrees)
_copy_keyword(ST.LST, 'LST')
#
# Date ('-'-separated, presumably YYYY-MM-DD) and time (':'-separated) feed
# the MJD keyword. They fall back to 2000-01-01 00:00:00 when missing.
k = _copy_keyword(ST.DATES, STP.DATE)
if k is not None:
    parts = k.split('-')
    yy, me, dd = float(parts[0]), float(parts[1]), float(parts[2])
else:
    yy, me, dd = 2000., 1., 1.
k = _copy_keyword(ST.TIME, STP.TIME)
if k is not None:
    parts = k.split(':')
    hh, mi, ss = float(parts[0]), float(parts[1]), float(parts[2])
else:
    hh, mi, ss = 0., 0., 0.
#
# NOTE(review): FILTER is stored as PRISM and GRISM as FILTER. This is
# presumably intentional for the PAOLO table layout; confirm.
_copy_keyword(ST.FILTER, STP.PRISM)
_copy_keyword(ST.GRISM, STP.FILTER)
_copy_keyword(ST.SLIT, STP.PSTOP)
_copy_keyword(ST.OBJECT, 'OBJECT')
_copy_keyword(ST.PSLR, STP.POLSLIDE)
_copy_keyword(ST.RTRY1, STP.ROTLAM4, float)
_copy_keyword(ST.RTRY2, STP.ROTLAM2, float)
#
# MJD of the observation from the date/time parsed from the header (or the
# 2000-01-01 00:00:00 fallback), then write the table and report.
tnew.add_keyword(STP.MJD, UT2MJD(yy, me, dd, hh, mi, ss))
#
tnew.write(args.outfile, type='fits', overwrite=True)
nrows = len(ljoint)
if args.verbose:
    report = "Results saved in file %s with %d entries" % (args.outfile, nrows)
else:
    report = "%d %s" % (nrows, args.outfile)
print(report)
#
