#!/home/sam/anaconda3/bin/python

import matplotlib.pyplot as plt 
import os,sys 
import matplotlib.gridspec as gridspec
import argparse 
import numpy as np
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import os.path
from bruce.binarystar.lc import _lc , kernel_lc, sum_reduce, lc
import numba 
from scipy.signal import find_peaks
from lightkurve.lightcurve import TessLightCurve
import time as time_pack
from astropy.stats import sigma_clip
from numpy import mean


description = '''monofind'''

# Argument parser (prog was previously the copy-pasted name 'predict').
parser = argparse.ArgumentParser('monofind', description=description)

parser.add_argument('--trial', action="store_true", default=False,
                    help="Plot the resampled light curve and exit without searching")
parser.add_argument('--flatten', action="store_true", default=False, help="Flatten the LC")
parser.add_argument('--remove_dropouts', action="store_true", default=False,
                    help="Replace isolated single-point dropouts (>0.05 mag from both neighbours) with noise")

parser.add_argument("filename",
                    help='The filename.')

# The literal string 'None' (the default) means "no mask".
parser.add_argument('-g',
                    '--mask',
                    help='The mask used to exclude bad regions. A two-column text file specifying the start and end of bad points.',
                    default='None')

parser.add_argument('-i',
                    '--window_length',
                    help='The flattening window (currently not passed to flatten()).',
                    default=11, type=int)

parser.add_argument('-j',
                    '--saveplace',
                    help='Where to save',
                    default='/ngts/scratch/monofind/')

def lc_sort(time, mag, mag_err):
    """Return the light curve ordered chronologically.

    Rows ``(time, mag, mag_err)`` are compared lexicographically, so samples
    sharing a timestamp are ordered by magnitude and then by uncertainty.

    Returns a tuple of three new numpy arrays ``(time, mag, mag_err)``.
    """
    rows = list(zip(time, mag, mag_err))
    rows.sort()
    return tuple(np.array([row[col] for row in rows]) for col in range(3))


def lc_resample(time, mag, mag_err, width, weighted_mean, cadence=0.5/24):
    """Linearly resample a light curve onto a regular time grid.

    The grid starts at ``min(time) - width/2`` and steps by ``cadence``
    (default 0.5/24, i.e. 30 minutes if time is in days) up to, but not
    including, ``max(time) + width/2``.

    Parameters
    ----------
    time, mag, mag_err : array_like
        Input light curve; ``time`` must be increasing (``np.interp``
        requirement).
    width : float
        Total padding added around the data span (half on each side).
    weighted_mean : float
        Magnitude assigned to grid points outside the observed range.
    cadence : float, optional
        Grid step.

    Returns
    -------
    time_new, mag_new, mag_err_new : numpy.ndarray
    """
    time_new = np.arange(np.min(time) - width/2, np.max(time) + width/2, cadence)
    mag_new = np.interp(time_new, time, mag, left=weighted_mean, right=weighted_mean)
    # Pad each side with the uncertainty at that edge (previously the first
    # value was used on both sides).
    mag_err_new = np.interp(time_new, time, mag_err, left=mag_err[0], right=mag_err[-1])
    return time_new, mag_new, mag_err_new

@numba.njit(parallel=True)
def main_func(time, mag, mag_err, weighted_mean, model_period, b, spots, chi_ref,cube):
    """Fill ``cube`` in place with ``_lc`` log-likelihoods minus ``chi_ref``.

    For every radius_1 index ``i``, radius-ratio index ``j`` and epoch index
    ``k``, the transit centre is placed at ``time[k]``. Three model variants
    are evaluated:

    * ``cube[i, j, k, 0]``: ``ld_law_1=2``, ``incl=pi/2``
    * ``cube[i, j, k, 1]``: ``ld_law_1=0``, ``incl=pi/2``
    * ``cube[i, j, k, 2]``: ``ld_law_1=2``, ``incl=arccos(b * radius_1s[i])``

    ``cube`` is indexed as a 4-D array with the first three axes sized by
    ``radius_1s``, ``ks`` and ``time``, and at least 3 entries on the last
    axis. Nothing is returned.

    NOTE(review): ``radius_1s`` and ``ks`` are not parameters. They are read
    as module-level globals, which numba freezes as compile-time constants,
    so they must exist with their final values before the first call
    triggers compilation.

    Presumably ``_lc(..., loglike_switch=1)`` returns the log-likelihood of
    the data under the model. Confirm against bruce.binarystar.lc.
    """
    # NOTE(review): numba parallelises only the outermost prange; the nested
    # prange loops below are executed serially.
    for i in numba.prange(radius_1s.shape[0]):
        for j in numba.prange(ks.shape[0]):
            for k in numba.prange(time.shape[0]):
                # Depends only on i; it is recomputed for every (j, k).
                # b is presumably the impact parameter (cos(incl) = b * radius_1).
                incl = np.arccos(b*radius_1s[i])

                # Variant 0: ld_law_1=2 at incl = pi/2.
                cube[i,j,k, 0] =_lc(time, mag, mag_err, 0.0, weighted_mean,
                                    time[k], model_period,
                                    radius_1s[i], ks[j] ,
                                    fs=0., fc=0., 
                                    q=0., albedo=0.,
                                    alpha_doppler=0., K1=0.,
                                    spots = spots, omega_1=1., 
                                    incl = np.pi/2.,
                                    ld_law_1=2, ldc_1_1=0.8, ldc_1_2=0.8, gdc_1=0.4,
                                    SBR=0., light_3=0.,
                                    E_tol=1e-5,
                                    loglike_switch=1 ) - chi_ref


                
                # Variant 1: same geometry, ld_law_1=0.
                cube[i,j,k, 1] =_lc(time, mag, mag_err, 0.0, weighted_mean,
                                    time[k], model_period,
                                    radius_1s[i], ks[j] ,
                                    fs=0., fc=0., 
                                    q=0., albedo=0.,
                                    alpha_doppler=0., K1=0.,
                                    spots = spots, omega_1=1., 
                                    incl = np.pi/2.,
                                    ld_law_1=0, ldc_1_1=0.8, ldc_1_2=0.8, gdc_1=0.4,
                                    SBR=0., light_3=0.,
                                    E_tol=1e-5,
                                    loglike_switch=1 ) - chi_ref

                # Variant 2: ld_law_1=2 at the b-dependent inclination above.
                cube[i,j,k, 2] =_lc(time, mag, mag_err, 0.0, weighted_mean,
                                    time[k], model_period,
                                    radius_1s[i], ks[j] ,
                                    fs=0., fc=0., 
                                    q=0., albedo=0.,
                                    alpha_doppler=0., K1=0.,
                                    spots = spots, omega_1=1., 
                                    incl = incl,
                                    ld_law_1=2, ldc_1_1=0.8, ldc_1_2=0.8, gdc_1=0.4,
                                    SBR=0., light_3=0.,
                                    E_tol=1e-5,
                                    loglike_switch=1 ) - chi_ref



def get_lc(time, modelx, modely, t_zero, weighted_mean):
    """Evaluate a tabulated transit model centred on ``t_zero`` at ``time``.

    ``modelx`` holds model times relative to the transit centre and must be
    increasing (``np.interp`` requirement). Outside the model's span, the
    result is ``weighted_mean``.
    """
    shifted_x = modelx + t_zero
    fill = weighted_mean
    return np.interp(time, shifted_x, modely, left=fill, right=fill)

def get_lc_loglike(time, mag, mag_err, modelx, modely, t_zero, weighted_mean,chi_ref):
    """Gaussian log-likelihood of the data under a shifted model, minus ``chi_ref``.

    The model is the tabulated ``(modelx, modely)`` profile centred on
    ``t_zero`` and padded with ``weighted_mean`` outside its span. The
    constant ``log(2*pi)`` term is omitted.
    """
    inv_var = 1.0/(mag_err**2)
    model = np.interp(time, modelx + t_zero, modely, left=weighted_mean, right=weighted_mean)
    residual_sq = (mag - model)**2
    terms = residual_sq*inv_var - np.log(inv_var)
    return -0.5*np.sum(terms) - chi_ref


def transit_duration(period, radius_1, k, b, incl):
    """Total transit duration, in the units of ``period``.

    Computes ``P/pi * arcsin(radius_1 * sqrt((1+k)^2 - b^2) / sin(incl))``,
    with ``incl`` in radians. Inputs for which the square root or arcsin
    argument is out of range give NaN.
    """
    chord = np.sqrt((1 + k)**2 - b**2)
    sin_arg = radius_1*chord/np.sin(incl)
    return period*np.arcsin(sin_arg)/np.pi

def _mask_from_file(time, mask_path):
    """Flag samples of ``time`` lying strictly inside any bad-data window.

    ``mask_path`` is a text file whose rows give a window's start and end in
    the first two columns. Returns a boolean array the length of ``time``.
    """
    # ndmin=2 keeps a one-window file iterable row by row (plain loadtxt
    # would return a flat pair, making row[0] a scalar).
    windows = np.loadtxt(mask_path, ndmin=2)
    flagged = np.zeros(time.shape[0], dtype=bool)
    for row in windows:
        flagged |= (time > row[0]) & (time < row[1])
    return flagged


def _acquire_lock(lockfile, poll=1.0):
    """Block until ``lockfile`` can be created exclusively, then return.

    O_CREAT | O_EXCL makes the existence check and the creation one atomic
    step, so two concurrent runs cannot both take the lock.
    """
    while True:
        try:
            fd = os.open(lockfile, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        except FileExistsError:
            time_pack.sleep(poll)
        else:
            os.close(fd)
            return


def _release_lock(lockfile):
    """Remove ``lockfile``; a missing file is ignored."""
    try:
        os.remove(lockfile)
    except FileNotFoundError:
        pass


if __name__=='__main__':
    # Parse the arguments
    args = parser.parse_args()
    savepath = args.saveplace
    lockfile = args.saveplace + '.lock'
    stem = os.path.basename(os.path.splitext(args.filename)[0])
    # Compare by value: `is not 'None'` fails for a user-supplied "None".
    use_mask = args.mask != 'None'

    # Load the data: columns 0-2 are time, mag, mag_err; extra columns
    # (e.g. flux, flux_err) are ignored.
    data = np.loadtxt(args.filename, ndmin=2)
    time, mag, mag_err = data[:, 0], data[:, 1], data[:, 2]

    # Sort the time axis
    time, mag, mag_err = lc_sort(time, mag, mag_err)
    # The file's uncertainties are discarded: every point gets 1e-3.
    mag_err = np.ones(time.shape[0])*1e-3

    # Drop points inside the bad-data windows before resampling
    if use_mask:
        bad = _mask_from_file(time, args.mask)
        time = time[~bad]
        mag = mag[~bad]
        mag_err = mag_err[~bad]

    if args.flatten:
        flux = 10**(-0.4*mag)
        s = TessLightCurve(time, flux, 3000*1e-6*np.ones(time.shape[0]))
        # NOTE: args.window_length is not passed; lightkurve's default window is used.
        s = s.flatten()
        time, mag = s.time, -2.5*np.log10(s.flux)

    if args.remove_dropouts:
        # Replace points that jump by more than the threshold from BOTH
        # neighbours (isolated single-point outliers).
        drop_out_thresh = 0.05
        for i in range(1, time.shape[0]-1):
            d1 = abs(mag[i] - mag[i-1])
            d2 = abs(mag[i] - mag[i+1])
            if d1 > drop_out_thresh and d2 > drop_out_thresh:
                # NOTE(review): noise is centred on 0, not the median magnitude;
                # this only makes sense for normalised/differential magnitudes.
                mag[i] = np.random.normal(0, np.std(mag))

    # Reference level (a median, despite the variable name)
    weighted_mean = np.median(mag)

    # Resample onto a regular grid, padded by half of 16 h on each side
    time, mag, mag_err = lc_resample(time, mag, mag_err, 16/24, weighted_mean)

    # Re-apply the mask on the resampled grid and overwrite the interpolated
    # values there with zero-centred noise of the light curve's scatter.
    if use_mask:
        bad = _mask_from_file(time, args.mask)
        mag[bad] = np.random.normal(0, np.std(mag), np.count_nonzero(bad))

    # Log-likelihood of a flat model at the reference level
    chi_ref = get_lc_loglike(time, mag, mag_err, time, np.ones(time.shape[0])*weighted_mean, 0, weighted_mean, 0.)

    if args.trial:
        plt.figure(figsize=(15,5))
        plt.scatter(time, mag, c='k', s=10)
        plt.axhline(weighted_mean)
        plt.title('Loglike ref = ' + str(chi_ref))
        plt.gca().invert_yaxis()
        plt.xlabel('Time')
        plt.ylabel('Mag')
        plt.show()
        sys.exit()

    # Fixed nominal period used to build the single-transit models
    model_period = 10.

    # Search for transit widths between 1 and 16 hours in 0.5 h steps
    transit_durations = np.arange(1/24, 16/24 + 0.5/24, 0.5/24)
    radius_1s = np.pi*transit_durations/model_period

    # Radius ratios from sqrt(2*std(mag)) to 0.5; depth is approximated as k**2
    ks = np.linspace(np.sqrt(np.std(mag)*2), 0.5, 30)
    mag_depths = ks**2

    # Transit durations for each (radius_1, k) pair at b=0, incl=90 deg
    Tdur = np.zeros((radius_1s.shape[0], ks.shape[0]))
    for i in range(radius_1s.shape[0]):
        for j in range(ks.shape[0]):
            Tdur[i,j] = transit_duration(model_period, radius_1s[i], ks[j], 0, np.pi/2)

    # Tabulated transit models (relative time, magnitude) for each pair
    Npoints_in_transit_model = 100
    Models = np.zeros((radius_1s.shape[0], ks.shape[0], 2, Npoints_in_transit_model))
    for i in range(radius_1s.shape[0]):
        for j in range(ks.shape[0]):
            Models[i,j,0] = np.linspace(-Tdur[i,j]/2, Tdur[i,j]/2, Npoints_in_transit_model)
            Models[i,j,1] = weighted_mean-2.5*np.log10(lc(Models[i,j,0], period = model_period, radius_1=radius_1s[i], k=ks[j]))

    # Log-likelihood gain of every model centred on every resampled time stamp
    Cube = np.zeros((radius_1s.shape[0], ks.shape[0], time.shape[0]))
    for i in range(radius_1s.shape[0]):
        for j in range(ks.shape[0]):
            for k in range(time.shape[0]):
                Cube[i,j,k] = get_lc_loglike(time, mag, mag_err, Models[i,j,0], Models[i,j,1], time[k], weighted_mean, chi_ref)

    # Power map: best gain over all epochs for each (duration, depth)
    power = np.max(Cube, axis=2)
    fig1 = plt.figure()
    plt.imshow(power, interpolation='bilinear', aspect='auto', origin='lower')
    plt.contour(power, colors='k')

    best = np.unravel_index(power.argmax(), power.shape)
    plt.axhline(best[0])
    plt.axvline(best[1])

    # Ticks sit on grid indices, so label them with the grid values at those
    # indices (the previous interp against linspace(0, N+1, N) skewed them).
    xticks = np.arange(0, mag_depths.shape[0], 2)
    plt.xticks(xticks, ['{:.2f}'.format(d) for d in mag_depths[xticks]*1e3], rotation = 45)
    plt.xlabel('$Depth$ [mmag]')

    yticks = np.arange(0, transit_durations.shape[0], 2)
    plt.yticks(yticks, ['{:.2f}'.format(d) for d in transit_durations[yticks]*24], rotation = 45)
    plt.ylabel('$T_{dur}$ [hrs]')
    plt.gcf().subplots_adjust(bottom=0.15)

    # Centre the cube on the median of the best model's time series, then
    # detect peaks above 10x the scatter of its central 10-90 percentile range.
    Cube = Cube - np.median(Cube[best])
    series = Cube[best]
    core = series[(series > np.percentile(series, 10)) & (series < np.percentile(series, 90))]
    height = 10*np.std(core)
    peaks, _ = find_peaks(series, height=height, distance = 24)

    if len(peaks) > 0:
        fig1.savefig(savepath + stem + '_power_hovmoller.png')
        plt.close(fig1)

        time_ = np.linspace(time[0], time[-1], 10000)
        fig2, axs = plt.subplots(nrows = len(peaks) + 1, ncols = 1, figsize=(10,len(peaks)*5),)

        axs[0].plot(time, series, 'k')
        axs[0].plot(time[peaks], series[peaks], "x")
        axs[0].set_xlabel('Time')
        axs[0].set_ylabel(r'$\mathcal{L} - \mathcal{L}_{\rm wm}$')
        axs[0].set_title('Number of peaks: {:}'.format(len(peaks)))
        axs[0].axhline(height, c='b', ls='--')
        axs[0].axhline(0,ls='--', color='k')
        for n, p in enumerate(peaks):
            axs[0].text(time[p]+0.25, series[p], '{:}'.format(n+1), fontsize=15)

        rows = []
        for i, p in enumerate(peaks):
            ax = axs[i+1]

            # Best (radius_1, k) at this epoch
            current_best = np.unravel_index(Cube[:,:,p].argmax(), Cube[:,:,p].shape)
            radius_1 = radius_1s[current_best[0]]
            width = transit_durations[current_best[0]]
            k = ks[current_best[1]]
            depth = mag_depths[current_best[1]]
            shape_series = Cube[current_best]

            ax.scatter(time, mag, c='k', s=10)
            ax.plot(time_, get_lc(time_, Models[current_best[0], current_best[1],0], Models[current_best[0], current_best[1],1], time[p], weighted_mean), 'r')

            ax.set_xlabel('Time')
            ax.set_ylabel('Mag')
            ax.set_xlim(time[p]-4*width, time[p]+4*width)
            # Inverted magnitude axis around the reference level (the old limits
            # assumed the light curve was centred on 0).
            ax.set_ylim(weighted_mean + 3*depth, weighted_mean - depth)

            # SNR of THIS peak: gain at the peak over the series scatter
            # (previously the series maximum, which could belong to another peak).
            snr_peak = shape_series[p] / np.std(shape_series)
            rows.append('{:}, {:}, {:}, {:}, {:}, {:}, {:}, {:}\n'.format(args.filename, i+1, time[p], radius_1, k, width, depth, snr_peak))

        fig2.align_ylabels(axs)
        fig2.savefig(savepath + stem + '_monofind_transits_found.png')

        # The shared results file is appended under a lock so that concurrent
        # runs do not interleave; the lock is released even on error.
        results_path = savepath + 'monofind_results.dat'
        _acquire_lock(lockfile)
        try:
            new_file = not os.path.isfile(results_path)
            with open(results_path, 'a') as fout:
                if new_file:
                    fout.write('filename,peak, t_cen, radius_1, k, width, depth, snr_peak\n')
                fout.writelines(rows)
        finally:
            _release_lock(lockfile)

    plt.close('all')
    
