#!/usr/bin/env python3
"""

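Pointy-clicky inspection of spectra: a Tkinter GUI for overlaying redshifted templates and spectral
features on object spectra, and recording redshifts, quality flags and comments for each one.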

"""
import astropy.io.fits as pyfits
import argparse
import numpy as np
import pylab as plt
import matplotlib.patches as patches
from astLib import astSED
from scipy import optimize
from scipy import ndimage
import os
import sys
import glob
import datetime
import tkinter
import RSSMOSPipeline
#import tkFont
#import IPython

# Spectral templates ship with the package, under <RSSMOSPipeline>/data/templateSpectra
# Might want to use pkg resources
TEMPLATES_DIR=os.path.join(RSSMOSPipeline.__path__[0], "data", "templateSpectra")

LIGHT_SPEED=299792.458  # km/s - used to convert between cz and z

#------------------------------------------------------------------------------------------------------------
# SDSS templates to use - note later on we explicitly load in the Tremonti et al. starburst template,
# which is in a different (text) format and is stored under TEMPLATES_DIR/TremontiStarburstTemplate
tremontiFileName=os.path.join(TEMPLATES_DIR, "TremontiStarburstTemplate", "fos_ghrs_composite.txt")
templateFileNames=[
# Galaxies
"spDR2-023.fit",
"spDR2-024.fit",
"spDR2-025.fit",
"spDR2-026.fit",
"spDR2-027.fit",
"spDR2-028.fit",
# QSOs
"spDR2-029.fit",
"spDR2-030.fit",
"spDR2-031.fit",
"spDR2-032.fit"
]
# Labels for the template radio buttons: one per entry in templateFileNames (same order), plus a
# final one for the Tremonti starburst template, which loadTemplates appends after the SDSS ones
templateLabels=[
# Galaxies
"SDSS-023",
"SDSS-024",
"SDSS-025",
"SDSS-026",
"SDSS-027",
"SDSS-028",
# QSOs
"SDSS-029",
"SDSS-030",
"SDSS-031",
"SDSS-032",
# Starburst
"T03 Starburst"
]
# Fail early (at import) rather than part way through building the GUI if the two lists drift apart
if len(templateLabels) != len(templateFileNames)+1:
    raise ValueError("templateLabels must have one entry per SDSS template in templateFileNames, plus one for the Tremonti starburst template")

# Spectral features that can be overplotted, as [label, rest wavelength] pairs in order of
# increasing wavelength. Labels need not be unique (unlike in the old style scripts), since
# features are referred to by their position in this list.
spectralFeaturesCatalogue=[
    ["Lyalpha",  1216.0],
    ["NV",       1240.14],
    ["SiII",     1262.59],
    ["OI",       1304.3],
    ["SiII",     1306.82],
    ["CII",      1335.30],
    ["SiIV",     1396.76],
    ["OIV]",     1402.06],
    ["CIV",      1549.0],
    ["[NeV]",    1575.0],
    ["[NeIV]",   1602.0],
    ["OIII]",    1663.0],
    ["CIII",     1909.0],
    ["CII]",     2326.44],
    ["FeII",     2344.0],
    ["FeII",     2374.0],
    ["FeII",     2382.0],
    ["[NeIV]",   2423.83],
    ["[OII]",    2471.03],
    ["FeII",     2586.0],
    ["FeII",     2600.0],
    ["AlII]",    2669.95],
    ["OIII",     2672.04],
    ["MgII",     2798.0],
    ["MgII",     2800.0],
    ["MgII",     2803.0],
    ["MgI",      2852.0],
    # The next two were added Dec 2015, based on http://www.physics.nist.gov/PhysRefData/ASD/lines_form.html
    ["FeI",      3230.96],
    ["FeI",      3370.78],
    ["[NeV]",    3426.0],
    ["FeI",      3582.1],
    ["[OII]",    3727.3],
    ["Htheta",   3798.6],
    ["FeI",      3820.4],
    ["MgI",      3830.4],
    ["CN",       3875.0],
    ["Hzeta",    3888.9],
    ["K",        3933.6],
    ["[NeIII]",  3967.5],
    ["H",        3968.5],
    ["Hepsilon", 3970.10],
    ["FeI",      4045.8],
    ["Hdelta",   4101.70],
    ["CaI",      4226.7],
    ["FeI",      4271.7],
    ["G",        4300.4],
    ["FeI",      4329.7],
    ["HeII",     4338.6],
    ["Hgamma",   4340.50],
    ["HeI",      4471.50],
    ["HeII",     4685.70],
    ["Hbeta",    4861.33],
    ["[OIII]",   4958.91],
    ["[OIII]",   5006.84],
    ["MgI",      5175.4],
    ["E",        5269.00],
    ["HeI",      5875.6],
    ["NaI",      5892.5],
    ["[OI]",     6300.3],
    ["[NII]",    6548.1],
    ["Halpha",   6562.8],
    ["[NII]",    6583.4],
    ["[SII]",    6717.0],
    ["[SII]",    6731.3],
]

# For checking wavelength calibration accuracy
# Useful plot: http://www.astro.keele.ac.uk/jkt/GrSpInstructions/skylines1.jpg
# Sky line wavelengths currently used for the check. Other candidates, disabled for now:
#   4980.7, 5461.0, 6923.220, 7276.405, 7316.282, 7340.885, 7750.640, 7794.112, 7821.503,
#   7913.708, 7993.332, 8344.602, 8399.170, 8430.174, 8761.314, 8767.912, 8778.333
checkSkyLines=[5577.0, 5893.0, 6300.304, 6363.708, 6863.955]

#------------------------------------------------------------------------------------------------------------
class App:

    def __init__(self, master, objectSpecFileNames, outDir):
        """Build the inspection GUI and display the first spectrum.

        Parameters:
            master: parent Tk widget into which the control frames are gridded.
            objectSpecFileNames: list of object spectrum file names, inspected in order.
            outDir: directory in which to write output (results.txt and .png plots);
                created if it does not already exist. results.txt is opened in "w" mode,
                so an existing results file in outDir is overwritten.

        Raises:
            ValueError: if templateLabels has fewer entries than there are loaded templates.

        """

        # Layout - these get drawn top to bottom
        self.buttonFrame=tkinter.Frame(master, padx=5, pady=5)
        self.buttonFrame.grid()
        self.qualityFrame=tkinter.Frame(master, padx=5, pady=5)
        self.qualityFrame.grid()
        self.smoothFrame=tkinter.Frame(master, padx=5, pady=5)
        self.smoothFrame.grid()
        templatesFrame=tkinter.Frame(master, padx=5, pady=5)
        templatesFrame.grid()
        scaleFrame=tkinter.Frame(master, padx=5, pady=5)
        scaleFrame.grid()
        featuresFrame=tkinter.Frame(master, padx=5, pady=5)
        featuresFrame.grid()
        self.objectSpecFileNames=objectSpecFileNames
        self.currentSpecFileIndex=0
        print("Checking spectrum %d/%d ..." % (self.currentSpecFileIndex+1, len(self.objectSpecFileNames)))

        # Create the output text file for results
        self.outDir=outDir
        os.makedirs(outDir, exist_ok=True)
        self.outFile=open(outDir+os.path.sep+"results.txt", "w")
        self.outFile.write("# Filename\tRedshift\tRedshiftError\tQuality\tComments\n")

        # Load data we need
        self.templates=self.loadTemplates(TEMPLATES_DIR, templateFileNames)

        obj=self.loadObjectSpectrum(objectSpecFileNames[self.currentSpecFileIndex])
        self.objectSpecFileName=objectSpecFileNames[self.currentSpecFileIndex]  # for plot titles etc.
        self.objSED=obj['object']
        # Take a real copy: for numpy arrays flux[:] is only a view of the same data
        self.unsmoothedObjFlux=self.objSED.flux.copy()
        self.skySED=obj['sky']
        self.badLines=[]

        # List of feature names in spectralFeaturesCatalogue to plot
        self.plotFeatures=[]

        # Buttons frame
        self.nextButton=tkinter.Button(self.buttonFrame, text="Done - show next", fg="blue", command=self.next)
        self.nextButton.grid(row=0, column=6)
        self.quitButton=tkinter.Button(self.buttonFrame, text="QUIT", fg="red", command=self.buttonFrame.quit)
        self.quitButton.grid(row=0, column=7)

        self.savePNGButton=tkinter.Button(self.buttonFrame, text="Save .png", command=self.savePNG)
        self.savePNGButton.grid(row=0, column=1)

        self.redrawButton=tkinter.Button(self.buttonFrame, text="Redraw plot", command=self.redrawPlot)
        self.redrawButton.grid(row=0, column=0)

        self.outPathLabel=tkinter.Label(self.buttonFrame, text="Output .png file : ")
        self.outPathLabel.grid(row=0, column=2)

        self.outPathEntryVar=tkinter.StringVar()
        self.outPathEntryVar.set(self.outDir+os.path.sep+os.path.split(self.objectSpecFileName)[-1].replace(".fits", ".png"))
        self.outPathEntry=tkinter.Entry(self.buttonFrame, textvariable=self.outPathEntryVar, width=80)
        self.outPathEntry.grid(row=0, column=3)

        # Quality frame, contains radio buttons and comments field
        numQualityFlags=4
        self.qualityRadioVar=tkinter.IntVar()
        self.qualityRadioList=[]
        self.qualityLabel=tkinter.Label(self.qualityFrame, text="Quality flag : ", anchor=tkinter.E)
        self.qualityLabel.grid(row=2, column=0)
        for i in range(numQualityFlags):
            self.qualityRadioList.append(tkinter.Radiobutton(self.qualityFrame, text = str(i),
                                            variable=self.qualityRadioVar, value = i, command = self.redrawPlot))
            self.qualityRadioList[-1].grid(row=2, column=i+1)
        self.qualityRadioList[0].select()

        # Comments widgets go in the columns after the last quality radio button
        self.commentsLabel=tkinter.Label(self.qualityFrame, text="Comments: ")
        self.commentsLabel.grid(row=2, column=numQualityFlags+1)
        self.commentsEntryVar=tkinter.StringVar()
        self.commentsEntryVar.set("")
        self.commentsEntry=tkinter.Entry(self.qualityFrame, textvariable=self.commentsEntryVar, width=80)
        self.commentsEntry.grid(row=2, column=numQualityFlags+2)

        # Slider to set smoothing of object spectrum
        self.smoothScaleVar=tkinter.DoubleVar()
        self.smoothLabel=tkinter.Label(self.smoothFrame, text="Spectrum smoothing : ", anchor=tkinter.E)
        self.smoothLabel.grid(row=1, column=0)
        self.smoothScale = tkinter.Scale(self.smoothFrame, orient=tkinter.HORIZONTAL, length = 300,
            from_=0, to=100, tickinterval=25,
            command = self.smoothSpectrum, variable=self.smoothScaleVar, resolution=1)
        self.smoothScale.set(5)
        x = self.smoothScale.get()
        # Always smooth a fresh copy of the unsmoothed flux, never the stored original
        self.objSED.flux=self.unsmoothedObjFlux.copy()
        self.objSED.smooth(x)
        self.smoothScale.grid(row=1, column=1, columnspan=3)

        # Buttons to finely tune the smoothing
        self.smoothPlusButton=tkinter.Button(self.smoothFrame, text="+", command=self.increaseSmoothing)
        self.smoothPlusButton.grid(row=1, column=5)
        self.smoothMinusButton=tkinter.Button(self.smoothFrame, text="-", command=self.decreaseSmoothing)
        self.smoothMinusButton.grid(row=1, column=4)

        # Min, max wavelength range
        self.minWavelengthLabel=tkinter.Label(self.smoothFrame, text="Min WL: ", anchor=tkinter.E)
        self.minWavelengthLabel.grid(row=1, column=6)
        self.minWavelengthEntryVar=tkinter.StringVar()
        self.minWavelengthEntryVar.set(self.objSED.wavelength.min())
        self.minWavelengthEntry=tkinter.Entry(self.smoothFrame, textvariable=self.minWavelengthEntryVar, width=6)
        self.minWavelengthEntry.grid(row=1, column=7)

        self.maxWavelengthLabel=tkinter.Label(self.smoothFrame, text="Max WL: ", anchor=tkinter.E)
        self.maxWavelengthLabel.grid(row=1, column=8)
        self.maxWavelengthEntryVar=tkinter.StringVar()
        self.maxWavelengthEntryVar.set(self.objSED.wavelength.max())
        self.maxWavelengthEntry=tkinter.Entry(self.smoothFrame, textvariable=self.maxWavelengthEntryVar, width=6)
        self.maxWavelengthEntry.grid(row=1, column=9)

        # Alt normalisation method
        self.altNormCheckVar=tkinter.IntVar()
        self.altNormLabel=tkinter.Label(self.smoothFrame, text="Alt norm", width=10, anchor=tkinter.E)
        self.altNormLabel.grid(row=1, column=10)
        self.altNormCheckButton=tkinter.Checkbutton(self.smoothFrame,variable=self.altNormCheckVar, command=self.plotSkyChanged)
        self.altNormCheckButton.grid(row=1, column=11)

        # Turn sky plotting on/off
        self.plotSkyCheckVar=tkinter.IntVar()
        self.plotSkyLabel=tkinter.Label(self.smoothFrame, text="Plot sky", width=10, anchor=tkinter.E)
        self.plotSkyLabel.grid(row=1, column=12)
        self.plotSkyCheckButton=tkinter.Checkbutton(self.smoothFrame,variable=self.plotSkyCheckVar, command=self.plotSkyChanged)
        self.plotSkyCheckButton.grid(row=1, column=13)

        # Templates frame
        # Radio buttons used to select template
        self.templateRadioVar=tkinter.IntVar()
        self.templateRadioList=[]
        self.templateLabel=tkinter.Label(templatesFrame, text="Template:", anchor=tkinter.E)
        self.templateLabel.grid(row=2, column=0)

        for i in range(len(self.templates)):
            # Every loaded template (SDSS + Tremonti) needs a label; fail with a clear message
            # (this used to call an undefined ipshell(), i.e. crash with a NameError)
            if i >= len(templateLabels):
                raise ValueError("No entry in templateLabels for template %d - templateLabels and templateFileNames are out of sync" % (i))
            tempName=templateLabels[i]
            self.templateRadioList.append(tkinter.Radiobutton(templatesFrame, text = tempName,
                                            variable=self.templateRadioVar, value = i, command = self.redrawPlot))
            self.templateRadioList[-1].grid(row=2, column=i+1)
        self.templateRadioList[0].select()

        # Slider used to set trial redshift of template
        self.redshiftScaleVar=tkinter.DoubleVar()
        self.redshiftLabel=tkinter.Label(scaleFrame, text="Template redshift : ", anchor=tkinter.E)
        self.redshiftLabel.grid(row=3, column=0)
        self.redshiftScale = tkinter.Scale(scaleFrame, orient=tkinter.HORIZONTAL, length = 600,
            from_=0.0, to=5.01, tickinterval=1,
            command = self.getRedshiftScaleValue, variable=self.redshiftScaleVar, resolution=0.001)
        self.redshiftScale.set(0.5)
        self.redshiftScale.grid(row=3, column=1, columnspan=3)

        # Buttons to finely tune the trial redshift
        self.redshiftPlusButton=tkinter.Button(scaleFrame, text="+", command=self.increaseRedshift)
        self.redshiftPlusButton.grid(row=3, column=5)
        self.redshiftMinusButton=tkinter.Button(scaleFrame, text="-", command=self.decreaseRedshift)
        self.redshiftMinusButton.grid(row=3, column=4)

        # Redshift uncertainty entry box
        self.redshiftErrorLabel=tkinter.Label(scaleFrame, text="+/-")
        self.redshiftErrorLabel.grid(row=3, column=6)
        self.redshiftErrorEntryVar=tkinter.StringVar()
        self.redshiftErrorEntryVar.set(0.001)
        self.redshiftErrorEntry=tkinter.Entry(scaleFrame, textvariable=self.redshiftErrorEntryVar, width=10)
        self.redshiftErrorEntry.grid(row=3, column=7)

        # XCSAO buttons
        self.runXCSAOButton=tkinter.Button(scaleFrame, text="XC Galaxies", command=self.runXCSAOGalaxies, fg="green")
        self.runXCSAOButton.grid(row=3, column=8)
        self.runXCSAOLRGsButton=tkinter.Button(scaleFrame, text="XC LRGs", command=self.runXCSAOLRGs, fg="green")
        self.runXCSAOLRGsButton.grid(row=3, column=9)
        self.runXCSAOQSOsButton=tkinter.Button(scaleFrame, text="XC QSOs", command=self.runXCSAOQSOs, fg="green")
        self.runXCSAOQSOsButton.grid(row=3, column=10)

        # Features to optionally plot
        # Have to have separate labels for the check boxes because other layout goes stupid
        self.featuresCheckList=[]
        self.featuresCheckLabelsList=[]
        self.featuresCheckVars=[]
        maxPerRow=20
        row=0
        column=0
        for i in range(len(spectralFeaturesCatalogue)):
            if column == maxPerRow:
                row=row+1
                column=0
            self.featuresCheckVars.append(tkinter.IntVar())
            self.featuresCheckLabelsList.append(tkinter.Label(featuresFrame, text=spectralFeaturesCatalogue[i][0], width=10, anchor=tkinter.E))
            self.featuresCheckLabelsList[-1].grid(row=row+4, column=column)
            self.featuresCheckList.append(tkinter.Checkbutton(featuresFrame,
                                            variable=self.featuresCheckVars[-1], command=self.setPlotFeatures))
            self.featuresCheckList[-1].grid(row=row+4, column=column+1)
            column=column+2

        # Start up the figure for drawing
        plt.figure(figsize=(12,8))

        # Do initial plot
        self.updatePlot(self.objSED, self.templates[self.templateRadioVar.get()], self.skySED,
                        self.redshiftScaleVar.get(),
                        tempLabel=os.path.split(templateLabels[self.templateRadioVar.get()])[-1],
                        redrawSky = True, redrawFeatures=True, plotFeatures=self.plotFeatures)
    
    def getRedshiftScaleValue(self, event):
        """Callback for the trial redshift slider (the tkinter.Scale ``command``).

        Intentionally does nothing. The slider is bound to self.redshiftScaleVar, and that
        variable is read whenever the plot is drawn (e.g. by __init__ and next), so no work
        is needed on each slider movement.

        Parameters:
            event: the new slider value passed in by tkinter.Scale (unused).

        """
    
    def increaseRedshift(self):
        self.redshiftScale.set(self.redshiftScaleVar.get()+0.001)

    def decreaseRedshift(self):
        self.redshiftScale.set(self.redshiftScaleVar.get()-0.001)

    def runXCSAOGalaxies(self):
        """Cross correlate against the SDSS galaxy templates, numbers 23 to 28 inclusive.

        Template numbers are parsed from the file names in templateFileNames (the digits after
        the last '-' and before the first '.'). Emission lines are not chopped (s_emchop="n").
        """
        def templateNumber(fileName):
            # e.g. "spDR2-023.fit" -> 23
            return int(os.path.split(fileName)[-1].split("-")[-1].split(".")[0])
        galaxyTemplates=[name for name in templateFileNames if 23 <= templateNumber(name) < 29]
        self.runXCSAO(templatesToInclude = galaxyTemplates, s_emchop = "n")

    def runXCSAOLRGs(self):
        """Cross correlate against the single SDSS template numbered 28, with emission-line
        chopping switched on (s_emchop="y").

        Earlier experiments used templates 23-28, or 23 and 28, here instead.
        """
        lrgTemplates=[]
        for name in templateFileNames:
            stem=os.path.split(name)[-1]
            number=int(stem.split("-")[-1].split(".")[0])
            if number == 28:
                lrgTemplates.append(name)
        self.runXCSAO(templatesToInclude = lrgTemplates, s_emchop = "y")
        
    def runXCSAOQSOs(self):
        """Cross correlate against the SDSS QSO templates, numbers 29 to 32 inclusive, without
        emission-line chopping (s_emchop="n").
        """
        qsoTemplates=[name for name in templateFileNames
                      if 29 <= int(os.path.split(name)[-1].split("-")[-1].split(".")[0]) < 33]
        self.runXCSAO(templatesToInclude = qsoTemplates, s_emchop = "n")
    
    def runXCSAO(self, templatesToInclude = None, s_emchop = "y"):
        """Cross correlate the current object spectrum against SDSS templates using IRAF rvsao.xcsao.

        The object spectrum must be available as toiraf.fits. If it is missing,
        convertToIRAFFormat is called to make it. The velocity search is centred on the
        current trial redshift (+/- 0.05 in z), within the wavelength range given in the Min WL /
        Max WL entry boxes. On success, the redshift slider, redshift error box, comments box and
        template radio buttons are set from the best (highest R) result. The plot is redrawn
        in all cases.

        Parameters:
            templatesToInclude: list of template file names (found in TEMPLATES_DIR) to cross
                correlate against. If empty or None, no cross correlation is run.
            s_emchop: "y" or "n". Passed to xcsao as both s_emchop and t_emchop (xcsao's
                emission-line chopping switches for the object and template spectra).

        """
        # pyraf must be imported before IRAF packages can be loaded with "from iraf import ..."
        from pyraf import iraf
        from iraf import rvsao

        if templatesToInclude is None:
            templatesToInclude=[]

        xMin=float(self.minWavelengthEntry.get())
        xMax=float(self.maxWavelengthEntry.get())

        # This has to have the file available in IRAF friendly format
        # NOTE(review): toiraf.fits has a fixed name and is only regenerated when missing, so a file
        # left over from a previously inspected spectrum would be reused here. Confirm that it is
        # removed when a new spectrum is loaded.
        irafFileName="toiraf.fits"
        print(irafFileName)
        logFileName="xcsao.log"
        badLinesFileName=os.path.abspath("badlines.dat")
        result=None
        if os.path.exists(irafFileName) or self.convertToIRAFFormat():
            print("--> cross correlating "+irafFileName+" ...")

            # Only ask xcsao to mask bad regions when we have a sky spectrum (and so have
            # written the bad lines file below)
            fixbad="n" if self.skySED is None else "y"
            if fixbad == "y":
                # Keep bad regions around also so we can plot. Copy, so that appending below
                # does not also modify self.gapsList
                self.badLines=list(self.gapsList)
                print(self.gapsList)
                # Regions affected by atmospheric absorption - currently none are masked
                # (previously tried: [6860., 6930.], [7590., 7710.])
                skyLines=[]
                with open(badLinesFileName, "w") as outFile:
                    for line in skyLines:
                        outFile.write(str(line[0])+"\t"+str(line[1])+"\n")
                        self.badLines.append(line)

            if templatesToInclude:
                # Results for all templates are parsed from a single log afterwards, so make sure
                # nothing is left in it from an earlier (e.g. crashed) run
                if os.path.exists(logFileName):
                    os.remove(logFileName)
                zGuess=self.redshiftScaleVar.get()
                for temp in templatesToInclude:
                    rvsao.xcsao(spectra=irafFileName, tempdir=TEMPLATES_DIR, fixbad=fixbad,
                                badlines=badLinesFileName,
                                vel_init="zguess", czguess=zGuess, templates=os.path.split(temp)[-1],
                                st_lambda=xMin, end_lambda=xMax, zeropad="y", nsmooth=1,
                                s_emchop=s_emchop, t_emchop=s_emchop, nzpass=10,
                                minvel=(zGuess-0.05)*LIGHT_SPEED, maxvel=(zGuess+0.05)*LIGHT_SPEED,
                                renormalize="y", ncols=4096, low_bin=5, top_low=10, top_nrun=380,
                                nrun=400, bell_window=0.05, dispmode=2, curmode="n",
                                ablines="ablines.dat", displot="no", logfiles=logFileName,
                                save_vel="n", pkfrac=0.8, report_mode=1, interp_mode="sums")
                result=self.parseXCSAOResult()
                os.remove(logFileName)
            else:
                # Without this guard, parsing would fail on a missing xcsao.log
                print("... no templates selected - skipping cross correlation ...")

        if result is not None:
            self.redshiftScaleVar.set(result['z'])
            self.commentsEntryVar.set("XCSAO (R=%.3f): " % (result['RVal']))
            self.commentsEntry['textvariable']=self.commentsEntryVar
            bestTemplate=result.get('template')
            if bestTemplate in templateFileNames:
                templateIndex=templateFileNames.index(bestTemplate)
                self.templateRadioVar.set(templateIndex)
                self.templateRadioList[templateIndex].select()
            else:
                print("... best-matching template %s not recognised - template selection unchanged ..." % (bestTemplate))
            self.commentsEntry.update()
            self.redshiftErrorEntryVar.set("%.6f" % (result['zErr']))
            print(result)
        else:
            print("XCSAO failed.")

        self.redrawPlot()


    def convertToIRAFFormat(self):
        """Convert the current object spectrum (self.objectFileName) into IRAF format, for XCSAO.

        The LAMBDA and SPEC columns of FITS table extensions 1 and 2 ('boxcar' method) are written
        as tab-separated wavelength/flux text to toiraf.csv. Extensions that do not exist are
        skipped. If toiraf.fits does not already exist, it is then made from the csv with IRAF
        onedspec.rspectext. Where the mean flux is < 100 (and non-zero), the flux is rescaled to a
        mean of 1e6, to stop xcsao from crashing on fluxed spectra.

        Returns:
            True if toiraf.fits was created, False if rspectext failed, or None if toiraf.fits
            already existed (in which case it is not regenerated).

        """
        # pyraf must be imported before IRAF packages can be loaded with "from iraf import ..."
        from pyraf import iraf
        from iraf import onedspec
        method="boxcar"
        print("--> Converting "+self.objectFileName+" to IRAF format ...")
        if method == "boxcar":
            tabExts=[1, 2]
        elif method == "optimal":
            tabExts=[3, 4]
        else:
            raise ValueError("Unknown extraction method '%s'" % (method))
        outFileName="toiraf.csv"
        fitsFileName=outFileName.replace(".csv", ".fits")

        # NOTE: a separate 'sky' csv used to be opened here via outFileName.replace("iraf_", ...),
        # but that pattern does not occur in "toiraf.csv", so it reopened (and truncated) this same
        # file. Sky output was never written, so that handle has been dropped.
        # newline="\n" keeps the output byte-identical on all platforms (no \r\n translation)
        with pyfits.open(self.objectFileName) as idlfits, open(outFileName, "w", newline="\n") as writer:
            for tabExt in tabExts:
                # Sometimes things just don't work as they should ... skip missing extensions/columns
                try:
                    fluxData=idlfits[tabExt].data.field('SPEC')
                except (IndexError, KeyError):
                    continue
                if len(fluxData.shape) > 1:
                    fluxData=fluxData[0]
                wavelengthData=idlfits[tabExt].data.field('LAMBDA')
                if len(wavelengthData.shape) > 1:
                    wavelengthData=wavelengthData[0]
                # xcsao crash preventing when running on fluxed spectra (it should do this itself, of course)
                # A zero mean is left alone rather than dividing by zero
                meanFlux=fluxData.mean()
                if meanFlux < 100 and meanFlux != 0:
                    fluxData=(fluxData/meanFlux)*1e6
                for wavelength, flux in zip(wavelengthData, fluxData):
                    writer.write(str(wavelength)+"\t"+str(flux)+"\n")

        if os.path.exists(fitsFileName):
            return None
        # This might fall over below if there's a dodgy file
        try:
            onedspec.rspectext(input=outFileName, output=fitsFileName, flux="no", dtype="interp")
            return True
        except Exception:
            print("... Argh! there's a problem with this file that causes IRAF to crash! ...")
            print("... skipping ...")
            return False
    
    def parseXCSAOResult(self, logFileName="xcsao.log"):
        """Parse an xcsao results log file and return the result with the highest R value.

        Every line containing "CZ:" is split on whitespace. The value following each of the
        tokens "Temp:", "R:", "CZ:" and "+/-" is read. Lines missing any of these are ignored.

        Parameters:
            logFileName: path of the xcsao log file to read (default "xcsao.log").

        Returns:
            dict with keys 'template' (str), 'RVal', 'cz', 'czErr', 'z' and 'zErr' (floats)
            for the result with the highest R (only R > 0 is accepted), or None if there is none.

        """
        # log token -> (result key, converter for the value that follows the token)
        tokenMap={"Temp:": ('template', str),
                  "R:": ('RVal', float),
                  "CZ:": ('cz', float),
                  "+/-": ('czErr', float)}
        requiredKeys=('template', 'RVal', 'cz', 'czErr')
        with open(logFileName, "r") as inFile:
            lines=inFile.readlines()
        bestRVal=0.0
        bestResult=None
        for line in lines:
            if line.find("CZ:") == -1:
                continue
            bits=line.split()
            result={}
            # Pair each token with the one after it; a token at the end of a line has no value
            for token, value in zip(bits, bits[1:]):
                if token in tokenMap:
                    key, convert=tokenMap[token]
                    result[key]=convert(value)
            if not all(key in result for key in requiredKeys):
                continue
            if result['RVal'] > bestRVal:
                result['z']=result['cz']/LIGHT_SPEED
                # Same as (czErr/cz)*(cz/c), but without dividing by zero when cz == 0
                result['zErr']=result['czErr']/LIGHT_SPEED
                bestResult=result
                bestRVal=result['RVal']
        return bestResult
        
    def setPlotFeatures(self):
        """Sets self.plotFeatues and triggers redrawing of plot, according to which spectral features
        are selected. Redraws plot afterwards.
        
        """
        self.plotFeatures=[]
        for i in range(len(self.featuresCheckVars)):
            val=self.featuresCheckVars[i].get()
            self.plotFeatures.append(val)
        #self.redrawPlot()
        
    def savePNG(self):
        """Save the current matplotlib figure to the path typed in the output .png entry box."""
        pngPath=self.outPathEntry.get()
        plt.savefig(pngPath)
        
    def loadTemplates(self, TEMPLATES_DIR, templateFileNamesList):
        """Load the SDSS spectral templates, followed by the Tremonti et al. starburst template.

        Parameters:
            TEMPLATES_DIR: directory containing the SDSS template FITS files. This shadows the
                module-level constant of the same name.
            templateFileNamesList: SDSS template file names, relative to TEMPLATES_DIR.

        Returns:
            list of astSED.SED objects: one per entry in templateFileNamesList (same order),
            then the Tremonti starburst template (loaded from tremontiFileName) last.

        """

        print("Loading templates ...")

        templatesList=[]
        for t in templateFileNamesList:

            # This loads in an SDSS spectral template, and feeds it into a SED object.
            # Use a context manager so the FITS file is closed again (it was previously leaked)
            with pyfits.open(TEMPLATES_DIR+os.path.sep+t) as timg:
                th=timg[0].header
                tc0=th['COEFF0']
                tc1=th['COEFF1']
                numPix=timg[0].data.shape[1]
                # Copy the first image row so it stays valid after the file is closed
                tflux=np.array(timg[0].data[0])
            # Wavelength solution is log-linear: log10(wavelength) = COEFF0 + COEFF1*pixel
            tpixRange=np.arange(numPix)
            twavelengthRange=10.0**(tc0+tc1*tpixRange)

            tempSED=astSED.SED(wavelength = twavelengthRange, flux = tflux)
            templatesList.append(tempSED)

        # Tremonti star burst - in a different format, loaded via astSED directly
        s=astSED.SED()
        s.loadFromFile(tremontiFileName)
        templatesList.append(s)

        return templatesList
        
    def next(self):
        """Record the result for the current spectrum, then move on to the next one.

        Saves the current plot as .png (to the path in the output entry box) and appends a line
        to results.txt. If the quality flag is 0, the redshift and its error are written as
        "None". If there is another spectrum, it is loaded, smoothed with the current smoothing
        setting and plotted, and the per-spectrum GUI fields are reset. After the last spectrum,
        the Tk main loop is quit.

        """

        # Save results for current spectrum
        self.savePNG()
        if self.qualityRadioVar.get() != 0:
            self.outFile.write("%s\t%.5f\t%.5f\t%d\t%s\n" % (self.objectSpecFileNames[self.currentSpecFileIndex],
                                            self.redshiftScaleVar.get(), float(self.redshiftErrorEntryVar.get()), self.qualityRadioVar.get(),
                                            self.commentsEntryVar.get()))
        else:
            self.outFile.write("%s\t%s\t%s\t%d\t%s\n" % (self.objectSpecFileNames[self.currentSpecFileIndex],
                                            "None", "None", self.qualityRadioVar.get(),
                                            self.commentsEntryVar.get()))
        # Flush so results already recorded are not lost if the GUI (or IRAF) later crashes
        self.outFile.flush()

        # Move on to next spectrum
        if self.currentSpecFileIndex < len(self.objectSpecFileNames)-1:
            self.currentSpecFileIndex=self.currentSpecFileIndex+1

            print("Checking spectrum %d/%d ..." % (self.currentSpecFileIndex+1, len(self.objectSpecFileNames)))

            obj=self.loadObjectSpectrum(self.objectSpecFileNames[self.currentSpecFileIndex])
            self.objectSpecFileName=self.objectSpecFileNames[self.currentSpecFileIndex]  # for plot titles etc.
            self.objSED=obj['object']
            # Take a real copy: for numpy arrays flux[:] is only a view of the same data
            self.unsmoothedObjFlux=self.objSED.flux.copy()
            self.skySED=obj['sky']

            # Re-apply the current smoothing setting to a fresh copy of the unsmoothed flux
            x = self.smoothScale.get()
            self.objSED.flux=self.unsmoothedObjFlux.copy()
            self.objSED.smooth(x)

            self.resetFeatures()

            self.qualityRadioVar.set(0)

            self.commentsEntryVar.set("")
            self.outPathEntryVar.set(self.outDir+os.path.sep+os.path.split(self.objectSpecFileName)[-1].replace(".fits", ".png"))
            self.redshiftErrorEntryVar.set(0.001)

            self.minWavelengthEntryVar.set(self.objSED.wavelength.min())
            self.maxWavelengthEntryVar.set(self.objSED.wavelength.max())

            self.updatePlot(self.objSED, self.templates[self.templateRadioVar.get()], self.skySED,
                            self.redshiftScaleVar.get(),
                            tempLabel=os.path.split(templateLabels[self.templateRadioVar.get()])[-1],
                            redrawSky = True, redrawFeatures=True, plotFeatures=self.plotFeatures)
        else:
            print("Finished checking all spectra!")
            self.buttonFrame.quit()
            
                        
    def loadObjectSpectrum(self, objectFileName, smoothPix = 30):
        """Loads in an object spectrum stored as FITS table(s).

        Two layouts are understood:
          - DEEP2 pipeline spec1d format: 'HORNE-B' and 'HORNE-R' table extensions, each with
            LAMBDA, SPEC and SKYSPEC columns; the first row of each is used and the blue and red
            halves are concatenated (blue first).
          - Otherwise, a '1D_SPECTRUM' table extension with LAMBDA and SPEC columns, plus optional
            SKYSPEC (sky spectrum) and MASK (chip gap mask) columns.

        Negative object fluxes are set to zero and the object spectrum is rescaled to run from 0 to 1.
        Chip gaps found in the MASK column are stored in self.gapsList as [minWavelength, maxWavelength]
        pairs (empty list if there is no MASK column), and objectFileName is stored in
        self.objectFileName.

        smoothPix is not used - no smoothing is applied here; it is kept so existing calls still work.

        Returns a dictionary containing object and sky astSED.SED objects {'object', 'sky'}; 'sky' is
        None when the file has no sky spectrum.

        """

        print("Loading object spectrum ...")

        self.objectFileName=objectFileName

        # Everything needed is read (and copied) inside the with block, so the file is always closed
        with pyfits.open(objectFileName) as oimg:

            # If DEEP2 format, concatenate blue, red spectra
            # Otherwise, fall back to the single '1D_SPECTRUM' table
            specHDU=None
            try:
                rtab=oimg['HORNE-R'].data
                btab=oimg['HORNE-B'].data
                owavelengthRange=np.array(btab.field('LAMBDA')[0].tolist()+rtab.field('LAMBDA')[0].tolist())
                oflux=np.array(btab.field('SPEC')[0].tolist()+rtab.field('SPEC')[0].tolist())
                oskyflux=np.array(btab.field('SKYSPEC')[0].tolist()+rtab.field('SKYSPEC')[0].tolist())
            except KeyError:
                # Missing HORNE-R/B extension (or column): not DEEP2 format
                specHDU=oimg['1D_SPECTRUM']
                columnNames=specHDU.columns.names
                owavelengthRange=np.array(specHDU.data.field('LAMBDA'))
                oflux=np.array(specHDU.data.field('SPEC'))
                if 'SKYSPEC' in columnNames:
                    oskyflux=np.array(specHDU.data.field('SKYSPEC'))
                else:
                    oskyflux=None

            # If there is a chip gap mask given... (only possible for the '1D_SPECTRUM' layout)
            if specHDU is not None and 'MASK' in specHDU.columns.names:
                chipGapMask=np.array(specHDU.data.field('MASK'))
            else:
                chipGapMask=np.zeros(oflux.shape[0])

        # Remove -ve values... helps find chip gaps and get xcsao to work more easily
        oflux[np.less(oflux, 0)]=0.

        self.gapsList=self._findChipGaps(owavelengthRange, chipGapMask)

        objSED=astSED.SED(wavelength = owavelengthRange, flux = oflux)
        objSED.flux=objSED.flux-objSED.flux.min()   #  make it >= 0 everywhere
        objSED.flux=objSED.flux/objSED.flux.max()
        if oskyflux is not None:
            skySED=astSED.SED(wavelength = owavelengthRange, flux = oskyflux)
        else:
            skySED=None

        return {'object': objSED, 'sky': skySED}

    @staticmethod
    def _findChipGaps(wavelengths, chipGapMask, minPix = 10):
        """Returns [minWavelength, maxWavelength] pairs for the chip gaps flagged in chipGapMask.

        Connected runs of non-zero mask pixels are labelled; a run is reported as a gap if it
        contains more than minPix pixels whose mask value equals 1.

        """
        # ndimage.label treats any non-zero value as foreground, so labelling the boolean
        # (mask != 0) is equivalent - and the comparison yields a native-byte-order array,
        # which avoids problems with big-endian FITS columns
        segmentationMap=ndimage.label(np.asarray(chipGapMask) != 0)[0]
        sigPixMask=np.equal(chipGapMask, 1)
        objIDs=np.unique(segmentationMap)
        objNumPix=ndimage.sum(sigPixMask, labels = segmentationMap, index = objIDs)
        gapsList=[]
        for objID, nPix in zip(objIDs, objNumPix):
            if nPix > minPix:
                indices=np.where(segmentationMap == objID)[0]
                gapsList.append([wavelengths[indices.min()], wavelengths[indices.max()]])
        return gapsList
   
    def smoothSpectrum(self, event):
        """Tk callback for the smoothing slider: re-smooths the object spectrum.

        The object flux is first reset from self.unsmoothedObjFlux and then passed to
        the SED's smooth() with the slider's current value, so repeated slider moves do
        not compound. The event argument is ignored.

        """
        self.objSED.flux=self.unsmoothedObjFlux[:]
        self.objSED.smooth(self.smoothScale.get())

    def increaseSmoothing(self):
        """Moves the smoothing slider up by 1 from its current value."""
        current=self.smoothScaleVar.get()
        self.smoothScale.set(current+1)

    def decreaseSmoothing(self):
        """Moves the smoothing slider down by 1 from its current value."""
        current=self.smoothScaleVar.get()
        self.smoothScale.set(current-1)

    def resetFeatures(self):
        """Unticks the feature checkboxes (one per entry of self.plotFeatures) and re-applies them."""
        for index, _flag in enumerate(self.plotFeatures):
            self.featuresCheckVars[index].set(0)
        self.setPlotFeatures()
        
    def plotSkyChanged(self):
        """Wipes the whole figure and redraws it from scratch.

        Intended for when the sky-plotting option is toggled (or the normalisation
        method is changed), so nothing from the previous drawing is left behind.

        """
        plt.clf()
        self.redrawPlot()
        
    def updatePlot(self, objSED, tempSED, skySED, redshift, tempLabel = None, redrawSky = True, 
                        redrawFeatures = False, plotFeatures = None):
        """Updates the plt plot of the object spectrum with the redshifted template overlaid.

        Both spectra are rescaled to run from 0 to 1 within the wavelength window given by the
        min/max wavelength entry boxes (the whole spectrum is used if the window does not overlap
        it). If the alternative-normalisation checkbox is ticked, the object flux is further divided
        by a factor fitted (optimize.leastsq) to match the template through 100 Angstrom top-hat
        passbands.

        Args:
            objSED: astSED.SED of the object; its flux attribute is replaced by the rescaled flux.
            tempSED: astSED.SED of the template; it is redshifted to `redshift` and its flux
                attribute replaced by the rescaled flux.
            skySED: astSED.SED of the sky spectrum, or None.
            redshift: trial redshift for the template and the spectral feature markers.
            tempLabel: legend label for the template curve.
            redrawSky: if true (and the plot-sky checkbox is ticked) draw the sky spectrum,
                telluric bands and sky-line markers.
            redrawFeatures: if true, draw markers for the spectral features that are switched on.
            plotFeatures: on/off flags (1 = on), matched in order against spectralFeaturesCatalogue.
                Defaults to self.plotFeatures.

        """
        if plotFeatures is None:
            plotFeatures=self.plotFeatures

        xMin=float(self.minWavelengthEntry.get())
        xMax=float(self.maxWavelengthEntry.get())

        # Template: shift to the trial redshift, then scale to 0..1 using its peak inside the window
        tempSED.redshift(redshift)
        tempMask=np.logical_and(np.greater(tempSED.wavelength, xMin), np.less(tempSED.wavelength, xMax))
        if not tempMask.any():
            # Window misses the template entirely - normalise on the whole template rather than crash
            tempMask=np.ones_like(tempMask)
        tempSED.flux=tempSED.flux-tempSED.flux.min()
        tempSED.flux=tempSED.flux/tempSED.flux[tempMask].max()

        # Object: scale to 0..1 within the window (e.g. in case we don't want to see redward of 10000 Angstroms)
        objMask=np.logical_and(np.greater(objSED.wavelength, xMin), np.less(objSED.wavelength, xMax))
        if not objMask.any():
            objMask=np.ones_like(objMask)
        objSED.flux=objSED.flux-objSED.flux[objMask].min()
        objSED.flux=objSED.flux/objSED.flux[objMask].max()

        # Norm based on matching flux as closely as possible between template and object spectrum
        if self.altNormCheckVar.get() == 1:
            dw=100
            passbands=[astSED.TopHatPassband(b, b+dw) for b in np.arange(xMin, xMax, dw)]
            objSEDDict=objSED.getSEDDict(passbands)
            tempSEDDict=tempSED.getSEDDict(passbands)
            norm0=0.9
            norm, _ier=optimize.leastsq(fitSEDNormErrFunc, norm0, args = (tempSEDDict['flux'], objSEDDict['flux']))
            objSED.flux=objSED.flux/norm

        plt.cla()
        plt.title(self.objectSpecFileName)
        plt.plot(objSED.wavelength, objSED.flux, 'k-')
        plt.plot(tempSED.wavelength, tempSED.flux, 'r-', label=tempLabel)
        plt.text(0.05, 0.92, u"z = %.5f ± %.5f (Q = %s)" % (tempSED.z, float(self.redshiftErrorEntryVar.get()), self.qualityRadioVar.get()), ha='left', va='top', transform=plt.gca().transAxes, size=12, color='r')

        # Plots the spectral features in turn
        if redrawFeatures:
            for on, item in zip(plotFeatures, spectralFeaturesCatalogue):
                if on != 1:
                    continue
                featureLabel=item[0]
                # Greek letter names -> mathtext. Do not add "eta" to this list: it would also
                # match inside "beta", "zeta" and "theta"
                for greek in ("alpha", "beta", "gamma", "delta", "epsilon", "zeta", "theta"):
                    featureLabel=featureLabel.replace(greek, "$\\%s$" % greek)
                featureLambda=(1.0+float(redshift))*item[1]
                plt.plot((featureLambda, featureLambda), (0, 1.0), 'g--')
                plt.text(featureLambda, 1.05, featureLabel, 
                            ha='center', va='top', size=10, rotation='vertical')

        if redrawSky and self.plotSkyCheckVar.get() == 1:
            if skySED is not None:
                plt.plot(skySED.wavelength, skySED.flux/skySED.flux.max()*0.3, 'b-', label='Sky')
            # Main telluric absorption features
            for bandMin, bandMax in ((6860, 6930), (7590, 7710)):
                c=patches.Rectangle((bandMin, 0), (bandMax-bandMin), 1.2, fill=True, edgecolor=(0.8, 0.8, 0.8), 
                                    facecolor=(0.8, 0.8, 0.8), linewidth=1)
                plt.gca().add_patch(c)

            # Plot sky line locations, for checking wavelength calib (should be its own option)
            for skyLine in checkSkyLines:
                plt.plot([skyLine]*10, np.linspace(0, 1.2, 10), 'b:')

        # Shade bad regions, if any are set; each entry is a (minWavelength, maxWavelength) pair
        for line in self.badLines:
            c=patches.Rectangle((line[0], 0), (line[1]-line[0]), 1.2, fill=True, edgecolor=(0.9, 0.9, 0.9), 
                                facecolor=(0.9, 0.9, 0.9), linewidth=1)
            plt.gca().add_patch(c)

        # Finish drawing the object spectrum plot
        plt.ylim(0, 1.2)
        plt.xlim(xMin,xMax)
        plt.ylabel("Relative Flux")
        plt.xlabel("Wavelength (Angstroms)")
        plt.legend(loc="upper right")

        plt.show()
        
    def redrawPlot(self):
        """Redraws the plot using the currently selected template, trial redshift and feature flags."""
        templateIndex=self.templateRadioVar.get()
        label=os.path.split(templateLabels[templateIndex])[-1]
        self.updatePlot(self.objSED, self.templates[templateIndex], self.skySED,
                        self.redshiftScaleVar.get(), tempLabel=label,
                        redrawFeatures=True, plotFeatures=self.plotFeatures)

#------------------------------------------------------------------------------------------------------------
# Pair of helper functions for fitting the SED normalisation with optimize.leastsq
# p is the parameter array supplied by leastsq; its single element is the normalisation factor
def fitSEDNorm(p, modelFluxes):
    """Returns modelFluxes scaled by the normalisation p.

    p may be a scalar or a length-1 array (as passed in by optimize.leastsq); it is
    multiplied into modelFluxes with ordinary numpy broadcasting.

    """
    return p*modelFluxes

def fitSEDNormErrFunc(p, modelFluxes, observedFluxes):
    """Returns the sum of squared differences between the scaled model and the observed fluxes.

    This is a single scalar, unweighted (so not really a chi-squared). NOTE: leastsq squares
    whatever this returns, so it effectively minimises this value squared; since the value is
    non-negative, the minimum is at the same normalisation.

    """
    residuals=fitSEDNorm(p, modelFluxes)-observedFluxes
    return np.sum(residuals**2)
    
#------------------------------------------------------------------------------------------------------------
if __name__ == "__main__":

    if len(sys.argv) < 3:

        print("Run: rss_mos_visual_inspector <spec1d object spectra .fits [wildcards allowed]> ... <outputDir>")
        print("e.g. rss_mos_visual_inspector 1DSpec_2DSpec_stackAndExtract/1D_*.fits results")
        # Non-zero exit status so calling scripts can tell that nothing was inspected
        sys.exit(1)

    # Every argument but the last is a spectrum file (wildcards are expected to have been
    # expanded by the shell); the last argument is the output directory
    objectSpecFileNames=sys.argv[1:-1]
    outDir=sys.argv[-1]

    print("Spectra to be checked: ")
    print(objectSpecFileNames)

    root = tkinter.Tk()
    root.title("RSSMOSPipeline - Visual Inspector")
    app = App(root, objectSpecFileNames, outDir)
    root.mainloop()
