""" Code to match FITS tables
    
Context : SRP.GW
Module  : SRPGWUnionMatch
Author  : Stefano Covino
Date    : 19/12/2017
E-mail  : stefano.covino@brera.inaf.it
URL:    : http://www.merate.mi.astro.it/utenti/covino
Purpose : Coordinate match of FITS tables

usage: SRPGWUnionMatch [-h] [-a [addcols ...]] [-c RA DEC] -i [cat ...]
                       [-I Idpref] -o outcat -r radius [-u] [-v] [--version]

optional arguments:
  -h, --help            show this help message and exit
  -a [addcols ...], --addcols [addcols ...]
                        Columns to be added to the output table
  -c RA DEC, --coordcols RA DEC
                        Coordinate column labels
  -i [cat ...], --inputcats [cat ...]
                        Input catalogues
  -I Idpref, --Id Idpref
                        Id prefix
  -o outcat, --outcat outcat
                        Output catalogue
  -r radius, --radius radius
                        Radius for matching (arcsec)
  -u, --union           Operate just a table horizontal union
  -v, --verbose         Fully describe operations
  --version             show program's version number and exit


History : (16/01/2016) First version.
        : (18/01/2016) Units in SkyCoord.
        : (29/01/2016) Possibility to work on already adequate tables (dual catalogues, etc.)
        : (02/02/2016) Better path management.
        : (17/02/2016) unique before saving.
        : (17/05/2016) Id prefix can be chosen.
        : (02/03/2017) Update.
        : (19/12/2017) Columns to be added.
"""

__version__ = '2.2.0'


import argparse, os
import numpy
from astropy.table import hstack, Table, Column, unique
from astropy.coordinates import SkyCoord
from astropy.coordinates import search_around_sky
from astropy import units as u
import SRPGW as GW




# Command-line interface. Options are declared in the order in which they
# appear in the generated help text.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-a", "--addcols",
    nargs='*',
    metavar='addcols',
    help="Columns to be added to the output table",
)
parser.add_argument(
    "-c", "--coordcols",
    nargs=2,
    metavar=('RA','DEC'),
    default=('X_WORLD','Y_WORLD'),
    help="Coordinate column labels",
)
parser.add_argument(
    "-i", "--inputcats",
    nargs='*',
    metavar='cat',
    required=True,
    help="Input catalogues",
)
parser.add_argument(
    "-I", "--Id",
    metavar='Idpref',
    default=GW.VSTID,
    help="Id prefix",
)
parser.add_argument(
    "-o", "--outcat",
    metavar='outcat',
    required=True,
    help="Output catalogu",
)
parser.add_argument(
    "-r", "--radius",
    type=float,
    metavar='radius',
    required=True,
    help="Radius for matching (arcsec)",
)
parser.add_argument(
    "-u", "--union",
    action="store_true",
    help="Operate just a  table horizontal union",
)
parser.add_argument(
    "-v", "--verbose",
    action="store_true",
    help="Fully describe operations",
)
parser.add_argument("--version", action="version", version=__version__)
# Parse sys.argv; argparse exits on its own for -h/--version or invalid input.
options = parser.parse_args()


#
# Main script body: read the input catalogues, build either their horizontal
# union or a positional cross-match, tag every source with an identifier,
# optionally copy requested columns from the inputs, and write the result.
# NOTE: a radius of exactly 0 is falsy, so it falls through to the help text.
if options.inputcats and options.radius:
    # Every input catalogue must exist before any work is done.
    for f in options.inputcats:
        if not os.path.isfile(f):
            parser.error ("Input file %s does not exist." % f)
    # The radius only drives the matching, so it is not checked for a union.
    if not options.union and options.radius < 0.0:
        parser.error ("Matching radius must be positive.")
    #
    racol, deccol = options.coordcols
    #
    # Per-catalogue bookkeeping, all indexed in the order of options.inputcats.
    fnames = []
    fnamesf = []
    fnamesw = []
    coords = []
    tabs = []
    for fl in options.inputcats:
        if options.verbose:
            print ("Reading table ", fl)
        dt = Table.read(fl, format='ascii.ecsv')
        # The file-name columns are read from the first row, so an empty
        # table cannot be processed.
        if len(dt) == 0:
            parser.error("Table %s is empty." % fl)
        colnames = dt.columns.keys()
        #
        # parser.error() takes a single message string, so the table name
        # has to be formatted into it.
        if GW.FNAMECOL in colnames:
            fnames.append(dt[0][GW.FNAMECOL])
        else:
            parser.error("File name not found in table %s" % fl)
        #
        if GW.FNAMECOLF in colnames:
            fnamesf.append(dt[0][GW.FNAMECOLF])
        else:
            parser.error("FITS file name not found in table %s" % fl)
        #
        if GW.FNAMECOLW in colnames:
            fnamesw.append(dt[0][GW.FNAMECOLW])
        else:
            parser.error("Weight file name not found in table %s" % fl)
        #
        if racol in colnames and deccol in colnames:
            ct = SkyCoord(ra=dt[racol], dec=dt[deccol], unit="deg")
        else:
            parser.error("RA,DEC coordinates not found in table %s" % fl)
        coords.append(ct)
        #
        tabs.append(dt)
        #
        if options.verbose:
            print ("Table %s contains %d entries." % (fl, len(ct)))
    #
    # Running reference list of positions, seeded with the first catalogue.
    rt = coords[0]
    #
    if options.union:
        if options.verbose:
            print ("Joining tables...")
        # hstack expects the sequence of tables itself, not a list wrapping it.
        t = hstack(tabs, metadata_conflicts='warn')
        # The joined table keeps the input column labels.
        # NOTE(review): with several inputs hstack renames clashing columns,
        # so these labels may no longer exist -- confirm the intended usage.
        outra, outdec = racol, deccol
    else:
        if options.verbose:
            print ("Matching...")
        for catname, cat in zip(options.inputcats[1:], coords[1:]):
            if options.verbose:
                print ("Match %s vs %s" % (options.inputcats[0],catname))
            #
            id1, id2, d2d, d3d = search_around_sky(rt, cat, options.radius*u.arcsec)
            #
            ind1 = numpy.unique(id1)
            ind2 = numpy.unique(id2)
            # Boolean masks flagging the sources that found no counterpart.
            mask1 = numpy.ones(len(rt), dtype=bool)
            mask2 = numpy.ones(len(cat), dtype=bool)
            mask1[ind1] = False
            mask2[ind2] = False
            noid1 = numpy.arange(len(rt))[mask1]
            noid2 = numpy.arange(len(cat))[mask2]
            #
            # New reference list: matched reference sources (once each),
            # unmatched reference sources, and sources only in this catalogue.
            d1 = rt[ind1]
            d2 = rt[noid1]
            d3 = cat[noid2]
            rt = SkyCoord([d1,d2,d3])
            #
        t = Table([rt.data.lon,rt.data.lat], names=(GW.RACOL,GW.DECCOL))
        # Record the file names of every input catalogue, suffixed 1..N.
        for n, fname in enumerate(fnames, start=1):
            suffix = '_'+str(n)
            t.add_column(Column([fname]*len(rt),name=GW.FNAMECOL+suffix))
            #
            t.add_column(Column([fnamesf[n-1]]*len(rt),name=GW.FNAMECOLF+suffix))
            #
            t.add_column(Column([fnamesw[n-1]]*len(rt),name=GW.FNAMECOLW+suffix))
        # The matched table was just built with the GW coordinate labels, so
        # those (not the input labels) are the columns available from here on.
        outra, outdec = GW.RACOL, GW.DECCOL
    #
    if options.verbose:
        print ("Sorting table...")
    t.sort([outra,outdec])
    if options.verbose:
        print ("Source name...")
    # Identifier = prefix + RA + signed Dec, both with five decimals.
    cooint = ["%s%.5f%+.5f" % (options.Id, ra, dec) for ra, dec in zip(t[outra],t[outdec])]
    sname = Column(cooint, name=GW.IDCOL)
    t.add_column(sname, index=0)
    #
    # Drop rows sharing the same identifier.
    tt = unique(t,GW.IDCOL)
    #
    if options.addcols is not None:
        gct = SkyCoord(ra=tt[outra], dec=tt[outdec], unit="deg")
        # The match against a given input table does not depend on the
        # requested column, so it is computed once per table and reused.
        matches = {}
        for c in options.addcols:
            for n, tab in enumerate(tabs):
                if c in tab.columns.keys():
                    if n not in matches:
                        matches[n] = search_around_sky(gct, coords[n], options.radius*u.arcsec)[:2]
                    id1, id2 = matches[n]
                    newcol = c+'_'+str(n+1)
                    # 99. marks output rows with no counterpart in this table.
                    tt[newcol] = 99.
                    # Take the values of the matched input rows (id2) only;
                    # if an output row has several matches the last one wins.
                    tt[newcol][id1] = tab[c][id2]
                else:
                    print("Columns {:s} not existent in table {:s}\n".format(c,fnames[n]))
    #
    if options.verbose:
        print ("Saving table %s with %d entries" % (options.outcat,len(tt)))
    else:
        print (options.outcat,len(tt))
    tt.write(options.outcat,format='ascii.ecsv',overwrite=True)
else:
    parser.print_help()
#
