#!/home/sam/anaconda3/bin/python

import matplotlib.pyplot as plt
import numpy as np, os, sys
import argparse
import eleanor
from lightkurve.lightcurve import TessLightCurve
import matplotlib.gridspec as gridspec


# Command-line interface: which TIC to fetch and where to write the outputs.
description = '''Retrieve eleanor lightcurve for TIC'''
parser = argparse.ArgumentParser('tessget', description=description)
parser.add_argument('-a',
	'--tic',
	required=True,  # every downstream eleanor query needs a TIC ID
	help='TESS Input Catalog (TIC) ID of the target star.', type=int)

parser.add_argument('-b',
	'--saveplace',
	help='Directory in which the aperture plots and the light-curve file are written (default: current directory).', type=str,
	default='.')

# Per-sector windows of usable data, as (start, stop) open intervals in the
# same time units as the eleanor time array (the values are consistent with
# TESS BTJD = BJD - 2457000 -- NOTE(review): confirm). Cadences that fall in
# none of their sector's windows are discarded by get_mask.
SECTOR_WINDOWS = {
	1: ((1325.5, 1338.3), (1339.9, 1347.0), (1350.0, 1353.0)),
	2: ((1345.3, 1366.9), (1368.8, 1381.3)),
	3: ((1385.7, 1395.3), (1396.4, 1406.1)),
	4: ((1413.0, 1419.0), (1424.7, 1436.6)),
	5: ((1438.0, 1450.0), (1451.7, 1463.9)),
	6: ((1468.3, 1477.0), (1478.2, 1490.0)),
	# Second window's stop was previously 1416.0, below its own start of
	# 1505.0, which silently discarded the whole window; 1516.0 sits just
	# before sector 8's first window (1517.5).
	7: ((1491.7, 1503.0), (1505.0, 1516.0)),
	8: ((1517.5, 1529.0), (1530.5, 1531.7), (1535.2, 1541.7)),
	9: ((1543.8, 1555.5), (1557.1, 1568.4)),
	10: ((1571.0, 1581.7), (1584.8, 1595.6)),
	11: ((1600.0, 1609.6), (1614.3, 1623.8)),
}


def get_mask(time, sector):
	"""Return a boolean mask selecting the usable cadences of one sector.

	A cadence is kept when it (a) is strictly after ``time[10]`` and strictly
	before ``time[-10]`` (trimming the edges of the series) and (b) lies
	strictly inside one of the windows listed for ``sector`` in
	``SECTOR_WINDOWS``.

	Parameters
	----------
	time : array_like
		Cadence times; assumed sorted ascending, since the edge trim compares
		against ``time[10]`` and ``time[-10]``.
	sector : int
		TESS sector number used to look up the windows.

	Returns
	-------
	numpy.ndarray of bool
		Same shape as ``time``. If fewer than 11 cadences are given, every
		point is an edge point and the mask is all False. For a sector with
		no entry in ``SECTOR_WINDOWS`` only the edge trim is applied and a
		``UserWarning`` is emitted.
	"""
	import warnings

	time = np.asarray(time)
	# time[10] does not exist: nothing survives the edge trim.
	if time.size < 11:
		return np.zeros(time.shape, dtype=bool)

	condition = (time > time[10]) & (time < time[-10])

	windows = SECTOR_WINDOWS.get(sector)
	if windows is None:
		warnings.warn(
			'No time windows defined for sector {:}; applying only the edge trim.'.format(sector))
		return condition

	in_window = np.zeros(time.shape, dtype=bool)
	for start, stop in windows:
		in_window |= (time > start) & (time < stop)
	return condition & in_window

if __name__ == "__main__":
	# Parse the command-line arguments.
	args = parser.parse_args()

	# Query eleanor for every available sector of this TIC.
	star = eleanor.multi_sectors(tic=args.tic, sectors='all', tc=True)

	# Per-sector arrays, concatenated once at the end in sector order.
	time_chunks = []
	flux_chunks = []

	for source in star:
		data = eleanor.TargetData(source, height=15, width=15, bkg_size=31, do_psf=True, do_pca=True)

		# Remove the background.
		data.bkg_subtraction()

		sector = int(data.header['SECTOR'])

		# Plot and save the aperture used for this sector.
		vis = eleanor.Visualize(data)
		vis.aperture_contour()
		plt.xlabel('X')
		plt.ylabel('Y')
		plt.title('Sector {:}'.format(sector))
		plt.savefig(os.path.join(args.saveplace, '{:}_sector_{:}_aperture.png'.format(args.tic, sector)))
		plt.close()

		# Keep cadences whose quality flag is 0, then apply the per-sector
		# time windows.
		good = data.quality == 0
		time = data.time[good]
		flux = data.corr_flux[good]
		mask = get_mask(time, sector)
		time, flux = time[mask], flux[mask]

		if time.size == 0:
			print('Sector {:}: no usable cadences, skipping.'.format(sector), file=sys.stderr)
			continue

		# Normalise to unit median (nanmedian, so a single NaN cadence does
		# not turn the whole sector into NaN) and flatten with lightkurve.
		lc = TessLightCurve(time=time, flux=flux / np.nanmedian(flux))
		lc = lc.flatten()

		# Newer lightkurve versions may return astropy Time/Quantity objects
		# here; reduce both to plain float arrays.
		time_chunks.append(np.asarray(getattr(lc.time, 'value', lc.time), dtype=float))
		flux_chunks.append(np.asarray(getattr(lc.flux, 'value', lc.flux), dtype=float))

	master_time = np.concatenate(time_chunks) if time_chunks else np.array([])
	master_flux = np.concatenate(flux_chunks) if flux_chunks else np.array([])

	# Columns: time, -2.5*log10(normalised flux), and a constant 1e-3
	# (presumably a nominal uncertainty -- it is not computed from the data).
	tmp = np.column_stack((master_time, -2.5 * np.log10(master_flux), np.full(master_time.size, 1e-3)))
	np.savetxt(os.path.join(args.saveplace, '{:}_eleanor_lightcurve.dat'.format(args.tic)), tmp)
