#!/usr/bin/env python3
import os,sys,seaborn,numpy,re,math,matplotlib,argparse
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def isfloat(value):
    """
    Check whether *value* can be converted to a float.

    Parameters
    ----------
    value : object
      Candidate value, typically a token read from a text file.

    Returns
    -------
    bool
      True if ``float(value)`` succeeds, False otherwise.
    """
    try:
        float(value)
    except (TypeError, ValueError):
        # float() signals unsupported types (None, lists, ...) with a
        # TypeError rather than a ValueError; both mean "not a float".
        return False
    return True

def getshift(explist,left,right,slope,random=False):
    """
    Calculate the velocity shift due to long-range distortion in a region.

    The linear distortion model ``slope * (middle - WMID)`` is evaluated at
    the centre of the fitting region for every exposure whose wavelength
    coverage strictly contains the whole region. These values are then
    averaged with weights sqrt(EXPTIME).

    Parameters
    ----------
    explist : numpy structured array
      Exposure list with fields WMIN, WMID, WMAX and EXPTIME. The values
      must be convertible with ``float``.
    left : float
      Minimum wavelength of the fitting region
    right : float
      Maximum wavelength of the fitting region
    slope : float
      Slope of the long-range distortion model
    random : boolean (default: False)
      Not used by the current implementation; kept for compatibility.

    Returns
    -------
    shift : str
      Weighted shift divided by 1000, formatted with four decimals. This
      converts m/s to km/s when the slope is in m/s per wavelength unit, as
      in the plot labels elsewhere in this file.

    Raises
    ------
    ValueError
      If no exposure with a positive exposure time fully covers the region.
      The weighted average is undefined in that case.
    """
    # Central wavelength of the fitting region
    middle = (left+right)/2
    # Running weighted sum of shifts and running sum of weights
    sumshift = sumcount = 0
    for row in range(len(explist)):
        wbeg = float(explist['WMIN'][row])
        cent = float(explist['WMID'][row])
        wend = float(explist['WMAX'][row])
        # Only exposures covering the entire fitting region contribute
        if wbeg < left and right < wend:
            weight = numpy.sqrt(float(explist['EXPTIME'][row]))
            sumshift = sumshift + weight * slope * (middle-cent)
            sumcount = sumcount + weight
    if sumcount == 0:
        raise ValueError('No exposure with positive exposure time fully covers '
                         'the fitting region %s-%s' % (left, right))
    shift = sumshift / sumcount
    return '{0:.4f}'.format(float(shift)/1000.)

class DistModel:
    '''
    Plot a long-range distortion model and the corresponding correction.

    The upper panel draws, for every distinct exposure setting, the linear
    velocity distortion ``slope * (wavelength - WMID)`` across the setting's
    wavelength coverage. The lower panel draws, at each wavelength, the
    average distortion of all exposures covering it, weighted by the square
    root of their exposure time. Short red ticks in both panels mark the
    centre of each fitting region of the Voigt profile model.

    Examples
    --------
    From executable:

    >>> alphaDist model --explist xshooter_exposures.dat --model finished.13
    '''
    def __init__(self,fort13,explist,slope=1,output=None):
        '''
        Main method to execute the calculations

        Parameters
        ----------
        fort13 : str
          Path to input Voigt profile model
        explist : str
          Path to exposure list. It is read with a header row (after one
          skipped line) and must provide the WMIN, WMID, WMAX, EXPTIME and
          BRANCH columns.
        slope : float
          Slope of the distortion model (m/s per Angstrom, per the axis labels).
        output : str or None
          Filename of the saved figure; the figure is shown when None.
        '''
        explist = numpy.genfromtxt(explist,names=True,skip_header=1,dtype=object)
        trans = self._transition_centres(self._read_lines(fort13))
        fig = plt.figure(figsize=(7,4),dpi=300,frameon=False)
        fig.patch.set_alpha(0.0)
        plt.subplots_adjust(left=0.12, right=0.95, bottom=0.12, top=0.95, hspace=0.05, wspace=0.1)
        ax = plt.subplot(211)
        # exp1: one row [color, WMIN, WMID, WMAX] per distinct central wavelength
        # exp2: one row [EXPTIME, WMIN, WMID, WMAX] per exposure
        exp1 = numpy.empty((0,4))
        exp2 = numpy.empty((0,4))
        for k in range(len(explist)):
            wbeg = float(explist['WMIN'][k])
            cent = float(explist['WMID'][k])
            wend = float(explist['WMAX'][k])
            exptime = float(explist['EXPTIME'][k])
            if cent not in [float(j) for j in exp1[:,2]]:
                branch = explist['BRANCH'][k]
                # genfromtxt may return bytes for dtype=object; compare as text
                if isinstance(branch,bytes):
                    branch = branch.decode()
                color = 'blue' if branch=='VIS' else 'red'
                exp1 = numpy.vstack((exp1,[color,wbeg,cent,wend]))
            exp2 = numpy.vstack((exp2,[exptime,wbeg,cent,wend]))
        # Distortion models and fitting regions
        for k in range(len(exp1)):
            color   = exp1[k,0]
            wastart = float(exp1[k,1])
            wacent  = float(exp1[k,2])
            waend   = float(exp1[k,3])
            x       = numpy.arange(wastart,waend,1)
            y       = slope*(x-wacent)
            ax.plot(x,y,color=color,lw=1.5,zorder=3)
            ax.axvline(x=wastart,color='black',zorder=2,lw=1,ls='dotted')
            ax.axvspan(wastart,waend,facecolor='yellow',alpha=0.1,zorder=1)
            ax.axvline(x=waend,color='black',zorder=2,lw=1,ls='dotted')
            ax.scatter(wacent,0,color='black',marker='d',edgecolors='none',s=25,zorder=4)
        ax.axhline(y=0,ls='dotted',color='black')
        plt.setp(ax.get_xticklabels(), visible=False)
        plt.ylabel('Velocity shift (m/s)',size=10)
        for wave in trans:
            ax.axvline(x=wave,ymin=0.85,ymax=0.9,color='red',lw=1)
        # Correction function: sqrt(EXPTIME)-weighted mean distortion
        ax = plt.subplot(212,sharex=ax,sharey=ax)
        for k in range(len(exp1)):
            wastart = float(exp1[k,1])
            waend   = float(exp1[k,3])
            ax.axvline(x=wastart,color='black',zorder=2,lw=1,ls='dotted')
            ax.axvspan(wastart,waend,facecolor='yellow',alpha=0.1)
            ax.axvline(x=waend,color='black',zorder=2,lw=1,ls='dotted')
        xmin = min(float(wmin) for wmin in explist['WMIN'])
        xmax = max(float(wmax) for wmax in explist['WMAX'])
        x = numpy.arange(xmin,xmax,1)
        y = []
        for wa in x:
            vdist,texp = [],[]
            for i in range(len(exp2)):
                if exp2[i,1] < wa < exp2[i,3]:
                    vdist.append(slope*(wa-exp2[i,2]))
                    texp.append(numpy.sqrt(exp2[i,0]))
            if vdist:
                y.append(sum(numpy.array(texp)*numpy.array(vdist))/sum(numpy.array(texp)))
            else:
                # No exposure covers this wavelength: leave a gap in the curve
                y.append(numpy.nan)
        ax.plot(x,y,color='black',lw=1.5,zorder=1)
        ax.axhline(y=0,ls='dotted',color='black')
        for wave in trans:
            ax.axvline(x=wave,ymin=0.85,ymax=0.9,color='red',lw=1)
        # Raw string: '\m' and '\A' are invalid escape sequences otherwise
        plt.xlabel(r'Wavelength ($\mathrm{\AA}$)',size=12)
        plt.ylabel('Velocity shift (m/s)',size=12)
        if output is None:
            plt.show()
        else:
            plt.savefig(output)
        plt.close(fig)

    @staticmethod
    def _read_lines(path):
        """
        Return the non-empty lines of a text file with '#' comments removed.

        This replaces ``numpy.loadtxt(path, dtype=str)`` with a newline
        delimiter, which NumPy 1.23+ rejects. Trailing line breaks are
        removed; other whitespace is kept.
        """
        lines = []
        with open(path) as handle:
            for raw in handle:
                text = raw.split('#',1)[0].rstrip('\r\n')
                if text:
                    lines.append(text)
        return lines

    @staticmethod
    def _transition_centres(lines):
        """
        Return the central wavelength of each fitting region of a fort.13 model.

        Lines are read until the second line containing '*'. Every other line
        before it is treated as a region line, except '!'-commented ones,
        which are skipped as in AlphaDist.create_model. The mean of a region
        line's 3rd and 4th fields (start and end wavelength) is its centre.
        """
        centres = []
        seen_star = False
        for line in lines:
            if '*' in line:
                if seen_star:
                    break
                seen_star = True
            elif not line.lstrip().startswith('!'):
                fields = line.split()
                centres.append((float(fields[2])+float(fields[3]))/2)
        return centres

class AlphaDist:
    '''
    Self-contained class to do all the distortion calculations.

    For each slope of the distortion grid, a distorted copy of the input
    Voigt profile model is written to ``distortion/<slope>/`` inside the
    model's directory. VPFIT is run on it, and the fitted result is saved as
    ``fort_fit.13``. Slopes are processed outward from ``distmid``, and each
    slope other than ``distmid`` starts from the fit of its neighbour.

    Examples
    --------
    From executable:

    >>> alphaDist --explist xshooter_exposures.dat --model finished.13
    '''
    def __init__(self,fort13,explist,distmin=-1,distmax=1,distmid=0,distsep=0.1):
        '''
        Main method to execute the calculations

        Parameters
        ----------
        fort13 : str
          Path to input Voigt profile model. Its directory is expected to
          hold atom.dat, vp_setup.dat, header.dat and a data/ folder.
        explist : str
          Path to exposure list (resolved against the current directory).
        distmin : float
          Minimum slope to compute. Default is -1.
        distmax : float
          Maximum slope to compute. Default is 1.
        distmid : float
          Starting slope value. Default is 0
        distsep : float
          Interval between consecutive slope values. Default is 0.1.
        '''
        local = os.getcwd()
        # Add both inputs into self object
        self.fort13 = os.path.basename(fort13)
        self.explist = numpy.genfromtxt(explist,names=True,skip_header=1)
        # os.path.join keeps absolute model paths intact (plain concatenation did not)
        self.path = os.path.join(local,os.path.dirname(fort13))
        # Slopes from distmid upward, then from distmid downward
        slope1 = numpy.arange(distmid,distmax+distsep,+distsep)
        slope2 = numpy.arange(distmid,distmin-distsep,-distsep)
        distlist = numpy.hstack((slope1,slope2))
        try:
            # Move to model's directory and store atomic data list
            os.chdir(self.path)
            self.makeatomlist(os.path.join(self.path,'atom.dat'))
            for self.slope in distlist:
                os.chdir(self.path)
                # Folder name of this slope, e.g. '0.000', 'p0.100', 'm0.100'
                self.distortion = self._slope_label(self.slope)
                slopedir = os.path.join('distortion',self.distortion)
                os.makedirs(slopedir,exist_ok=True)
                # Share the model's data folder with every slope folder
                datalink = os.path.join(slopedir,'data')
                if not os.path.lexists(datalink):
                    os.symlink('../../data',datalink)
                print('> Fit distortion '+self.distortion)
                # Prepare distorted model, then write, fit and convert it
                self.create_model(distmid,distsep)
                os.chdir(os.path.join(self.path,'distortion',self.distortion))
                self.write_model()
                self.fit_system()
                self.convert26to13()
        finally:
            # Always return to the caller's working directory, even on failure
            os.chdir(local)

    @staticmethod
    def _slope_label(value):
        """
        Return the folder name of a slope: '0.000', 'p<x.xxx>' or 'm<x.xxx>'.
        """
        if round(value,3)==0:
            return '0.000'
        if value<0:
            return ('%.3f'%value).replace('-','m')
        return 'p'+'%.3f'%value

    @staticmethod
    def _read_lines(path):
        """
        Return the non-empty lines of a text file with '#' comments removed.

        This replaces ``numpy.loadtxt(path, dtype=str)`` with a newline
        delimiter, which NumPy 1.23+ rejects. Trailing line breaks are
        removed; other whitespace is kept.
        """
        lines = []
        with open(path) as handle:
            for raw in handle:
                text = raw.split('#',1)[0].rstrip('\r\n')
                if text:
                    lines.append(text)
        return lines

    def makeatomlist(self,atompath):
        """
        Store data from atom.dat in ``self.atom``.

        Each row is ``[species, wave, f, gamma, mass, alpha]`` as strings. A
        missing or non-numeric field is stored as '0'. Lines whose species
        is one of '>>', '<<', '<>', '__' are skipped.
        """
        rows = []
        for element in self._read_lines(atompath):
            l = element.split()
            if not l:
                # Whitespace-only line: nothing to parse
                continue
            # A one-character first token is joined with the next one
            i       = 0      if len(l[0])>1 else 1
            species = l[0]   if len(l[0])>1 else l[0]+l[1]
            wave,f,gamma,mass,alpha = [l[i+n] if len(l)>=i+n+1 and isfloat(l[i+n]) else 0
                                       for n in range(1,6)]
            if species not in ['>>','<<','<>','__']:
                rows.append([species,wave,f,gamma,mass,alpha])
        self.atom = numpy.array(rows,dtype=str) if rows else numpy.empty((0,6))

    def atominfo(self,atomID):
        """
        Find transition in atom list and extract information.

        Parameters
        ----------
        atomID : string
          Name of the transition, written as ion_wavelength (or just ion).

        Returns
        -------
        list
          ``[element, wavelength, oscillator, gamma, qcoeff]`` of the entry
          with that ion name whose wavelength is closest to the requested
          one. Without a wavelength, the last entry for the ion is returned.
          Exits the program if nothing matches.
        """
        target = [0,0,0,0,0]
        atomID = atomID.split('_')
        for i in range(len(self.atom)):
            element     = self.atom[i,0]
            wavelength  = self.atom[i,1]
            oscillator  = self.atom[i,2]
            gammavalue  = self.atom[i,3]
            qcoeff      = self.atom[i,5]
            if (len(atomID)>1 and element==atomID[0]
                and abs(float(wavelength)-float(atomID[1]))<abs(float(target[1])-float(atomID[1]))) \
               or (len(atomID)==1 and element==atomID[0]):
               target = [element,wavelength,oscillator,gammavalue,qcoeff]
        if target==[0,0,0,0,0]:
            print(atomID,'not identifiable...')
            # sys.exit works even when the site module (which defines quit) is absent
            sys.exit()
        return target

    def create_model(self,distmid,distsep):
        """
        Build the distorted model for the current slope.

        For the first slope (``distmid``), the starting model is the original
        fort.13. For other slopes, it is the ``fort_fit.13`` of the
        previously processed neighbouring slope. Region and component lines
        are reformatted. For a non-zero slope, one '>>' component per
        fitting region is appended, holding that region's shift from
        ``getshift``.

        Results are stored in ``self.fort_header_new`` and
        ``self.fort_content_new``. ``self.model`` is set to 'turbulent' or
        'thermal'.
        """
        # Transition labels, one line per fitting region
        read_head = self._read_lines('header.dat')
        if self.slope!=distmid:
            # Slope of the previous step, one grid step closer to distmid
            previous = self.slope - distsep if self.slope > distmid else self.slope + distsep
            fortpath = self.path+'/distortion/'+self._slope_label(previous)+'/fort_fit.13'
        else:
            fortpath = self.fort13
        with open(fortpath,'r') as handle:
            read_fort = [line.strip() for line in handle]
        nondigit = re.compile(r'[^\d.-]+')
        def number(token):
            # Numeric value of a token whose last two characters may hold flag letters
            return float(token[:-2]+nondigit.sub('',token[-2:]))
        def flags(token):
            # Letters among the last two characters of a token (presumably VPFIT tie flags)
            return " ".join(re.findall("[a-zA-Z]+",token[-2:]))
        fort_header_old  = numpy.empty((0,8))
        fort_content_old = numpy.empty((0,8))
        # flag counts the '*' separator lines seen: 1 = regions, 2 = components
        i = flag = 0
        while i < len(read_fort):
            # Stop at a blank line or at the first '>>' shift component (region 1);
            # shift components are regenerated below for the current slope
            if flag==2 and (read_fort[i]=='' or (read_fort[i].split()[0]=='>>' and read_fort[i].split('!')[0].split()[-1]=='1')):
                break
            if read_fort[i]=='*':
                flag = flag+1
                # offset = index of the first line after the separator
                i = offset = i + 1
            if flag==1 and read_fort[i][0]!='!':
                vals = read_fort[i].replace('!',' ').split()
                # Point the data path to the local data/ link
                val0 = 'data/'+vals[0].split('/')[-1]
                val1 = int(vals[1])
                val2 = '%.2f'%float(vals[2])
                val3 = '%.2f'%float(vals[3])
                val4 = vals[4].split('=')[0]+'='+str('%5.8f'%float(vals[4].split('=')[1]))
                # Transition labels come from header.dat, line for line
                head = read_head[i-offset].split()
                val5 = head[0]
                val6 = '' if len(head)==1 else head[1]
                val7 = '' if len(head)==1 else head[2]
                fort_header_old = numpy.vstack((fort_header_old,[val0,val1,val2,val3,val4,val5,val6,val7]))
            if flag==2 and read_fort[i][0]!='!':
                vals = read_fort[i].split()
                # A one-character first token is joined with the next one
                val0 = vals[0]+' '+vals[1] if len(vals[0])==1 else vals[0]
                k = 1 if len(vals[0])==1 else 0
                val1 = '%.5f'%number(vals[k+1])+flags(vals[k+1])
                val2 = '%.7f'%number(vals[k+2])+flags(vals[k+2])
                val3 = '%.4f'%number(vals[k+3])+flags(vals[k+3])
                div  = 10**(-6) if 'E-0' in vals[k+4] else 1
                val4 = ('0.000' if '*' in vals[k+4] else '%.3f'%(number(vals[k+4])/div))+flags(vals[k+4])
                val5 = '%.2f'%number(vals[k+5])+flags(vals[k+5])
                val6 = '%.2E'%number(vals[k+6])+flags(vals[k+6])
                val7 = str(int(number(vals[k+7])))+flags(vals[k+7])
                fort_content_old = numpy.vstack((fort_content_old,[val0,val1,val2,val3,val4,val5,val6,val7]))
            i = i + 1
        fort_content_new = fort_content_old.copy()
        self.model = 'turbulent' if 1 in numpy.array(fort_content_old[:,5],dtype=float) else 'thermal'
        # Velocity shift of every fitting region for the current slope
        store_shift = []
        for row in fort_header_old:
            self.armflag = 'null'
            # Validates the transition name (atominfo exits if it is unknown)
            self.atominfo(row[-3])
            store_shift.append(getshift(self.explist,float(row[2]),float(row[3]),self.slope))
        fort_header_new = fort_header_old.copy()
        # Add one fixed-shift component per region for non-zero slopes
        if self.distortion not in ['','0.000']:
            for p,shift in enumerate(store_shift):
                row = ['>>','1.00000FF','0.0000000FF',shift+'FF','0.000FF','0.00','0.00E+00',p+1]
                fort_content_new = numpy.vstack((fort_content_new,row))
        self.fort_header_new = fort_header_new
        self.fort_content_new = fort_content_new

    def write_model(self):
        """
        Write the distorted model to fort_ini.13 and its labels to header.dat.

        Both files are written in the current working directory. Each region
        line of fort_ini.13 ends with a '!' comment holding the transition
        labels, which are also written, one region per line, to header.dat.
        """
        header = self.fort_header_new
        # All data paths are left-aligned to the longest one
        pathfmt = "{0:<"+str(max((len(row[0]) for row in header),default=1))+"}"
        with open('header.dat','w') as write_header, open('fort_ini.13','w') as write_fort:
            write_fort.write('   *\n')
            for row in header:
                write_fort.write(pathfmt.format(row[0])+' ')
                write_fort.write('{0:>5}'.format(row[1])+' ')
                write_fort.write('{0:>10}'.format('%.2f'%float(row[2]))+' ')
                write_fort.write('{0:>10}'.format('%.2f'%float(row[3]))+' ')
                write_fort.write('{0:<17}'.format(row[4])+' ! ')
                write_fort.write('{0:<15}'.format(row[5])+' ')
                write_header.write('{0:<15}'.format(row[5])+' ')
                if row[6]!='':
                    write_fort.write('{0:<15}'.format(row[6])+' ')
                    write_header.write('{0:<15}'.format(row[6])+' ')
                    write_fort.write('{0:<15}'.format(row[7]))
                    write_header.write('{0:<15}'.format(row[7]))
                write_header.write('\n')
                write_fort.write('\n')
            write_fort.write('  *\n')
            for i,val in enumerate(self.fort_content_new):
                write_fort.write('   '+'{0:<5}'.format(val[0])+' ')
                write_fort.write('{0:>5}'.format(val[1].split('.')[0])+'.'+'{0:<7}'.format(val[1].split('.')[1])+' ')
                write_fort.write('{0:>3}'.format(val[2].split('.')[0])+'.'+'{0:<9}'.format(val[2].split('.')[1])+' ')
                write_fort.write('{0:>4}'.format(val[3].split('.')[0])+'.'+'{0:<6}'.format(val[3].split('.')[1])+' ')
                write_fort.write('{0:>5}'.format(val[4].split('.')[0])+'.'+'{0:<6}'.format(val[4].split('.')[1])+' ')
                write_fort.write('{0:>5}'.format(val[5].split('.')[0])+'.'+'{0:<2}'.format(val[5].split('.')[1])+' ')
                write_fort.write('  %s '%val[6])
                write_fort.write('{0:>2}'.format(val[7])+' ! ')
                write_fort.write('{0:>4}'.format(i+1))
                write_fort.write('\n')

    def fit_system(self):
        """
        Fit fort_ini.13 with VPFIT in the current directory.

        atom.dat and vp_setup.dat are (re)linked from two levels up. The
        ATOMDIR/VPFSETUP environment variables are pointed at them, and
        ``vpfit`` is fed the commands in 'fitcommands', with its terminal
        output redirected to 'termout'.
        """
        for name in ('atom.dat','vp_setup.dat'):
            # lexists also catches dangling links, which os.path.exists misses
            if os.path.lexists(name):
                os.remove(name)
            os.symlink('../../'+name,name)
        os.environ['ATOMDIR']='./atom.dat'
        os.environ['VPFSETUP']='./vp_setup.dat'
        with open('fitcommands','w') as opfile:
            opfile.write('f\n\n\nfort_ini.13\n')
            # One extra empty answer per FITS region
            for line in self.fort_header_new:
                if '.fits' in line[0]:
                    opfile.write('\n')
            opfile.write('n\nn\n')
        os.system('vpfit < fitcommands > termout')

    def convert26to13(self):
        """
        Create fort_fit.13 after fitting fort_ini.13 completed.

        Region lines are the leading '%%' lines of fort.26. Component lines
        are those following the last 'chi-squared' line of fort.18 (skipping
        one line) up to the first one-character line.

        Raises
        ------
        RuntimeError
          If fort.18 contains no 'chi-squared' line (e.g. VPFIT failed).
        """
        line26 = self._read_lines('fort.26')
        line18 = self._read_lines('fort.18')
        start = None
        # Search backwards, including line 0, for the final chi-square summary
        for i in range(len(line18)-1,-1,-1):
            if 'chi-squared' in line18[i]:
                chisq   = '%.4f'%float(line18[i].split('(')[1].split(',')[0])
                chisqnu = '%.3f'%float(line18[i].split('(')[0].split(':')[1])
                ndf     = '%.0f'%float(line18[i].split(')')[0].split(',')[1])
                print('  | chisq=%s, ndf=%s, chisq_nu=%s'%(chisq,ndf,chisqnu))
                start = i + 2
                break
        if start is None:
            raise RuntimeError('No chi-squared line found in fort.18; the VPFIT fit probably failed.')
        with open('fort_fit.13','w') as final:
            final.write('   *\n')
            for line in line26:
                if line[0:2]!='%%':
                    break
                final.write(line.replace('%% ','')+'\n')
            final.write('  *\n')
            for line in line18[start:]:
                if len(line)==1:
                    break
                final.write(line+'\n')

class PlotCurve:
    """
    Plot chi-square curves.

    Notes
    -----
    The operation must be executed from the path where both thermal and
    turbulent folder are present.

    Examples
    --------
    From executable:

    >>> alphaDist --curve --thermal thermal/ --turbulent turbulent/ --distmin -1.5 --distmax 1 --distsep 0.1 --xmin -0.5 --xmax 0.5 --output plot_curves

    From python script:

    >>> import alpha
    >>> alpha.PlotCurve(thermal='thermal/',turbulent='turbulent/',distmin=-1.5,distmax=1,distsep=0.1,xmin=-0.5,xmax=0.5)
    """
    def __init__(self,thermal=None,turbulent=None,distmin=-1,distmax=1,distsep=0.1,xmin=None,xmax=None,output=None):
        '''
        Collect the fit results for every distortion slope and plot them.

        Parameters
        ----------
        thermal : str
          Directory of the thermal model (holding ``distortion/<slope>/``).
        turbulent : str
          Directory of the turbulent model (holding ``distortion/<slope>/``).
        distmin : float
          Most negative slope to collect. Default is -1.
        distmax : float
          Most positive slope to collect. Default is 1.
        distsep : float
          Step between consecutive slopes. Default is 0.1.
        xmin : float or str
          Lower slope limit of the curve fits ('a:b:c' gives one value per model).
        xmax : float or str
          Upper slope limit of the curve fits ('a:b:c' gives one value per model).
        output : str
          Output figure filename; the figure is displayed when None.
        '''
        # Slopes are scanned outward from zero in both directions; zero is
        # listed twice and the duplicate is dropped by order_list when plotting
        upward   = numpy.arange(0,distmax+distsep,distsep)
        downward = numpy.arange(0,distmin-distsep,-distsep)
        table = self.extract_results(thermal,turbulent,numpy.hstack((upward,downward)))
        self.plot_curves(table,xmin,xmax,output,plot_range=[distmin,distmax])
        
    def extract_results(self,thermal,turbulent,distlist):
        """
        Collect chi-square and da/a fit results for every distortion slope.

        Parameters
        ----------
        thermal : str
          Directory of the thermal model; results are read from
          ``<thermal>/distortion/<slope>/fort.26``.
        turbulent : str
          Directory of the turbulent model, laid out the same way.
        distlist : array-like of float
          Distortion slopes to collect.

        Returns
        -------
        fitres : numpy.ndarray
          One row per slope. The first seven columns are ``[slope, ther_df,
          ther_chisq, turb_df, turb_chisq, mom_df, mom_chisq]``. They are
          followed, for each alpha-anchor label, by ``[zabs, ther_alpha,
          ther_error, turb_alpha, turb_error, mom_alpha, mom_error]``.
          Missing values are None. An empty ``distlist`` gives a (0, 7) array.
        """
        fitres = None
        for i in distlist:
            self.stats = {}
            self.alpha = {}
            # Folder name of this slope, e.g. '0.000', 'p0.100', 'm0.100'
            slope = '0.000' if round(i,3)==0 else ('%.3f'%i).replace('-','m') if i<0 else 'p'+'%.3f'%i
            # Extract thermal and turbulent fit results
            self.readfort26('ther',thermal+'/distortion/'+slope+'/fort.26')
            self.readfort26('turb',turbulent+'/distortion/'+slope+'/fort.26')
            # A fit without recorded statistics is treated as missing
            ther_chisq,ther_df,ther_n = self.stats.get('ther',[None,None,None])
            turb_chisq,turb_df,turb_n = self.stats.get('turb',[None,None,None])
            # Method-of-moments weights are reset for every slope, so a slope
            # missing one of the fits never reuses a previous slope's weights
            k1 = k2 = None
            mom_df,mom_chisq = None,None
            if ther_chisq is not None and turb_chisq is not None:
                # Small-sample corrected AIC of each model
                k         = ther_n - ther_df
                ther_AICc = ther_chisq + 2*k + 2*k*(k+1)/(ther_n-k-1)
                k         = turb_n - turb_df
                turb_AICc = turb_chisq + 2*k + 2*k*(k+1)/(turb_n-k-1)
                # Relative likelihoods, normalised so that k1 + k2 = 1
                csmin = min(ther_AICc,turb_AICc)
                k1    = math.exp(-(ther_AICc-csmin)/2)
                k2    = math.exp(-(turb_AICc-csmin)/2)
                total = k1 + k2
                k1,k2 = k1/total,k2/total
                mom_df    = k1 * ther_df    + k2 * turb_df
                mom_chisq = k1 * ther_chisq + k2 * turb_chisq
            if fitres is None:
                # Column count fixed from the first slope; widened below if needed
                fitres = numpy.empty((0,7*(len(self.alpha)+1)))
            results = [round(i,3),ther_df,ther_chisq,turb_df,turb_chisq,mom_df,mom_chisq]
            for label in self.alpha:
                mom_alpha,mom_error = None,None
                zabs,ther_alpha,ther_error,turb_alpha,turb_error = self.alpha[label]
                if k1 is not None and ther_alpha is not None and turb_alpha is not None:
                    mom_alpha = k1 * ther_alpha + k2 * turb_alpha
                    mom_error = numpy.sqrt(k1*ther_error**2 + k2*turb_error**2 + k1*ther_alpha**2 + k2*turb_alpha**2 - mom_alpha**2)
                results.extend([zabs,ther_alpha,ther_error,turb_alpha,turb_error,mom_alpha,mom_error])
            # Pad whichever of fitres / results is narrower with None
            imiss = abs(len(results)-fitres.shape[1])
            if len(results)>fitres.shape[1]:
                fitres = numpy.hstack((fitres,numpy.reshape([None]*len(fitres)*imiss,(len(fitres),imiss))))
            if len(results)<fitres.shape[1]:
                results.extend([None]*imiss)
            fitres = numpy.vstack((fitres,results))
        if fitres is None:
            # No slope requested: return an empty table with the base columns
            fitres = numpy.empty((0,7))
        return fitres
    
    def readfort26(self,model,fortpath):
        """
        Extract fitting results from a VPFIT fort.26 output file.

        Parameters
        ----------
        model : str
          'ther' or 'turb'; selects where the da/a values are stored.
        fortpath : str
          Path to the fort.26 file. A vp_setup.dat in the same folder is read
          for the 'daoaun' scaling factor (default 1).

        Notes
        -----
        ``self.stats[model]`` is set to ``[chisq, df, n]`` from the last
        'Stats:' line. It is ``[None, None, None]`` when the file is missing
        or has no such line. Lines after that 'Stats:' line (all but the
        first line when there is none) are scanned backwards for components
        whose da/a field contains a 'q'. Each such component is stored as
        ``self.alpha[label] = [zabs, ther_alpha, ther_error, turb_alpha,
        turb_error]``, in units of 1e-5 scaled by daoaun. Only the two
        entries belonging to *model* are updated.
        """
        def read_lines(path,comment):
            # Non-empty lines with comments removed; replaces numpy.loadtxt
            # with a newline delimiter, which NumPy 1.23+ rejects
            with open(path) as handle:
                stripped = [raw.split(comment,1)[0].rstrip('\r\n') for raw in handle]
            return [text for text in stripped if text]

        # Default when the file is missing or contains no 'Stats:' line
        self.stats[model] = [None,None,None]
        if not os.path.exists(fortpath):
            return
        daoaun = 1.
        vpsetup = os.path.join(os.path.dirname(fortpath),'vp_setup.dat')
        for line in read_lines(vpsetup,'!'):
            if 'daoaun' in line:
                daoaun = float(line.split()[1])
        fort26 = read_lines(fortpath,'#')
        # Positions of this model's [alpha, error] pair within self.alpha entries
        slots = {'ther':slice(1,3),'turb':slice(3,5)}
        # Walk backwards from the last line down to (excluding) line 0
        for i in range(len(fort26)-1,0,-1):
            fields = fort26[i].split()
            # A one-character first token shifts the value columns by one
            s = 1 if len(fields[0])==1 else 0
            if 'Stats:' in fort26[i]:
                chisq_nu = float(fields[3])
                n  = float(fields[4])
                df = float(fields[5])
                self.stats[model] = [chisq_nu*df,df,n]
                break
            # A 'q' in the da/a field marks an alpha anchor component
            if 'q' in fields[7+s]:
                label = " ".join(re.findall("[a-zA-Z]+",fields[7+s][-2:]))
                zabs  = float(fields[1+s][:-2])
                daoa  = float(fields[7+s].split('q')[0])*daoaun/1e-5
                error = float(fields[8+s].split('q')[0])*daoaun/1e-5
                if model in slots:
                    entry = self.alpha.setdefault(label,[zabs,None,None,None,None])
                    entry[slots[model]] = [daoa,error]

    def order_list(self,x,y,yerr=None):
        """
        Order distortion results by slope and remove duplicated slopes.

        Parameters
        ----------
        x : numpy.ndarray
          Slopes.
        y : numpy.ndarray
          Values associated with each slope.
        yerr : numpy.ndarray or None
          Optional errors associated with each slope.

        Returns
        -------
        x, y, yerr
          Arrays sorted by slope. Within a run of equal slopes, only the
          entry that came last in the input is kept; yerr stays None if not
          given.
        """
        # A stable sort makes "last in the input wins" deterministic for ties,
        # independent of array size and of NumPy's default sort algorithm
        order = numpy.argsort(x,kind='stable')
        x,y = x[order],y[order]
        yerr = None if yerr is None else yerr[order]
        # Drop every entry equal to its successor, keeping the last of each run
        idxs = [i for i in range(len(x)-1) if x[i]==x[i+1]]
        x = numpy.delete(x,idxs)
        y = numpy.delete(y,idxs)
        yerr = None if yerr is None else numpy.delete(yerr,idxs)
        return x,y,yerr
        
    def plot_curves(self,fitres,xmin,xmax,output,plot_range):
        """
        Scatter-plot chi-square and da/a results versus distortion slope.

        The first row shows the chi-square of the thermal, turbulent and
        method-of-moments fits. Each following row shows da/a for one
        absorption system, ordered by mean redshift. A parabola (chi-square)
        or a linear fit (da/a) is overlaid when there are at least two points.

        Parameters
        ----------
        fitres : numpy.ndarray
          Table from ``extract_results`` (blocks of 7 columns per row).
        xmin, xmax : float, str or None
          Slope range used by the fits. A string 'a:b:c' gives one value per
          model column (thermal, turbulent, method of moments).
        output : str or None
          Output filename; the figure is shown interactively when None.
        plot_range : list of float
          ``[distmin, distmax]``, passed on to the fitting helpers.
        """
        # One chi-square row plus one row per alpha system; plt.subplot
        # requires an integer number of rows, hence floor division
        nrows = fitres.shape[1]//7
        plt.rc('font', size=2, family='sans-serif')
        plt.rc('axes', labelsize=10, linewidth=0.2)
        plt.rc('legend', fontsize=10, handlelength=10)
        plt.rc('xtick', labelsize=7)
        plt.rc('ytick', labelsize=7)
        plt.rc('lines', lw=0.2, mew=0.2)
        plt.rc('grid', linewidth=0.5)
        fig = plt.figure(figsize=(9,2.1*nrows),frameon=False,dpi=300)
        # Matplotlib 3.6 renamed the bundled seaborn style to 'seaborn-v0_8'
        # and later releases dropped the old name; use whichever is available
        for style in ('seaborn-v0_8','seaborn'):
            if style in plt.style.available:
                plt.style.use(style)
                break
        slopes = fitres[:,0]
        model = ['Thermal','Turbulent','Method of Moments']
        # Quantities shared by every model column
        dx = (max(slopes)-min(slopes))/20
        allchisq = numpy.vstack((fitres[:,2],fitres[:,4],fitres[:,6])).T
        valid = numpy.where(allchisq!=None)
        for j in range(len(model)):
            # Fitting range of this model, defined before any fit so the da/a
            # panels can use it even when the chi-square fit is skipped
            fitmin = xmin if xmin is None else float(xmin) if isfloat(xmin) else float(xmin.split(':')[j])
            fitmax = xmax if xmax is None else float(xmax) if isfloat(xmax) else float(xmax.split(':')[j])
            chisq = fitres[:,2+2*j]
            idxs  = numpy.where(chisq!=None)[0]
            x,y,_ = self.order_list(slopes[idxs],chisq[idxs])
            # Shared y-range across the three chi-square panels
            ymin  = 0 if len(y)==0 else min(y)-1 if min(y)==max(y) else allchisq[valid].min()
            ymax  = 1 if len(y)==0 else max(y)+1 if min(y)==max(y) else allchisq[valid].max()
            dy    = (ymax-ymin)/10
            ax    = plt.subplot(nrows,3,1+j,xlim=[min(slopes)-dx,max(slopes)+dx],ylim=[ymin-dy,ymax+3*dy])
            if len(y)>0:
                ax.errorbar(x,y,fmt='o',ms=4,markeredgecolor='none',ecolor='grey',alpha=0.7,color='black')
            if len(y)>1:
                self.fit_parabola(x,y,fitmin,fitmax,ymin,ymax,plot_range)
            y_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
            ax.yaxis.set_major_formatter(y_formatter)
            ax.set_title(model[j])
            plt.setp(ax.get_xticklabels(), visible=False)
            if j==0:
                ax.set_ylabel(r'$\chi^2_\mathrm{abs}$')
            else:
                plt.setp(ax.get_yticklabels(), visible=False)
            # Plot alpha systems in redshift order
            order = []
            for k in range(nrows-1):
                temp = numpy.array(fitres[:,7+7*k],dtype=float)
                order.append(numpy.nanmean(temp))
            for num,k in enumerate(numpy.argsort(order)):
                temp  = numpy.array(fitres[:,7+7*k],dtype=float)
                zabs  = numpy.nanmean(temp)
                alpha = fitres[:,8+7*k+2*j]
                error = fitres[:,9+7*k+2*j]
                idxs  = numpy.where(alpha!=None)[0]
                x,y,yerr = self.order_list(slopes[idxs],alpha[idxs],error[idxs])
                # Shared y-range across the three models of this system
                allalpha = [v for v in numpy.hstack((fitres[:,8+7*k],fitres[:,8+7*k+2],fitres[:,8+7*k+4])) if v is not None]
                allerror = [v for v in numpy.hstack((fitres[:,9+7*k],fitres[:,9+7*k+2],fitres[:,9+7*k+4])) if v is not None]
                ymin  = 0 if len(y)==0 else min(y)-1 if min(y)==max(y) else min(allalpha)-max(allerror)
                ymax  = 1 if len(y)==0 else max(y)+1 if min(y)==max(y) else max(allalpha)+max(allerror)
                dy    = (ymax-ymin)/10
                ax    = plt.subplot(nrows,3,4+3*num+j,xlim=[min(slopes)-dx,max(slopes)+dx],ylim=[ymin,ymax+dy])
                if len(y)>0:
                    ax.errorbar(x,y,yerr=yerr,fmt='o',ms=4,markeredgecolor='none',ecolor='grey',alpha=0.7,color='black',lw=0.5)
                if len(y)>1:
                    self.fit_linear(x,y,yerr,fitmin,fitmax,ymin,ymax,plot_range)
                y_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
                ax.yaxis.set_major_formatter(y_formatter)
                if j==0:
                    ax.set_ylabel(r'$\Delta\alpha/\alpha$ $(10^{-5})$'+'\n'+r'for $z_\mathrm{abs}=%.4f$'%zabs)
                else:
                    plt.setp(ax.get_yticklabels(), visible=False)
                if num==nrows-2:
                    ax.set_xlabel(r'Distortion slope (m/s/$\mathrm{\AA}$)')
                else:
                    plt.setp(ax.get_xticklabels(), visible=False)
        plt.tight_layout()
        if output is None:
            plt.show()
        else:
            plt.savefig(output)
        plt.close(fig)
        
    def fit_parabola(self,x,y,xmin,xmax,ymin,ymax,plot_range):
        """
        Fit a parabola to a chi-square curve and mark its minimum.

        The best-fit slope is the abscissa of the fitted parabola's minimum.
        The 1-sigma uncertainty is the distance from that minimum to the point
        where the fitted curve has risen by 1 (Delta chi^2 = 1). Both are read
        off a fixed grid spanning [-100, 100) with step 1e-4, so results are
        quantised to that step and confined to that interval.

        Parameters
        ----------
        x : array-like
          Distortion slope values. The index-based range selection below
          assumes they are sorted in increasing order (presumably done by
          the caller; verify).
        y : array-like
          Chi-square value at each slope.
        xmin, xmax : float or None
          Inclusive bounds of the fitting range; None leaves that side open.
        ymin, ymax : float
          Vertical limits of the current axes, used to place the label.
        plot_range : array-like
          Horizontal plotting range; its average is the label x position.

        Side effects
        ------------
        Draws the fit, the minimum and the +/-1 sigma lines on the current
        matplotlib axes, prints the result, and sets ``self.slope``,
        ``self.slope_error``, ``self.xm1sig`` and ``self.xp1sig``.
        """
        # Determine in-window label position (slightly above the axes top)
        xpos = numpy.average(plot_range)
        ypos = ymax+0.1*(ymax-ymin)
        # Convert first so the range comparisons below are numeric/element-wise
        x = numpy.array(x,dtype=float)
        y = numpy.array(y,dtype=float)
        # Define the inclusive fitting range. An open upper bound must slice to
        # the end (None); slicing to -1 would silently drop the last point.
        imin = 0 if xmin is None else numpy.where(x>=xmin)[0][0]
        imax = None if xmax is None else numpy.where(x<=xmax)[0][-1]+1
        x,y = x[imin:imax],y[imin:imax]
        # Least-squares fit of y = c0*x**2 + c1*x + c2
        A = numpy.vander(x,3)
        coeffs = numpy.linalg.lstsq(A,y,rcond=None)[0]
        f = numpy.poly1d(coeffs)
        # Evaluate the fitted parabola once on a fine grid and locate the
        # minimum and the Delta chi^2 = 1 crossing on it
        xfit = numpy.arange(-100,100,0.0001)
        yfit = f(xfit)
        chimin = yfit.min()
        imid = yfit.argmin()
        isig = numpy.abs(yfit-(chimin+1)).argmin()
        xmid = xfit[imid]
        xsig = abs(xfit[isig]-xfit[imid])
        self.xm1sig = xmid-xsig
        self.xp1sig = xmid+xsig
        plt.plot(xfit,yfit,c='red',lw=1)
        print('Slope: {0:>8}+/-{1:<8}'.format('%.4f'%xmid,'%.4f'%xsig))
        plt.axvline(x=self.xm1sig,ls='dotted',color='blue',lw=1)
        plt.axvline(x=xmid,ls='dashed',color='red',lw=1)
        plt.axvline(x=self.xp1sig,ls='dotted',color='blue',lw=1)
        t1 = plt.text(xpos,ypos,r'$\chi^2_\mathrm{min}$ at %.4f $\pm$ %.4f'%(xmid,xsig),
                      color='red',fontsize=10,ha='center')
        t1.set_bbox(dict(color='white', alpha=0.7, edgecolor=None))
        self.slope = xmid
        self.slope_error = xsig
    
    def fit_linear(self,x,y,yerr,xmin,xmax,ymin,ymax,plot_range):
        """
        Do linear fit to da/a vs. distortion slope curves.

        A straight line a + b*x is fitted with scipy's ``curve_fit``, using
        ``yerr`` as (relative) weights. It is then evaluated on a grid spanning
        [-100, 100) with step 1e-3. The reported da/a is the line's value at
        the grid point nearest ``self.slope``. The statistical error is the
        mean of the range-restricted ``yerr``. The systematic error is the
        change of the line between ``self.slope`` and
        ``self.slope + self.slope_error``.

        Reads ``self.slope``, ``self.slope_error``, ``self.xm1sig`` and
        ``self.xp1sig``, which are set by ``fit_parabola``. That method must
        therefore have been run first.

        Parameters
        ----------
        x : array-like
          Distortion slope values. The index-based range selection below
          assumes they are sorted in increasing order (verify with caller).
        y : array-like
          da/a measurements at each slope.
        yerr : array-like
          Uncertainty of each da/a measurement.
        xmin, xmax : float or None
          Inclusive bounds of the fitting range; None leaves that side open.
        ymin : float
          Lower vertical axes limit (unused here, kept for interface symmetry).
        ymax : float
          Upper vertical axes limit, used as the label y position.
        plot_range : array-like
          Horizontal plotting range; its average is the label x position.

        Side effects
        ------------
        Draws the fit and reference lines on the current matplotlib axes
        and prints the result.
        """
        # Determine in-window label position
        xpos = numpy.average(plot_range)
        ypos = ymax
        # Convert first so the range comparisons below are numeric/element-wise
        x = numpy.array(x,dtype=float)
        y = numpy.array(y,dtype=float)
        yerr = numpy.array(yerr,dtype=float)
        # Define the inclusive fitting range. An open upper bound must slice to
        # the end (None); slicing to -1 would silently drop the last point.
        imin = 0 if xmin is None else numpy.where(x>=xmin)[0][0]
        imax = None if xmax is None else numpy.where(x<=xmax)[0][-1]+1
        x,y,yerr = x[imin:imax],y[imin:imax],yerr[imin:imax]
        # Execute linear fit; the model must use the x values curve_fit passes in
        def line(xdata,a,b):
            return a + b*xdata
        pars = curve_fit(line,x,y,sigma=yerr)[0]
        # Evaluate the fitted line on a fine grid
        xfit = numpy.arange(-100,100,0.001)
        yfit = line(xfit,*pars)
        plt.plot(xfit,yfit,color='red',lw=1)
        # Grid indices nearest the best slope and its lower/upper 1-sigma bounds
        imid  = numpy.abs(xfit-self.slope).argmin()
        ilow  = numpy.abs(xfit-(self.slope-self.slope_error)).argmin()
        ihigh = numpy.abs(xfit-(self.slope+self.slope_error)).argmin()
        plt.axvline(x=self.xm1sig,ls='dotted',color='blue',lw=1)
        plt.axvline(x=xfit[imid],ls='dashed',color='red',lw=1)
        plt.axvline(x=self.xp1sig,ls='dotted',color='blue',lw=1)
        plt.axhline(y=yfit[imid],ls='dashed',color='red',lw=1)
        plt.axhline(y=yfit[ihigh],ls='dotted',color='blue',lw=1)
        plt.axhline(y=yfit[ilow],ls='dotted',color='blue',lw=1)
        alpha      = yfit[imid]
        alpha_stat = numpy.average(yerr)
        alpha_syst = abs(yfit[ihigh]-yfit[imid])
        t1 = plt.text(xpos,ypos,
                      r'$\Delta\alpha/\alpha$ = %.4f $\pm$ %.4f $\pm$ %.4f'%(yfit[imid],alpha_stat,alpha_syst),
                      color='red',fontsize=10,ha='center',va='top')
        t1.set_bbox(dict(color='white', alpha=0.7, edgecolor=None))
        print('Alpha: {0:>8}+/-{1:<6}+/-{2:<6}'.format('%.4f'%alpha,'%.4f'%alpha_stat,'%.4f'%alpha_syst))
                
if __name__=="__main__":
    # Define arguments. 'choices' makes argparse reject unknown operations
    # instead of silently doing nothing.
    parser = argparse.ArgumentParser(description='Alpha Distortion Estimator')
    parser.add_argument('operation',choices=['model','fit','curve','getshift'],help='Operation to perform')
    # --xmin/--xmax stay strings: 'curve' accepts colon-separated per-panel values
    parser.add_argument('--xmin',default=None,help='Minimum slope of fitting range')
    parser.add_argument('--xmax',default=None,help='Maximum slope of fitting range')
    parser.add_argument('--distmin',default=-1,type=float,help='Minimum slope to compute')
    parser.add_argument('--distmax',default=1,type=float,help='Maximum slope to compute')
    parser.add_argument('--distmid',default=0,type=float,help='Starting slope value')
    parser.add_argument('--distsep',default=0.1,type=float,help='Interval between consecutive slope values') 
    parser.add_argument('--explist',help='Exposure list. Must contain column EXPTIME, WMIN, WMAX and WMID.')
    parser.add_argument('--model',help='Input fort.13')
    parser.add_argument('--output',help='Output figure filename')
    parser.add_argument('--thermal',default='./',type=str,help='Directory path where to do the calculations')
    parser.add_argument('--turbulent',default='./',type=str,help='Directory path where to do the calculations')
    parser.add_argument('--slope',default=1,type=str,help='Distortion slope')
    args = parser.parse_args()
    # Dispatch on the requested operation
    if args.operation=='model':
        DistModel(args.model,args.explist,args.slope,args.output)
    elif args.operation=='fit':
        AlphaDist(args.model,args.explist,args.distmin,args.distmax,args.distmid,args.distsep)
    elif args.operation=='curve':
        PlotCurve(args.thermal,args.turbulent,args.distmin,args.distmax,args.distsep,args.xmin,args.xmax,args.output)
    elif args.operation=='getshift':
        # getshift does arithmetic and comparisons on the wavelength bounds,
        # so they must be converted from the raw command-line strings.
        if args.explist is None or args.xmin is None or args.xmax is None:
            parser.error('getshift requires --explist, --xmin and --xmax')
        try:
            left,right = float(args.xmin),float(args.xmax)
        except ValueError:
            parser.error('getshift requires numeric --xmin and --xmax')
        explist = numpy.genfromtxt(args.explist,names=True,skip_header=1,dtype=object)
        print(getshift(explist,left,right,float(args.slope)))
