#!python

import numba, numba.cuda 
from bruce.binarystar import lc
import numpy as np, sys, os, math
import argparse
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.stats import  sem
from scipy.signal import find_peaks

# LC bin 
def lc_bin(time, flux, bin_width):
	'''
	Bin a light curve into fixed-width time bins.

	time and bin_width must have the same units.

	Bin edges start at min(time) and advance in steps of bin_width, stopping
	before max(time); samples at or beyond the last edge are therefore not
	assigned to any returned bin. Bins holding fewer than two samples have no
	defined standard error and are dropped from the output, as are bins whose
	standard error evaluates to NaN.

	:param time: array of sample times.
	:param flux: array of measurements, one per entry of time.
	:param bin_width: width of each bin.
	:returns: tuple (time_bin, flux_bin, err_bin) holding the bin centres,
		the mean of the samples in each bin and the standard error of that
		mean, for the retained bins only. All three are empty arrays when
		the input is empty.
	'''

	time = np.asarray(time)
	flux = np.asarray(flux)

	# np.min / np.max raise on zero-size input; there is nothing to bin.
	if time.size == 0:
		return np.array([], dtype=float), np.array([], dtype=float), np.array([], dtype=float)

	edges = np.arange(np.min(time), np.max(time), bin_width)
	dig = np.digitize(time, edges)
	time_binned = (edges[1:] + edges[:-1]) / 2

	# Start from NaN everywhere so that skipped bins are filtered out below.
	n_bins = max(len(edges) - 1, 0)
	flux_binned = np.full(n_bins, np.nan)
	err_binned = np.full(n_bins, np.nan)

	for i in range(1, len(edges)):
		# Build the membership mask once per bin.
		members = flux[dig == i]
		if members.size < 2:
			# The standard error of fewer than two samples is NaN, so the bin
			# would be discarded anyway; skip it without calling sem().
			continue
		flux_binned[i - 1] = members.mean()
		err_binned[i - 1] = sem(members)

	keep = ~np.isnan(err_binned)
	return time_binned[keep], flux_binned[keep], err_binned[keep]


# Argument parser.
# Every model parameter is an optional float with a default; the light-curve
# file is the only required (positional) argument.
parser = argparse.ArgumentParser('templatematch')

parser.add_argument("filename",
                    help='The light-curve file to run the template search on.')

parser.add_argument('-b',
                    '--period',
                    help='The orbital period in arbitrary time units consistent with the input file.',
                    default=10, type=float)

parser.add_argument('-c',
                    '--radius_1',
                    help='The radius of star 1 in units of the semi-major axis, a.',
                    default=0.2, type=float)

parser.add_argument('-d',
                    '--k',
                    help='The ratio of the radii of star 2 and star 1 (R2/R1).',
                    default=0.2, type=float)

parser.add_argument('-e',
                    '--b',
                    help='The impact parameter of the orbit (incl = arccos(radius_1*b)).',
                    default=0., type=float)

# NOTE: light_3 is parsed here but the main script does not forward it to the
# model calls.
parser.add_argument('-f',
                    '--light_3',
                    help='The third light.',
                    default=0.0, type=float)

parser.add_argument('--gpu', action="store_true", default=False,
                    help='Evaluate the models on the GPU.')


def transit_width(r, k, b, P=1):
	"""
	Return the total duration of a transit.

	Evaluates equation (3) of Seager and Mallen-Ornelas,
	2003ApJ...585.1038S, with cos(i) expressed as b*r.

	:param r: R_star/a, the stellar radius in units of the semi-major axis
	:param k: R_planet/R_star, the radius ratio
	:param b: impact parameter = a.cos(i)/R_star
	:param P: orbital period (optional, default P=1)
	:returns: Total transit duration in the same units as P.
	:raises ValueError: if the square-root or arcsine argument falls outside
		its domain (for example b > 1+k).
	"""

	numerator = (1+k)**2-b**2
	denominator = 1-b**2*r**2
	sin_arg = r*math.sqrt(numerator/denominator)
	return P*math.asin(sin_arg)/math.pi



if __name__=="__main__":
	# Script entry point: load a light curve, slide a transit template along
	# it, and plot the change in log-likelihood relative to a reference model
	# together with the best candidate events.

	# First, parse the args
	args = parser.parse_args()

	# Load the lightcurve once. Files carry either 3 columns
	# (time, mag, mag_err) or 5 columns (additionally flux, flux_err); only
	# the first three are used below.
	data = np.loadtxt(args.filename)
	if data.ndim != 2 or data.shape[1] not in (3, 5):
		raise ValueError('Expected 3 or 5 columns in {:}, got array of shape {:}'.format(args.filename, data.shape))
	time, mag, mag_err = data.T[:3]

	# NOTE(review): the first three characters of the path are dropped before
	# it is used to name the output plots - presumably to strip a fixed
	# prefix such as '../'; confirm against how the script is invoked.
	args.filename = args.filename[3:]
	out_stem = args.filename.split('.')[0]

	# Sort the lightcurve by time (ties broken by mag, then mag_err) and
	# ensure float64. lexsort treats its LAST key as the primary key.
	order = np.lexsort((mag_err, mag, time))
	time = np.asarray(time, dtype = np.float64)[order]
	mag = np.asarray(mag, dtype = np.float64)[order]
	mag_err = np.asarray(mag_err, dtype = np.float64)[order]

	# Baseline magnitude. Despite the variable name this is the median.
	weighted_mean = np.median(mag)

	# Pad the data set by 0.5 time units at each end so that templates centred
	# near the edges still overlap data. Pad times are the first/last 0.5 time
	# units of real time stamps shifted outwards; pad magnitudes are drawn from
	# a normal distribution (median, std of the whole light curve) and pad
	# errors uniformly between the smallest and largest real error.
	mag_std = np.std(mag)
	err_lo, err_hi = mag_err.min(), mag_err.max()

	time_start_pad = time[time < (time[0] + 0.5)]  - 0.5
	mag_start_pad = np.random.normal(weighted_mean, mag_std, time_start_pad.shape[0])
	mag_err_start_pad = np.random.uniform(err_lo, err_hi, time_start_pad.shape[0])

	time_end_pad = time[time > (time[-1] - 0.5)] + 0.5
	mag_end_pad = np.random.normal(weighted_mean, mag_std, time_end_pad.shape[0])
	mag_err_end_pad = np.random.uniform(err_lo, err_hi, time_end_pad.shape[0])

	time = np.concatenate((time_start_pad, time, time_end_pad))
	mag = np.concatenate((mag_start_pad, mag, mag_end_pad))
	mag_err = np.concatenate((mag_err_start_pad, mag_err, mag_err_end_pad))

	# Create the axis of trial epochs at which the template is evaluated.
	# This is required to catch ingress/egress events at the start/end of nights
	dt = 0.1/24
	time_models = np.arange(time[0], time[-1], dt)
	dChi = np.empty(time_models.shape[0])

	# Save a figure of the (padded) light curve with the baseline overplotted
	plt.figure(figsize=(15,5))
	plt.scatter(time, mag, c='k', s=10)
	plt.axhline(weighted_mean, ls='--', c='r')
	plt.gca().invert_yaxis()
	plt.xlabel('Time')
	plt.ylabel('Mag')
	plt.savefig('{:}_weighted_mean.png'.format(out_stem))
	plt.close()

	# Now put onto the GPU if needed and create the relevant arrays
	if args.gpu:
		time_    = numba.cuda.to_device(time)
		mag_     = numba.cuda.to_device(mag)
		mag_err_ = numba.cuda.to_device(mag_err)
		llike_    = numba.cuda.to_device(np.empty(time.shape[0], dtype = np.float64))
		threads_per_block = 256
		blocks = int(np.ceil(time.shape[0]/threads_per_block))

	# Now get the transit width
	width = transit_width(args.radius_1, args.k, args.b, P=args.period)
	print('Transit width = {:.2f} hrs'.format(width * 24.))

	# Now create the template (inclination in degrees from radius_1 and b)
	incl = 180*np.arccos(args.radius_1*args.b)/np.pi
	time_template = np.linspace(-width, width, 100)
	mag_template = weighted_mean - 2.5*np.log10(lc(time_template, period = args.period, radius_1 = args.radius_1, k = args.k, incl = incl))
	template_depth = np.max(mag_template) - np.min(mag_template)

	plt.figure(figsize=(15,5))
	plt.plot(time_template, mag_template, 'r')
	plt.xlabel('Time from epoch [d]')
	plt.ylabel('Mag')
	plt.gca().invert_yaxis()
	plt.savefig('{:}_template.png'.format(out_stem))
	plt.close()

	# Now get the reference log-likelihood, evaluated at incl = 40 degrees
	# (presumably a non-transiting geometry - confirm against lc's defaults).
	# NOTE(review): the GPU call passes ld_law_1=-2 but the CPU call does not;
	# left as-is because the effect inside lc cannot be verified from here.
	if args.gpu : reference_loglike = lc(time_, mag_, mag_err_, loglike = llike_, J=0, blocks = blocks, threads_per_block=threads_per_block,incl = 40., zp=weighted_mean, gpu=1, ld_law_1=-2)
	else        : reference_loglike = lc(time,  mag,  mag_err,  J=0, incl = 40., zp=weighted_mean, gpu=0)
	print('Reference log-likliehood = {:.2f}'.format(reference_loglike))

	# Main loop: for every trial epoch store the log-likelihood of the
	# template MINUS the reference. Both branches must subtract the reference;
	# the threshold and peak finding below operate on this difference.
	if args.gpu:
		for i, t_zero in enumerate(tqdm(time_models)):
			dChi[i] = lc(time_, mag_, mag_err_, loglike = llike_, J=0, blocks = blocks, threads_per_block=threads_per_block,incl = incl, zp=weighted_mean, gpu=1,
						radius_1 = args.radius_1, k = args.k, t_zero = t_zero, period = args.period, ld_law_1=-2) - reference_loglike
	else :
		for i, t_zero in enumerate(tqdm(time_models)):
			dChi[i] = lc(time, mag, mag_err, J=0,incl = incl, zp=weighted_mean,
						radius_1 = args.radius_1, k = args.k, t_zero = t_zero, period = args.period, ld_law_1=-2) - reference_loglike

	# Plot the delta log-likelihood (top) above the light curve (bottom)
	height=100
	fig, (ax1,ax2) = plt.subplots(nrows=2, ncols=1, figsize=(15,8), sharex=True)
	ax1.plot(time_models, dChi, c='k')
	ax1.set_ylim(0, None)
	ax1.set_ylabel(r'$\Delta \mathcal{L}$')
	ax1.grid()
	ax1.axhline(height, c='k', ls='--')
	ax2.scatter(time, mag, c='k', s=10)
	ax2.invert_yaxis()
	ax2.set_ylabel('Mag')
	ax2.set_xlabel('Time [jd]')
	ax2.grid()

	# Now run a peak finding algorithm on the delta log-likelihood.
	# NOTE(review): distance is in samples, so 12/dt corresponds to 12 time
	# units between peaks - confirm this (rather than 12 hours) is intended.
	peaks, meta = find_peaks(dChi, distance=12/dt, height = height)
	print('Number of peaks : {:}'.format(len(peaks)))
	if len(peaks) > 5 :
		print('\tKeeping only the best 5')
		# Indices of the five tallest peaks, tallest first
		height_best_idx = meta['peak_heights'].argsort()[-5:][::-1]
		peaks = peaks[height_best_idx]

	ax1.plot(time_models[peaks], dChi[peaks], "x")
	plt.savefig('{:}_dChi.png'.format(out_stem))
	# Release the figure; it is not needed for the per-peak plots
	plt.close(fig)

	# One zoomed plot per retained peak: raw points, template and binned data
	half_window = min(width,0.4)
	for n, peak in enumerate(peaks):
		t_peak = time_models[peak]
		in_window = (time > (t_peak - half_window)) & (time < (t_peak + half_window))
		t_temp = time[in_window] - t_peak
		m_temp = mag[in_window]
		plt.figure(figsize = (15,5))
		plt.scatter(t_temp, m_temp, s=10,c='k')
		plt.plot(time_template, mag_template, 'r')
		t_temp_bin, m_temp_bin, _ = lc_bin(t_temp, m_temp, dt)
		plt.plot(t_temp_bin, m_temp_bin, 'b-', alpha = 0.5)
		# Limits are given as (bottom, top) with the larger magnitude first,
		# which inverts the axis so that fainter is down.
		plt.ylim(weighted_mean +2.5*template_depth,weighted_mean-1.5*template_depth)
		plt.xlabel('Time - {:}'.format(t_peak))
		plt.ylabel('Mag')
		plt.grid()
		plt.savefig('{:}_peak{:}.png'.format(out_stem, n+1))
		plt.close()
