#!/home/u1870241/anaconda3/bin/python



#
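# prepare: build ngtsfit inputs and SGE jobs from a BLS "mega" FITS file.
# Folder layout under HOME_PATH:
# -> LC_files/    light curves binned to 30 minutes
# -> SGEjobs/     SGE job scripts (moved to SGEjobs/submitted/ by --submit_jobs)
# -> NGTSFIT/     ngtsfit output (plots, backends, etc.)
# -> crossmatch/  Gaia/TIC cross-match tables, one FITS table per prod_id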
#
# Root of the output tree. LC_files/, SGEjobs/ (with SGEjobs/submitted/),
# NGTSFIT/ and crossmatch/ are all created or written beneath it.
HOME_PATH = '/ngts/scratch/fitpipe'
# Executable placed at the start of every generated ngtsfit command line.
# (Alternative install location, kept for reference:)
#NGTSFIT_PATH = '/home/sam/anaconda3/bin/ngtsfit'
NGTSFIT_PATH = '/home/u1870241/anaconda3/bin/ngtsfit'


# Free parameters passed after --fitpars in the plain (non-GP) ngtsfit job.
FITPARS = 't_zero period radius_1 k b zp h1 h2 J'
# Sampler options appended right after the --fitpars list in both job types.
FITFLAGS = '--emcee_steps 50 --emcee_burn_in 25'
# Free parameters for the GP job. NOTE(review): this string ends in '--gp' and
# the GP job command also passes --gp explicitly, so the flag appears twice on
# that command line (harmless for a store_true-style flag, but confirm intent).
GP_FITPARS = 't_zero period radius_1 k b zp h1 h2 log_amp log_timescale  log_period log_factor log_sigma --gp'


import matplotlib.pyplot as plt
from astropy.io import fits 
from astropy.table import Table, vstack
import numpy as np 
import sys,os, math
import astropy.units as u
from astropy.coordinates import SkyCoord
from astroquery.gaia import Gaia
from astropy import constants
#from pycheops import funcs 
#from pycheops.ld import stagger_claret_interpolator
#ld_interpolater = stagger_claret_interpolator(passband='NGTS')
import argparse 
from multiprocessing import Pool
from scipy.stats import chisquare, sem
from contextlib import contextmanager
from collections import defaultdict
import pymysql
import argparse 
import time
from tqdm import tqdm
from astroquery.mast import Catalogs
import glob



parser = argparse.ArgumentParser(
    'prepare',
    description='Prepare ngtsfit inputs from a BLS "mega" FITS file: cross-match '
                'candidates against Gaia/TIC, write binned light curves and SGE '
                'job scripts, and submit those jobs.')


# Positional even for --crossmatch_batch, which does not read it.
parser.add_argument("filename",
                    help='Path to the BLS "mega" FITS file (PROD_ID in the primary '
                         'header, candidate table in HDU 4, light curves in HDU 5). '
                         'Required by argparse in every mode.')

# 'no' is the sentinel meaning "batch mode off".
parser.add_argument("--crossmatch_batch",
                    help='Run --crossmatch on the first .fits file of every ngts_pipe '
                         'prod_dir directory whose path contains this substring '
                         '(SQL LIKE wildcards allowed)',
                    default='no')

parser.add_argument('--submit_jobs', action="store_true", default=False,
                    help='qsub the SGE jobs written by --make_jobs and move them '
                         'to SGEjobs/submitted')
parser.add_argument('--make_jobs', action="store_true", default=False,
                    help='Create the output folders, binned light curves and SGE '
                         'job scripts (ignored when --submit_jobs is also given)')
parser.add_argument('--crossmatch', action="store_true", default=False,
                    help='Cross-match every unique OBJ_ID against Gaia and TIC and '
                         'write crossmatch/<prod_id>.fits')




@contextmanager
def open_db(host='ngtsdb', db='ngts_archive', user='u1870241', cur_class='list'):
    """
    Reusable database connection manager.

    Opens a PyMySQL connection, yields a cursor on it, commits if the ``with``
    body finishes normally, and always closes the connection afterwards.

    Parameters
    ----------
    host, db, user : str
        Passed straight to ``pymysql.connect``.
    cur_class : str
        ``'list'`` (default) yields a plain cursor (rows come back as tuples);
        any other value yields a ``pymysql.cursors.DictCursor`` (rows are dicts).

    Yields
    ------
    cursor
        A PyMySQL cursor bound to the open connection.

    Notes
    -----
    ``with pymysql.connect(...) as cur`` only yields a cursor on PyMySQL < 1.0.
    From 1.0 onwards it yields the Connection itself, which has no
    ``execute``. An explicit cursor is therefore created here. If the body
    raises, the commit is skipped and closing the connection discards the
    transaction.
    """
    connect_kwargs = {'host': host, 'db': db, 'user': user}
    if cur_class != 'list':
        connect_kwargs['cursorclass'] = pymysql.cursors.DictCursor

    connection = pymysql.connect(**connect_kwargs)
    try:
        with connection.cursor() as cur:
            yield cur
        connection.commit()
    finally:
        connection.close()


def query_ra_dec_from_objid_prodid(prod_id, obj_id):
    """
    Look up the catalogue RA/Dec of one object for one ORION run.

    Joins ``ngts_archive.orion_runs`` to ``ngts_archive.catalogue`` on
    ``cat_prod_id`` and returns the first row matching both ids. (The
    ``WHERE cat.obj_id`` condition makes the LEFT JOIN behave as an inner join.)

    Parameters
    ----------
    prod_id, obj_id
        Values compared with ``orun.prod_id`` and ``cat.obj_id``. They are sent
        as bound query parameters rather than spliced into the SQL text.

    Returns
    -------
    tuple
        ``(ra_deg, dec_deg)`` from the first matching row.

    Raises
    ------
    IndexError
        If no row matches. This is the same exception type as before, now with a message.
    """
    qry = ('SELECT cat.ra_deg, cat.dec_deg FROM ngts_archive.orion_runs AS orun '
           'LEFT JOIN ngts_archive.catalogue AS cat ON orun.cat_prod_id = cat.cat_prod_id '
           'WHERE orun.prod_id = %s AND cat.obj_id = %s ;')
    # Unwrap numpy scalars (e.g. ids read from a FITS column) into plain Python
    # values so the driver emits numeric ids as numeric literals, as the old
    # unquoted string formatting did.
    params = tuple(v.item() if isinstance(v, np.generic) else v
                   for v in (prod_id, obj_id))

    with open_db(cur_class='list') as cur:
        cur.execute(qry, params)
        row = cur.fetchone()

    if row is None:
        raise IndexError('No catalogue entry for obj_id {:} in prod_id {:}'.format(obj_id, prod_id))
    return row


def find_mega_fits(string_to_match):
    """
    Return every ``ngts_pipe.prod_dir`` row whose ``directory`` contains a substring.

    Parameters
    ----------
    string_to_match
        Text to look for (formatted with ``str.format``). It is wrapped as
        ``%<text>%`` in a SQL ``LIKE``, so ``%`` and ``_`` inside it still act
        as wildcards, as before. It is now sent as a bound parameter, so
        quotes in it can no longer break (or inject into) the query.

    Returns
    -------
    The rows from ``cursor.fetchall()`` for ``SELECT *``, as tuples.
    """
    qry = "select * from prod_dir WHERE directory LIKE %s;"
    pattern = '%{:}%'.format(string_to_match)
    print(qry, '<-', repr(pattern))
    with open_db(cur_class='list', db='ngts_pipe') as cur:
        cur.execute(qry, (pattern,))
        return cur.fetchall()

        

def lc_bin(time, flux, bin_width):
    '''
    Bin a light curve into consecutive bins of fixed width.

    Bins start at ``min(time)`` and are half-open,
    ``[t0 + j*w, t0 + (j+1)*w)``. The last, possibly partial, bin is included,
    so no trailing data are dropped. The input does not need to be sorted.

    Parameters
    ----------
    time : array_like
        Sample times.
    flux : array_like
        Values to bin; same length as ``time``.
    bin_width : float
        Bin width, in the same units as ``time``. Must be positive.

    Returns
    -------
    time_bin, flux_bin, err_bin : numpy.ndarray
        Bin centres, the mean value in each bin, and the standard error of
        that mean (sample std with ddof=1 divided by sqrt(N), the same
        definition as ``scipy.stats.sem``). Only bins whose standard error is
        defined are returned: empty bins, single-point bins and bins containing
        a NaN value are dropped.

    Raises
    ------
    ValueError
        If ``bin_width`` is not positive or ``time`` contains non-finite values.
    '''
    time = np.asarray(time, dtype=float)
    flux = np.asarray(flux, dtype=float)
    if not bin_width > 0:
        raise ValueError('bin_width must be positive, got {:}'.format(bin_width))
    if time.size == 0:
        return np.array([]), np.array([]), np.array([])
    if not np.all(np.isfinite(time)):
        raise ValueError('time contains non-finite values')

    t_start = time.min()
    # Bin index of every sample. The previous np.arange(min, max, w) edges
    # stopped short of max(time), silently discarding the final partial bin.
    bin_idx = np.floor((time - t_start) / bin_width).astype(np.intp)
    n_bins = int(bin_idx.max()) + 1

    # One O(n) pass per statistic instead of re-masking the data per bin.
    counts = np.bincount(bin_idx, minlength=n_bins)
    sums = np.bincount(bin_idx, weights=flux, minlength=n_bins)

    occupied = counts > 0
    means = np.full(n_bins, np.nan)
    means[occupied] = sums[occupied] / counts[occupied]

    # Two-pass variance: squared deviations from each bin's own mean.
    sq_dev = np.bincount(bin_idx, weights=(flux - means[bin_idx]) ** 2, minlength=n_bins)
    errs = np.full(n_bins, np.nan)
    multi = counts > 1
    errs[multi] = np.sqrt(sq_dev[multi] / (counts[multi] - 1)) / np.sqrt(counts[multi])

    centres = t_start + (np.arange(n_bins) + 0.5) * bin_width
    keep = ~np.isnan(errs)
    return centres[keep], means[keep], errs[keep]

def _to_hjd(seconds):
    """
    Convert mega-file time values (divided by 86400, i.e. seconds) to the HJD-day scale.

    Uses the same conversion as before, ``2450000 + t / 86400. + 6658.5``.
    Values are cast to float64 first: near 2.45e6 days float32 resolves only
    0.25 d, far coarser than the 30-minute binning. The cast is a no-op for
    float64 input. NOTE(review): the FITS column dtypes are not visible here.
    """
    return 2450000 + np.asarray(seconds, dtype=np.float64) / 86400. + 6658.5


def _write_sge_job(job_path, job_name, log_path, command):
    """
    Write one SGE job script.

    The script sets the job name, sends both stdout and stderr to ``log_path``,
    pins the job to host ``ngts01.local``, and then runs ``command``.
    """
    with open(job_path, "w+") as f:
        f.write('#$ -N {:}'.format(job_name))
        f.write('\n#$ -o {:}'.format(log_path))
        f.write('\n#$ -e {:}'.format(log_path))
        f.write('\n#$ -l h=ngts01.local')
        f.write('\n\n')
        f.write(command)


def _ngtsfit_command(lc_path, t_zero, period, radius_1, k, fitpars, name, gp=False):
    """Build the ngtsfit command line for one candidate (``gp`` adds ``--gp`` after ``--emcee``)."""
    return ('{:} {:} --threads 1 --t_zero {:} --period {:} --radius_1 {:} --k {:} '
            '--J 0.1 --emcee{:} --fitpars {:} {:} --name {:} --savepath {:}/NGTSFIT').format(
                NGTSFIT_PATH, lc_path, t_zero, period, radius_1, k,
                ' --gp' if gp else '', fitpars, FITFLAGS, name, HOME_PATH)


def _crossmatch(filename):
    """
    Cross-match every unique OBJ_ID of a mega file against Gaia and TIC.

    Writes ``HOME_PATH/crossmatch/<prod_id>.fits`` with one row per unique
    OBJ_ID. Each row holds the first Gaia ``source_id`` found in a 0.01 deg box
    and the first TIC ``ID`` found within radius 0.01; -99 means nothing was
    found. If that file already exists the whole script exits, as before.
    """
    with fits.open(filename) as megafile:
        prod_id = str(megafile[0].header['PROD_ID'])
        # np.unique returns an in-memory copy, so it outlives the open file.
        unique_obj_ids = np.unique(megafile[4].data["OBJ_ID"])

    crossmatch_dir = '{:}/crossmatch'.format(HOME_PATH)
    crossmatch_file = '{:}/{:}.fits'.format(crossmatch_dir, prod_id)
    if os.path.isfile(crossmatch_file):
        print('{:} has already been crossmatched'.format(prod_id))
        sys.exit()

    # The search box is the same for every object, so build it once.
    width = u.Quantity(0.01, u.deg)
    height = u.Quantity(0.01, u.deg)

    source_id_gaia = []
    source_id_tic = []
    for obj_id in tqdm(unique_obj_ids):
        ra, dec = query_ra_dec_from_objid_prodid(prod_id, obj_id)
        coord = SkyCoord(ra=ra, dec=dec, unit=(u.degree, u.degree), frame='icrs')
        r_gaia = Gaia.query_object_async(coordinate=coord, width=width, height=height)
        r_tic = Catalogs.query_object('{:} {:}'.format(ra, dec), radius=.01, catalog="TIC")
        source_id_gaia.append(int(r_gaia['source_id'][0]) if len(r_gaia) > 0 else -99)
        source_id_tic.append(int(r_tic['ID'][0]) if len(r_tic) > 0 else -99)

    current_table = Table(names=['prod_id', 'obj_id', 'gaia_source_id', 'tic_ID'],
                          dtype=['U25', 'U25', int, int])
    print()
    rows = zip(unique_obj_ids, source_id_gaia, source_id_tic)
    for obj_id, gaia_id, tic_id in tqdm(rows, total=len(unique_obj_ids)):
        current_table.add_row([prod_id, obj_id, gaia_id, tic_id])

    # crossmatch/ used to be created only by --make_jobs; create it here too
    # so a first-time --crossmatch does not fail at the very last step.
    os.makedirs(crossmatch_dir, exist_ok=True)
    current_table.write(crossmatch_file)


def _crossmatch_batch(pattern):
    """Run ``fitpipe --crossmatch`` on the first .fits file in every prod_dir directory matching ``pattern``."""
    import subprocess

    for row in find_mega_fits(pattern):
        directory = row[1]  # column 1 of prod_dir; presumably `directory` — confirm
        matches = glob.glob('{:}/*.fits'.format(directory))
        if not matches:
            # Previously an IndexError here aborted the whole batch.
            print('No .fits file found in {:}; skipping'.format(directory))
            continue
        # Argument list (no shell), so unusual characters in paths are safe.
        subprocess.run(['fitpipe', '--crossmatch', matches[0]])


def _submit_jobs(filename):
    """
    qsub both SGE jobs of every candidate and move each accepted script into SGEjobs/submitted.

    A script whose ``qsub`` exits non-zero is left in place so it can be
    resubmitted.
    """
    import shutil
    import subprocess

    submitted_dir = '{:}/SGEjobs/submitted'.format(HOME_PATH)
    os.makedirs(submitted_dir, exist_ok=True)

    with fits.open(filename) as megafile:
        prod_id = str(megafile[0].header['PROD_ID'])
        candidates = megafile[4].data
        for obj_id, peak in zip(candidates["OBJ_ID"], candidates["RANK"]):
            stem = '{:}/SGEjobs/{:}_{:}_{:}'.format(HOME_PATH, prod_id, obj_id, peak)
            for job_path in (stem + '.job', stem + '_gp.job'):
                result = subprocess.run(['qsub', '-p', '-999', job_path])
                if result.returncode != 0:
                    print('qsub failed for {:}; leaving it in SGEjobs'.format(job_path))
                    continue
                # Move exactly this file. The old `mv {stem}*.job` glob also
                # matched other peaks sharing the prefix (peak 1 matched 10,
                # 11, ...), moving them away before they had been submitted.
                shutil.move(job_path, os.path.join(submitted_dir, os.path.basename(job_path)))


def _make_jobs(filename):
    """
    Create the output tree, bin every candidate's light curve and write its two SGE jobs.

    For each row of the candidate table (HDU 4) this saves
    ``LC_files/<prod_id>_<obj_id>_binned.dat`` (30-minute bins) and writes
    ``SGEjobs/<prod_id>_<obj_id>_<peak>.job`` (plain fit) and
    ``SGEjobs/<prod_id>_<obj_id>_<peak>_gp.job`` (GP fit).
    """
    print('Preparing ', filename)

    for subdir in ('LC_files', 'SGEjobs/submitted', 'NGTSFIT', 'crossmatch'):
        os.makedirs('{:}/{:}'.format(HOME_PATH, subdir), exist_ok=True)

    with fits.open(filename) as megafile:
        lightcurves = megafile[5]
        prod_id = str(megafile[0].header['PROD_ID'])
        candidates = megafile[4].data
        obj_ids = candidates["OBJ_ID"]
        periods = candidates["PERIOD"] / 86400.       # Period in days
        epochs = _to_hjd(candidates["EPOCH"])         # epoch in HJD / days
        peaks = candidates["RANK"]                    # BLS peak rank
        lc_idx = candidates['LC_IDX'] - 1             # 1-based -> 0-based row of HDU 5
        widths = candidates["WIDTH"] / 86400.         # Width in days
        depths = candidates["DEPTH"] * -1.            # sign flipped; original comment: mmag
        # Initial guesses for ngtsfit, clipped to [0.001, 0.79] / [0.0001, 0.79].
        # NOTE(review): sqrt(depth) implies a fractional depth, not mmag — confirm units.
        radius_1s = np.clip(np.pi * widths / periods, 0.001, 0.79)
        ks = np.clip(np.sqrt(depths), 0.0001, 0.79)

        rows = zip(obj_ids, peaks, lc_idx, epochs, periods, radius_1s, ks)
        for obj_id, peak, idx, epoch, period, radius_1, k in tqdm(rows, total=len(obj_ids)):
            lc_time, mag, mag_err = lightcurves.data[idx]
            lc_time = _to_hjd(lc_time)
            # Sort by time, ties by mag then mag_err (same order as sorted(zip(...))).
            order = np.lexsort((mag_err, mag, lc_time))
            lc_time = lc_time[order]
            mag = np.asarray(mag)[order]
            mag_err = np.asarray(mag_err)[order]

            # 30-minute bins (0.5 h in days). The per-bin standard error from
            # lc_bin replaces the per-point mag_err.
            lc_time, mag, mag_err = lc_bin(lc_time, mag, 0.5 / 24)
            lc_path = '{:}/LC_files/{:}_{:}_binned.dat'.format(HOME_PATH, prod_id, obj_id)
            np.savetxt(lc_path, np.column_stack((lc_time, mag, mag_err)))

            name = '{:}_{:}_{:}'.format(prod_id, obj_id, peak)
            job_stem = '{:}/SGEjobs/{:}'.format(HOME_PATH, name)
            job_name = 'fitpipe_{:}'.format(name)

            # Plain transit fit; its log goes directly under HOME_PATH.
            _write_sge_job(job_stem + '.job', job_name,
                           '{:}/{:}.log'.format(HOME_PATH, name),
                           _ngtsfit_command(lc_path, epoch, period, radius_1, k,
                                            FITPARS, name))
            # GP fit (same SGE job name); its log goes under HOME_PATH/NGTSFIT.
            _write_sge_job(job_stem + '_gp.job', job_name,
                           '{:}/NGTSFIT/{:}.log'.format(HOME_PATH, name),
                           _ngtsfit_command(lc_path, epoch, period, radius_1, k,
                                            GP_FITPARS, name, gp=True))


if __name__ == "__main__":

    args = parser.parse_args()

    # Modes run in this order; --crossmatch exits early if already done.
    if args.crossmatch:
        _crossmatch(args.filename)

    if args.crossmatch_batch != 'no':
        _crossmatch_batch(args.crossmatch_batch)

    if args.submit_jobs:
        _submit_jobs(args.filename)
    elif args.make_jobs:
        _make_jobs(args.filename)