#! python
""" Code to match sources across the quadrants of PAOLO polarimetric FITS frames and combine their photometry

Context : SRP
Module  : SRPTNGPAOLOSourceMatch
Author  : Stefano Covino
Date    : 22/05/2017
E-mail  : stefano.covino@brera.inaf.it
URL     : http://www.merate.mi.astro.it/utenti/covino
Purpose : Associate sources on PAOLO polarimeter frames.

Usage   : SRPTNGPAOLOSourceMatch [-h] [-b offset] -f file -i file -I col col col
            col col -o file [-q row row row row row row row row]
            [-s shift shift shift shift shift shift]
            [-S seq seq seq seq] [-t tol] [-v] [--version]
    
            arguments (-f, -i, -I and -o are required):
            -b Offset from bottom frame
            -f Input FITS file
            -i Input photometry file
            -I Positions (1-based) of Id, X, Y, mag, emag columns in input file
            -o Output FITS file
            -q Lower and upper row limits for each quadrant
            -s Shifts (x and y) between quadrant 1 (the lowest) and the others
            -S Sequence of quadrants (1 is the lowest) [Qo Uo Qe Ue]
            -t Tolerance in source matching
            -v Fully describe operations
 
    
History : (08/03/2012) First version.
        : (03/08/2012) Better choice of parameters.
        : (27/09/2012) MJD in tables.
        : (28/11/2012) Minor correction to DATE keywords.
        : (25/09/2013) Latest atpy update.
        : (27/01/2014) New quadrant defaults.
        : (13/02/2017) Python3 porting.
        : (14/02/2017) Minor bug.
        : (19/04/2017) SRPFITS added.
        : (18/05/2017) Minor update.
        : (22/05/2017) No more atpy.
"""

__version__ = '1.1.0'


import argparse, math
import numpy
from astropy.table import Table, Column
import SRPTNG as ST
import SRPTNG.PAOLO as STP
from SRPFITS.Fits.GetHeaderValue import GetHeaderValue
from SRPFITS.Fits.IsFits import IsFits
from SRP.SRPMath.PointMatch import PointMatch
from SRPFITS.Photometry.Counts2Mag import Counts2Mag
from SRPFITS.Photometry.Mag2Counts import Mag2Counts
from SRP.SRPTime.UT2MJD import UT2MJD


# NOTE: the old note "orig 4,2,3,1" presumably records an earlier default
# for --sequence (the current default is 2 4 1 3).

# Command-line interface.
parser = argparse.ArgumentParser()
parser.add_argument("-b", "--bottomoffset", action="store", type=float, help="Offset from bottom frame", default=0.0, metavar='offset')
parser.add_argument("-f", "--fitsinputfile", action="store", help="Input FITS file", required=True, metavar='file')
parser.add_argument("-i", "--inputphotometry", action="store", help="Input photometry file", required=True, metavar='file')
parser.add_argument("-I", "--inputcolumns", action="store", nargs=5, type=int, help="Positions of Id, X, Y, mag, emag columns in input file", required=True, metavar='col')
parser.add_argument("-o", "--outfile", action="store", help="Output FITS file", required=True, metavar='file')
parser.add_argument("-q", "--quadrants", action="store", nargs=8, type=float, default=(1, 125, 140, 264, 273, 395, 414, 538),help="Limits for each quadrants", metavar='row')
parser.add_argument("-s", "--shifts", action="store", nargs=6, type=float, help="Shifts (x and y) between quadrant 1 (the lowest) and the others", metavar='shift')
parser.add_argument("-S", "--sequence", action="store", nargs=4, type=int, default = (2,4,1,3), choices=(1,2,3,4), help="Sequence of quadrants (1 is the lowest [Qo Uo Qe Ue])", metavar='seq')
parser.add_argument("-t", "--tolerance", action="store", type=float, help="Tolerance in source matching", default=10.0, metavar='tol')
parser.add_argument("-v", "--verbose", action="store_true", help="Fully describe operations")
parser.add_argument("--version", action="version", version=__version__)
args = parser.parse_args()
# `choices` only checks each value individually. Every quadrant must appear
# exactly once in --sequence: the per-quadrant output lists are indexed by
# quadrant number, so a repeated quadrant would be gathered twice (giving
# columns of mismatched length) while another one would be silently dropped.
if sorted(args.sequence) != [1, 2, 3, 4]:
    parser.error("Sequence must contain each quadrant 1-4 exactly once.")


# Validate the input FITS frame.
if not IsFits(args.fitsinputfile):
    parser.error("Invalid input FITS file.")
if args.verbose:
    print("Input FITS file %s" % args.fitsinputfile)
#
if args.verbose:
    print("Input photometry file: %s" % args.inputphotometry)
# Read the ASCII photometry catalogue. A missing or unreadable file raises
# IOError. Content that astropy cannot parse raises ValueError subclasses
# (e.g. InconsistentTableError). Both are reported through the parser
# instead of a traceback.
try:
    t = Table.read(args.inputphotometry,format='ascii')
except (IOError, ValueError):
    parser.error("File %s is not readable." % args.inputphotometry)
#
if args.verbose:
    print("\tRead %s entries." % len(t))
#
if args.verbose:
    print("Id column  : %d" % args.inputcolumns[0])
    print("X column   : %d" % args.inputcolumns[1])
    print("Y column   : %d" % args.inputcolumns[2])
    print("mag column : %d" % args.inputcolumns[3])
    print("emag column: %d" % args.inputcolumns[4])
# Resolve the 1-based column positions given on the command line into column
# names. The range is checked explicitly: a position of 0 or less would
# otherwise become a negative list index and silently select one of the last
# columns instead of being rejected.
colnames = t.colnames
if min(args.inputcolumns) < 1 or max(args.inputcolumns) > len(colnames):
    parser.error("Wrong columns numbers.")
idcol, xcol, ycol, magcol, emagcol = [colnames[c-1] for c in args.inputcolumns]
#
if args.verbose:
    print("Offset: %.1f" % args.bottomoffset)
# Split the catalogue into the four quadrants by Y position. q[n-1] is the
# sub-table of sources whose Y lies within the (inclusive) limits of
# quadrant n, both limits shifted down by --bottomoffset. Each sub-table is
# sorted by magnitude.
q = []
for nquad in range(1, 5):
    lower = args.quadrants[2*nquad - 2] - args.bottomoffset
    upper = args.quadrants[2*nquad - 1] - args.bottomoffset
    if args.verbose:
        print("Quadrant %d limits: %.1f %.1f" % (nquad, lower, upper))
    inside = (t[ycol] >= lower) & (t[ycol] <= upper)
    quad = t[inside]
    quad.sort(magcol)
    q.append(quad)
    if args.verbose:
        print("\t%d entries in quadrant %d" % (len(quad), nquad))
#
# Reference positions: (x, y) of every source in quadrant 1.
refset = [(row[xcol], row[ycol]) for row in q[0]]
#
# The matching tolerance (pixels) must be strictly positive. It is written
# as "not > 0" so that a NaN tolerance is rejected as well.
if not args.tolerance > 0.0:
    parser.error("Tolerance must be positive.")
if args.verbose:
    print("Tolerance is %.1f pixel" % args.tolerance)
#
# Match quadrants 2, 3 and 4 against quadrant 1. q1set[k] is the PointMatch
# result for quadrant k+2. Element [0] is the list of matches; element [1]
# is the (x, y) shift reported for that pair of quadrants. The user-supplied
# shifts, when given, are forwarded to PointMatch.
q1set = []
for ii, quad in enumerate(q[1:4]):
    objset = [(row[xcol], row[ycol]) for row in quad]
    if args.shifts:
        result = PointMatch(refset, objset, args.tolerance, args.shifts[2*ii], args.shifts[2*ii+1], True)
    else:
        result = PointMatch(refset, objset, args.tolerance)
    q1set.append(result)
    if args.verbose:
        print("%d matches between quadrants 1 and %d with shifts %.1f,%.1f" % (len(result[0]), ii+2, result[1][0], result[1][1]))
#
# Keep only the quadrant-1 sources that were matched in ALL three other
# quadrants. Each match m pairs a quadrant-1 row index m[0] with a row index
# m[1] in the other quadrant; these indices are later used to index the
# quadrant tables. Each totmatch entry is
#   (row in quadrant 1, row in quadrant 2, row in quadrant 3, row in quadrant 4).
#
# The quadrant-3 and quadrant-4 matches are grouped by their quadrant-1 index
# first. The join then needs one dict lookup per quadrant-2 match instead of
# rescanning both lists for each of them (the former triple nested loop).
# The output is unchanged, including ordering and any multiple pairings of
# the same quadrant-1 source.
by_ref_q3 = {}
for m3 in q1set[1][0]:
    by_ref_q3.setdefault(m3[0], []).append(m3[1])
by_ref_q4 = {}
for m4 in q1set[2][0]:
    by_ref_q4.setdefault(m4[0], []).append(m4[1])
totmatch = [(m2[0], m2[1], r3, r4)
            for m2 in q1set[0][0]
            for r3 in by_ref_q3.get(m2[0], ())
            for r4 in by_ref_q4.get(m2[0], ())]

if args.verbose:
    print("%d global matches." % len(totmatch))

# Gather Id, X, Y, mag and emag of every globally matched source. Entry n-1
# of each list holds the values from quadrant n; the row of a source in
# quadrant n is element n-1 of its totmatch tuple. --sequence has exactly
# four entries (argparse nargs=4), hence four lists each.
ids = [[] for _ in range(4)]
xs = [[] for _ in range(4)]
ys = [[] for _ in range(4)]
mags = [[] for _ in range(4)]
emags = [[] for _ in range(4)]
#
for ii in args.sequence:
    quad = q[ii-1]
    for match in totmatch:
        row = match[ii-1]
        ids[ii-1].append(quad[idcol][row])
        xs[ii-1].append(quad[xcol][row])
        ys[ii-1].append(quad[ycol][row])
        mags[ii-1].append(quad[magcol][row])
        emags[ii-1].append(quad[emagcol][row])
# Total magnitude: convert each quadrant's magnitudes to counts, sum the
# counts source by source, then convert back. The accumulators are numpy
# arrays so the sum is element-wise whatever sequence type Mag2Counts
# returns; with the former Python-list accumulators a list return would
# have been concatenated instead of summed.
# NOTE(review): the count errors are summed linearly, not in quadrature;
# confirm this conservative choice is intended.
totcnt = numpy.zeros(len(totmatch))
etotcnt = numpy.zeros(len(totmatch))
for ii in args.sequence:
    cnt, ecnt = Mag2Counts(mags[ii-1],emags[ii-1])
    totcnt = totcnt + numpy.asarray(cnt)
    etotcnt = etotcnt + numpy.asarray(ecnt)
totmag, etotmag = Counts2Mag(totcnt,etotcnt)
#
# Output table. For each quadrant n, in --sequence order, it holds the
# columns <Id>_n, <X>_n, <Y>_n, <Mag>_n, <eMag>_n, followed by the total
# magnitude and its error.
# NOTE(review): ids are stored as int16; catalogue ids above 32767 would not
# fit - confirm they stay small.
tnew = Table()
for nq in args.sequence:
    slot = nq - 1
    tnew['%s_%d' % (STP.Id, nq)] = Column(numpy.array(ids[slot]), dtype=numpy.int16)
    for colbase, values in ((STP.X, xs[slot]), (STP.Y, ys[slot])):
        tnew['%s_%d' % (colbase, nq)] = Column(numpy.array(values), unit='pixel', dtype=numpy.float32)
    for colbase, values in ((STP.Mag, mags[slot]), (STP.eMag, emags[slot])):
        tnew['%s_%d' % (colbase, nq)] = Column(numpy.array(values), dtype=numpy.float32)
#
tnew[STP.TotMag] = Column(totmag, dtype=numpy.float32)
tnew[STP.eTotMag] = Column(etotmag, dtype=numpy.float32)
# Provenance and processing parameters.
tnew.meta['FITSFILE'] = args.fitsinputfile
tnew.meta['PHOTFILE'] = args.inputphotometry
# Sequence stored as space-terminated numbers, e.g. "2 4 1 3 ".
tnew.meta[STP.SEQUENCE] = ''.join('%d ' % nq for nq in args.sequence)
# Quadrant row limits as given on the command line (before --bottomoffset).
for nq in range(1, 5):
    tnew.meta['Q%dLIMU' % nq] = '%.1f' % args.quadrants[2*nq - 2]
    tnew.meta['Q%dLIMD' % nq] = '%.1f' % args.quadrants[2*nq - 1]
# Shifts from quadrant 1 to quadrants 2-4 reported by the matching step.
for nq in range(2, 5):
    shift = q1set[nq - 2][1]
    tnew.meta['SHF1to%dX' % nq] = shift[0]
    tnew.meta['SHF1to%dY' % nq] = shift[1]
#
def _copy_header_keyword(hdrkey, metakey, convert=None):
    """Copy one header keyword of the input FITS frame into ``tnew.meta``.

    The value is read with GetHeaderValue. When it is not None it is stored
    under ``metakey``, after being passed through ``convert`` if one is given.

    Returns the raw header value, or None when GetHeaderValue gives None
    (presumably a missing keyword).
    """
    value = GetHeaderValue(args.fitsinputfile, hdrkey)[0]
    if value is not None:
        tnew.meta[metakey] = value if convert is None else convert(value)
    return value

# Keywords are copied in this order, which fixes their order in the output
# metadata. Angles converted with math.degrees are treated as radians in the
# header.
_copy_header_keyword(ST.EXPTIME, STP.EXPTIME)
# NOTE(review): the ST.RADEG value is multiplied by 15, i.e. treated as
# hours despite the keyword name; confirm the header unit.
_copy_header_keyword(ST.RADEG, STP.RA, lambda v: v*15.0)
_copy_header_keyword(ST.DECDEG, STP.DEC)
_copy_header_keyword(ST.POSANG, STP.POSANG)
_copy_header_keyword(ST.AZ, STP.AZ, math.degrees)
_copy_header_keyword(ST.ALT, STP.ALT, math.degrees)
_copy_header_keyword(ST.ROTPOS, STP.DEROT, math.degrees)
_copy_header_keyword(ST.PARANG, STP.PARANG, math.degrees)
_copy_header_keyword(ST.LST, 'LST')
# Observation date (split on '-' into year, month, day) and time (split on
# ':' into hours, minutes, seconds). They feed the MJD computed below and
# default to 2000-01-01 00:00:00 when absent.
datestr = _copy_header_keyword(ST.DATES, STP.DATE)
if datestr is not None:
    dparts = datestr.split('-')
    yy, me, dd = float(dparts[0]), float(dparts[1]), float(dparts[2])
else:
    yy, me, dd = 2000., 1., 1.
timestr = _copy_header_keyword(ST.TIME, STP.TIME)
if timestr is not None:
    tparts = timestr.split(':')
    hh, mi, ss = float(tparts[0]), float(tparts[1]), float(tparts[2])
else:
    hh, mi, ss = 0., 0., 0.
# NOTE(review): the FILTER keyword is stored as PRISM and the GRISM keyword
# as FILTER; presumably intentional for the PAOLO set-up - confirm.
_copy_header_keyword(ST.FILTER, STP.PRISM)
_copy_header_keyword(ST.GRISM, STP.FILTER)
_copy_header_keyword(ST.SLIT, STP.PSTOP)
_copy_header_keyword(ST.OBJECT, 'OBJECT')
_copy_header_keyword(ST.PSLR, STP.POLSLIDE)
_copy_header_keyword(ST.RTRY1, STP.ROTLAM4, float)
_copy_header_keyword(ST.RTRY2, STP.ROTLAM2, float)
#
tnew.meta[STP.MJD] = UT2MJD(yy,me,dd,hh,mi,ss)
#
# Save the matched catalogue as FITS (replacing any existing file), then
# report the number of globally matched sources.
tnew.write(args.outfile, format='fits', overwrite=True)
nmatched = len(totmatch)
if not args.verbose:
    print("%d %s" % (nmatched, args.outfile))
else:
    print("Results saved in file %s with %d entries" % (args.outfile, nmatched))
#
