#!python



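#
# Plan for the `prepare` script
# Folder layout created under HOME_PATH:
# -> LC_files (light curves binned to 30 minutes)
# -> SGEjobs  (SGE job files; submitted ones are moved to SGEjobs/submitted)
# -> NGTSFIT  (plots, backends etc.)
#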
# Root directory under which the LC_files/, SGEjobs/ and NGTSFIT/ folders live.
HOME_PATH = '/ngts/scratch/fitpipe'

# Absolute path of the ngtsfit executable that every generated job invokes.
# Alternative install location: '/home/u1870241/anaconda3/bin/ngtsfit'
NGTSFIT_PATH = '/home/sam/anaconda3/bin/ngtsfit'


# Space-separated parameter list handed to `--fitpars` for the standard fit.
FITPARS = ' '.join(('t_zero', 'period', 'radius_1', 'k', 'b', 'zp', 'J'))

# Extra sampler options appended verbatim to every ngtsfit command line.
FITFLAGS = '--emcee_steps {:d} --emcee_burn_in {:d}'.format(50000, 40000)

# Parameter list handed to `--fitpars` for the GP fit. The trailing `--gp`
# and the double space before `log_period` are part of the value.
GP_FITPARS = (
    't_zero period radius_1 k b zp '
    'log_amp log_timescale  log_period log_factor log_sigma '
    '--gp'
)


import matplotlib.pyplot as plt
from astropy.io import fits 
import numpy as np 
import sys,os, math
import astropy.units as u
from astropy.coordinates import SkyCoord
from astroquery.gaia import Gaia
from astropy import constants
#from pycheops import funcs 
#from pycheops.ld import stagger_claret_interpolator
#ld_interpolater = stagger_claret_interpolator(passband='NGTS')
import argparse 
from multiprocessing import Pool
from scipy.stats import chisquare, sem
from contextlib import contextmanager
from collections import defaultdict
import pymysql
import argparse 
import time
from tqdm import tqdm


# Command-line interface: one positional input filename plus two boolean
# switches (both off by default) that select what the script does.
parser = argparse.ArgumentParser(prog='prepare')

parser.add_argument(
    'filename',
    help='The filename of the binary star information',
)

for _switch in ('--submit', '--make_jobs'):
    parser.add_argument(_switch, default=False, action='store_true')
del _switch  # keep the module namespace free of the loop variable


def lc_bin(time, flux, bin_width):
    '''
    Bin a light curve into contiguous bins of fixed width.

    Bins start at ``min(time)`` and are ``bin_width`` wide; enough bins are
    created that the latest point always falls inside the last bin (the
    trailing, possibly partial, bin is no longer discarded).

    Parameters
    ----------
    time : array-like
        Time stamps. Need not be sorted. Must share units with ``bin_width``.
    flux : array-like
        Measurements, same length as ``time``.
    bin_width : float
        Width of each bin; must be positive.

    Returns
    -------
    time_bin, flux_bin, err_bin : numpy.ndarray
        Bin centres, per-bin mean of ``flux`` and per-bin standard error of
        the mean. Bins whose standard error is undefined (empty bins, bins
        holding a single point, bins containing NaN) are dropped.

    Raises
    ------
    ValueError
        If ``bin_width`` is not positive.
    '''
    if bin_width <= 0:
        raise ValueError('bin_width must be positive')

    time = np.asarray(time)
    flux = np.asarray(flux)

    # Nothing to bin: return three independent empty arrays.
    if time.size == 0:
        return np.array([]), np.array([]), np.array([])

    t_min = np.min(time)
    t_max = np.max(time)

    # Number of bins such that the last one contains t_max.
    n_bins = int(np.floor((t_max - t_min) / bin_width)) + 1
    edges = t_min + bin_width * np.arange(n_bins + 1)
    time_binned = (edges[1:] + edges[:-1]) / 2

    # 0-based bin index per point. The clip guards against floating-point
    # rounding pushing t_max just past the final edge.
    bin_index = np.clip(np.digitize(time, edges), 1, n_bins) - 1

    # Group the flux values bin by bin with one stable sort instead of
    # re-evaluating a boolean mask for every bin; the stable sort keeps the
    # input order inside each bin.
    order = np.argsort(bin_index, kind='stable')
    counts = np.bincount(bin_index, minlength=n_bins)
    groups = np.split(flux[order], np.cumsum(counts)[:-1])

    flux_binned = np.array([g.mean() if g.size else np.nan for g in groups])
    # The standard error needs at least two points; otherwise mark the bin NaN.
    err_binned = np.array([sem(g) if g.size > 1 else np.nan for g in groups])

    keep = ~np.isnan(err_binned)
    return time_binned[keep], flux_binned[keep], err_binned[keep]

if __name__=="__main__":
    # Script entry point. Depending on the switch given:
    #   --submit     queue the previously generated standard job files with
    #                qsub and move each queued file to SGEjobs/submitted;
    #   --make_jobs  write binned light curves and SGE job files for the
    #                candidates in the FITS file, and queue each GP fit
    #                through `tsp`.
    # --submit takes precedence when both switches are given.

    args = parser.parse_args()

    if args.submit:
        with fits.open(args.filename) as megafile:
            prod_id = str(megafile[0].header['PROD_ID'])
            obj_ids = megafile[4].data["OBJ_ID"]
            PEAKS = megafile[4].data["RANK"]

            for obj_id, peak in zip(obj_ids, PEAKS):
                job_file = '{:}_{:}_{:}.job'.format(prod_id, obj_id, peak)
                job_path = '{:}/SGEjobs/{:}'.format(HOME_PATH, job_file)

                # Already submitted (moved away) or never generated.
                if not os.path.isfile(job_path):
                    continue

                status = os.system('qsub -p -100 {:}'.format(job_path))
                if status != 0:
                    # Leave the job file in place so it can be retried.
                    print('qsub failed for {:} (status {:}); job file left in place'.format(job_path, status))
                    continue

                os.system('mv {:} {:}/SGEjobs/submitted/{:}'.format(job_path, HOME_PATH, job_file))

    elif args.make_jobs:
        print('Preparing ', args.filename)

        # First, we need to create the file structure (same as `mkdir -p`)
        for subdir in ('LC_files', 'SGEjobs/submitted', 'NGTSFIT'):
            os.makedirs('{:}/{:}'.format(HOME_PATH, subdir), exist_ok=True)

        # One template for both fits; only the thread count, the `--gp`
        # switch and the fitted-parameter list differ between them.
        command_template = ('{ngtsfit} {lc_file} --threads {threads} --t_zero {t_zero} '
                            '--period {period} --radius_1 {radius_1} --k {k} --J 0.1 '
                            '--emcee{gp_flag} --fitpars {fitpars} {fitflags} '
                            '--name {name} --savepath {home}/NGTSFIT')

        # Keep the FITS file open for the whole loop: table and light-curve
        # data are read from it on demand.
        with fits.open(args.filename) as megafile:
            lightcurves = megafile[5]
            candidates = megafile[4].data
            prod_id = str(megafile[0].header['PROD_ID'])
            obj_ids = candidates["OBJ_ID"]
            periods = candidates["PERIOD"] / 86400.                 # Period in days
            # NOTE(review): 6658.5 is presumably the survey time zero-point
            # offset -- confirm against the data-product documentation.
            epochs = 2450000 + candidates["EPOCH"] / 86400. + 6658.5    # epoch in HJD / days
            PEAKS = candidates["RANK"]                              # PEAK
            LC_IDX = candidates['LC_IDX'] - 1                       # 0-based index into the light curves

            widths = candidates["WIDTH"] / 86400.                   # Width in days
            depths = candidates["DEPTH"] * -1.                      # depth in mmag
            # Starting guesses for the fit, clipped to fixed ranges.
            radius_1s = np.clip(np.pi*widths/periods, 0.001, 0.79)
            ks = np.clip(np.sqrt(depths), 0.0001, 0.79)

            # NOTE(review): only the first 500 candidates are processed.
            for i in tqdm(range(len(obj_ids))[:500]):
                name = '{:}_{:}_{:}'.format(prod_id, obj_ids[i], PEAKS[i])
                lc_file = '{:}/LC_files/{:}_{:}_binned.dat'.format(HOME_PATH, prod_id, obj_ids[i])

                # Get and save array for ngtsfit. The catalogue errors are
                # not used: the saved errors come from the binning itself.
                lc_time, lc_mag, lc_mag_err = lightcurves.data[LC_IDX[i]]
                lc_time = 2450000 + lc_time / 86400. + 6658.5    # HJD / days
                order = np.argsort(lc_time, kind='stable')
                bin_time, bin_mag, bin_err = lc_bin(lc_time[order], np.asarray(lc_mag)[order], 0.5/24)
                np.savetxt(lc_file, np.column_stack((bin_time, bin_mag, bin_err)))

                shared = dict(ngtsfit=NGTSFIT_PATH, lc_file=lc_file,
                              t_zero=epochs[i], period=periods[i],
                              radius_1=radius_1s[i], k=ks[i],
                              fitflags=FITFLAGS, name=name, home=HOME_PATH)
                standard_command = command_template.format(threads=2, gp_flag='', fitpars=FITPARS, **shared)
                gp_command = command_template.format(threads=1, gp_flag=' --gp', fitpars=GP_FITPARS, **shared)

                # Now make the SGE job to go with it.
                # NOTE(review): this job logs to HOME_PATH while the GP job
                # logs to HOME_PATH/NGTSFIT -- confirm that is intended.
                with open('{:}/SGEjobs/{:}.job'.format(HOME_PATH, name), 'w') as job:
                    job.write('\n'.join([
                        '#$ -N fitpipe_{:}'.format(name),
                        '#$ -o {:}/{:}.log'.format(HOME_PATH, name),
                        '#$ -e {:}/{:}.log'.format(HOME_PATH, name),
                        '#$ -l h=ngts01.local',
                        '',
                        standard_command,
                    ]))

                # Now do GP fit (no host restriction in this job file)
                with open('{:}/SGEjobs/{:}_gp.job'.format(HOME_PATH, name), 'w') as job:
                    job.write('\n'.join([
                        '#$ -N fitpipe_{:}'.format(name),
                        '#$ -o {:}/NGTSFIT/{:}.log'.format(HOME_PATH, name),
                        '#$ -e {:}/NGTSFIT/{:}.log'.format(HOME_PATH, name),
                        '',
                        gp_command,
                    ]))

                # The GP fit is also queued right away through `tsp`, in
                # addition to the job file written above.
                os.system('tsp ' + gp_command)

    else:
        # Neither switch given: show the usage instead of exiting silently.
        parser.print_help()

