#!/usr/bin/env python

"""

    Create reference spectrum model - takes as input 2d arc spectrum image.

    If we don't find a plain text file of coords matching a row to wavelength in angstroms, 
    we display a plot so the user can make one. Then we fit and save a model.

    Copyright 2014-2018 Matt Hilton (matt.hilton@mykolab.com)
    
    This file is part of RSSMOSPipeline.

    RSSMOSPipeline is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    RSSMOSPipeline is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with RSSMOSPipeline.  If not, see <http://www.gnu.org/licenses/>.
    
"""

import os
import sys
import astropy.io.fits as pyfits
import astropy.table as atpy
#from astLib import *
import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage
from scipy import interpolate
from scipy import optimize
import argparse
import pickle
import IPython
#plt.matplotlib.interactive(True)

#-------------------------------------------------------------------------------------------------------------
def parseApproxArcCoordsFile(fileName):
    """Parse a plain text file matching arc line wavelengths to approximate pixel coordinates.
    
    Each data line holds two whitespace-separated columns: wavelength in Angstroms (float), followed
    by pixel coordinate on the x-axis (integer). Blank lines, and lines whose first non-whitespace
    character is '#', are ignored.
    
    Args:
        fileName (str): Path to the coordinates file.
    
    Returns:
        Dictionary indexed by x coord (int) of line wavelengths in Angstroms (float). If the same
        x coord is listed more than once, the last entry wins.
    
    Raises:
        ValueError: If a data line has fewer than two columns, or a column cannot be converted to
            a number.
    
    """
    
    coordsDict={}
    # Context manager ensures the file is closed even if a line fails to parse
    with open(fileName, "r") as inFile:
        for lineNumber, line in enumerate(inFile, start = 1):
            # Strip first, so short lines without a trailing newline are kept and
            # whitespace-only lines / indented comments are skipped
            line=line.strip()
            if not line or line.startswith("#"):
                continue
            bits=line.split()
            if len(bits) < 2:
                raise ValueError("%s, line %d: expected 'wavelength x', got '%s'" 
                                 % (fileName, lineNumber, line))
            coordsDict[int(bits[1])]=float(bits[0])
    
    return coordsDict

#-------------------------------------------------------------------------------------------------------------
def detectLines(data, sigmaCut = 3.0, thresholdSigma = 5.0, featureMinPix = 30):
    """Detect lines in a 2d arc spectrum. Positions and amplitudes are measured using the central
    row of the 2d spectrum only.
    
    Args:
        data: 2d array containing the arc spectrum.
        sigmaCut (float): Clipping level (in sigma) used when iteratively estimating the background
            mean and scatter.
        thresholdSigma (float): Pixels more than this many sigma above the background mean are
            treated as belonging to a feature.
        featureMinPix (int): Features must contain more than this number of pixels to be kept.
    
    Returns: featureTable, segmentationMap
    
    featureTable has columns 'id' (label in segmentationMap), 'x_centreRow', 'y_centreRow' and
    'amplitude' (value of data in the central row at the rounded x position).
    
    """
    
    # Iteratively clipped estimate of background level and scatter, ignoring zero-valued pixels
    isNonZero=(data != 0)
    bckLevel=0
    bckSigma=1e6
    for _ in range(20):
        withinClip=np.abs(data-bckLevel) < sigmaCut*bckSigma
        clippedPix=data[np.logical_and(isNonZero, withinClip)]
        bckLevel=np.mean(clippedPix)
        bckSigma=np.std(clippedPix)
    significant=(data-bckLevel) > thresholdSigma*bckSigma

    # Label connected groups of significant pixels in 2d; label 0 (the background) contains no
    # significant pixels, so it can never pass the minimum size cut below
    segmentationMap, numObjects=ndimage.label(significant)
    labels=np.unique(segmentationMap)
    pixCounts=ndimage.sum(np.equal(significant, 1), labels = segmentationMap, index = labels)
    
    # x position of each feature, weighted by the flux in the central row
    centreRow=int(data.shape[0]/2)
    xCentres=ndimage.center_of_mass(data[centreRow], labels = segmentationMap, index = labels)
    xCentres=np.array(xCentres).flatten()
    
    # Build the output table from features that are big enough
    bigEnough=pixCounts > featureMinPix
    featureTable=atpy.Table()
    featureTable.add_column(atpy.Column(labels[bigEnough], 'id'))
    featureTable.add_column(atpy.Column(xCentres[bigEnough], 'x_centreRow'))
    featureTable.add_column(atpy.Column([centreRow]*len(featureTable), 'y_centreRow'))
    xPix=np.array(np.round(featureTable['x_centreRow']), dtype = int)
    featureTable.add_column(atpy.Column(data[centreRow, xPix], 'amplitude'))
    
    return featureTable, segmentationMap

#-------------------------------------------------------------------------------------------------------------
def tagWavelengthFeatures(featureTable, approxCoordsDict):
    """Adds a 'wavelength' column to featureTable. For each x coord in approxCoordsDict, the feature
    nearest to it in 'x_centreRow' is given the corresponding wavelength. This will only work on the 
    reference arc spectrum.
    
    The column is added to the given featureTable in place; features which are not tagged with a
    wavelength (i.e., wavelength is still 0) are left out of the returned table.
    
    Args:
        featureTable: Table with an 'x_centreRow' column, as made by detectLines.
        approxCoordsDict (dict): Maps approximate x pixel coord to wavelength.
    
    Returns featureTable containing only the tagged features.
    
    """

    untagged=np.zeros(len(featureTable))
    featureTable.add_column(atpy.Column(untagged, 'wavelength'))
    
    # If two coords select the same feature, the one processed last wins
    for approxX, lineWavelength in approxCoordsDict.items():
        offsets=abs(approxX-featureTable['x_centreRow'])
        nearestRow=np.argmin(offsets)
        featureTable['wavelength'][nearestRow]=lineWavelength

    taggedRows=np.where(featureTable['wavelength'] != 0)
    
    return featureTable[taggedRows]

#-------------------------------------------------------------------------------------------------------------
def makeModelArcSpectrum(data, approxCoordsDict, outFileName, yRow, sigmaCut = 3.0, thresholdSigma = 5.0, 
                         featureMinPix = 30):
    """Make reference model arc spectrum. This has wavelengths of features identified in a table. We also
    save row yRow of the spectrum.
    
    Args:
        data: 2d array containing the arc spectrum.
        approxCoordsDict (dict): Maps approximate x pixel coord to wavelength in Angstroms (see
            parseApproxArcCoordsFile).
        outFileName (str): Path of the pickle file to write.
        yRow (int): Row of data saved in the model under the key 'arc_centreRow'.
        sigmaCut (float): Passed on to detectLines.
        thresholdSigma (float): Passed on to detectLines.
        featureMinPix (int): Passed on to detectLines.
    
    Returns:
        None. The model, a dictionary with keys 'featureTable' and 'arc_centreRow', is pickled to
        outFileName.
    
    """
    
    # Detect and tag features with known wavelengths in reference spectrum
    # The detection parameters were previously accepted here but never handed on to detectLines
    # NOTE(review): detectLines measures features in the central row of data, whereas the
    # spectrum saved below is row yRow - confirm this is intended when yRow is not the centre
    featureTable, segMap=detectLines(data, sigmaCut = sigmaCut, thresholdSigma = thresholdSigma,
                                     featureMinPix = featureMinPix)
    featureTable=tagWavelengthFeatures(featureTable, approxCoordsDict)
    data_centreRow=data[yRow]
    
    # Save reference model as a pickled dictionary (file is closed even if pickling fails)
    refModelDict={'featureTable': featureTable, 'arc_centreRow': data_centreRow}
    with open(outFileName, "wb") as pickleFile:
        pickle.dump(refModelDict, pickleFile)
        
#-------------------------------------------------------------------------------------------------------------
def findWavelengthCalibration(arcFileName, modelFileName, sigmaCut = 3.0, thresholdSigma = 5.0, 
                              featureMinPix = 50, order = 6):
    """Find wavelength calibration for .fits image arcFileName containing a 2d arc spectrum.
    
    modelFileName is the path to a model made by makeModelArcSpectrum.
    
    Args:
        arcFileName (str): Path to a .fits file with a 'SCI' extension holding the 2d arc spectrum.
        modelFileName (str): Path to the pickled reference model (dictionary with keys 
            'featureTable' and 'arc_centreRow').
        sigmaCut (float): Currently not used (see NOTE(review) in the body).
        thresholdSigma (float): Currently not used (see NOTE(review) in the body).
        featureMinPix (int): Currently not used (see NOTE(review) in the body).
        order (int): Order of the polynomial fitted to the x position of each feature as a 
            function of y.
    
    Returns:
        An array of polynomial fit coefficients that can be fed into wavelengthCalibrateAndRectify.
        There is one entry per row of the arc image, each holding the 3 coefficients (highest 
        power first) of the 2nd order polynomial that maps x pixel coord to wavelength in that row.
    
    Raises:
        ValueError: If no features are detected in the arc, or none can be matched to the model.
    
    NOTE: CHECK ORDER, ADD STEPPIX FOR LONGSLIT
    
    """
    
    # Load reference model
    # NOTE: unpickling can run arbitrary code - only load model files from a trusted source
    with open(modelFileName, "rb") as pickleFile:
        refModelDict=pickle.load(pickleFile)
    ref_centreRow=refModelDict['arc_centreRow']
    
    # First need to find arc features
    # NOTE(review): sigmaCut, thresholdSigma, featureMinPix are not handed on to detectLines, so
    # its own defaults apply. featureMinPix defaults to 50 here but 30 there, so passing it on 
    # would change which features are found - confirm which is intended before wiring it up
    with pyfits.open(arcFileName) as arcImg:
        arcData=arcImg['SCI'].data
    arcFeatureTable, arcSegMap=detectLines(arcData)
    if len(arcFeatureTable) == 0:
        raise ValueError("No arc features detected in %s" % (arcFileName))
    
    # Integer division: a float row index is an error in python 3
    centreRowIndex=arcData.shape[0]//2
    
    # Use cross correlation to get initial guess at shift between arc and model
    # NOTE(review): zero lag in 'full' mode is at index len(arc_centreRow)-1, so this looks like it
    # is off by one pixel - left as is, since it is well inside maxDistancePix used below
    arc_centreRow=arcData[centreRowIndex]
    shift=np.argmax(np.correlate(ref_centreRow, arc_centreRow, mode = 'full'))-ref_centreRow.shape[0]
    
    # Find wavelength dependent scale change by brute force - takes ~2 sec.
    # We tried using optimize.minimize but it wasn't robust and didn't work as well as this
    # NOTE: We may need to set min/max scales somehow
    # np.trapz was renamed np.trapezoid in numpy 2.0 - use whichever is available
    integrate=getattr(np, 'trapezoid', None)
    if integrate is None:
        integrate=np.trapz
    data_x=np.arange(0, arc_centreRow.shape[0])
    scales=np.linspace(-0.05, 0.05, 2000)
    overlaps=[]
    for s in scales:
        tck=interpolate.splrep((data_x+shift)+s*data_x, arc_centreRow)
        arc_centreRow_shifted=interpolate.splev(data_x, tck, ext = 1)
        overlaps.append(integrate(abs(ref_centreRow-arc_centreRow_shifted)))
    bestFitScale=scales[np.argmin(overlaps)]
    #plt.plot(scales, overlaps, 'r.')
    
    # Sanity check: coord transformation arc -> ref model coords
    arc_x_shifted=(data_x+shift)+data_x*bestFitScale
    plt.plot(data_x, ref_centreRow, 'b-')
    plt.plot(arc_x_shifted, arc_centreRow, 'r-')
    plt.close()
    
    # Tag wavelength features in arc
    arcFeatureTable.add_column(atpy.Column(np.zeros(len(arcFeatureTable)), 'wavelength'))
    maxDistancePix=20   # maximum error allowed in our coord transformation
    arcX=np.array(arcFeatureTable['x_centreRow'])
    arcX_inRefCoords=arcX+shift+arcX*bestFitScale
    for row in refModelDict['featureTable']:
        dist=abs(row['x_centreRow']-arcX_inRefCoords)
        if dist.min() < maxDistancePix:
            index=np.argmin(dist)
            arcFeatureTable['wavelength'][index]=row['wavelength']
    arcFeatureTable=arcFeatureTable[np.where(arcFeatureTable['wavelength'] != 0)]
    if len(arcFeatureTable) == 0:
        raise ValueError("No arc features in %s could be matched to model %s" 
                         % (arcFileName, modelFileName))
    
    # Sanity check: tagged features
    plt.plot(arcData[centreRowIndex], 'b-')
    for row in arcFeatureTable:
        plt.text(row['x_centreRow'], row['amplitude'], row['wavelength'])
    plt.close()
    #plt.show()
      
    # Find 2d wavelength solution which we can use for rectification/wavelength calibration
    # Fit functions for how feature x positions change with y
    # For longslit, we have a border with no data - here is a rough estimate of size in %-age terms
    border=int(round(arcData.shape[0]*0.05))
    stepPix=10
    ys=np.arange(border, arcData.shape[0]-border, stepPix)
    for o in range(order+1):
        arcFeatureTable.add_column(atpy.Column(np.zeros(len(arcFeatureTable)), 'poly%d' % (o)))
    
    # Measure x positions of all features once per sampled row (rather than once per row per feature)
    # xsAll[i, j] is the x position of feature j measured in image row ys[i]
    featureIDs=np.array(arcFeatureTable['id'])
    xsAll=np.zeros((len(ys), len(featureIDs)))
    for i, y in enumerate(ys):
        objPositions=ndimage.center_of_mass(arcData[y], labels = arcSegMap, index = featureIDs)
        xsAll[i]=np.array(objPositions, dtype = float).reshape(-1)
    
    # Fit x position as fn. of y for each feature; coeffs are stored highest power first, 
    # i.e., column poly0 holds the coefficient of y**order
    featurePolys=[]
    for count, row in enumerate(arcFeatureTable, start = 1):
        print("... %d/%d ..." % (count, len(arcFeatureTable)))
        xs=xsAll[:, count-1]
        # For long slit, this is not linear it is also, obviously, double valued!
        validMask=np.logical_not(np.isnan(xs))
        coeffs=np.polyfit(ys[validMask], xs[validMask], order)
        for o in range(order+1):
            row['poly%d' % (o)]=coeffs[o]
        featurePolys.append(np.poly1d(coeffs))

    # Wavelength calibration and model with arbitrary order polynomials - get coeffs for each row
    # We could potentially fit these coeffs as fn. of y - they are all well behaved
    # This array should be all we need to wavelength calibrate + rectify
    # KEEP TRACK OF ORDER HERE! NEED TO GET ALL COEFFS, BUT SUBSEQUENT FIT CAN BE 2ND ORDER
    wavelengths=np.array(arcFeatureTable['wavelength'])
    fitCoeffsArr=[]
    for y in range(arcData.shape[0]):
        xs=np.array([poly(y) for poly in featurePolys])
        fitCoeffsArr.append(np.polyfit(xs, wavelengths, 2))
    fitCoeffsArr=np.array(fitCoeffsArr)
    
    # Sanity check: wavelength calibration model with tagged features
    wavelengthCalibPoly=np.poly1d(fitCoeffsArr[centreRowIndex])
    centreRowWavelengths=wavelengthCalibPoly(np.arange(arcData.shape[1]))
    plt.plot(centreRowWavelengths, arcData[centreRowIndex], 'r-')
    plt.plot(arcFeatureTable['wavelength'], arcFeatureTable['amplitude'], 'bo')
    for row in arcFeatureTable:
        plt.text(row['wavelength'], row['amplitude'], row['wavelength'])
    plt.close()
    
    return fitCoeffsArr
    
#-------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':

    usage="""
    %prog arcFrame2d.fits

    """

    # Command line: path to the 2d arc frame, plus optionally the row to extract
    parser=argparse.ArgumentParser("rss_mos_create_arc_model")
    parser.add_argument("arcFileName", help="Path to 2d arc file name")
    parser.add_argument("-y", "--y", dest="yRow", help="row from which to extract arc spectrum")
    args=parser.parse_args()
    
    with pyfits.open(args.arcFileName) as img:
        data=img['SCI'].data
        header=img[0].header

    # Default to the middle row of the image if no row was given
    row=int(args.yRow) if args.yRow is not None else int(data.shape[0]/2)
    
    # File names are built from the grating, lamp and binning recorded in the primary header
    coordsFileName="_".join([header['GRATING'], header['LAMPID'], header['CCDSUM'].replace(" ", "x")])+".txt"
    modelFileName="RefModel_"+coordsFileName.replace(".txt", ".pickle")
    print("coordsFileName = %s" % (coordsFileName))
    print("modelFileName = %s" % (modelFileName))
    
    if not os.path.exists(coordsFileName):
        # No coords file yet: tell the user how to make one, show them the spectrum, and stop
        instructions=(
            "Coordinates file not found - you need to make one...",
            "1. Use the plot window to record approximate pixel coords (x-axis) for spectral lines identified in a calibrated arc plot",
            "   (see http://pysalt.salt.ac.za/lineatlas/lineatlas.html and look under the heading 'Arc Lamp Plots').",
            "2. Save a text file (name it %s) in the current directory, which has columns: Wavelengh (Angstroms, float), pixel coord (integer) on x-axis." % (coordsFileName),
            "3. Close the plot window and re-run this script to generate the model.",
            "4. Copy the model .pickle and .txt file to the RSSMOSPipeline/data/modelArcSpectra/ directory and re-run the setup.py script",
        )
        for message in instructions:
            print(message)
        plt.plot(data[row])
        plt.title(coordsFileName)
        plt.show()
        sys.exit()
    
    # Coords file exists: fit and save the reference model
    approxCoordsDict=parseApproxArcCoordsFile(coordsFileName)
    makeModelArcSpectrum(data, approxCoordsDict, modelFileName, row)
    
