#! python
""" Code to derive model parameters
    
Context : SRP
Module  : SRPTNGPAOLOParamFit
Author  : Stefano Covino
Date    : 17/04/2019
E-mail  : stefano.covino@brera.inaf.it
URL     : http://www.merate.mi.astro.it/utenti/covino
Purpose : Derive TNG polarimetric model parameters.

Usage   : SRPTNGPAOLOParamFit [-h] [-a n k] [-c n k detoff q0 u0 v0] [-d detoff]
                           [-f file] -i file [-o file] [-s syst] [-v] [--version]
                           [-z q0 u0 v0]
            -a Aluminium refractive index and extinction coefficient multiplicative factors
            -c Fit choice flags for n, k, detoff, q0, u0, v0 (default: 1 -1 1 1 1 0)
            -d Detector offset (deg)
            -f Input FITS fit parameter file (same columns as written by -o); replaces -a, -d and -z
            -i Input normalized instrumental Stokes parameter file
            -o Output fit parameter FITS file
            -s Systematic error added in quadrature to the Q and U errors
            -z Normalized instrumental polarization

    
History : (01/03/2012) First version.
        : (27/09/2012) Better output.
        : (29/11/2012) Correct sign for position angle.
        : (04/12/2012) Possibility to choose the fit parameters.
        : (31/03/2013) Better help message.
        : (24/07/2013) Total intensity in polarization computation.
        : (25/09/2013) Update for the latest atpy version.
        : (13/02/2017) Python3 porting.
        : (14/02/2017) Minor bug.
        : (15/02/2017) Astropy porting.
        : (17/04/2019) Different statistical functions.
"""

__version__ = '0.4.0'  # version string reported by the --version option


import argparse
import atpy, numpy
from scipy.optimize import minimize
from scipy import stats
from SRP.SRPStatistics.GenFitPars import GenFitPars
from SRP.SRPPolarimetry.AluminiumRefractiveIndex import AluminiumRefractiveIndex
import SRPTNG.PAOLO as STP
from SRPTNG.PAOLO.TNGMuellerMatrix import TNGMuellerMatrix
from SRPTNG.PAOLO.TNGMuellerMatrixPlate2 import TNGMuellerMatrixPlate2
from SRPTNG.PAOLO.TNGMuellerMatrixPlate4 import TNGMuellerMatrixPlate4
from SRPTNG.PAOLO.StokesOffsetVector import StokesOffsetVector
from astropy.table import Table, Column

    


# Command-line interface.
parser = argparse.ArgumentParser()
parser.add_argument("-a", "--alum", action="store", nargs=2, type=float, help="Aluminium refractive and extinction coefficient multiplicative factors", metavar=('n','k'),default=(0.9,0.9))
# -c: one integer flag per parameter (n, k, detoff, q0, u0, v0). Each flag is
# paired with the parameter's starting value and handed to GenFitPars; a
# positive flag counts the parameter as free when the degrees of freedom are
# computed. The help text states the actual default (v0 flag = 0).
parser.add_argument("-c", "--choice", action="store", nargs=6, type=int, help="Choices for fitting: n, k, detoff, q0, u0, v0 [default=(1 -1 1 1 1 0)]", metavar=('n','k','detoff','q0','u0','v0'), default=(1,-1,1,1,1,0))
parser.add_argument("-d", "--detoff", action="store", type=float, help="Detector offset (deg)", metavar='detoff',default=0.5)
parser.add_argument("-f", "--fitsfile", action="store", help="Input FITS Stokes parameter file", metavar='file')
parser.add_argument("-i", "--instrpolfile", action="store", help="Input normalized instrumental Stokes parameter file", metavar='file', required=True)
parser.add_argument("-o", "--outfile", action="store", help="Output fit parameter FITS file", metavar='file')
parser.add_argument("-s", "--syst", action="store", type=float, default=0.0, help="Systematic error to be added", metavar=('syst'))
parser.add_argument("-v", "--verbose", action="store_true", help="Fully describe operations")
parser.add_argument("--version", action="version", version=__version__)
parser.add_argument("-z", "--zero", action="store", nargs=3, type=float, help="Normalized instrumental polarization", metavar=('q0','u0','v0'),default=(0.,0.,0.))
options = parser.parse_args()


#
# Load the normalized instrumental Stokes parameter table given with -i.
# IOError is an alias of OSError; astropy raises it for unreadable files.
try:
    dtp = Table.read(options.instrpolfile, format='fits')
except OSError:
    parser.error("Invalid input instrumental FITS Stokes parameter file.")
else:
    if options.verbose:
        print("Input instrumental Stokes parameter file: %s" % options.instrpolfile)
#
# Decide which model the fit uses. The first row whose polarization-slide
# entry contains STP.LAMBDA2 selects the func2 model, the first containing
# STP.LAMBDA4 selects func4; if neither is found, func is used.
lambd2 = False
lambd4 = False
for i in dtp:
    try:
        slide = i[STP.POLSLIDE].upper()
    except (KeyError, IndexError):
        # astropy rows raise KeyError when the column name is missing.
        parser.error("Table %s format not corrected." % options.instrpolfile)
    if slide.find(STP.LAMBDA2) >= 0:
        lambd2 = True
        break
    elif slide.find(STP.LAMBDA4) >= 0:
        lambd4 = True
        break
#
# Starting values: read from a previous fit-parameter FITS file (-f) when
# given, otherwise taken from the -a, -d and -z command-line options.
if options.fitsfile:
    try:
        dt = Table.read(options.fitsfile, format='fits')
    except IOError:
        parser.error("Invalid input Stokes parameter file.")
    if options.verbose:
        print("Input FITS Stokes parameter file: %s" % options.fitsfile)
    #
    try:
        nn = dt[STP.N][0]
        kk = dt[STP.K][0]
        offoff = dt[STP.DETOFF][0]
        q0q0 = dt[STP.Q0][0]
        u0u0 = dt[STP.U0][0]
        v0v0 = dt[STP.V0][0]
    except (KeyError, IndexError):
        # KeyError: a required column is missing; IndexError: the table is empty.
        parser.error("FITS Stokes parameter file without the expected entries.")
else:
    nn = options.alum[0]
    kk = options.alum[1]
    offoff = options.detoff
    q0q0 = options.zero[0]
    u0u0 = options.zero[1]
    v0v0 = options.zero[2]
#
# Reject unphysical starting values before any data are modified.
if any(c > 1. for c in (q0q0, u0u0, v0v0)) or sum(c**2 for c in (q0q0, u0u0, v0v0)) > 1:
    parser.error("Unrealistic instrumental polarization.")
if nn <= 0.0 or kk <= 0:
    parser.error("Multiplicative factors must be positive.")
#
if options.syst < 0.0:
    parser.error("Systematic error must be positive.")
#
# Add the systematic error in quadrature to the Q and U uncertainties.
for _errcol in (STP.eQ, STP.eU):
    dtp[_errcol] = numpy.sqrt(dtp[_errcol]**2 + options.syst**2)
#
# Echo the starting configuration.
if options.verbose:
    print("Refractive index multiplicative factor      : %.3f" % nn)
    print("Extinction coefficient multiplicative factor: %.3f" % kk)
    print("Detector offset (deg)                       : %.2f" % offoff)
    print("Instrumental polarization                   : Q0=%.3g, U0=%.3g, V0=%.3g" % (q0q0, u0u0, v0v0))
    if options.syst > 0:
        print("Systematic error                            : %.3g" % options.syst)
    # options.choice is a tuple when -c is omitted (argparse default) and a
    # list when given; joining the items prints both the same way.
    print("Fit rules                                   : ", ", ".join(str(c) for c in options.choice))
#
def func (vars,pars,args):
    """Model Stokes vector when no retarder plate is detected.

    vars -- sequence (wave, cq, cu, pa, p); chi2 fills it from the WAVE,
            CalQ, CalU, PARANG and POSANG columns of the instrumental table.
    pars, args -- forwarded to GenFitPars. Its result supplies, in order, the
            refractive-index factor, the extinction-coefficient factor, the
            detector offset, q0 and u0.

    The input Stokes vector (1, cq, cu, 0) is multiplied by TNGMuellerMatrix.
    That matrix uses the aluminium n and k at ``wave``, scaled by the two
    factors, and the detector offset minus p. StokesOffsetVector(q0, u0, 0)
    is then added. chi2 reads the result as s[k, 0].
    """
    fitted = GenFitPars(pars, args)
    wave, cq, cu, pa, p = vars[:5]
    n_factor, k_factor = fitted[0], fitted[1]
    det_offset, q0, u0 = fitted[2], fitted[3], fitted[4]
    angle = det_offset - p
    # Wavelength-dependent aluminium optical constants.
    n_of_wave, k_of_wave = AluminiumRefractiveIndex()
    n_al = n_of_wave(wave)
    k_al = k_of_wave(wave)
    stokes_in = numpy.matrix([1.0, cq, cu, 0.0]).transpose()
    mueller = TNGMuellerMatrix(pa, n_factor*n_al, k_factor*k_al, angle)
    return mueller*stokes_in + StokesOffsetVector(q0, u0, 0.0)
#
def func2 (vars,pars,args):
    """Model Stokes vector when the STP.LAMBDA2 plate is detected.

    vars -- sequence (wave, cq, cu, pa, p, rot); chi2 fills it from the WAVE,
            CalQ, CalU, PARANG, POSANG and ROTLAM2 columns.
    pars, args -- forwarded to GenFitPars. Its result supplies, in order, the
            refractive-index factor, the extinction-coefficient factor, the
            detector offset, q0 and u0.

    The input Stokes vector (1, cq, cu, 0) is multiplied by
    TNGMuellerMatrixPlate2, built from pa, the scaled aluminium n and k at
    ``wave``, rot, and the detector offset minus p.
    StokesOffsetVector(q0, u0, 0) is then added. chi2 reads the result as
    s[k, 0].
    """
    fitted = GenFitPars(pars, args)
    wave, cq, cu, pa, p, rot = vars[:6]
    n_factor, k_factor = fitted[0], fitted[1]
    det_offset, q0, u0 = fitted[2], fitted[3], fitted[4]
    angle = det_offset - p
    # Wavelength-dependent aluminium optical constants.
    n_of_wave, k_of_wave = AluminiumRefractiveIndex()
    n_al = n_of_wave(wave)
    k_al = k_of_wave(wave)
    stokes_in = numpy.matrix([1.0, cq, cu, 0.0]).transpose()
    mueller = TNGMuellerMatrixPlate2(pa, n_factor*n_al, k_factor*k_al, rot, angle)
    return mueller*stokes_in + StokesOffsetVector(q0, u0, 0.0)
#
def func4 (vars,pars,args):
    """Model Stokes vector when the STP.LAMBDA4 plate is detected.

    vars -- sequence (wave, cv, pa, p, rot); chi2 fills it from the WAVE,
            CalV, PARANG, POSANG and ROTLAM4 columns.
    pars, args -- forwarded to GenFitPars. Its result supplies, in order, the
            refractive-index factor, the extinction-coefficient factor, the
            detector offset and v0.

    The input Stokes vector (1, 0, 0, cv) is multiplied by
    TNGMuellerMatrixPlate4, built from pa, the scaled aluminium n and k at
    ``wave``, rot, and the detector offset minus p.
    StokesOffsetVector(0, 0, v0) is then added. chi2 reads the result as
    s[k, 0].
    """
    pari = GenFitPars(pars,args)
    wave = vars[0]
    cv = vars[1]
    pa = vars[2]
    p = vars[3]
    rot = vars[4]
    fn = pari[0]
    fk = pari[1]
    off = -p+pari[2]
    v0 = pari[3]
    #
    # Wavelength-dependent aluminium optical constants.
    nf, kf = AluminiumRefractiveIndex()
    n = nf(wave)
    k = kf(wave)
    sto = [1.0, 0.0, 0.0, cv]
    Stokes = numpy.matrix(sto).transpose()
    s = TNGMuellerMatrixPlate4(pa,fn*n,fk*k,rot,off)*Stokes+StokesOffsetVector(0.0,0.0,v0)
    return s
#
def chi2 (pars, args):
    """Return the chi-square of the model against the instrumental table.

    Loops over the rows of the module-level table ``dtp``. The model is
    chosen by the module-level flags: func2 when ``lambd2`` is set, func4
    when ``lambd4`` is set, otherwise func. ``pars``/``args`` are forwarded
    unchanged to the model.

    For each row, the model's normalized components (s[k,0]/s[0,0]) are
    compared with the measured values, weighted by the measured errors:
    Q and U for func/func2, component 3 for func4.

    Calls ``parser.error`` (which exits) if a required column is missing.
    """
    chiq = 0.0
    chiu = 0.0
    chiv = 0.0
    for i in dtp:
        try:
            if lambd2:
                s = func2((i[STP.WAVE],i[STP.CalQ],i[STP.CalU],i[STP.PARANG],i[STP.POSANG],i[STP.ROTLAM2]),pars,args)
                chiq = chiq + (((s[1,0]/s[0,0]-i[STP.Q])/i[STP.eQ])**2)
                chiu = chiu + (((s[2,0]/s[0,0]-i[STP.U])/i[STP.eU])**2)
            elif lambd4:
                # NOTE(review): the modelled V component is compared with the
                # Q / eQ columns. Presumably lambda/4 tables store the measured
                # V there; confirm against the table producer.
                s = func4((i[STP.WAVE],i[STP.CalV],i[STP.PARANG],i[STP.POSANG],i[STP.ROTLAM4]),pars,args)
                chiv = chiv + (((s[3,0]/s[0,0]-i[STP.Q])/i[STP.eQ])**2)
            else:
                s = func((i[STP.WAVE],i[STP.CalQ],i[STP.CalU],i[STP.PARANG],i[STP.POSANG]),pars,args)
                chiq = chiq + (((s[1,0]/s[0,0]-i[STP.Q])/i[STP.eQ])**2)
                chiu = chiu + (((s[2,0]/s[0,0]-i[STP.U])/i[STP.eU])**2)
        except (KeyError, IndexError):
            # astropy rows raise KeyError when the column name is missing.
            parser.error("Table %s format not corrected." % options.instrpolfile)
    return chiq + chiu + chiv
#
# Starting vector and per-parameter (choice flag, starting value) pairs. With
# the lambda/4 plate only v0 (choice index 5) replaces q0/u0 (indices 3, 4).
if lambd4:
    inizio = [nn, kk, offoff, v0v0]
    _choice_index = (0, 1, 2, 5)
else:
    inizio = [nn, kk, offoff, q0q0, u0u0]
    _choice_index = (0, 1, 2, 3, 4)
ags = [(options.choice[c], start) for c, start in zip(_choice_index, inizio)]
#
# Degrees of freedom: chi2 sums one residual per row with the lambda/4 plate
# and two (Q and U) otherwise, minus the free parameters (positive flags).
npr = sum(iii[0] > 0 for iii in ags)
if lambd4:
    ndf = len(dtp)-npr
else:
    ndf = 2*len(dtp)-npr
# The reduced CHI2 (tchi/ndf) printed and saved later needs ndf > 0;
# refuse up front instead of dividing by zero or reporting a negative dof.
if ndf <= 0:
    parser.error("Not enough data points for %d free parameters (dof=%d)." % (npr, ndf))
#
# Nelder-Mead minimization of chi2 starting from inizio.
parst = minimize (chi2, inizio, args=(ags,), method='Nelder-Mead', options={'disp':False}, tol=1e-4)
pars = GenFitPars(parst.x,ags)
tchi = chi2(pars,ags)
#
# Report the fit. With -v, print a labelled listing; otherwise print one
# line of best-fit values followed by the reduced CHI2.
redchi = tchi/ndf
if options.verbose:
    print("Fit reduced CHI2, dof, CHI2: %.2f %d %.2f" % (redchi, ndf, tchi))
    # NOTE(review): this is the chi-square CDF, P(CHI2 <= observed). A
    # goodness-of-fit probability is usually the survival function
    # (stats.chi2.sf); confirm which one is intended.
    print("Fit probability            : %.2f%%" % (100*stats.chi2.cdf(float(tchi),ndf)))
    print("Fit refractive index multiplicative factor      : %.3f" % pars[0])
    print("Fit extinction coefficient multiplicative factor: %.3f" % pars[1])
    print("Fit Detector offset (deg)                       : %.2f" % pars[2])
    if lambd4:
        print("Fit instrumental polarization                   : V0=%.3g" % pars[3])
    else:
        print("Fit instrumental polarization                   : Q0=%.3g, U0=%.3g" % (pars[3], pars[4]))
elif lambd4:
    print("%.3f %.3f %.2f %.2g %.3f" % (pars[0], pars[1], pars[2], pars[3], redchi))
else:
    print("%.3f %.3f %.2f %.2g %.2g %.3f" % (pars[0], pars[1], pars[2], pars[3], pars[4], redchi))
#
# Optionally save the best-fit parameters as a one-row FITS table. It has
# the same N, K, DETOFF, Q0, U0, V0 columns that the -f option reads back,
# plus the reduced CHI2.
if options.outfile:
    tout = Table()
    tout.add_column(Column(numpy.array([pars[0]]),STP.N))
    tout.add_column(Column(numpy.array([pars[1]]),STP.K))
    tout.add_column(Column(numpy.array([pars[2]]),STP.DETOFF))
    if lambd4:
        # Only V0 is fitted with the lambda/4 plate; Q0 and U0 are written as 0.
        tout.add_column(Column(numpy.array([0.0]),STP.Q0))
        tout.add_column(Column(numpy.array([0.0]),STP.U0))
        # Wrap in a list like every other column. numpy.array(scalar) would be
        # 0-d and would not match the table's one-row length.
        tout.add_column(Column(numpy.array([pars[3]]),STP.V0))
    else:
        tout.add_column(Column(numpy.array([pars[3]]),STP.Q0))
        tout.add_column(Column(numpy.array([pars[4]]),STP.U0))
        tout.add_column(Column(numpy.array([0.0]),STP.V0))
    tout.add_column(Column(numpy.array([tchi/ndf]),STP.CHI2))
    tout.write(options.outfile,format='fits',overwrite=True)
    if options.verbose:
        print("Fit parameters saved in file %s" % options.outfile)
    else:
        print("%s" % options.outfile)
#
