#!python

import matplotlib.pyplot as plt
import numpy as np, os, sys
import argparse
import eleanor
from lightkurve.lightcurve import TessLightCurve
import matplotlib.gridspec as gridspec
from astropy.coordinates import SkyCoord

# Argparse setup
description = '''Retrieve eleanor lightcurve for TIC'''
parser = argparse.ArgumentParser('tessget', description=description)

# Target selection: a non-zero --tic wins; otherwise the script falls back to
# --ra/--dec, which it interprets in degrees.
parser.add_argument('-a', '--tic', type=int, default=0,
                    help='TIC ID of the target. If 0 (the default), the target '
                         'is looked up by --ra/--dec instead.')

parser.add_argument('-b', '--saveplace', type=str, default='.',
                    help='Directory where the aperture plots and the light-curve '
                         'file are written (default: current directory).')

parser.add_argument('-c', '--ra', type=float, default=-99.,
                    help='The right ascension in degrees (used when --tic is 0).')

parser.add_argument('-d', '--dec', type=float, default=-99.,
                    help='The declination in degrees (used when --tic is 0).')

# Custom aperture: --forceaperture enables it, and each --boxN adds a fixed
# pixel box of the 15x15 cutout to the aperture mask.
parser.add_argument('--forceaperture', action="store_true", default=False,
                    help="Force a custom aperture built from the --box1/--box2/--box3 pixel boxes")
parser.add_argument('--box1', action="store_true", default=False,
                    help="With --forceaperture, add pixels [8:10, 7:9] (2x2) of the cutout to the aperture")
parser.add_argument('--box2', action="store_true", default=False,
                    help="With --forceaperture, add pixels [11:15, 6:10] (4x4) of the cutout to the aperture")
parser.add_argument('--box3', action="store_true", default=False,
                    help="With --forceaperture, add pixels [6:12, 5:11] (6x6) of the cutout to the aperture")



# Hand-picked (start, end) windows of usable data per TESS sector, in the same
# time units as the eleanor light curve (presumably BTJD, since the script adds
# 2457000.0 when saving). Bounds are exclusive on both ends.
_SECTOR_WINDOWS = {
    1: ((1325.5, 1338.3), (1339.9, 1347.0), (1350.0, 1353.0)),
    2: ((1345.3, 1366.9), (1368.8, 1381.3)),
    3: ((1385.7, 1395.3), (1396.4, 1406.1)),
    4: ((1413.0, 1419.0), (1424.7, 1436.6)),
    5: ((1438.0, 1450.0), (1451.7, 1463.9)),
    6: ((1468.3, 1477.0), (1479.0, 1490.0)),
    # The original code had 1416.0 as the upper bound here, an empty window
    # that dropped the whole second orbit of sector 7.
    7: ((1491.7, 1503.0), (1505.0, 1516.0)),
    8: ((1517.5, 1529.0), (1530.5, 1531.7), (1535.2, 1541.7)),
    9: ((1543.8, 1555.5), (1557.1, 1568.4)),
    10: ((1571.0, 1581.7), (1584.8, 1595.6)),
    11: ((1600.0, 1609.6), (1614.3, 1623.8)),
}


def get_mask(time, sector):
    """Return a boolean mask selecting the usable cadences of one TESS sector.

    Two cuts are combined:

    * an edge trim keeping only cadences strictly between ``time[10]`` and
      ``time[-10]`` (for sorted times: drops the first 11 and last 10);
    * the sector's hand-picked windows from ``_SECTOR_WINDOWS``.

    Parameters
    ----------
    time : array_like
        Cadence times. Presumably sorted ascending; the edge trim relies on it.
    sector : int
        TESS sector number. A sector with no entry in ``_SECTOR_WINDOWS``
        gets only the edge trim instead of raising an error.

    Returns
    -------
    numpy.ndarray of bool
        Mask with the same shape as ``time``.
    """
    time = np.asarray(time)
    if time.size <= 10:
        # time[10] does not exist. For sorted input of 11-20 cadences the edge
        # trim already keeps nothing, so keep nothing here too.
        return np.zeros(time.shape, dtype=bool)

    condition = (time > time[10]) & (time < time[-10])

    windows = _SECTOR_WINDOWS.get(sector)
    if windows is None:
        return condition

    in_window = np.zeros(time.shape, dtype=bool)
    for start, end in windows:
        in_window |= (time > start) & (time < end)
    return condition & in_window

def _remove_postcards(sources):
    """Best-effort deletion of the postcard file of every eleanor source.

    Uses os.remove instead of shelling out to ``rm``, so paths containing
    spaces or shell metacharacters are handled safely. Failures are reported
    on stderr and do not stop the remaining deletions.
    """
    for source in sources:
        try:
            os.remove(source.postcard_path)
        except OSError as err:
            print('Could not remove {:}: {:}'.format(source.postcard_path, err), file=sys.stderr)


if __name__ == "__main__":
    # First, parse the command-line arguments.
    args = parser.parse_args()

    # Query eleanor for all sectors, by TIC ID when given, otherwise by coordinates.
    if args.tic != 0:
        star = eleanor.multi_sectors(tic=args.tic, sectors='all', tc=True)
    else:
        star = eleanor.multi_sectors(coords=SkyCoord(args.ra, args.dec, unit="deg"), sectors='all', tc=True)

    # Master arrays collecting the flattened light curves of every sector.
    master_time = []
    master_flux = []

    try:
        for source in star:
            data = eleanor.TargetData(source, height=15, width=15, bkg_size=31, do_psf=True, do_pca=True)

            if args.forceaperture:
                # Build the custom aperture from the requested pixel boxes.
                aperture = np.zeros(np.shape(data.tpf[0]))
                if args.box1:
                    aperture[8:10, 7:9] = 1
                if args.box2:
                    aperture[11:15, 6:10] = 1
                if args.box3:
                    aperture[6:12, 5:11] = 1

                data.get_lightcurve(aperture=aperture)
                corr_flux = eleanor.TargetData.corrected_flux(data, flux=data.raw_flux)
                eleanor.TargetData.pca(data, flux=corr_flux, modes=4)

            # Now remove the background.
            data.bkg_subtraction()

            # Keep only cadences with no quality flags set.
            q = data.quality == 0
            sector = int(data.header['SECTOR'])

            # Plot and save the aperture used for this sector.
            vis = eleanor.Visualize(data)
            vis.aperture_contour()
            plt.xlabel('X')
            plt.ylabel('Y')
            plt.title('Sector {:}'.format(sector))
            plt.savefig(os.path.join(args.saveplace,
                                     '{:}_sector_{:}_aperture.png'.format(args.tic, data.header['SECTOR'])))
            plt.close()

            # Apply the per-sector time windows on top of the quality mask.
            keep = get_mask(data.time[q], sector)
            if not np.any(keep):
                # Nothing left to normalise or flatten; flatten would fail on empty input.
                print('Sector {:}: no cadences left after masking, skipping.'.format(sector), file=sys.stderr)
                continue

            time = data.time[q][keep]
            flux = data.corr_flux[q][keep]
            # nanmedian so a single NaN cadence does not turn the whole sector into NaN.
            lc = TessLightCurve(time, flux / np.nanmedian(flux)).flatten()

            master_time.extend(lc.time)
            master_flux.extend(lc.flux)
    finally:
        # Clean up the downloaded postcards even if a sector failed, to avoid disk use.
        _remove_postcards(star)

    # Output columns: time + 2457000.0 (presumably BTJD -> BJD), -2.5*log10 of the
    # normalised flux, and a constant 1e-3 (presumably a placeholder uncertainty).
    table = np.array([np.array(master_time) + 2457000.0,
                      -2.5 * np.log10(master_flux),
                      np.ones(len(master_time)) * 1e-3]).T
    np.savetxt(os.path.join(args.saveplace, '{:}_eleanor_lightcurve.dat'.format(args.tic)), table)
