#!/home/sam/anaconda3/bin/python

import argparse
import math
import os
import sys
import warnings

import matplotlib
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numba
import numba.cuda
import numpy as np
from scipy.signal import find_peaks
from scipy.stats import sem
from tqdm import tqdm

from bruce.binarystar import lc_fast as lc

# Silence every Python warning raised while the script runs. `np.warnings`
# was only an alias of the standard-library module and is gone from recent
# NumPy releases, so the standard `warnings` module is used directly.
warnings.filterwarnings('ignore')



# LC bin 
def lc_bin(time, flux, bin_width):
    '''
    Average a light curve into consecutive bins of fixed width.

    Bin edges run from min(time) up to (but not including) max(time) in
    steps of bin_width, so time and bin_width must share the same units.
    Samples at or beyond the last edge are not assigned to any returned bin.
    Bins whose standard error is NaN are dropped: that covers empty bins and
    any bin for which scipy's sem() yields NaN (presumably bins holding a
    single sample -- confirm against the scipy documentation).

    :param time: array of sample times
    :param flux: array of values to average, same length as time
    :param bin_width: width of each bin
    :returns: (bin centres, mean per bin, standard error per bin)
    '''
    bin_edges = np.arange(np.min(time), np.max(time), bin_width)
    membership = np.digitize(time, bin_edges)
    centres = (bin_edges[1:] + bin_edges[:-1]) / 2

    means = []
    errors = []
    for index in range(1, len(bin_edges)):
        members = flux[membership == index]
        if len(members) == 0:
            # Empty bin: flagged with NaN so it is filtered out below.
            means.append(np.nan)
            errors.append(np.nan)
        else:
            means.append(members.mean())
            errors.append(sem(members))
    means = np.array(means)
    errors = np.array(errors)

    # Selection is driven by the error column only.
    keep = ~np.isnan(errors)
    return centres[keep], means[keep], errors[keep]


def phaser(t, t0, p):
    """Return the fractional part of the number of periods p elapsed since t0.

    Works element-wise when t is a NumPy array.
    """
    cycles = (t - t0) / p
    return cycles - np.floor(cycles)

from astropy import constants

def period_grid(
    R_star,
    M_star,
    time_span,
    period_min=0,
    period_max=float("inf"),
    oversampling_factor=2,
    n_transits_min=1,
):
    """Returns array of optimal sampling periods for transit search in light curves
       Following Ofir (2014, A&A, 561, A138)

    :param R_star: stellar radius in solar radii; values outside
        [0.1, 10000] are clamped with a warning.
    :param M_star: stellar mass in solar masses; values outside
        [0.01, 1000] are clamped with a warning.
    :param time_span: length of the light curve in days.
    :param period_min: periods <= period_min (days) are discarded.
    :param period_max: periods > period_max (days) are discarded.
    :param oversampling_factor: divides the frequency step, so larger values
        give a denser grid.
    :param n_transits_min: minimum number of transits within time_span; sets
        the lowest trial frequency.
    :returns: array of trial periods in days, in order of increasing
        frequency (i.e. decreasing period).
    """

    if R_star < 0.1:
        warnings.warn(
            "Warning: R_star was set to 0.1 for period_grid (was unphysical: "
            + str(R_star)
            + ")"
        )
        R_star = 0.1

    if R_star > 10000:
        warnings.warn(
            "Warning: R_star was set to 10000 for period_grid (was unphysical: "
            + str(R_star)
            + ")"
        )
        R_star = 10000

    # The mass limits must clamp M_star itself; assigning to R_star here
    # would leave the mass unphysical and silently corrupt the radius.
    if M_star < 0.01:
        warnings.warn(
            "Warning: M_star was set to 0.01 for period_grid (was unphysical: "
            + str(M_star)
            + ")"
        )
        M_star = 0.01

    if M_star > 1000:
        warnings.warn(
            "Warning: M_star was set to 1000 for period_grid (was unphysical: "
            + str(M_star)
            + ")"
        )
        M_star = 1000

    # Convert to SI units.
    R_star = R_star * constants.R_sun.value
    M_star = M_star * constants.M_sun.value
    time_span = time_span * 86400  # seconds

    # boundary conditions
    f_min = n_transits_min / time_span
    f_max = 1.0 / (2 * np.pi) * np.sqrt(constants.G.value * M_star / (3 * R_star) ** 3)

    # optimal frequency sampling, Equations (5), (6), (7)
    A = (
        (2 * np.pi) ** (2.0 / 3)
        / np.pi
        * R_star
        / (constants.G.value * M_star) ** (1.0 / 3)
        / (time_span * oversampling_factor)
    )
    C = f_min ** (1.0 / 3) - A / 3.0
    N_opt = (f_max ** (1.0 / 3) - f_min ** (1.0 / 3) + A / 3) * 3 / A

    X = np.arange(N_opt) + 1
    f_x = (A / 3 * X + C) ** 3
    P_x = 1 / f_x

    # Cut to given (optional) selection of periods
    periods = P_x / 86400
    in_range = np.logical_and(periods > period_min, periods <= period_max)
    return periods[in_range]  # periods in [days]


# Argument parser
parser = argparse.ArgumentParser('mbls')

# NOTE: a '-t' option (dest="t_zero", the transit epoch) is currently
# disabled and therefore not registered here.

# Positional argument: the light-curve file to search.
parser.add_argument("filename",
                    help='The filename from which to template search')

# Optional numeric settings, registered in this order:
# (short flag, long flag, help text, default, type).
_NUMERIC_OPTIONS = (
    ('-b', '--period',
     'The orbital period in arbritraty time units consisting with the input file.',
     10, float),
    ('-c', '--radius_1',
     'The radius of star 1 in units of the semi-major axis, a.',
     0.2, float),
    ('-d', '--k',
     'The ratio of the radii of star 2 and star 1 (R2/R1).',
     0.2, float),
    ('-e', '--b',
     'The impact parameter of the orbit (incl = arccos(radius_1*b).',
     0., float),
    ('-f', '--light_3',
     'The third light.',
     0.0, float),
    ('-g', '--period_low',
     'Period low',
     1, float),
    ('-i', '--period_high',
     'Period high',
     5, float),
    ('-j', '--oversampling_factor',
     'oversampling factor',
     2, int),
)

for _short, _long, _help, _default, _type in _NUMERIC_OPTIONS:
    parser.add_argument(_short, _long,
                        help=_help, default=_default, type=_type)

# Boolean switch, off unless given on the command line.
parser.add_argument('--gpu', action="store_true", default=False)

@numba.njit
def transit_width(r, k, b, P=1):
    """
    Total transit duration.
    See equation (3) from Seager and Mallen-Ornelas, 2003ApJ...585.1038S.

    :param r: R_star/a
    :param k: R_planet/R_star
    :param b: impact parameter = a.cos(i)/R_star
    :param P: orbital period (optional, default P=1)
    :returns: Total transit duration in the same units as P.
    """
    chord_ratio = ((1+k)**2-b**2) / (1-b**2*r**2)
    sin_half_angle = r*math.sqrt(chord_ratio)
    return P*math.asin(sin_half_angle)/math.pi


###################################################
# Fortran conversions
###################################################
@numba.njit
def sign(a,b) :
    """Return |a| when b >= 0 and -|a| otherwise (Fortran-style SIGN)."""
    magnitude = abs(a)
    return magnitude if b >= 0.0 else -magnitude

###################################################
# Brent minimisation
###################################################
@numba.njit
def brent(x1,x2,  k, b, width):
    """
    Find the scaled stellar radius r = R_star/a whose transit width (for unit
    period) equals `width`, using Brent's root-finding method on
    transit_width(r, k, b, 1) - width.

    :param x1: lower end of the search interval for r
    :param x2: upper end of the search interval for r
    :param k: radius ratio passed through to transit_width
    :param b: impact parameter passed through to transit_width
    :param width: target transit width in units of the orbital period
    :returns: the root estimate; 1 if `itmax` iterations pass without
        meeting the tolerance. The interval is not checked for actually
        bracketing a root.
    """
    # pars
    tol = 1e-8
    itmax = 100
    eps = 1e-5

    # The bracket uses its own names (xa, xb, xc) so that the impact
    # parameter `b` is never overwritten by the running root estimate.
    xa = x1
    xb = x2
    xc = 0.
    d = 0.
    e = 0.
    fa = transit_width(xa, k, b, 1) - width
    fb = transit_width(xb, k, b, 1) - width

    fc = fb

    for _ in range(itmax):
        if (fb*fc > 0.0):
            # xb and xc no longer straddle the root: reset xc to xa.
            xc = xa
            fc = fa
            d = xb-xa
            e = d

        if (abs(fc) < abs(fb)):
            # Keep xb as the best estimate (smallest |f|).
            xa = xb
            xb = xc
            xc = xa
            fa = fb
            fb = fc
            fc = fa

        tol1 = 2.0*eps*abs(xb)+0.5*tol
        xm = (xc-xb)/2.0
        if (abs(xm) < tol1 or fb == 0.0) : return xb

        if (abs(e) > tol1 and abs(fa) > abs(fb)):
            # Attempt interpolation (secant when only two distinct points).
            s = fb/fa
            if (xa == xc):
                p = 2.0*xm*s
                q = 1.0-s
            else:
                q = fa/fc
                r = fb/fc
                p = s*(2.0*xm*q*(q-r)-(xb-xa)*(r-1.0))
                q = (q-1.0)*(r-1.0)*(s-1.0)

            if (p > 0.0) : q = - q
            p = abs(p)
            if (2.0*p < min(3.0*xm*q-abs(tol1*q),abs(e*q))):
                # Interpolated step accepted.
                e = d
                d = p/q
            else:
                # Interpolation rejected: fall back to bisection.
                d = xm
                e = d
        else:
            # Bounds shrinking too slowly: bisect.
            d = xm
            e = d

        xa = xb
        fa = fb

        # Always move by at least tol1 towards the other end of the bracket.
        if (abs(d) > tol1) : xb = xb + d
        else : xb = xb + sign(tol1, xm)

        fb = transit_width(xb, k, b, 1) - width
    return 1



def _fold_sorted_doubled(time, mag, period):
    """Fold (time, mag) on `period` with epoch 0, sort by phase, and repeat
    the folded curve one cycle earlier so that phases span [-1, 1)."""
    phase = phaser(time, 0, period)
    # Primary key phase, secondary key mag (vectorised lexicographic sort).
    order = np.lexsort((mag, phase))
    phase = phase[order]
    mag_sorted = mag[order]
    return (np.concatenate((phase - 1, phase)),
            np.concatenate((mag_sorted, mag_sorted)))


if __name__ == '__main__':
    # First, parse the args
    args = parser.parse_args()

    # Load the lightcurve once. The first three columns are used as
    # time, mag and mag_err; any further columns are ignored.
    data = np.loadtxt(args.filename, ndmin=2)
    if data.shape[1] < 3:
        raise ValueError(
            'Expected at least 3 columns (time, mag, mag_err) in {:}, found {:}'.format(
                args.filename, data.shape[1]))
    time, mag, mag_err = data[:, :3].T

    # NOTE(review): strips the first three characters of the filename; the
    # result is not used anywhere in this script -- confirm it can be removed.
    args.filename = args.filename[3:]
    incl = 180*np.arccos(args.radius_1*args.b)/np.pi

    # Sort the lightcurve by time (ties broken by mag, then mag_err) and
    # make sure the arrays are contiguous float64.
    order = np.lexsort((mag_err, mag, time))
    time = np.ascontiguousarray(time[order], dtype=np.float64)
    mag = np.ascontiguousarray(mag[order], dtype=np.float64)
    mag_err = np.ascontiguousarray(mag_err[order], dtype=np.float64)

    # Remove the median magnitude so the lightcurve is centred on zero.
    zero_point = np.median(mag)
    mag = mag - zero_point

    # Calculate the reference transit width
    ref_width = transit_width(args.radius_1, args.k, args.b, args.period)
    print('Transit width = {:.2f} hrs'.format(ref_width*24))

    # Trial periods; the stellar radius and mass arguments are fixed to 1.
    periods = period_grid(
        1.,
        1.,
        time[-1] - time[0],
        period_min=args.period_low,
        period_max=args.period_high,
        oversampling_factor=args.oversampling_factor,
        n_transits_min=1,
    )
    if periods.size == 0:
        raise SystemExit('No trial periods between {:} and {:}'.format(
            args.period_low, args.period_high))

    chi_max = np.empty_like(periods)
    chi_max_pos = np.zeros(periods.shape[0])

    reference_loglike = lc(time, mag, mag_err, J=0, incl=40., zp=zero_point)
    print('Ref loglike : ', reference_loglike)

    for i, period in enumerate(tqdm(periods)):
        # Fold on the trial period, covering phases -1..1
        phase, mag_ = _fold_sorted_doubled(time, mag, period)

        # Now calculate radius_1 to match the transit width
        radius_1 = brent(0.0001, 0.8, args.k, args.b, ref_width/period)

        # Bin the folded curve and step the trial epoch in 0.5/24 time
        # units expressed in phase (half an hour if times are in days).
        phase_step = 0.5/24/period
        phase_, mag__, mag_err__ = lc_bin(phase, mag_, phase_step)
        phase_trials = np.arange(-1, 1, phase_step)

        # Log-likelihood of the comparison model (incl=40.) for this fold.
        flat_loglike = lc(phase_, mag__, mag_err__, J=0, t_zero=0.,
                          radius_1=radius_1, k=args.k, incl=40.)

        chis = np.empty_like(phase_trials)
        for j, t_zero in enumerate(phase_trials):
            chis[j] = lc(phase_, mag__, mag_err__, J=0, t_zero=t_zero,
                         radius_1=radius_1, k=args.k, incl=incl) - flat_loglike

        best_trial = np.argmax(chis)
        chi_max[i] = chis[best_trial]
        chi_max_pos[i] = phase_trials[best_trial]

    best_idx = np.argmax(chi_max)
    best_period = periods[best_idx]

    f, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

    # Left panel: best log-likelihood improvement against trial period.
    ax1.plot(periods, chi_max)
    # Raw string: '\D' and '\m' are not valid Python escape sequences.
    ax1.set_ylabel(r'$\Delta \mathcal{L}$')
    ax1.set_xlabel('Period [d]')

    print('best = {:}'.format(best_period))
    ax1.axvline(best_period, ls='--', c='k')

    # Right panel: data folded on the best period, raw and binned.
    phase, mag_ = _fold_sorted_doubled(time, mag, best_period)
    phase_, mag__, mag_err__ = lc_bin(phase, mag_, 0.5/24/best_period)

    ax2.scatter(phase, mag_, c='k', s=10)
    ax2.scatter(phase_, mag__, c='r', s=10)
    std = np.std(mag__)
    # Upper limit first so the magnitude axis is inverted.
    ax2.set_ylim(np.max(mag__)+3*std, np.min(mag__)-3*std)

    ax2.set_xlabel('Phase')
    ax2.set_ylabel('Mag')

    # Phase of the best-fitting transit, wrapped into [0, 1).
    center = chi_max_pos[best_idx]
    if center < 0:
        center = center + 1

    # Overplot the model for the best period, converted to magnitudes.
    best_radius_1 = brent(0.0001, 0.8, args.k, args.b, ref_width/best_period)
    model_phase = np.linspace(-1, 1, 1000)
    ax2.plot(model_phase,
             -2.5*np.log10(lc(model_phase, t_zero=center, radius_1=best_radius_1,
                              k=args.k, incl=incl)),
             'b')

    ax2.set_xlim(center - ref_width/best_period, center + ref_width/best_period)
    plt.savefig('mbls.pdf')
    plt.show()