#!python

# standard Imports
import matplotlib.pyplot as plt
from astropy.io import fits 
import numpy as np 
import sys,os, math
import numba
import logging
import astropy.units as u
from astropy.coordinates import SkyCoord
from astroquery.gaia import Gaia
from astropy.table import Table
import emcee ,corner
from scipy.stats import sem
from astropy import constants
from lockfile import LockFile
import os.path
import time as time_package
from pycheops import funcs 
from pycheops.ld import stagger_claret_interpolator
# Limb-darkening interpolator for the NGTS passband (pycheops), built once at
# import time; fit_ngts_target uses it to get reference h1/h2 from Teff.
ld_interpolater = stagger_claret_interpolator(passband='NGTS')
import matplotlib.lines as lines

import argparse 
from multiprocessing import Pool
from bruce.binarystar import _lc, lc

from contextlib import contextmanager
from collections import defaultdict
import pymysql


# Reference planetary radii in km. All four are mean (volumetric) radii so the
# reference lines on the period-radius plot are mutually consistent.
Rjup = 69911  # km
Rsaturn = 58232  # km
Rneptune = 24622  # km
Rearth = 6371  # km (mean radius; 6356 km is the polar radius)


# Argparse
description = '''Fitpipe for NGTS'''
parser = argparse.ArgumentParser('fitpipe', description=description)

# Every option below is read unconditionally by the pipeline (input file,
# output directory, chain length, burn-in and exoplanet catalogue), so each one
# is required: omitting it now gives a usage error up front instead of a crash
# deep inside a worker process.
parser.add_argument('-a',
                    '--fitsfile',
                    required=True,
                    help='The orion output file that is in need of fitting.', type=str)

parser.add_argument('-b',
                    '--savedir',
                    required=True,
                    help='The save directory. Here a folder will be created for each prodID containing the data files and output plot.', type=str)

parser.add_argument('-c',
                    '--nsteps',
                    required=True,
                    help='The number of draws to make, total.', type=int)

parser.add_argument('-d',
                    '--burn_in',
                    required=True,
                    help='The number of draws to discard.', type=int)

parser.add_argument('-e',
                    '--exoplanetdata',
                    required=True,
                    help='The path to the exoplanet data file.', type=str)



@contextmanager
def open_db(host='ngtsdb', db='ngts_archive', user='u1870241', cur_class='list'):
    """
    Reusable database connection manager.

    Opens a PyMySQL connection and yields a cursor on it. On a clean exit the
    transaction is committed; on an exception it is rolled back and the
    exception re-raised. The connection is always closed afterwards.

    The cursor is created explicitly because ``with pymysql.connect(...)``
    yields the *connection* itself on PyMySQL >= 1.0 (not a cursor), and on
    older versions did not close the connection at all.

    Parameters
    ----------
    host, db, user : str
        Connection settings passed to ``pymysql.connect``.
    cur_class : str
        ``'list'`` for the default tuple-returning cursor; any other value
        selects ``pymysql.cursors.DictCursor`` (rows returned as dicts).
    """
    if cur_class == 'list':
        cursorclass = pymysql.cursors.Cursor
    else:
        cursorclass = pymysql.cursors.DictCursor

    connection = pymysql.connect(host=host,
                                 db=db,
                                 user=user,
                                 cursorclass=cursorclass)
    try:
        with connection.cursor() as cur:
            yield cur
        connection.commit()
    except BaseException:
        connection.rollback()
        raise
    finally:
        connection.close()


def query_ra_dec_from_objid_prodid(prod_id, obj_id):
    """
    Look up the catalogue RA/Dec of ``obj_id`` for the ORION run ``prod_id``.

    Both ids are cast to ``int`` and bound as query parameters, so the values
    are escaped by the driver instead of being formatted into the SQL text.
    The query is only valid for numeric ids in any case.

    Returns
    -------
    tuple
        ``(ra_deg, dec_deg)`` of the first matching row.

    Raises
    ------
    ValueError
        If either id is not an integer.
    IndexError
        If the archive has no matching row.
    """
    qry = ('SELECT cat.ra_deg, cat.dec_deg FROM ngts_archive.orion_runs AS orun '
           'LEFT JOIN ngts_archive.catalogue AS cat ON orun.cat_prod_id = cat.cat_prod_id '
           'WHERE orun.prod_id = %s AND cat.obj_id = %s ;')

    with open_db(cur_class='list') as cur:
        cur.execute(qry, (int(prod_id), int(obj_id)))
        result = cur.fetchone()

    if result is None:
        raise IndexError('No catalogue entry for obj_id={:} in prod_id={:}'.format(obj_id, prod_id))
    return result

def lc_bin(time, flux, bin_width):
    '''
    Bin a lightcurve into contiguous, fixed-width bins.

    Bins start at ``min(time)`` and are ``bin_width`` wide; ``time`` and
    ``bin_width`` must have the same units. The last bin may be partial, so
    every point with a finite time is assigned to a bin (points at the end of
    the lightcurve are not dropped). Each bin is reported at its centre with the
    mean flux and the standard error of the mean. Bins holding fewer than two
    points (no error estimate possible) or with a NaN statistic are omitted.

    Parameters
    ----------
    time, flux : array_like
        Sample times and values, same length.
    bin_width : float
        Bin width, must be > 0.

    Returns
    -------
    time_bin, flux_bin, err_bin : ndarray
        Bin centres, mean flux and standard error of each retained bin.

    Raises
    ------
    ValueError
        If ``bin_width`` is not positive.
    '''
    if bin_width <= 0:
        raise ValueError('bin_width must be positive')

    time = np.asarray(time, dtype=float)
    flux = np.asarray(flux, dtype=float)
    finite = np.isfinite(time)
    time, flux = time[finite], flux[finite]
    if time.size == 0:
        return np.array([]), np.array([]), np.array([])

    t_start = np.min(time)
    bin_idx = np.floor((time - t_start) / bin_width).astype(int)

    # Group the fluxes by bin once, instead of re-masking the whole array per bin.
    order = np.argsort(bin_idx, kind='stable')
    sorted_idx = bin_idx[order]
    sorted_flux = flux[order]
    occupied, starts = np.unique(sorted_idx, return_index=True)
    groups = np.split(sorted_flux, starts[1:])

    centres, means, errors = [], [], []
    for b, group in zip(occupied, groups):
        if group.size < 2:
            continue  # the standard error of a single point is undefined
        err = sem(group)
        if np.isnan(err):
            continue  # NaN fluxes in this bin
        centres.append(t_start + (b + 0.5) * bin_width)
        means.append(group.mean())
        errors.append(err)

    return np.array(centres), np.array(means), np.array(errors)

@numba.njit
def lnlike(theta, time, mag, mag_err, period_ref, t_zero_ref, h1_ref, h2_ref, third_light):
    """
    Log-posterior of the transit model for the (binned) magnitudes.

    theta = (t_zero, period, radius_1, k, b, h1, h2, zp, J), the same order the
    caller uses for its corner-plot labels: epoch, period, R_*/a, R_2/R_*,
    impact parameter, power-2 limb-darkening parameters h1/h2, magnitude zero
    point and a jitter term J (labelled "sigma [mag.]" by the caller).

    Priors:
      * uniform boxes - any theta outside them returns -inf: t_zero within
        +/-1e-2 of t_zero_ref, period within +/-1e-3 of period_ref,
        0 <= radius_1, k <= 0.9, 0 <= b <= 1 + k, h1/h2 within +/-0.1 of the
        reference values, |zp| <= 0.5, J >= 0;
      * Gaussian priors on h1 (sigma = 0.003) and h2 (sigma = 0.046).

    third_light is passed straight through to the light-curve model.
    """
    # Unpack
    t_zero, period, radius_1, k, b, h1, h2, zp, J = theta
    # Uniform priors: reject anything outside the boxes
    if (t_zero < (t_zero_ref - 1e-2)) or (t_zero > (t_zero_ref + 1e-2)) : return -np.inf
    if (period < (period_ref - 1e-3)) or (period > (period_ref + 1e-3)) : return -np.inf
    if (radius_1<0) or (radius_1 > 0.9) : return -np.inf
    if (k<0) or (k > 0.9) : return -np.inf
    if (b<0) or (b > (1+k)) : return -np.inf
    if (h1 < (h1_ref - 0.1)) or (h1 > (h1_ref + 0.1)) : return -np.inf
    if (h2 < (h2_ref - 0.1)) or (h2 > (h2_ref + 0.1)) : return -np.inf
    if (zp < -0.5) or (zp > 0.5) : return -np.inf
    if (J<0) : return -np.inf


    # Convert to the model's parametrisation.
    # cos(incl) = b * R_*/a; incl is left in radians here.
    # NOTE(review): the plotting code in fit_ngts_target converts incl to degrees
    # before calling lc(); confirm that _lc really expects radians.
    incl = np.arccos(b*radius_1)
    # Power-2 law: c = 1 - h1 + h2, alpha = log2(c / h2)
    ldc_1 = 1 - h1 + h2
    ldc_2 = np.log2(ldc_1/h2)

    # Gaussian priors on h1 (variance 9e-6 -> sigma 0.003) and h2 (variance 0.002116 -> sigma 0.046)
    lp = -0.5*(((h1-h1_ref)**2) / (9e-6) ) -0.5*(((h2-h2_ref)**2) / (0.002116) )
    # Positional arguments follow bruce.binarystar._lc's signature.
    # NOTE(review): presumably the literal 2 selects the power-2 limb-darkening law — confirm against bruce.
    return _lc(time, mag, mag_err, J, zp,
        t_zero, period,
        radius_1, k ,
        0, 0,
        0., 0.,
        0., 0.,
        np.array([0]), 0.,
        incl,
        2, ldc_1, ldc_2, 0.4,
        0., third_light,
        0.,[0],
        1e-5,
        1 ) + lp


def ngts_aperture_source_search(prod_id, obj_id):
    # Quary mysql data base for the RA and dec 
    # ra, dec = sql_query() 

    ra, dec = 84.16658111680, -1.42465678997 # placeholde for now
    #ra, dec = query_ra_dec_from_objid_prodid(prod_id, obj_id)

    coord = SkyCoord(ra=ra, dec=dec, unit=(u.degree, u.degree), frame='icrs')
    width = u.Quantity(0.00416667, u.deg) # width and height of 3 pixels @ 5" / pix
    height = u.Quantity(0.00416667, u.deg)

    # Make the query
    r = Gaia.query_object_async(coordinate=coord, width=width, height=height)
    #print(r.colnames)
    third_light = 0. 

    # Now work out third light 
    if len(r) < 1 : 
        third_light = 0.
    else:
        target_flux = 10**(-0.4*r['phot_g_mean_mag'][0])
        comp_flux = np.copy(target_flux) 
        for i in range(1, len(r)) : comp_flux += 10**(-0.4*r['phot_g_mean_mag'][i]) 
        #print(target_flux, comp_flux, 1 - target_flux/comp_flux)
        third_light = 1 - target_flux/comp_flux

    star_radius = -99
    star_radius_err = -99
    star_teff = -99
    star_teff_err = -99

    if len(r) > 0 : 
        star_radius = r['radius_val'][0] 
        star_radius_err = np.max([abs(r['radius_percentile_lower'][0]-star_radius) ,abs(r['radius_percentile_upper'][0] -star_radius) ] )
        star_teff = r['teff_val'][0] 
        star_teff_err = np.max([abs(r['teff_percentile_lower'][0]-star_teff) ,abs(r['teff_percentile_upper'][0] -star_teff) ] )

    return star_radius, star_radius_err, star_teff, star_teff_err, third_light
    



#def fit_ngts_target(i, lightcurves, prod_id, obj_ids, periods, epochs, widths, depths, radius_1s, ks, PEAKS, power2_table,LC_IDX):
def fit_ngts_target(i):

    # Get the lightcurve data
    time, mag, mag_err = lightcurves.data[LC_IDX[i]]

    # Convert time to HJD from 
    time = time/ 86400. + 6658.5 

    # Bin lc to 10 mins
    time, mag, mag_err = lc_bin(time, mag, 30/24/60)
    phase = ((time - epochs[i])/periods[i]) - np.floor((time - epochs[i])/periods[i]) 
    mask = ~((phase > 0.3) & (phase < 0.7))
    
    phase = ((time -epochs[i])/ periods[i] ) - np.floor((time -epochs[i])/ periods[i] ) 
    #plt.scatter(phase, mag, c='k', s=10)
    #plt.scatter(phase-1, mag, c='k', s=10)
    #plt.xlim(-0.1, 0.1)
    #plt.gca().invert_yaxis()

    #print(epochs[i], periods[i])
    #plt.close()

    # Query the coords to get an estimate of third light and get radius_1 
    star_radius, star_radius_err, star_teff, star_teff_err, third_light = ngts_aperture_source_search(prod_id, obj_ids[i])

    #print(star_radius, star_radius_err, star_teff, star_teff_err, third_light)

    # Now need to query 
    #print(star_radius, star_teff, third_light)

    # now interpolate power2 table for h1 and h2 
    c1, c2, h1_ref, h2_ref = ld_interpolater(star_teff, 4.5, 0.0)
    #h1_ref = np.interp(star_teff, power2_table['Teff'], power2_table['h1'], left = power2_table['h1'][0], right = power2_table['h1'][-1]) 
    #h2_ref = np.interp(star_teff, power2_table['Teff'], power2_table['h2'], left = power2_table['h2'][0], right = power2_table['h2'][-1]) 


    # Get thetas for planet fit and EB fit
    theta = np.array([epochs[i], periods[i], radius_1s[i], ks[i], 0.1, h1_ref, h2_ref, 0., 0.1])
    ndim = len(theta)
    nwalkers = 4*ndim 
    p0 = theta + + 1e-6 * np.random.randn(nwalkers, ndim) 

    # set up the samplers for EB and planet 
    backend = emcee.backends.HDFBackend('{:}/{:}/{:}_{:}_backend.h5'.format(args.savedir, prod_id, obj_ids[i],PEAKS[i]))
    backend.reset(nwalkers, ndim)
    sampler = emcee.EnsembleSampler(nwalkers, ndim, lnlike, args=[time, mag, mag_err, periods[i], epochs[i], h1_ref, h2_ref, third_light], backend=backend)

    ####################################
    # Planet
    ####################################
    sampler.run_mcmc(p0, args.nsteps, progress=False)
    samples = sampler.get_chain(flat=True, discard=args.burn_in) 
    logs = sampler.get_log_prob(flat=True, discard=args.burn_in) 
    best_idx = np.argmax(logs) 
    best_step = samples[best_idx] 


    # First, do the corner
    try:
        corner_labels = ['\n\nT$_0$ [jd]', '\n\nP [d]', '\n\nR$_*$/a', '\n\nR$_2$/R$_*$', '\n\nb', '\n\nh$_1$', '\n\nh$_2$', '\n\n$z_p$', '\n\n$\sigma$ [mag.]']
        fig_corner_planet =  corner.corner(samples, labels = corner_labels, truths=best_step, quantiles=(0.16, 0.84), levels=(1-np.exp(-0.5),), smooth=1, plot_contours=False) 
        fig_corner_planet.subplots_adjust(left=0.04, bottom=0.1, right=0.997, top=1, wspace=0, hspace=0)
        fig_corner_planet.savefig('{:}/{:}/{:}_{:}_corner.png'.format(args.savedir, prod_id, obj_ids[i], PEAKS[i]))
    except : pass
    plt.close()

    # Now do the plot 
    fig_planet_lc, ax_planet_lc = plt.subplots(nrows=2, ncols=1, figsize=(5,10))
    t_zero, period, radius_1, k, b, h1, h2, zp, J = best_step
    phase = ((time-t_zero)/period) - np.floor((time-t_zero)/period)
    incl = 180*np.arccos(b*radius_1)/np.pi
    ldc_1 = 1 - h1 + h2 
    ldc_2 = np.log2(ldc_1/h2) 

    # Now get the planet calculated parameters
    depth_lc = 2.5*np.log10(lc(np.linspace(-0.8,0.8,1000), radius_1=radius_1, k=k, incl = incl,ldc_1_1=ldc_1, ldc_1_2=ldc_2, light_3 = third_light))
    depth = (np.max(depth_lc) - np.min(depth_lc))*np.ones(len(samples))

    ax_planet_lc[0].scatter(phase,mag,c='k',s=5)
    ax_planet_lc[0].scatter(phase-1,mag,c='k',s=5)
    ax_planet_lc[0].plot(np.linspace(-0.8,0.8,1000), zp-2.5*np.log10(lc(np.linspace(-0.8,0.8,1000), radius_1=radius_1, k=k, incl = incl,ldc_1_1=ldc_1, ldc_1_2=ldc_2, light_3 = third_light)) , 'r')
    ax_planet_lc[0].invert_yaxis()
    ax_planet_lc[0].set_xlim(-0.2,0.8)
    ax_planet_lc[0].set_xlabel('Phase')
    ax_planet_lc[0].set_ylabel('Mag')

    ax_planet_lc[1].scatter(phase,mag,c='k',s=5)
    ax_planet_lc[1].scatter(phase-1,mag,c='k',s=5)
    ax_planet_lc[1].plot(np.linspace(-0.8,0.8,1000), zp-2.5*np.log10(lc(np.linspace(-0.8,0.8,1000), radius_1=radius_1, k=k, incl = incl,ldc_1_1=ldc_1, ldc_1_2=ldc_2, light_3 = third_light)) , 'r')
    ax_planet_lc[1].invert_yaxis()
    ax_planet_lc[1].set_xlim(-2*widths[i]/periods[i],2*widths[i]/periods[i])
    ax_planet_lc[1].set_xlabel('Phase')
    ax_planet_lc[1].set_ylabel('Mag')
    #ax_planet_lc[0].set_title('Planet model\nloglike = {:.2f}'.format(np.max(sampler.get_log_prob(flat=True, discard=args.burn_in))))
    ax_planet_lc[0].set_ylim(2*depth[0], -3*depth[0])
    ax_planet_lc[1].set_ylim(2*depth[0], -3*depth[0])
    ax_planet_lc[0].grid()
    ax_planet_lc[1].grid()

    #fig_planet_lc.subplots_adjust(left = 0.2, top=0.1, bottom = 0.05)
    fig_planet_lc.tight_layout()
    fig_planet_lc.savefig('{:}/{:}/{:}_{:}_best_model.png'.format(args.savedir, prod_id, obj_ids[i], PEAKS[i]))
    plt.close()


    planet_pars = np.array([1e3*depth, 24*funcs.transit_width(samples.T[2], samples.T[3], samples.T[4], P=samples.T[1]), funcs.rhostar(samples.T[2], samples.T[1]) , np.random.normal(star_radius, star_radius_err, samples.shape[0])*samples.T[3], 180*np.arccos(samples.T[4]*samples.T[2])/np.pi  ]).T
    planet_pars_labels = ["\n\nDepth [mmag]","\n\nDuration [hrs]", "\n\n $\\rho_\\star$", "\n\nR$_2$ [R$_\\odot$]", "\n\nincl [deg]"]
    try:
        fig_corner_planet_calculated =  corner.corner(planet_pars, labels=planet_pars_labels, truths=planet_pars[best_idx], quantiles=(0.16, 0.84), levels=(1-np.exp(-0.5),), smooth=1, plot_contours=False) 
        fig_corner_planet_calculated.subplots_adjust(bottom=0.15)
        fig_corner_planet_calculated.savefig('{:}/{:}/{:}_{:}_corner_calculated.png'.format(args.savedir, prod_id, obj_ids[i],PEAKS[i]))
    except : pass
    plt.close()

    # Now get the final samples WITH labels    
    planet_samples = np.hstack((samples, planet_pars))
    best_step = planet_samples[best_idx] 
    low_err = best_step - np.percentile(planet_samples, 16, axis=0)
    high_err = np.percentile(planet_samples, 84, axis=0) - best_step
    parameters = np.vstack((best_step,low_err, high_err)).T.flatten()
    parameters_labels = ['t_zero_planet', 'period', 'radius_1', 'k', 'b', 'h1', 'h2', 'zp', 'J', 'depth_mmag', 'duration_hrs', 'rho_star_1', 'R2', 'incl'] 
    parameters_labels_all = [] 
    for j in range(len(parameters_labels)) : 
        parameters_labels_all.append(parameters_labels[j])
        parameters_labels_all.append(parameters_labels[j]+'_err_low')
        parameters_labels_all.append(parameters_labels[j]+'_err_high')


    while True:
        if not os.path.isfile(lock_fname):
            # create the file
            flock = open(lock_fname, "w+")
            flock.close()
            fp = open(fname, 'a+')
            fp.write('{:},{:},'.format(obj_ids[i], PEAKS[i]))
            fp.write(','.join([str(i) for i in parameters]))
            fp.write(',{:},{:},{:},{:},{:},{:}\n'.format(third_light,star_radius, star_radius_err, star_teff, star_teff_err, np.max(sampler.get_log_prob(flat=True, discard=args.burn_in))))
            fp.close() 
            try : os.remove(lock_fname)
            except : pass 
            break
        else : 
            print('Im sleeping')
            time_package.sleep(3)
    plt.close()


    # Now plot jovian desert
    exoplanetdata = Table.read(args.exoplanetdata, format='csv')
    plt.scatter(exoplanetdata['PER'], exoplanetdata['R'], c='k', s=5, alpha = 0.1)
    plt.gca().set_yscale('log') 
    plt.gca().set_xscale('log') 

    l1 = lines.Line2D([0.15, 0.6], [0.9, 0.65], transform=plt.gcf().transFigure, figure=plt.gcf(), linestyle='--', c='k')
    l2 = lines.Line2D([0.15, 0.6], [0.35, 0.65], transform=plt.gcf().transFigure, figure=plt.gcf(), linestyle='--', c='k')
    plt.gcf().lines.extend([l1, l2])
    plt.text(0.15,0.5, "Sub-Jovian\ndesert")

    Rjup = 69911 # km 
    Rsaturn = 58232 # km 
    Rneptune = 24622 # km
    Rearth = 6356

    plt.axhline(1., ls='-.', c='b', alpha = 0.4)
    plt.axhline(Rsaturn/Rjup, ls='-.', c='b', alpha = 0.4)
    plt.axhline(Rneptune/Rjup, ls='-.', c='b', alpha = 0.4)
    plt.axhline(Rearth/Rjup, ls='-.', c='b', alpha = 0.4)
    plt.text(200,1.03,'Jupiter')
    plt.text(200,Rsaturn/Rjup + .03,'Saturn')
    plt.text(200,Rneptune/Rjup + .01,'Neptune')
    plt.text(200,Rearth/Rjup + .001,'Earth')

    plt.xlabel('Orbital period [d]')
    plt.ylabel('Radius [R_Jup]')
    plt.tight_layout()

    old_x = plt.gca().get_xlim()
    old_y = plt.gca().get_ylim()

    x = parameters[3]
    y = parameters[-6]
    xerr = np.max(parameters[4:6])
    yerr = np.max(parameters[-5:-3])
    (_, caps, _) = plt.errorbar(x,y, xerr=xerr, yerr=yerr, c='r')
    for cap in caps:
        cap.set_color('red')
        cap.set_markeredgewidth(100)
    plt.xlim(*old_x)
    plt.ylim(*old_y)

    plt.savefig('{:}/{:}/{:}_{:}_jovian_desert.png'.format(args.savedir, prod_id, obj_ids[i], PEAKS[i]))
    plt.close()


if __name__=="__main__":
    # Parse arguments
    args = parser.parse_args()

    # Read the ORION output. The file is deliberately left open for the life of
    # the script because the workers read lightcurves.data from it.
    megafile = fits.open(args.fitsfile)
    lightcurves = megafile[5]
    prod_id = str(megafile[0].header['PROD_ID'])
    candidates = megafile[4].data
    obj_ids = candidates["OBJ_ID"]
    periods = candidates["PERIOD"] / 86400.           # Period in days
    epochs = candidates["EPOCH"] / 86400. + 6658.5    # epoch in HJD-2450000 / days
    widths = candidates["WIDTH"] / 86400.             # Width in days
    depths = candidates["DEPTH"] * -1.                # depth in mmag
    # Starting guesses: R_*/a ~ pi * width / period and k ~ sqrt(depth).
    # NOTE(review): sqrt(depth) is a radius ratio only for a fractional flux
    # drop; if DEPTH really is in mmag it needs rescaling first — confirm.
    radius_1s = np.clip(np.pi*widths/periods, 0.001, 0.79)
    ks = np.clip(np.sqrt(depths), 0.0001, 0.79)
    PEAKS = candidates["RANK"]
    LC_IDX = candidates['LC_IDX'] - 1  # consistent with fortran 1 index

    # Create the working directory (no shell involved)
    out_dir = '{:}/{:}'.format(args.savedir, prod_id)
    os.makedirs(out_dir, exist_ok=True)

    # Header of the shared results table; every worker appends one row
    fname = '{:}/fit_table.dat'.format(out_dir)
    parameter_labels = ['t_zero_planet', 'period_planet', 'radius_1_planet', 'k_planet', 'b_planet', 'h1_planet', 'h2_planet', 'zp_planet', 'J_planet', 'depth_mmag_planet', 'duration_hrs_planet', 'rho_star_1_planet', 'R2_planet', 'incl_planet']
    headers_all = []
    for label in parameter_labels:
        headers_all.extend([label, label + '_err_low', label + '_err_high'])
    with open(fname, "w") as f:
        f.write('OBJ_ID,PEAKS,')
        f.write(','.join(headers_all))
        f.write(',{:},{:},{:},{:},{:},{:}\n'.format('Light_3', 'star_radius', 'star_radius_err', 'star_teff', 'star_teff_err', 'loglike_planet'))

    # Remove a stale lock left behind by an interrupted run
    lock_fname = '{:}/.lock'.format(out_dir)
    try:
        os.remove(lock_fname)
    except FileNotFoundError:
        pass

    # fit_ngts_target reads the names above as module globals, which the
    # workers inherit only under the 'fork' start method (the Linux default).
    # NOTE(review): this breaks under 'spawn' (the macOS/Windows default).
    max_targets = 100  # only the first 100 candidates are fitted
    with Pool(10) as pool:
        pool.map(fit_ngts_target, range(min(len(obj_ids), max_targets)))