#!python
#
# Script to run fraunhofer.fit() on a spectrum

import os
import sys
import datetime
import time
import doppler
from astropy.io import fits
from astropy.table import Table
from argparse import ArgumentParser
from dlnpyutils import utils as dln
import numpy as np
import subprocess
import logging
import traceback
from fraunhofer import specfit

# Main command-line program
if __name__ == "__main__":
    # Command-line driver: parse the options, fit every input spectrum with
    # fraunhofer.specfit (fit() by default, fit_lsq() when --fpars is given)
    # and write one FITS results file per spectrum, with the best-fit model
    # flux appended as an extra HDU.
    parser = ArgumentParser(description='Run Fraunhofer fitting on spectra')
    # nargs='*' (not '+') so that the spectra can be supplied through --list alone
    parser.add_argument('files', type=str, nargs='*', help='Spectrum FITS files')
    parser.add_argument('-e','--elem', type=str, nargs=1, default='', help='List of elements to fit, only for fit()')
    parser.add_argument('-f','--fpars', type=str, nargs=1, default='', help='List of parameters to fit, only fit_lsq()')
    parser.add_argument('-i','--init', type=str, nargs=1, default='', help='Initial parameters to use')
    parser.add_argument('--flim', type=str, nargs=1, default='', help='Limits on fitted parameters')
    parser.add_argument('-o','--outfile', type=str, nargs=1, default='', help='Output filename')
    parser.add_argument('--figfile', type=str, nargs=1, default='', help='Figure filename')
    parser.add_argument('-d','--outdir', type=str, nargs=1, default='', help='Output directory')
    parser.add_argument('-l','--list', type=str, nargs=1, default='', help='Input list of FITS files')
    parser.add_argument('-p','--plot', action='store_true', help='Save the plots')
    parser.add_argument('--vmicro', action='store_true', help='Fit vmicro')
    parser.add_argument('--vsini', action='store_true', help='Fit vsini')
    parser.add_argument('-r','--reader', type=str, nargs=1, default='', help='The spectral reader to use')
    parser.add_argument('--alinefile', type=str, nargs=1, default='', help='Atomic linelist to use')
    parser.add_argument('--mlinefile', type=str, nargs=1, default='', help='Molecular linelist to use')
    parser.add_argument('-v','--verbose', type=int, nargs='?', default=0, const=1, help='Verbosity level (0, 1, 2)')

    args = parser.parse_args()

    # Get input parameters.  Options declared with nargs=1 arrive as a
    # one-element list (or the '' default), hence the dln.first_el() calls.
    files = args.files
    inpoutfile = dln.first_el(args.outfile)
    inpfigfile = dln.first_el(args.figfile)
    outdir = dln.first_el(args.outdir)
    if outdir == '': outdir = None
    verbose = args.verbose
    reader = dln.first_el(args.reader)
    if reader == '': reader = None
    saveplot = args.plot
    inlist = dln.first_el(args.list)
    fitvmicro = args.vmicro
    fitvsini = args.vsini
    # Unwrap the linelist names like the other nargs=1 options and map the
    # empty default to None, following the reader/outdir convention above.
    # NOTE(review): assumes specfit treats None as "use the default linelist" - confirm.
    alinefile = dln.first_el(args.alinefile)
    if alinefile == '': alinefile = None
    mlinefile = dln.first_el(args.mlinefile)
    if mlinefile == '': mlinefile = None
    # Elements to fit: comma-separated, upper-cased; '' or 'none' means no elements
    elem = dln.first_el(args.elem)
    if elem != '':
        elem = elem.split(',')
        elem = list(np.char.array(elem).upper())
        if elem[0]=='' or elem[0].lower()=='none':
            elem = []
    else:
        elem = None
    # Parameters to fit: comma-separated, upper-cased
    fpars = dln.first_el(args.fpars)
    if fpars != '':
        fpars = fpars.split(',')
        fpars = list(np.char.array(fpars).upper())
    else:
        fpars = None
    # Fitted parameter limits
    # NOTE(review): these are handed to specfit as a flat list of upper-cased
    # strings - confirm that this is the form fit()/fit_lsq() expect for fparamlims.
    flim = dln.first_el(args.flim)
    if flim != '':
        flim = flim.split(',')
        fparamlims = list(np.char.array(flim).upper())
    else:
        fparamlims = None
    # Initial parameter values dictionary, from comma-separated
    # key=value or key:value pairs (keys upper-cased, values as floats)
    init = dln.first_el(args.init)
    if init != '':
        params = {}
        for init1 in init.split(','):
            if init1.find(':') != -1:
                arr = init1.split(':')
            elif init1.find('=') != -1:
                arr = init1.split('=')
            else:
                raise ValueError('Use format key=value or key:value')
            params[str(arr[0]).upper()] = float(arr[1])
    else:
        params = None

    # Set up the logger (timestamped messages sent to stdout)
    logger = dln.basiclogger()
    logger.handlers[0].setFormatter(logging.Formatter("%(asctime)s [%(levelname)-5.5s]  %(message)s"))
    logger.handlers[0].setStream(sys.stdout)

    now = datetime.datetime.now()
    start = time.time()
    if verbose>0:
        logger.info("Start: "+now.strftime("%Y-%m-%d %H:%M:%S"))
        logger.info(" ")

    # Load files from a list (only when no files were given on the command line)
    if len(files)==0 and inlist!='':
        # Check that file exists
        if not os.path.exists(inlist):
            raise ValueError(inlist+' NOT FOUND')
        # Read in the list
        files = dln.readlines(inlist)
    nfiles = len(files)
    if nfiles==0:
        # Prints the usage message and exits, as argparse does for bad arguments
        parser.error('No input spectra: give FITS files on the command line or use --list')

    # Outfile and figfile can ONLY be used with a SINGLE file
    if inpoutfile!='' and nfiles>1:
        raise ValueError('--outfile can only be used with a SINGLE input file')
    if inpfigfile!='' and nfiles>1:
        raise ValueError('--figfile can only be used with a SINGLE input file')

    if verbose>0 and nfiles>1:
        logger.info('--- Running Fraunhofer Fit on %d spectra ---' % nfiles)

    # Loop over the files
    for i,f in enumerate(files):
        # Check that the file exists
        if not os.path.exists(f):
            logger.info(f+' NOT FOUND')
            continue

        # Load the spectrum
        spec = doppler.read(f,format=reader)

        if verbose>0:
            if nfiles>1:
                if i>0: logger.info('')
                logger.info('Spectrum %3d:  %s  S/N=%6.1f ' % (i+1,f,spec.snr))
            else:
                logger.info('%s  S/N=%6.1f ' % (f,spec.snr))
            logger.info(' ')

        # Figure filename: the explicit --figfile wins; otherwise --plot
        # derives <base>_fraunhofer.pdf in outdir or next to the input file
        figfile = None
        if nfiles==1 and inpfigfile!='':
            figfile = inpfigfile
        if inpfigfile=='' and saveplot:
            fdir,base,ext = doppler.utils.splitfilename(f)
            figfile = base+'_fraunhofer.pdf'
            if outdir is not None: figfile = outdir+'/'+figfile
            if outdir is None and fdir != '': figfile = fdir+'/'+figfile

        # HAVE A FLAG TO CALL FIT_LSQ() DIRECTLY????

        # Run Fraunhofer
        try:
            if fpars is None:
                out, model = specfit.fit(spec,params=params,elem=elem,
                                         fitvmicro=fitvmicro,fitvsini=fitvsini,
                                         figfile=figfile,fparamlims=fparamlims,
                                         alinefile=alinefile,mlinefile=mlinefile,
                                         logger=logger,verbose=verbose)

            # Run fit_lsq() directly
            else:
                out, model = specfit.fit_lsq(spec,params,fitparams=fpars,
                                             fparamlims=fparamlims,verbose=verbose,
                                             alinefile=alinefile,mlinefile=mlinefile,
                                             logger=logger)

            # Save the output: the explicit --outfile, or <base>_fraunhofer.fits
            # in outdir or next to the input file
            if inpoutfile!='':
                outfile = inpoutfile
            else:
                fdir,base,ext = doppler.utils.splitfilename(f)
                outfile = base+'_fraunhofer.fits'
                if outdir is not None: outfile = outdir+'/'+outfile
                if outdir is None and fdir != '': outfile = fdir+'/'+outfile
            if verbose>0:
                logger.info('Writing output to '+outfile)
            if os.path.exists(outfile): os.remove(outfile)
            Table(out).write(outfile)
            # append best model; the with-block closes the file even if the write fails
            with fits.open(outfile) as hdulist:
                hdu = fits.PrimaryHDU(model.flux)
                hdulist.append(hdu)
                hdulist.writeto(outfile,overwrite=True)

        # Handle exceptions: report the failure and carry on with the next spectrum
        except Exception as e:
            logger.error('Fraunhofer failed on '+f+' '+str(e))
            if verbose>0:
                traceback.print_exc()

    now = datetime.datetime.now()
    if verbose>0:
        logger.info(" ")
        logger.info("End: "+now.strftime("%Y-%m-%d %H:%M:%S"))
        logger.info("elapsed: %0.1f sec." % (time.time()-start))
        
