#!/home/sam/anaconda3/bin/python

from bruce import lc
import emcee, corner, sys, os, numpy as np, math
import matplotlib.pyplot as plt 
import argparse 
from multiprocessing import Pool
from celerite.modeling import Model
from celerite import terms, GP
from scipy.stats import chisquare, sem
import matplotlib.cm as cm
import matplotlib.gridspec as gridspec
from scipy.optimize import minimize
#import astropy.units as u
#from astropy.coordinates import SkyCoord
#from astroquery.gaia import Gaia
#from isochrones import get_ichrone, SingleStarModel
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)
# Silence all warnings globally (this also hides numerical RuntimeWarnings).
# ``np.warnings`` was only an accidental alias of the stdlib ``warnings``
# module and is not available in recent NumPy releases, so use it directly.
import warnings
warnings.filterwarnings('ignore')

# Colour used for model curves and GP uncertainty bands in the plots.
color = "#ff7f0e"

def lc_bin(time, flux, bin_width):
    '''
    Bin a light curve into contiguous bins of fixed width.

    Parameters
    ----------
    time : array_like
        Time stamps; must be in the same units as ``bin_width``.
    flux : array_like
        Values to average in each bin (same length as ``time``).
    bin_width : float
        Width of each bin; must be > 0.

    Returns
    -------
    time_bin, flux_bin, err_bin : numpy.ndarray
        Bin centres, mean flux per bin and its standard error
        (``scipy.stats.sem``). Bins whose error is undefined -- empty bins
        and bins holding a single point -- are dropped.

    Raises
    ------
    ValueError
        If ``bin_width`` is not positive.
    '''
    if bin_width <= 0:
        raise ValueError('bin_width must be positive, got {:}'.format(bin_width))

    time = np.asarray(time, dtype=float)
    flux = np.asarray(flux, dtype=float)
    if time.size == 0:
        return np.array([]), np.array([]), np.array([])

    t_min = np.min(time)
    # Enough bins that the last one also contains max(time):
    # np.arange(min, max, width) stops short of max and silently dropped every
    # point beyond the last edge.
    n_bins = int(np.floor((np.max(time) - t_min) / bin_width)) + 1
    edges = t_min + bin_width * np.arange(n_bins + 1)
    # dig == i + 1 means edges[i] <= t < edges[i + 1]
    dig = np.digitize(time, edges)

    time_binned = (edges[1:] + edges[:-1]) / 2
    flux_binned = np.full(n_bins, np.nan)
    err_binned = np.full(n_bins, np.nan)
    for i in range(n_bins):
        in_bin = flux[dig == i + 1]
        if in_bin.size:
            flux_binned[i] = in_bin.mean()
            # NaN for a single point (ddof=1), so such bins are dropped below
            err_binned[i] = sem(in_bin)

    keep = ~np.isnan(err_binned)
    return time_binned[keep], flux_binned[keep], err_binned[keep]


# Phase
def phaser(time, t_zero, period):
    '''
    Fold ``time`` on ``period`` about the reference epoch ``t_zero``.

    Returns the fractional part of the number of elapsed cycles, i.e. a phase
    nominally in [0, 1). Works element-wise on NumPy arrays and on scalars.
    '''
    cycles = (time - t_zero) / period
    return cycles - np.floor(cycles)


# Define the model
# Parameter bounds passed as ``transitmodel(..., bounds=transit_model_bounds)``.
# Every key must be one of ``transitmodel.parameter_names``: celerite looks the
# bounds up by parameter name, so a key that matches no parameter (this used
# to be misspelled ``light3``) has no effect and leaves the parameter
# unbounded. ``None`` means "no limit on that side".
transit_model_bounds = dict(
    radius_1=(0, 0.9),
    k=(0, 0.9),
    b=(0, 2),
    fs=(-1, 1),
    fc=(-1, 1),
    SBR=(0, None),
    light_3=(0, None),
    J=(0, None),
)



# Define the custom kernel
class RotationTerm(terms.Term):
    '''
    celerite term for quasi-periodic (rotational) variability.

    All four parameters are natural logarithms. The term combines one real
    and one complex celerite component, with f = exp(log_factor):

    * real:    a = exp(log_amp) * (1 + f) / (2 + f),  c = exp(-log_timescale)
    * complex: a = exp(log_amp) / (2 + f),  b = 0,
               c = exp(-log_timescale),  d = 2*pi / exp(log_period)

    A larger ``log_factor`` therefore moves amplitude from the complex
    (oscillating) component into the real one.
    '''
    parameter_names = ("log_amp", "log_timescale", "log_period", "log_factor")

    def get_real_coefficients(self, params):
        # log_period only affects the complex component
        log_amp, log_timescale, _log_period, log_factor = params
        factor = np.exp(log_factor)
        amplitude = np.exp(log_amp) * (1.0 + factor) / (2.0 + factor)
        decay_rate = np.exp(-log_timescale)
        return amplitude, decay_rate

    def get_complex_coefficients(self, params):
        log_amp, log_timescale, log_period, log_factor = params
        factor = np.exp(log_factor)
        amplitude = np.exp(log_amp) / (2.0 + factor)
        decay_rate = np.exp(-log_timescale)
        frequency = 2*np.pi*np.exp(-log_period)
        return amplitude, 0.0, decay_rate, frequency







class transitmodel(Model):
    '''
    celerite mean model wrapping ``bruce.lc``; predicts magnitudes.

    Limb darkening is parameterised by (h1, h2) and converted on the fly to
    the coefficients handed to ``lc``:

        ldc_1_1 = 1 - h1 + h2
        ldc_1_2 = log2(ldc_1_1 / h2)

    (presumably the power-2 law h1/h2 parameterisation). NOTE(review): the
    conversion is applied whatever ``ldc_law_1`` is -- confirm this is the
    intended input for the uniform/quadratic laws.

    The inclination passed to ``lc`` is arccos(radius_1 * b), in degrees.
    '''
    parameter_names = ("t_zero", "period", "radius_1", "k", "fs", "fc", "b", "q", "albedo", "alpha_doppler",
                        "K1", "spots", "omega_1", "ldc_law_1","h1", "h2", "gdc_1",
                        "SBR", "light_3", "E_tol" , "zp", "J" )

    def get_value(self, t):
        '''Return the model magnitude ``zp - 2.5*log10(lc(t, ...))`` at times ``t``.'''
        # (h1, h2) -> (ldc_1_1, ldc_1_2); see the class docstring
        ldc_1_1 =  1 - self.h1 + self.h2
        ldc_1_2 =   np.log2( ldc_1_1 / self.h2)
        return self.zp - 2.5*np.log10(lc(t, t_zero = self.t_zero, period = self.period,
                radius_1 = self.radius_1, k=self.k, 
                fs = self.fs, fc = self.fc, 
                q=self.q, albedo = self.albedo,
                alpha_doppler=self.alpha_doppler, K1 = self.K1,
                spots = np.array(self.spots), omega_1=self.omega_1,
                incl = 180*np.arccos(self.radius_1*self.b)/np.pi,
                ld_law_1=int(self.ldc_law_1), ldc_1_1 = ldc_1_1, ldc_1_2 = ldc_1_2, gdc_1 = self.gdc_1,
                SBR=self.SBR, light_3 = self.light_3,
                E_tol=self.E_tol))

    def get_log_prob_prior(self, h1_ref, h2_ref, h1_err=0.003, h2_err=0.046):
        '''
        Gaussian log-prior (up to a constant) on the limb-darkening parameters.

        ``h1`` and ``h2`` are drawn towards ``h1_ref``/``h2_ref`` with standard
        deviations ``h1_err``/``h2_err`` (defaults 0.003 and 0.046).
        '''
        return -0.5*( ((self.h1 - h1_ref)**2)/(h1_err**2) +     ((self.h2 - h2_ref)**2)/(h2_err**2)    )

    def log_likelihood(self, t, mag, mag_err, h1_ref, h2_ref):
        '''
        Gaussian log-likelihood of ``mag`` plus the h1/h2 log-prior.

        The per-point variance is ``mag_err**2 + J**2`` (jitter added in
        quadrature); constant terms are omitted. Returns -inf instead of a
        non-finite value -- e.g. an unphysical geometry such as
        radius_1*b > 1 makes arccos, and hence the model, NaN -- because
        samplers such as emcee raise on NaN log-probabilities.
        '''
        model = self.get_value(t)
        wt = 1.0 / (mag_err**2 + self.J**2)
        loglike = -0.5*np.sum((mag - model)**2*wt - np.log(wt)) + self.get_log_prob_prior(h1_ref, h2_ref)
        if not np.isfinite(loglike):
            return -np.inf
        return loglike


def lnlike(theta, time, mag, mag_err, t_zero_ref, period_ref, theta_names, h1_ref, h2_ref, minimize_switch=False ):
    '''
    Log-likelihood (or minimisation cost) of the global ``transit_model``.

    The values in ``theta`` are assigned to the model parameters named in
    ``theta_names``. Hard limits on the parameters are then checked and, if
    they pass, ``transit_model.log_likelihood`` (which already includes the
    Gaussian prior on h1/h2) is evaluated.

    Returns the log-likelihood for emcee, or -inf if the point is rejected.
    With ``minimize_switch=True`` it returns (and prints) the cost
    -2*log-likelihood instead, and +inf if the point is rejected: returning
    -inf there would look like the best possible cost to a minimiser.
    '''
    rejected = np.inf if minimize_switch else -np.inf

    # First, set the attributes
    for i, value in enumerate(theta):
        transit_model.set_parameter(theta_names[i], value)

    m = transit_model
    # (value, lower, upper): each value must satisfy lower <= value <= upper.
    # Written this way, a NaN parameter is rejected as well.
    limits = (
        (m.t_zero, t_zero_ref - 0.2*period_ref, t_zero_ref + 0.2*period_ref),
        # NOTE(review): lnlike_gp restricts the period to +/-1e-3 instead.
        (m.period, period_ref - 1e-2, period_ref + 1e-2),
        (m.k, 0.0, 0.8),
        (m.radius_1, 0.0, 0.8),
        (m.b, 0.0, 1.0 + m.k),
        (m.J, 0.0, np.inf),
        (m.q, 0.0, np.inf),
        (m.zp, -20.0, 20.0),
        (m.fs, -0.7, 0.7),
        (m.fc, -0.7, 0.7),
        # e = fs**2 + fc**2 must stay below 1
        (m.fc**2 + m.fs**2, -np.inf, 0.999),
        (m.SBR, 0.0, 1.0),
    )
    if not all(lower <= value <= upper for value, lower, upper in limits):
        return rejected

    # Evaluate the (expensive) model once, even in minimize mode.
    loglike = transit_model.log_likelihood(time, mag, mag_err, h1_ref, h2_ref)
    # A NaN likelihood (e.g. from an unphysical geometry) counts as rejected.
    if np.isnan(loglike):
        return rejected

    if minimize_switch:
        cost = -2*loglike
        print(cost)
        return cost
    return loglike




def lnlike_gp(theta, time, mag, mag_err, t_zero_ref, period_ref, theta_names,  h1_ref, h2_ref, minimize_switch=False ):
    '''
    Log-posterior (or minimisation cost) of the global celerite ``gp``.

    The values in ``theta`` are assigned to the GP parameters named in
    ``theta_names`` (full celerite names such as ``mean:k`` or
    ``kernel:terms[0]:log_amp``). ``t_zero`` and ``period`` are restricted to
    windows around the reference values. The remaining limits come from
    ``gp.log_prior()``, to which a Gaussian prior on h1/h2 is added.

    ``time`` and ``mag_err`` are unused here; the GP is expected to have been
    computed on them already (``gp.compute``).

    Returns log-likelihood + log-prior for emcee, or -inf if the point is
    rejected. With ``minimize_switch=True`` it returns the negated value
    instead, and +inf if the point is rejected, so a minimiser is pushed away
    from forbidden regions rather than towards them.
    '''
    rejected = np.inf if minimize_switch else -np.inf

    # First, set the attributes
    for i, value in enumerate(theta):
        gp.set_parameter(theta_names[i], value)

    t_zero = gp.get_parameter('mean:t_zero')
    period = gp.get_parameter('mean:period')
    if not (t_zero_ref - 0.2*period_ref <= t_zero <= t_zero_ref + 0.2*period_ref):
        return rejected
    # NOTE(review): lnlike allows +/-1e-2 in period; this uses +/-1e-3.
    if not (period_ref - 1e-3 <= period <= period_ref + 1e-3):
        return rejected

    # The other parameter bounds are applied by gp.log_prior()
    h1_term = ((gp.get_parameter('mean:h1') - h1_ref)**2)/(0.003**2)
    h2_term = ((gp.get_parameter('mean:h2') - h2_ref)**2)/(0.046**2)
    lp = gp.log_prior() - 0.5*(h1_term + h2_term)
    if not np.isfinite(lp):
        return rejected

    log_prob = gp.log_likelihood(mag) + lp
    # A NaN (e.g. from a NaN mean model) counts as rejected.
    if np.isnan(log_prob):
        return rejected
    return -log_prob if minimize_switch else log_prob




# Welcome messages
welcome_message = '''---------------------------------------------------
-                   NGTSfit V.2                   -
-             samuel.gill@warwick.ac.uk           -
---------------------------------------------------'''

description = '''A program to fit binary star observations elegantly. 
Use the -h flag to see all available options for the fit. For any questions, 
please email samuel.gill@warwick.ac.uk'''

emcee_message = '''---------------------------------------------------
-                   emcee                         -
---------------------------------------------------'''


# Argument parser
# Help strings use %(default)s so the documented defaults cannot drift from
# the real ones.
parser = argparse.ArgumentParser('ngtsfit', description=description)

parser.add_argument("filename",
                    help='The light-curve file: whitespace-separated columns time, mag, mag_err '
                         '(two extra trailing columns are tolerated and ignored).')

parser.add_argument('-a', '--t_zero',
                    help='The transit epoch, in the same time units as the input file [default %(default)s].',
                    default=0.0, type=float)

parser.add_argument('-b', '--period',
                    help='The orbital period, in the same time units as the input file [default %(default)s].',
                    default=1.0, type=float)

parser.add_argument('-c', '--radius_1',
                    help='The radius of star 1 in units of the semi-major axis, a [default %(default)s].',
                    default=0.2, type=float)

parser.add_argument('-d', '--k',
                    help='The ratio of the radii of star 2 and star 1, R2/R1 [default %(default)s].',
                    default=0.2, type=float)

parser.add_argument('-e', '--b',
                    help='The impact parameter of the orbit, incl = arccos(radius_1*b) [default %(default)s].',
                    default=0., type=float)

parser.add_argument('-f', '--zp',
                    help='The photometric zero-point [default %(default)s].',
                    default=0., type=float)

# The main program maps only these three names to an lc() law number.
parser.add_argument('-g', '--limb_darkening_law',
                    help='The limb-darkening law for star 1 [default %(default)s].',
                    choices=('uniform', 'quadratic', 'power2'),
                    default='power2')

parser.add_argument('-i', '--h1',
                    help='The first limb-darkening coefficient [default %(default)s].',
                    default=0.5, type=float)

parser.add_argument('-j', '--h2',
                    help='The second limb-darkening coefficient [default %(default)s].',
                    default=0.5, type=float)

parser.add_argument('-k', '--spots',
                    help='Spot parameters for star 1, four values per spot: longitude and latitude of '
                         'the spot centre, angular radius (all in radians) and contrast ratio.',
                    nargs='+',
                    type=float,
                    default=[])

parser.add_argument('-l', '--gdc_1',
                    help='The gravity darkening coefficient of star 1 [default %(default)s].',
                    default=0.4, type=float)

parser.add_argument('-m', '--q',
                    help='The mass ratio of star 2 to star 1 [default %(default)s].',
                    default=0., type=float)

parser.add_argument('-n', '--albedo',
                    help='The albedo of the secondary [default %(default)s].',
                    default=0., type=float)

parser.add_argument('-o', '--alpha_doppler',
                    help='The alpha_doppler parameter [default %(default)s].',
                    default=0., type=float)

parser.add_argument('-p', '--K1',
                    help='The semi-amplitude [km/s] of the radial velocity (used for ellipsoidal '
                         'variation and rv) [default %(default)s].',
                    default=10, type=float)

parser.add_argument('-q', '--light_3',
                    help='The third light in the system [default %(default)s].',
                    default=0., type=float)

parser.add_argument('-r', '--SBR',
                    help='The surface-brightness ratio [default %(default)s].',
                    default=0., type=float)

parser.add_argument('-w', '--J',
                    help='The additional jitter, added in quadrature to mag_err [default %(default)s].',
                    default=0., type=float)

parser.add_argument('--trial', action="store_true", default=False,
                    help='Plot the starting model against the data.')

parser.add_argument('-s', '--plot_alpha',
                    help='The transparency (alpha) of the plotted data points [default %(default)s].',
                    default=1., type=float)

parser.add_argument('-t', '--fitpars',
                    help='A space-separated list of free parameters, e.g. "-t radius_1 k b".',
                    nargs='+',
                    default=[])

parser.add_argument('--emcee', action="store_true", default=False,
                    help='Sample the free parameters with emcee.')
parser.add_argument('--minimize', action="store_true", default=False)

parser.add_argument('-u', '--emcee_steps',
                    help='The number of emcee steps [default %(default)s].',
                    default=10000, type=int)

parser.add_argument('-v', '--emcee_burn_in',
                    help='The number of initial emcee steps to discard as burn-in [default %(default)s].',
                    default=5000, type=int)

parser.add_argument('-x', '--threads',
                    help='The number of worker processes used by emcee [default %(default)s].',
                    default=10, type=int)

parser.add_argument('-y', '--bin',
                    help='The bin width used to bin the light curve, in minutes; 0 disables binning '
                         '[default %(default)s].',
                    default=0.0, type=float)

parser.add_argument('--gp', action="store_true", default=False,
                    help='Model correlated noise with a celerite GP (rotation term plus jitter).')
parser.add_argument('--backend', action="store_true", default=False,
                    help='Save the emcee chain to an HDF5 file in the save path.')

parser.add_argument('-ab', '--log_amp',
                    help='log_amp for GP [default %(default)s].',
                    default=-10.92420820929161, type=float)

parser.add_argument('-ac', '--log_timescale',
                    help='log_timescale for GP [default %(default)s].',
                    default=5.821080194036538, type=float)

parser.add_argument('-ax', '--log_period',
                    help='log_period for GP [default %(default)s].',
                    default=1.7458691443708192, type=float)

parser.add_argument('-ay', '--log_factor',
                    help='log_factor for GP [default %(default)s].',
                    default=4.360312565858812, type=float)

parser.add_argument('-z', '--log_sigma',
                    help='log_sigma for GP [default %(default)s].',
                    default=-4, type=float)

parser.add_argument('-ad', '--fs',
                    help='fs for eccentricity = sin(omega)*root(e) [default %(default)s].',
                    default=0.0, type=float)

parser.add_argument('-ae', '--fc',
                    help='fc for eccentricity = cos(omega)*root(e) [default %(default)s].',
                    default=0.0, type=float)

# NOTE(review): the transit model in __main__ is constructed with omega_1=1.0
# rather than args.omega_1 -- confirm whether this option should reach it.
parser.add_argument('-af', '--omega_1',
                    help='Ratio of angular rotation of the host star to orbiting body [default %(default)s].',
                    default=1.0, type=float)

parser.add_argument('-aq', '--savepath',
                    help='The save path [default %(default)s].',
                    default='.', type=str)

parser.add_argument('-aw', '--name',
                    help='The system name [default %(default)s].',
                    default='star', type=str)

parser.add_argument('-bc', '--R1',
                    help='The radius of star 1', type=float,
                    default=1.)
			

# NOTE(review): the triple-quoted string below is an inert string literal, not
# executed code. It preserves an earlier version of ``lnlike`` that worked on a
# copy of the parsed ``args`` and passed mag/mag_err straight to ``lc``. It
# references names that no longer exist in that form (e.g. ``nspots``,
# ``ld_law``, ``args1.ldc_1``) and would not run as-is. The active ``lnlike`` is
# defined earlier in this file. Consider deleting this block.
'''
# Emcee function 
def lnlike(theta, time, mag, mag_err, theta_names, t_zero_ref, period_ref):


    # Make a copy of the args and copy over the values
    args1 = np.copy(args).all()
    for i in range(len(theta_names)) : args1.__setattr__(theta_names[i], theta[i])

    if (args1.t_zero < t_zero_ref - 0.2*period_ref) or (args1.t_zero > t_zero_ref + 0.2*period_ref) : return -np.inf
    if (args1.period < period_ref - 1e-3) or (args1.period > period_ref + 1e-3) : return -np.inf 
    if (args1.k < 0.0) or (args1.k > 0.8) : return -np.inf 
    if (args1.radius_1 < 0.0) or (args1.radius_1 > 0.8) : return -np.inf 
    if (args.b < 0) or (args.b > 1.0 + args.k) : return -np.inf 
    if (args.J < 0) : return -np.inf 
    if (args.q < 0) : return -np.inf 

    # Return the loglike
    log =  lc(time, mag=mag, mag_err=mag_err, J=args1.J, zp = args1.zp,
        t_zero = args1.t_zero, period = args1.period,
        radius_1 = args1.radius_1, k=args1.k, 
        fs = 0.0, fc = 0.0, 
        q=args1.q, albedo = args1.albedo,
        alpha_doppler=args1.alpha_doppler, K1 = args1.K1,
        spots = np.array(args1.spots), omega_1=1., nspots=nspots,
        incl = 180*np.arccos(args1.radius_1*args1.b)/np.pi,
        ldc_law_1=ld_law, ldc_1_1 = args1.ldc_1, ldc_1_2 = args1.ldc_2, gdc_1 = args1.gdc_1,
        SBR=args1.SBR, light_3 = args1.light_3,
        Accurate_t_ecl=0, t_ecl_tolerance=1e-5, Accurate_Eccentric_Anomaly=1, E_tol=1e-5,
        nthreads=1)

    if np.isnan(log) : return -np.inf 
    else : return log
'''


import math
def transit_width(r, k, b, P=1):
    """
    Total transit duration.

    See equation (3) from Seager and Mallen-Ornelas, 2003ApJ...585.1038S:
    T = P/pi * arcsin( r * sqrt(((1+k)**2 - b**2) / (1 - b**2 r**2)) ).

    :param r: R_star/a
    :param k: R_planet/R_star
    :param b: impact parameter = a.cos(i)/R_star
    :param P: orbital period (optional, default P=1)
    :returns: Total transit duration in the same units as P. Returns 0.0 when
        |b| >= 1 + k (the bodies never overlap). The arcsin argument is capped
        at 1, so very large bodies give at most P/2 instead of a math error.
    :raises ValueError: if b*r >= 1, for which no inclination exists
        (cos(i) = b*r would have to be >= 1).
    """
    # No overlap at conjunction: the numerator under the root is <= 0.
    if abs(b) >= 1 + k:
        return 0.0

    sin_i_squared = 1 - b**2 * r**2
    if sin_i_squared <= 0:
        raise ValueError('No valid inclination for b*r = {:} (must be < 1)'.format(b * r))

    arg = r * math.sqrt(((1 + k)**2 - b**2) / sin_i_squared)
    # The formula saturates when the bodies are large compared with the orbit.
    return P * math.asin(min(arg, 1.0)) / math.pi



if __name__ == "__main__":
    args = parser.parse_args()

    # Print the welcome message 
    print(welcome_message)

    # Check for a file
    if len(sys.argv) == 1 : raise ValueError('No file specified')

    # Now load the datafile
    try:
        time, mag, mag_err = np.loadtxt(args.filename).T
    except ValueError:
        time, mag, mag_err, aaaaa, aaaaaaa = np.loadtxt(args.filename).T

    # now mask 
    mask = np.isnan(mag) | np.isinf(mag) | np.isnan(mag_err) | np.isinf(mag_err)
    try:
        time, mag, mag_err = np.loadtxt(args.filename)[~mask].T
    except ValueError:
        time, mag, mag_err, aaaaa, aaaaaaa = np.loadtxt(args.filename)[~mask].T

    print('Loaded {:,} lines from {:}'.format(len(time),args.filename))
    if args.bin > 0 : 
        time, mag, mag_err = lc_bin(time, mag, args.bin/24./60.)
        print('\treduced to {:} lines with {:}-minute binning'.format(len(time), args.bin))

    time = time.astype(np.float64)
    mag = mag.astype(np.float64) 
    mag_err = mag_err.astype(np.float64)

    print('---------------------------------------------------')

    # Report 
    print('System parameters:')
    print('Name : {:}'.format(args.name))
    print('Savepath : {:}'.format(args.savepath))
    print('\tt_zero   : {:}'.format(args.t_zero))
    print('\tperiod   : {:}'.format(args.period))
    print('\tradius_1 : {:}'.format(args.radius_1))
    print('\tk        : {:}'.format(args.k))
    print('\tb        : {:} [{:.2f} deg]'.format(args.b, 180*np.arccos(args.radius_1*args.b)/np.pi))
    print('\tzp       : {:}'.format(args.zp))
    print('\tld_law   : {:}'.format(args.limb_darkening_law))
    print('\t\t   -------')
    print('\t\t   h1 {:}'.format(args.h1))
    print('\t\t   h2 {:}'.format(args.h2)) 
    print('\t\t   gdc_1 {:}'.format(args.gdc_1)) 
    nspots = len(args.spots)//4
    print('\tspots    : {:}'.format(nspots))
    if (nspots > 0):
        for i in range(nspots):
            print('\t\t   Spot ', i, '\n\t\t   -------')
            print('\t\t   longitude of spot centre (radians) = {:}'.format(args.spots[4*i + 0]))
            print('\t\t   latitude of spot centre (radians)  = {:}'.format(args.spots[4*i + 1]))
            print('\t\t   angular radius of spot (radians)   = {:}'.format(args.spots[4*i + 2]))
            print('\t\t   Spot contrast ratio (a=Is/Ip).     = {:}'.format(args.spots[4*i + 3])) 
    print('\tq      : {:}'.format(args.q)) 
    print('\talbedo : {:}'.format(args.albedo)) 
    print('\talpha  : {:}'.format(args.alpha_doppler)) 
    print('\tK1     : {:}'.format(args.K1)) 
    print('\tfs     : {:}'.format(args.fs)) 
    print('\tfc    : {:}'.format(args.fc)) 
    print('\tomega_1    : {:}'.format(args.omega_1)) 
    print('\tlight_3: {:}'.format(args.light_3)) 
    print('\tsbr    : {:}'.format(args.SBR))
    print('\ttrial  : {:}'.format(args.trial))
    print('\tFree parameters ({:}):'.format(len(args.fitpars)))
    for i in range(len(args.fitpars)):
        print('\t\t{:}'.format(args.fitpars[i]))
    print('\tGP : {:}'.format(args.gp))
    print('\t\tlog_amp : {:}'.format(args.log_amp))
    print('\t\tlog_timescale : {:}'.format(args.log_timescale))
    print('\t\tlog_period : {:}'.format(args.log_period))
    print('\t\tlog_factor : {:}'.format(args.log_sigma))
    print('\t\tlog_sigma : {:}'.format(args.log_factor))

    print('\tThreads  : {:}'.format(args.threads))

    if args.limb_darkening_law =='uniform' : ld_law = 0
    if args.limb_darkening_law =='quadratic' : ld_law = 1
    if args.limb_darkening_law =='power2'    : ld_law = 2

    # Now let's do a trial if needed
    # First, let's initialse the transit model 
    transit_model = transitmodel(t_zero = args.t_zero, period = args.period, radius_1 = args.radius_1, k=args.k, 
                                fs=args.fs,fc = args.fc, b = args.b, q = args.q, albedo = args.albedo, alpha_doppler = args.alpha_doppler,
                                K1 = args.K1, spots=  np.array(args.spots), omega_1=1.0, ldc_law_1 = ld_law,
                                h1 = args.h1, h2 = args.h2, gdc_1 = args.gdc_1, SBR  = args.SBR, 
                                light_3 = args.light_3, 
                                E_tol = 1e-4, zp = args.zp, bounds=transit_model_bounds, J=args.J)


    if args.gp : 
        kernel = RotationTerm(
            log_amp=args.log_amp,
            log_timescale=args.log_timescale,
            log_period=args.log_period,
            log_factor=args.log_factor,
            bounds=dict(
                log_amp=(-15.0, -0.1),
                log_timescale=(0.1, 10.0),
                log_period=(0.01, 8.0),
                log_factor=(-30, 20.0),
            ),
        )

        kernel += terms.JitterTerm(
            log_sigma=args.log_sigma,
            bounds=[(-20, -0.1)],
        )


        gp = GP(kernel, mean=transit_model, fit_mean=True)
        gp.compute(time, mag_err)





                         
    if args.trial:
        # Plotting pre-processing 

        if args.gp : 
            mu, var = gp.predict(mag, time, return_var=True)
            std = np.sqrt(var) 

            # First, plot the model
            fig = plt.figure(figsize=(15,10))


            spec = gridspec.GridSpec(ncols=2, nrows=2, figure=fig)
            ax1 = fig.add_subplot(spec[0, :])
            ax2 = fig.add_subplot(spec[1, 0])
            ax3 = fig.add_subplot(spec[1, 1])


            ax1.scatter(time-time[0], mag, c='k', s=10, alpha=args.plot_alpha)
            ax1.fill_between(time-time[0], mu+std, mu-std, color=color, alpha=0.3, edgecolor="none")
            ax1.invert_yaxis()
            ax1.set_ylabel('Mag')
            ax1.set_xlabel('JD -{:}'.format(time[0]))
            #ax1.set_title('$\\chi^2_r$ : {:.6}'.format(-2*(gp.log_likelihood(mag) + gp.log_prior())/len(time)))

            # Then plot the data
            detrended =  mag - mu 
            phase = phaser(time ,args.t_zero, args.period) 
            transit_model.set_parameter('t_zero', 0.0)
            transit_model.set_parameter('period', 1.0) 
            transit_model.set_parameter('zp', 0.)
            ax2.scatter(phase, detrended + transit_model.get_value(phase), c='k', s=10)
            ax2.scatter(phase-1, detrended + transit_model.get_value(phase), c='k', s=10)

            phase_time = np.linspace(-0.2,1,10000)
            #ax2.plot(phase_time, transit_model.get_value(phase_time), 'r')
            ax2.fill_between(phase_time, transit_model.get_value(phase_time)+np.median(std), transit_model.get_value(phase_time)-np.median(std), color=color, alpha=0.3, edgecolor="none")
            width = transit_width(args.radius_1, args.k, args.b, P=1)
            ax2.set_xlim(-width,width)
            ax2.set_ylabel('Mag')
            ax2.set_xlabel('Phase')


            ax3.scatter(phase, detrended + transit_model.get_value(phase), c='k', s=10)
            ax3.scatter(phase-1, detrended + transit_model.get_value(phase), c='k', s=10)

            phase_time = np.linspace(-0.2,1,10000)
            #ax2.plot(phase_time, transit_model.get_value(phase_time), 'r')
            ax3.fill_between(phase_time, transit_model.get_value(phase_time)+np.median(std), transit_model.get_value(phase_time)-np.median(std), color=color, alpha=0.3, edgecolor="none")
            ax3.set_xlim(0.4,0.6)
            ax3.set_ylabel('Mag')
            ax3.set_xlabel('Phase')

            # For third light
            depth = np.max(transit_model.get_value(phase)) - np.min(transit_model.get_value(phase)) + 2*np.median(std) # in mag
            ax2.set_ylim(1.5*depth, -0.5*depth)
            ax3.set_ylim(1.5*depth, -0.5*depth)

            plt.tight_layout()
            plt.savefig('{:}/{:}_trial_gp.pdf'.format(args.savepath, args.name))
            plt.close()


            # Reset
            transit_model.set_parameter('t_zero', args.t_zero)
            transit_model.set_parameter('period', args.period)
            transit_model.set_parameter('zp', args.zp)


        else:
            # First, plot the model
            fig = plt.figure(figsize=(15,10))
            spec = gridspec.GridSpec(ncols=2, nrows=2, figure=fig)
            ax1 = fig.add_subplot(spec[0, :])
            ax2 = fig.add_subplot(spec[1, 0])
            ax3 = fig.add_subplot(spec[1, 1])


            ax1.scatter(time-time[0], mag, c='k', s=10, alpha=args.plot_alpha)
            ax1.invert_yaxis()
            ax1.set_ylabel('Mag')
            ax1.set_xlabel('JD -{:}'.format(time[0]))
            #ax1.set_title('$\\chi^2_r$ : {:.6}'.format(-2*(gp.log_likelihood(mag) + gp.log_prior())/len(time)))

            transit_model.set_parameter('t_zero', 0.0)
            transit_model.set_parameter('period', 1.0)

            phase = phaser(time ,args.t_zero, args.period) 
            phase_time = np.linspace(-0.2,0.8, 10000)
            ax2.scatter(phase, mag,   c='k', s=10, alpha=args.plot_alpha)
            ax2.scatter(phase-1, mag, c='k', s=10, alpha=args.plot_alpha)
            ax2.plot(phase_time, transit_model.get_value(phase_time), color)
            width = transit_width(args.radius_1, args.k, args.b, P=1)
            
            if args.q != 0 : ax2.set_xlim(-0.2,0.8)
            else : ax2.set_xlim(-width,width)
            ax2.set_xlabel('Phase')
            ax2.set_ylabel('Mag')

            phase_time = np.linspace(-0.2,1,10000)
            centre = 0.5
            ax3.scatter(phase, mag,   c='k', s=10, alpha=args.plot_alpha)
            ax3.scatter(phase-1, mag, c='k', s=10, alpha=args.plot_alpha)
            ax3.set_xlim(centre - width, centre+width)
            ax3.plot(phase_time, transit_model.get_value(phase_time), color)
            ax3.set_ylabel('Mag')
            ax3.set_xlabel('Phase')

            # For third light
            depth = np.max(transit_model.get_value(phase_time)) - np.min(transit_model.get_value(phase_time)) # in mag
            ax2.set_ylim(transit_model.get_parameter('zp') + 1.5*depth,transit_model.get_parameter('zp') -0.5*depth)
            ax3.set_ylim(transit_model.get_parameter('zp') + 1.5*depth,transit_model.get_parameter('zp') -0.5*depth)


            #plt.title('$\\chi^2_r$ : {:.6}'.format(-2*transit_model.log_likelihood(time, mag, mag_err)/len(time)))
            plt.tight_layout()
            plt.savefig('{:}/{:}_trial.pdf'.format(args.savepath, args.name))
            plt.show()

            # Reset
            transit_model.set_parameter('t_zero', args.t_zero)
            transit_model.set_parameter('period', args.period)


    if args.emcee:
        # first, let's validat arguments 
        print(emcee_message)
        ndim = len(args.fitpars)
        for i in range(len(args.fitpars)):
            if not hasattr(args, args.fitpars[i]) : raise ValueError('Parameter "{:}" is not a valid identifier.'.format(args.fitpars[i]))

        nwalkers = 4*ndim 
        theta = []
        for i in range(len(args.fitpars)) : theta.append(float(eval('args.{:}'.format(args.fitpars[i]))))
        p0 = np.array([np.random.normal(theta, 1e-5).tolist() for i in range(nwalkers)]) 

        # Set up the backend
        # Don't forget to clear it in case the file already exists
        if args.backend:
            if args.gp : filename = '{:}/{:}_emcee_output_gp.h5'.format(args.savepath, args.name)
            else : filename = '{:}/{:}_emcee_output.h5'.format(args.savepath, args.name)

            backend = emcee.backends.HDFBackend(filename)
            backend.reset(nwalkers, ndim)
        else:
            backend = None

        with Pool(int(args.threads)) as pool:
            if not args.gp : 
                sampler = emcee.EnsembleSampler(nwalkers, ndim, lnlike, args = (time, mag, mag_err, args.t_zero, args.period, args.fitpars, args.h1, args.h2), backend=backend, pool=pool)

            else:
                for i in range(len(args.fitpars)):
                    if (args.fitpars[i] == 'log_amp') or (args.fitpars[i] == 'log_timescale') or (args.fitpars[i] == 'log_period') or (args.fitpars[i] == 'log_factor') :
                         args.fitpars[i] =  'kernel:terms[0]:' + args.fitpars[i]
                    elif (args.fitpars[i] == 'log_sigma'):
                         args.fitpars[i] =  'kernel:terms[1]:' + args.fitpars[i]
                    else : args.fitpars[i] =  'mean:'   + args.fitpars[i]

                sampler = emcee.EnsembleSampler(nwalkers, ndim, lnlike_gp, args = (time, mag, mag_err, args.t_zero, args.period, args.fitpars, args.h1, args.h2), backend=backend, pool=pool)
            sampler.run_mcmc(p0, args.emcee_steps, progress=True) 

        fig_chain, axes = plt.subplots(ndim, figsize=(6, 3*ndim))
        samples = sampler.get_chain()
        for i in range(ndim):
            ax = axes[i] 
            ax.semilogx(samples[:,:,i], 'k', alpha = 0.3)
            ax.set_xlim(0,len(samples))
            ax.set_ylabel(args.fitpars[i]) 
        fig_chain.tight_layout()
        if args.gp : fig_chain.savefig('{:}/{:}_chains_gp.pdf'.format(args.savepath, args.name))
        else       : fig_chain.savefig('{:}/{:}_chains.pdf'.format(args.savepath, args.name))

        plt.close(fig_chain)


        samples = sampler.get_chain(flat=True, discard=args.emcee_burn_in)
        logs = sampler.get_log_prob(flat=True, discard=args.emcee_burn_in) 

        best_idx = np.argmax(logs) 
        best_step = samples[best_idx] 
        low_err = best_step - np.percentile(samples, 16, axis=0)
        high_err = np.percentile(samples, 84, axis=0) - best_step

        print('Best result:')
        if args.gp : output_file = open('{:}/{:}_results_gp.txt'.format(args.savepath, args.name), 'w') 
        else : output_file = open('{:}/{:}_results.txt'.format(args.savepath, args.name), 'w') 
        for i in range(ndim) : 
            print('{:>15} = {:.5f} + {:.5f} - {:.5f}'.format(args.fitpars[i], best_step[i], high_err[i], low_err[i]))
            output_file.write('{:>15} {:.5f} {:.5f} {:.5f}\n'.format(args.fitpars[i], best_step[i], high_err[i], low_err[i]))
        output_file.close() 


        # now make the corner
        fig_corner = corner.corner(samples, labels=args.fitpars, truths = best_step)
        if args.gp : fig_corner.savefig('{:}/{:}_corner_gp.pdf'.format(args.savepath, args.name))
        else :       fig_corner.savefig('{:}/{:}_corner.pdf'.format(args.savepath, args.name))
        plt.close(fig_corner)


        # Now get the best model: plot the full light curve and the
        # phase-folded primary (phase 0) and secondary (phase 0.5) eclipses
        # at the best step.
        if args.gp:
            # Push the best step into the GP; mean-model parameters
            # ('mean:' prefix) are also copied onto transit_model.
            for i in range(len(args.fitpars)):
                gp.set_parameter(args.fitpars[i], best_step[i])
                if 'mean:' in args.fitpars[i] :
                    transit_model.set_parameter(args.fitpars[i][5:], best_step[i])
            transit_model.set_parameter('zp', 0.)

            # The folded panels overwrite the ephemeris with t_zero=0 and
            # period=1. Remember the real values so they can be restored
            # afterwards; re-applying only args.fitpars would leave an
            # unfitted t_zero/period stuck at 0/1 for later absolute-time
            # model evaluations.
            saved_t_zero = transit_model.get_parameter('t_zero')
            saved_period = transit_model.get_parameter('period')

            # GP predictive mean and 1-sigma width at the observed times.
            mu, var = gp.predict(mag, time, return_var=True)
            std = np.sqrt(var)

            fig = plt.figure(figsize=(15,10))
            spec = gridspec.GridSpec(ncols=2, nrows=2, figure=fig)
            ax1 = fig.add_subplot(spec[0, :])   # full light curve
            ax2 = fig.add_subplot(spec[1, 0])   # around phase 0
            ax3 = fig.add_subplot(spec[1, 1])   # around phase 0.5

            ax1.scatter(time-time[0], mag, c='k', s=10, alpha=args.plot_alpha)
            ax1.fill_between(time-time[0], mu+std, mu-std, color=color, alpha=0.3, edgecolor="none")
            ax1.invert_yaxis()
            ax1.set_ylabel('Mag')
            ax1.set_xlabel('JD -{:}'.format(time[0]))

            # Subtract the GP prediction, then add back the transit model
            # evaluated in phase (t_zero=0, period=1, zp=0).
            detrended =  mag - mu
            phase = phaser(time ,args.t_zero, args.period)
            transit_model.set_parameter('t_zero', 0.0)
            transit_model.set_parameter('period', 1.0)
            transit_model.set_parameter('zp', 0.)
            model_at_phase = transit_model.get_value(phase)
            folded = detrended + model_at_phase

            phase_time = np.linspace(-0.2,1,10000)
            model_curve = transit_model.get_value(phase_time)
            band = np.median(std)
            for panel in (ax2, ax3):
                # Plot each point at phase and phase-1 so the eclipse at
                # phase 0 is continuous across the wrap.
                panel.scatter(phase, folded, c='k', s=10)
                panel.scatter(phase-1, folded, c='k', s=10)
                panel.fill_between(phase_time, model_curve+band, model_curve-band, color=color, alpha=0.3, edgecolor="none")
                panel.set_ylabel('Mag')
                panel.set_xlabel('Phase')

            # Transit width in phase units (P=1).
            width = transit_width(gp.get_parameter('mean:radius_1'), gp.get_parameter('mean:k'),gp.get_parameter('mean:b'), P=1)
            ax2.set_xlim(-width,width)
            ax3.set_xlim(0.4,0.6)

            # y-range: model depth plus twice the median GP standard deviation.
            depth = np.max(model_at_phase) - np.min(model_at_phase) + 2*np.median(std) # in mag
            for panel in (ax2, ax3):
                panel.set_ylim(1.5*depth, -0.5*depth)

            plt.tight_layout()
            plt.savefig('{:}/{:}_best_gp.pdf'.format(args.savepath, args.name))
            plt.close()

            # Restore the real ephemeris, then re-apply the best step.
            transit_model.set_parameter('t_zero', saved_t_zero)
            transit_model.set_parameter('period', saved_period)
            for i in range(len(args.fitpars)):
                gp.set_parameter(args.fitpars[i], best_step[i])
                if 'mean:' in args.fitpars[i] :
                    transit_model.set_parameter(args.fitpars[i][5:], best_step[i])
            transit_model.set_parameter('zp', 0.)

        else:
            # Load the best step into transit_model.
            for i in range(len(args.fitpars)):
                transit_model.set_parameter(args.fitpars[i], best_step[i])

            # Save the real ephemeris before switching to phase units
            # (t_zero=0, period=1) so it can be restored after plotting.
            saved_t_zero = transit_model.get_parameter('t_zero')
            saved_period = transit_model.get_parameter('period')
            transit_model.set_parameter('t_zero', 0.0)
            transit_model.set_parameter('period', 1.0)

            phase = phaser(time ,args.t_zero, args.period)

            fig = plt.figure(figsize=(15,10))
            spec = gridspec.GridSpec(ncols=2, nrows=2, figure=fig)
            ax1 = fig.add_subplot(spec[0, :])   # full light curve
            ax2 = fig.add_subplot(spec[1, 0])   # around phase 0
            ax3 = fig.add_subplot(spec[1, 1])   # around phase 0.5

            ax1.scatter(time-time[0], mag, c='k', s=10, alpha=args.plot_alpha)
            ax1.invert_yaxis()
            ax1.set_ylabel('Mag')
            ax1.set_xlabel('JD -{:}'.format(time[0]))

            # Primary eclipse, centred on phase 0.
            phase_time = np.linspace(-0.2,0.8, 10000)
            ax2.scatter(phase, mag,   c='k', s=10, alpha=args.plot_alpha)
            ax2.scatter(phase-1, mag, c='k', s=10, alpha=args.plot_alpha)
            ax2.plot(phase_time, transit_model.get_value(phase_time), color)
            # Transit width in phase units (P=1).
            width = transit_width(transit_model.get_parameter('radius_1'), transit_model.get_parameter('k'),transit_model.get_parameter('b'), P=1)
            # When args.q is non-zero show the whole -0.2..0.8 phase range
            # instead of zooming on the eclipse.
            if args.q != 0 : ax2.set_xlim(-0.2,0.8)
            else : ax2.set_xlim(-width,width)
            ax2.set_xlabel('Phase')
            ax2.set_ylabel('Mag')

            # Secondary eclipse, centred on phase 0.5.
            phase_time = np.linspace(-0.2,1,10000)
            centre = 0.5
            ax3.scatter(phase, mag,   c='k', s=10, alpha=args.plot_alpha)
            ax3.scatter(phase-1, mag, c='k', s=10, alpha=args.plot_alpha)
            ax3.set_xlim(centre - width, centre+width)
            ax3.plot(phase_time, transit_model.get_value(phase_time), color)
            ax3.set_ylabel('Mag')
            ax3.set_xlabel('Phase')

            # y-range: model depth, offset by the fitted zero point.
            depth = np.max(transit_model.get_value(phase_time)) - np.min(transit_model.get_value(phase_time)) # in mag
            best_zp = transit_model.get_parameter('zp')
            ax2.set_ylim(best_zp + 1.5*depth, best_zp - 0.5*depth)
            ax3.set_ylim(best_zp + 1.5*depth, best_zp - 0.5*depth)

            plt.tight_layout()
            plt.savefig('{:}/{:}_best.pdf'.format(args.savepath, args.name))
            plt.close()

            # Restore the real ephemeris, then re-apply the best step.
            transit_model.set_parameter('t_zero', saved_t_zero)
            transit_model.set_parameter('period', saved_period)
            for i in range(len(args.fitpars)):
                transit_model.set_parameter(args.fitpars[i], best_step[i])






        # Now let's do the night plots: one panel per contiguous run of
        # observations that falls inside the primary-eclipse window.
        phase = phaser(time ,args.t_zero, args.period)
        # Transit width in phase units (P=1).
        if args.gp : width = transit_width(gp.get_parameter('mean:radius_1'), gp.get_parameter('mean:k'),gp.get_parameter('mean:b'), P=1)
        else :       width = transit_width(transit_model.get_parameter('radius_1'), transit_model.get_parameter('k'),transit_model.get_parameter('b'), P=1)
        # Keep points within one transit width of phase 0 (either side of the wrap).
        mask = (phase < (width) ) | (phase > (1 - width) )

        # Selected times and magnitudes; with a GP the prediction `mu` is
        # subtracted so only the detrended signal remains.
        t_sel = time[mask]
        m_sel = mag[mask] - mu[mask] if args.gp else mag[mask]

        # Start a new night wherever consecutive selected points are more
        # than night_gap apart (same units as `time`, presumably days).
        # np.diff measures the real gap between neighbours; a centred
        # difference (np.gradient) would drop the last point of each night
        # and only split on gaps of roughly twice the threshold.
        night_gap = 0.4
        starts = np.flatnonzero(np.diff(t_sel) > night_gap) + 1
        dt_, dm_ = [], []
        for t_night, m_night in zip(np.split(t_sel, starts), np.split(m_sel, starts)):
            # Single-point groups are not plotted; an empty selection yields
            # one empty chunk, which is skipped here as well.
            if len(t_night) > 1:
                dt_.append(t_night)
                dm_.append(m_night)

        # Transit width in time units (P = model period), used to pad the
        # model curve either side of each night.
        if args.gp : width = transit_width(gp.get_parameter('mean:radius_1'), gp.get_parameter('mean:k'),gp.get_parameter('mean:b'), P=gp.get_parameter('mean:period'))
        else :       width = transit_width(transit_model.get_parameter('radius_1'), transit_model.get_parameter('k'),transit_model.get_parameter('b'), P=transit_model.get_parameter('period'))

        number_of_nights = len(dt_)
        print('Nights : ', number_of_nights)
        if number_of_nights > 0:
            # squeeze=False always returns a 2-D array of axes, so a single
            # night goes through the same code path (model, labels, title)
            # as several nights.
            f, axs = plt.subplots(figsize=(7,number_of_nights*3), ncols=1, nrows = number_of_nights, squeeze=False)
            for i in range(number_of_nights):
                panel = axs[i, 0]
                print('Night ', i)
                time__ = np.linspace(dt_[i][0] - width/2, dt_[i][-1] + width/2, 1000)

                # NOTE(review): the model is evaluated at absolute times, so
                # transit_model must hold the real t_zero/period here.
                if args.gp :
                    panel.scatter(dt_[i], dm_[i] + transit_model.get_value(np.array(dt_[i])), c='k', s=10, alpha=args.plot_alpha)
                    panel.fill_between(time__, transit_model.get_value(time__)-np.median(std), transit_model.get_value(time__)+np.median(std), color=color, alpha=0.3, edgecolor="none")
                else :
                    panel.scatter(dt_[i], dm_[i], c='k', s=10, alpha=args.plot_alpha)
                    panel.plot(time__, transit_model.get_value(time__), color)

                panel.set_ylabel('Mag')
                panel.invert_yaxis()
                panel.set_title('Night {:}'.format(i+1))
                panel.set_xticks([])

            plt.tight_layout()
            if args.gp : plt.savefig('{:}/{:}_nights_gp.pdf'.format(args.savepath, args.name))
            else : plt.savefig('{:}/{:}_nights.pdf'.format(args.savepath, args.name))
            plt.close()

        # Now let's do odd/even: compare primary eclipses at even and odd
        # epochs by selecting on the phase folded at twice the period.
        fig = plt.figure(figsize=(15,5))
        spec = gridspec.GridSpec(ncols=2, nrows=1, figure=fig)
        ax1 = fig.add_subplot(spec[0, 0])   # even epochs
        ax2 = fig.add_subplot(spec[0, 1])   # odd epochs

        # Selection half-window in double-period phase (computed from the
        # initial args.* geometry, not the fitted one).
        width2 = transit_width(args.radius_1, args.k, args.b, P=args.period*2)/(args.period*2)
        phase1 = phaser(time ,args.t_zero, args.period*2)
        # Double-period phase near 0 -> even epochs; near 0.5 -> odd epochs.
        mask1 = (phase1 < (width2 * 2)) | (phase1 >  (1 - (width2 * 2)))
        mask2 = (phase1 > (0.5 - width2 * 2)) & (phase1 <  ( 0.5 + width2 * 2))
        phase = phaser(time ,args.t_zero, args.period)

        # Transit width in phase units (P=1) for the phase axes. A width
        # computed with P = period is in time units and would scale the
        # x-range by the period.
        if args.gp : phase_width = transit_width(gp.get_parameter('mean:radius_1'), gp.get_parameter('mean:k'),gp.get_parameter('mean:b'), P=1)
        else :       phase_width = transit_width(transit_model.get_parameter('radius_1'), transit_model.get_parameter('k'),transit_model.get_parameter('b'), P=1)

        # Points within half a transit width of phase 0 (either side of the
        # wrap), counted per panel for the titles.
        in_transit = (phase < phase_width/2) | (phase > (1 - phase_width/2))
        n_even = np.count_nonzero(in_transit & mask1)
        n_odd = np.count_nonzero(in_transit & mask2)

        if args.gp:
            # Subtract the GP prediction and add back the transit model
            # evaluated in phase (t_zero=0, period=1, zp=0).
            detrended =  mag - mu
            transit_model.set_parameter('t_zero', 0.0)
            transit_model.set_parameter('period', 1.0)
            transit_model.set_parameter('zp', 0.)
            model_at_phase = transit_model.get_value(phase)
            folded = detrended + model_at_phase

            phase_time = np.linspace(-0.2,1,10000)
            model_curve = transit_model.get_value(phase_time)
            band = np.median(std)
            for panel, sel in ((ax1, mask1), (ax2, mask2)):
                panel.fill_between(phase_time, model_curve+band, model_curve-band, color=color, alpha=0.3, edgecolor="none")
                panel.scatter(phase[sel], folded[sel], c='k', s=10,alpha=args.plot_alpha)
                panel.scatter(phase[sel]-1, folded[sel], c='k', s=10,alpha=args.plot_alpha)

            depth = np.max(model_at_phase) - np.min(model_at_phase) # in mag
            y_bottom, y_top = 1.5*depth, -0.5*depth
            out_template = '{:}/{:}_odd_even_gp.pdf'

        else:
            transit_model.set_parameter('t_zero', 0.0)
            transit_model.set_parameter('period', 1.0)
            # Evaluate the curve with the fitted zero point so it overlays the
            # raw magnitudes, and offset the y-range by it (as in the
            # best-model plot).
            fitted_zp = transit_model.get_parameter('zp')
            phase_time = np.linspace(-0.2,0.8, 10000)
            model_curve = transit_model.get_value(phase_time)
            # zp is still zeroed afterwards so transit_model ends in the same
            # state as before this fix.
            transit_model.set_parameter('zp', 0.)
            for panel, sel in ((ax1, mask1), (ax2, mask2)):
                panel.scatter(phase[sel], mag[sel],   c='k', s=10, alpha=args.plot_alpha)
                panel.scatter(phase[sel]-1, mag[sel], c='k', s=10, alpha=args.plot_alpha)
                panel.plot(phase_time, model_curve, color)

            zero_zp_curve = transit_model.get_value(phase_time)
            depth = np.max(zero_zp_curve) - np.min(zero_zp_curve) # in mag
            y_bottom, y_top = fitted_zp + 1.5*depth, fitted_zp - 0.5*depth
            out_template = '{:}/{:}_odd_even.pdf'

        for panel in (ax1, ax2):
            panel.set_xlim(-phase_width,phase_width)
            panel.set_ylim(y_bottom, y_top)

        ax1.set_ylabel('Mag')
        ax1.set_xlabel('Even phase')
        ax2.set_xlabel('Odd phase')
        ax1.set_title('Even phase with {:} data points'.format(n_even))
        ax2.set_title('Odd phase with {:} data points'.format(n_odd))
        plt.tight_layout()
        plt.savefig(out_template.format(args.savepath, args.name))
        plt.close()



    elif args.minimize:
        theta = []
        for i in range(len(args.fitpars)) : theta.append(float(eval('args.{:}'.format(args.fitpars[i]))))



        if args.gp:
            pass 

        else :
            bounds = [[args.t_zero - 0.1*args.period, args.t_zero + 0.1*args.period],
                    [args.period - 0.1*args.period, args.period + 0.1*args.period]]

            res = minimize(lnlike, theta, method='SLSQP', options={'eps' : 1e-5, 'gtol' : 100000},
                            args = (time, mag, mag_err, args.t_zero, args.period, args.fitpars, args.h1, args.h2, False),
                            bounds=bounds)

            print(res)

            print('Start = ', lnlike(theta, time, mag, mag_err, args.t_zero, args.period, args.fitpars, args.h1, args.h2, True))
            print(' Best = ', lnlike(res.x, time, mag, mag_err, args.t_zero, args.period, args.fitpars, args.h1, args.h2, True))



    # sys.exit() rather than the site-module exit() helper, which is not
    # guaranteed to exist (e.g. under `python -S`).
    sys.exit()

    # NOTE: everything below is unreachable because of the sys.exit() above;
    # it is kept for reference. It relies on `depth` from the plotting code,
    # which the minimize branch never defines.

    # Now do third light corrections: for each light_3 value, find the
    # radius ratio k whose model at phase 0 reaches `depth`.
    third_lights = np.linspace(0,1,1000)
    ks = np.empty_like(third_lights)

    for i in range(third_lights.shape[0]):
        transit_model.set_parameter('light_3', third_lights[i])
        diff = -1
        # Step k up in 1e-4 increments until the model value at phase 0
        # reaches `depth`.
        # NOTE(review): unbounded loop; never terminates if `depth` cannot
        # be reached for this light_3 value.
        while diff < 0:
            transit_model.set_parameter('k', transit_model.get_parameter('k') + 1e-4)
            diff = transit_model.get_value(np.array([0.]))[0] - depth
        ks[i] = transit_model.get_parameter('k')
        # Start the next search from the k found for the first entry
        # (light_3 = 0).
        transit_model.set_parameter('k',ks[0])

    plt.plot(100*third_lights, ks, 'k')

    plt.gca().xaxis.set_major_locator(MultipleLocator(10))
    plt.gca().xaxis.set_major_formatter(FormatStrFormatter('%d'))

    # For the minor ticks, use no labels; default NullFormatter.
    plt.gca().xaxis.set_minor_locator(MultipleLocator(2))

    # Grid visibility passed positionally: the `b=` keyword was deprecated
    # in matplotlib 3.5 (renamed to `visible`).
    plt.grid(True, which='major', color='k', linestyle='-', alpha = 0.5)
    plt.grid(True, which='minor', color='k', linestyle='--', alpha = 0.2)
    plt.xlabel('Dilution [%]')
    plt.ylabel('k = R$_2$ / R$_1$')

    # Secondary axis: companion radius for the given args.R1.
    ax2 = plt.gca().twinx()
    ax2.plot(100*third_lights, ks*args.R1, 'k')
    # Raw string: '\o' is not a valid escape sequence (same resulting text).
    ax2.set_ylabel(r'R$_2$ [R$_\odot$, assuming R$_1$ = ' + '{:.2f}]'.format(args.R1))
    plt.tight_layout()
    if args.gp : plt.savefig('{:}/{:}_third_light_gp.pdf'.format(args.savepath, args.name))
    else : plt.savefig('{:}/{:}_third_light.pdf'.format(args.savepath, args.name))