#!python

from astropy.table import Table

import numpy as np
from astroquery.mast import Catalogs
from astroquery.mast import Tesscut
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
from astropy.io import fits
import matplotlib.pyplot as plt
from tqdm import tqdm
import os,sys
import matplotlib.path
import math

import astropy.units as u
from astropy.coordinates import SkyCoord
from astroquery.gaia import Gaia

import os, fnmatch
def find(pattern, path):
    """
    Recursively collect files below ``path`` whose base name matches ``pattern``.

    ``pattern`` uses shell-style wildcards as understood by ``fnmatch``.
    Returns a list of ``os.path.join(directory, name)`` strings, in the order
    produced by ``os.walk``.
    """
    return [
        os.path.join(directory, name)
        for directory, _subdirs, names in os.walk(path)
        for name in fnmatch.filter(names, pattern)
    ]



# Define a function to simplify the plotting command that we do repeatedly.
def plot_cutout(image, vmin_percentile=5, vmax_percentile=96):
    """
    Plot a 2D image on the current axes with a percentile colour stretch and grid lines.

    Parameters
    ----------
    image : 2D array
        The image to display (drawn with ``origin='lower'``).
    vmin_percentile, vmax_percentile : float, optional
        Percentiles of the non-NaN pixel values used as the lower and upper
        colour limits (defaults 5 and 96).
    """
    # nanpercentile ignores NaN pixels; plain np.percentile would return NaN
    # for the whole frame as soon as one pixel is NaN, giving NaN colour limits.
    plt.imshow(image, origin='lower', cmap='gray_r',
               vmax=np.nanpercentile(image, vmax_percentile),
               vmin=np.nanpercentile(image, vmin_percentile))

    plt.grid(axis='both', color='white', ls='solid')




def aperture_phot(image, aperture):
    """
    Return the total flux of the pixels selected by ``aperture`` in one image.

    ``image`` is a 2D array and ``aperture`` a boolean array of the same shape;
    pixels where ``aperture`` is True are summed.
    """
    return np.sum(image[aperture])

def make_lc(flux_data, aperture):
    """
    Build a photometric time series by applying one 2D aperture to every frame.

    ``flux_data`` is an iterable of 2D images (e.g. a time-ordered cube) and
    ``aperture`` is a boolean array with the frames' 2D shape, True for the
    pixels to include. Returns a 1D numpy array holding one summed flux per
    frame, in input order.
    """
    return np.array([aperture_phot(frame, aperture) for frame in flux_data])


def phaser(time, t_zero, period):
    """
    Return the fractional part of the number of periods elapsed since ``t_zero``.

    Works element-wise on numpy arrays as well as on scalars.
    """
    cycles = (time - t_zero) / period
    return cycles - np.floor(cycles)


def _find_enclosing_segment(segments, center):
    """
    Return the index of the first contour segment whose polygon contains ``center``.

    ``segments`` is a sequence of (N, 2) vertex arrays in (x, y) pixel
    coordinates, as found in ``ContourSet.allsegs[level]``. ``center`` is an
    (x, y) tuple. Returns None when no segment encloses the point.
    """
    for index, vertices in enumerate(segments):
        if matplotlib.path.Path(vertices).contains_point(center):
            return index
    return None


def _polygon_mask(vertices, shape):
    """
    Return a boolean (rows, cols) mask that is True for pixels inside the polygon.

    Pixel (row, col) is tested at the point (x=col, y=row), which is the
    coordinate convention ``plt.contour`` uses for an image given without
    explicit X/Y arrays.
    """
    n_rows, n_cols = shape
    xv, yv = np.meshgrid(np.arange(n_cols), np.arange(n_rows), indexing='xy')
    points = np.column_stack((xv.ravel(), yv.ravel()))
    inside = matplotlib.path.Path(vertices).contains_points(points)
    return inside.reshape(n_rows, n_cols)


def _main(argv):
    """
    Extract a background-subtracted aperture light curve from a TESScut cutout.

    Usage: <script> RA_DEG DEC_DEG CUTOUT_SIZE CONTOUR_LEVEL

    Writes Target_field.png, Target_lightcurve.png and LC.dat to the current
    directory, then shows the figures.
    """
    if len(argv) < 5:
        sys.exit('usage: {} RA_DEG DEC_DEG CUTOUT_SIZE CONTOUR_LEVEL'.format(argv[0]))

    ###############################################
    # Step 1 - parse the command-line arguments
    ###############################################
    ra, dec = float(argv[1]), float(argv[2])
    cutout_size = int(argv[3])
    contour_level = float(argv[4])

    #########################################
    # Step 2 - Query the object with TESS cut
    #########################################
    coord = SkyCoord(ra, dec, unit="deg")
    try:
        hdulist = Tesscut.get_cutouts(coordinates=coord, size=cutout_size)
        hdu1 = hdulist[0]
    except Exception as err:
        # Any failure (network, service, or no sectors returned -> IndexError)
        # is reported uniformly, keeping the original cause attached.
        raise ValueError('I failed to get the TESS cut :(') from err

    #########################
    # Step 3 - plot the image
    #########################
    flux_cube = hdu1[1].data['FLUX']
    # Display frame 100, or the last frame when the series is shorter.
    image = flux_cube[min(100, len(flux_cube) - 1)]
    wcs = WCS(hdu1[2].header)

    fig = plt.figure(figsize=(8, 8))
    fig.add_subplot(111, projection=wcs)
    plot_cutout(image)

    #####################################################################
    # Step 4 - make an aperture from the contour at CONTOUR_LEVEL (argv[4])
    #####################################################################
    cs = plt.contour(image, np.array([contour_level]))
    n_rows, n_cols = image.shape
    # Contour vertices are (x, y) = (column, row), so x spans the columns.
    center = (n_cols / 2., n_rows / 2.)
    # allsegs[0]: list of vertex arrays for the single requested level; works
    # on matplotlib versions where ContourSet.collections no longer exists.
    segments = cs.allsegs[0]
    idx = _find_enclosing_segment(segments, center)
    if idx is None:
        print('I could not find a contour which overlapped the target')
        plt.show()
        raise ValueError('No valid contour')
    print('Best path is ', idx)

    ####################################################
    # Step 5 - now create the mask based on the contour
    ####################################################
    aperture = _polygon_mask(segments[idx], image.shape)
    plt.imshow(aperture, alpha=0.2, cmap='jet', origin='lower')

    ###############################################
    # Step 6 - Find the background stars using Gaia
    ###############################################
    # Currently disabled. To enable, uncomment:
    # coord = SkyCoord(ra=ra, dec=dec, unit=(u.degree, u.degree), frame='icrs')
    # width = u.Quantity(0.1, u.deg)
    # height = u.Quantity(0.1, u.deg)
    # r = Gaia.query_object_async(coordinate=coord, width=width, height=height)
    # mag_target = r['phot_g_mean_mag'][0]
    # r = r[r['phot_g_mean_mag'] < mag_target + 3]
    # for star in r:
    #     x, y = wcs.world_to_pixel(SkyCoord(ra=star['ra'], dec=star['dec'],
    #                                        unit=(u.degree, u.degree), frame='icrs'))
    #     plt.plot(int(x), int(y), 'r+')

    ##################################
    # Step 7 - Tidy up the field plot
    ##################################
    plt.xlabel('RA', fontsize=12)
    plt.ylabel('Dec', fontsize=12)
    plt.savefig('Target_field.png')

    ####################################################################################
    # Step 8 - Calculate the background flux using a mask from the lowest 5 % percentile
    ####################################################################################
    first_frame = flux_cube[0]
    # nanpercentile so NaN pixels in the first frame do not poison the threshold.
    bkg_aperture = first_frame < np.nanpercentile(first_frame, 5)
    n_bkg = np.sum(bkg_aperture)
    if n_bkg == 0:
        raise ValueError('No background pixels found in the first frame')
    bkg_flux = make_lc(flux_cube, bkg_aperture)

    ######################
    # Step 9 - plot the LC
    ######################
    flux = make_lc(flux_cube, aperture)
    # Scale the summed background to the number of pixels in the target aperture.
    bkg_sub_flux = flux - bkg_flux * np.sum(aperture) / n_bkg

    quality_flagged = hdu1[1].data['QUALITY'] > 0
    # Sector-specific time windows can be OR-ed into quality_flagged here, e.g.
    #   | ((time > 1347) & (time < 1350)) | ((time > 1352) & (time < 1358))
    #   | ((time > 1437.5) & (time < 1438.5)) | ((time > 1449.91) & (time < 1451.85))
    #   | ((time > 1463.5) & (time < 1464))
    keep = ~quality_flagged
    time1 = hdu1[1].data['TIME'][keep]
    bkg_sub_flux = bkg_sub_flux[keep]

    np.savetxt('LC.dat', np.column_stack((time1, bkg_sub_flux)))
    plt.figure(figsize=(15, 5))
    plt.scatter(time1, bkg_sub_flux, c='k', s=10)
    plt.xlabel('Time')
    plt.ylabel('Flux')
    plt.savefig('Target_lightcurve.png')

    plt.figure(figsize=(15, 5))
    plt.scatter(range(len(bkg_sub_flux)), bkg_sub_flux, c='k', s=10)

    plt.show()


if __name__ == '__main__':
    _main(sys.argv)
