#!/usr/bin/env python
import numpy as np 
import argparse 
from photutils import aperture_photometry, CircularAperture, CircularAnnulus
from astropy.visualization import simple_norm
from astropy.io import fits
import matplotlib.pyplot as plt
import glob

import numpy as np
from astropy.io import fits
from scipy.signal import fftconvolve
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from numba import autojit

def Gauss(x, a, x0, sigma):
    """Evaluate an un-normalised Gaussian at ``x``.

    Parameters
    ----------
    x : float or ndarray
        Abscissa value(s).
    a : float
        Peak height.
    x0 : float
        Centre of the Gaussian.
    sigma : float
        Standard deviation (width).
    """
    squared_offset = (x - x0) ** 2
    two_variance = 2 * sigma ** 2
    return a * np.exp(-squared_offset / two_variance)

def count_flux(image, x, y, r1, r2, r3):
    """Sum pixel values inside three concentric circles centred on ``(x, y)``.

    ``x`` is a column (axis-1) coordinate and ``y`` a row (axis-0)
    coordinate, so pixel ``image[j, i]`` sits at ``(i, j)``. A pixel is
    included when its squared distance from ``(x, y)`` is strictly less than
    the squared radius.

    Parameters
    ----------
    image : 2-D array_like
        Image to integrate.
    x, y : float
        Centre of the circles in pixel coordinates.
    r1, r2, r3 : float
        Radii of the three circles.

    Returns
    -------
    (flux1, flux2, flux3) : tuple of float
        Sums of the pixels within ``r1``, ``r2`` and ``r3`` respectively.
    """
    image = np.asarray(image)
    # Open grids broadcast to the image shape without materialising two
    # full index arrays; NumPy vectorisation replaces the former JIT loop.
    rows, cols = np.ogrid[:image.shape[0], :image.shape[1]]
    dist2 = (cols - x) ** 2 + (rows - y) ** 2
    # float() keeps the original contract of float sums even for int images.
    return tuple(float(image[dist2 < r * r].sum()) for r in (r1, r2, r3))





def _fit_ccf_profile(profile, window=(0.35, 0.7)):
    """Fit ``Gauss`` to a 1-D cross-correlation profile and return its offset.

    The profile is shifted to a minimum of 0 and scaled to a maximum of 1.
    Only the samples between the fractional indices ``window`` are fitted,
    against lags measured from the profile centre (``index - n / 2``).

    Returns
    -------
    (offset, error) : tuple
        The negated fitted centre and its 1-sigma uncertainty, or ``(0, 99)``
        when the fit cannot be performed.
    """
    profile = profile - np.min(profile)
    profile = profile / profile.max()
    n = profile.shape[0]
    lags = np.arange(n) - n / 2
    # NOTE(review): the window is not symmetric about zero lag (0.35-0.7 of
    # the profile); confirm this is intended.
    n_low, n_high = int(np.floor(window[0] * n)), int(np.floor(window[1] * n))
    profile, lags = profile[n_low:n_high], lags[n_low:n_high]
    try:
        popt, pcov = curve_fit(Gauss, lags, profile, p0=[max(profile), 0, 1])
    except (RuntimeError, ValueError, TypeError):
        # RuntimeError: the fit did not converge. ValueError: NaN/inf in the
        # data (e.g. a flat profile normalised by zero) or an empty window.
        # TypeError: fewer samples in the window than fit parameters.
        return 0, 99
    perr = np.sqrt(np.diag(pcov))
    return -popt[1], perr[1]


def find_offset(im1, im2, xlim, ylim, method='gaussian_fit'):
    """Measure the translation between two images by cross-correlation.

    Both images are cropped to ``im[xlim[0]:xlim[1], ylim[0]:ylim[1]]``, so
    ``xlim`` indexes axis 0 (rows) and ``ylim`` indexes axis 1 (columns).
    They are then mean-subtracted and cross-correlated with ``fftconvolve``.

    Parameters
    ----------
    im1, im2 : 2-D ndarray
        Reference image and image to compare. Neither is modified.
    xlim, ylim : sequence of two ints
        Crop limits along axis 0 and axis 1 respectively.
    method : {'gaussian_fit', 'best_pixel'}
        ``'best_pixel'`` uses the location of the correlation maximum, with
        a fixed 0.5 pixel uncertainty. ``'gaussian_fit'`` fits a Gaussian to
        the correlation image averaged along each axis. ``'gaussian'`` is
        accepted as an alias of ``'gaussian_fit'``.

    Returns
    -------
    dx, dxe, dy, dye
        Offset and uncertainty along axis 1 (columns), then along axis 0
        (rows). An offset is positive when features in ``im2`` lie at larger
        indices than in ``im1``. With ``'gaussian_fit'``, an axis whose fit
        fails yields ``(0, 99)``.

    Raises
    ------
    ValueError
        If a crop limit is reversed or ``method`` is not recognised.
    """
    if xlim[0] > xlim[1]:
        raise ValueError('The first x limit cannot exceed the second.')
    if ylim[0] > ylim[1]:
        raise ValueError('The first y limit cannot exceed the second.')
    if method == 'gaussian':
        method = 'gaussian_fit'
    if method not in ('best_pixel', 'gaussian_fit'):
        raise ValueError('\nMethod choice not understood.\n\nAvailable choices are:\n\n'
                         '1) best_pixel\n\n2) gaussian_fit\n')

    # Crop first, then convert: astype() returns a copy, so the in-place mean
    # subtraction below never touches the caller's arrays.
    im1 = im1[xlim[0]:xlim[1], ylim[0]:ylim[1]].astype(np.float64)
    im2 = im2[xlim[0]:xlim[1], ylim[0]:ylim[1]].astype(np.float64)

    # Remove the means, otherwise the correlation is dominated by the
    # background level rather than by the features.
    im1 -= np.mean(im1)
    im2 -= np.mean(im2)

    # Cross-correlation = convolution with one image flipped along both axes.
    corr_img = fftconvolve(im1, im2[::-1, ::-1], mode='same')

    if method == 'best_pixel':
        peak = np.array(np.unravel_index(np.argmax(corr_img), corr_img.shape))
        centre = np.array(corr_img.shape) / 2
        dy, dx = (peak - centre) * -1
        return dx, 0.5, dy, 0.5

    # Averaging over axis 0 gives the profile along columns (x offset);
    # averaging over axis 1 gives the profile along rows (y offset).
    dx, dxe = _fit_ccf_profile(corr_img.mean(axis=0))
    dy, dye = _fit_ccf_profile(corr_img.mean(axis=1))
    return dx, dxe, dy, dye

           
        
    

        






# Welcome message
welcome_message = '''---------------------------------------------------
-                   quickphot V.1                 -
-             samuel.gill@wariwck.ac.uk           -
---------------------------------------------------'''

description = '''Quick differential aperture photometry of a target star against comparison and check stars.'''

parser = argparse.ArgumentParser('quickphot', description=description)

parser.add_argument('-a',
                    '--reference_image',
                    help='The reference image.',
                    type=str, default='a.fits')

parser.add_argument('-aa',
                    '--files',
                    help='Glob pattern matching the images to process.',
                    type=str, default='*.fits')

# Star positions are exactly one (x, y) pair each.
parser.add_argument('-b', '--target',
                    help='The target star in pixel position e.g. --target x y',
                    nargs=2,
                    default=[0, 0], type=float)

parser.add_argument('-c', '--comparison',
                    help='The comparison star in pixel position e.g. --comparison x y',
                    nargs=2,
                    default=[0, 0], type=float)

parser.add_argument('-d', '--check',
                    help='The check star in pixel position e.g. --check x y',
                    nargs=2,
                    default=[0, 0], type=float)

parser.add_argument('-e', '--r1',
                    help='The radius of the source aperture.',
                    type=float, default=5.)

parser.add_argument('-f', '--skyin',
                    help='The inner radius of the sky annulus.',
                    type=float, default=8.)

parser.add_argument('-g', '--skyout',
                    help='The outer radius of the sky annulus.',
                    type=float, default=10.)

# [None, None] means "leave the plot limits alone".
parser.add_argument('-i', '--trialxlim',
                    help='The xlim for the field e.g. --trialxlim 0 10.',
                    nargs=2, type=float,
                    default=[None, None])

parser.add_argument('-j', '--trialylim',
                    help='The ylim for the field e.g. --trialylim 0 10.',
                    nargs=2, type=float,
                    default=[None, None])

parser.add_argument('--image_align', action="store_true", default=False,
                    help='Shift the apertures on each frame by its measured offset from the reference image.')

parser.add_argument('--trial', action="store_true", default=False,
                    help='Only show the field chart with the apertures, then exit.')


def p_(x):
    """Flat fallback trend: an array of ones, one per element along axis 0 of ``x``."""
    n_points = x.shape[0]
    return np.ones(n_points)

def _load_masked_image(path, noise_sigma=3.):
    """Load extension 1 of ``path`` and in-fill the pixels flagged in extension 3.

    Flagged pixels are replaced by Gaussian noise of width ``noise_sigma``
    about the median of the unflagged pixels of the same image. Returns the
    filled image and a copy of the extension-1 header.
    """
    with fits.open(path) as hdul:
        # Copy so the arrays stay valid (and writable) after the file closes.
        image = np.array(hdul[1].data, copy=True)
        mask = np.asarray(hdul[3].data).astype(bool)
        header = hdul[1].header.copy()
    image[mask] = np.random.normal(np.median(image[~mask]), noise_sigma, image[mask].shape[0])
    return image, header


def _make_apertures(positions, r, r_in, r_out):
    """Build the source apertures and sky annuli centred on each (x, y) row of ``positions``."""
    aperture = CircularAperture(positions, r=r)
    annulus_aperture = CircularAnnulus(positions, r_in=r_in, r_out=r_out)
    return aperture, annulus_aperture


def _aperture_fluxes(image, aperture, annulus_aperture):
    """Return the sky-subtracted flux in each source aperture.

    The mean sky per pixel (annulus sum / annulus area) is scaled by the
    source-aperture area and subtracted from the source-aperture sum.
    """
    phot_table = aperture_photometry(image, [aperture, annulus_aperture])
    bkg_mean = phot_table['aperture_sum_1'] / annulus_aperture.area
    return np.asarray(phot_table['aperture_sum_0'] - bkg_mean * aperture.area)


if __name__ == "__main__":
    # print welcome message
    print(welcome_message)

    # Parse arguments
    args = parser.parse_args()

    # Echo the reference image and coordinates
    print('Files glob: ', args.files)
    print('\tReference image : ', args.reference_image)
    print('\tTarget:     x = {:.3f}, y = {:.3f}'.format(*args.target))
    print('\tComparison: x = {:.3f}, y = {:.3f}'.format(*args.comparison))
    print('\tCheck:      x = {:.3f}, y = {:.3f}'.format(*args.check))
    print('---------------------------------------------------')

    # Aperture centres on the reference image: rows are target, comparison, check.
    reference_positions = np.array([args.target, args.comparison, args.check], dtype=float)
    aperture, annulus_aperture = _make_apertures(reference_positions, args.r1, args.skyin, args.skyout)

    # Load the reference image with its masked pixels in-filled
    reference_image, reference_header = _load_masked_image(args.reference_image)
    name = reference_header['OBJECT']
    norm = simple_norm(reference_image, 'sqrt', percent=99)

    # Field chart with the apertures overlaid
    plt.imshow(reference_image, norm=norm, origin='lower', aspect='auto')
    aperture.plot(color='white', lw=2)
    annulus_aperture.plot(color='red', lw=2)
    plt.xlabel('X pix')
    plt.ylabel('Y pix')
    for label, (star_x, star_y) in zip(('Target', 'Comparison', 'Check'),
                                       (args.target, args.comparison, args.check)):
        plt.text(star_x + 10, star_y + 10, s=label)

    if None not in args.trialxlim:
        plt.xlim([float(i) for i in args.trialxlim])
    if None not in args.trialylim:
        plt.ylim([float(i) for i in args.trialylim])
    plt.gca().invert_xaxis()
    plt.savefig('{:}_field.pdf'.format(name))

    # Trial mode: only inspect the field chart
    if args.trial:
        plt.show()
        raise SystemExit(0)
    plt.close()

    # glob order is arbitrary; sort so frames are processed in a reproducible order.
    files = sorted(glob.glob(args.files))
    print('Number of files found : {:}'.format(len(files)))
    if not files:
        raise SystemExit('No files match {:}'.format(args.files))

    # Crop limits (axis 0, axis 1) for the cross-correlation, and the largest
    # shift accepted as a genuine measurement.
    align_xlim, align_ylim = [1000, 3000], [2000, 3000]
    max_shift = 50.

    flux1, flux2, flux3, X, Y = [], [], [], [], []

    _, ax = plt.subplots()

    for path in files:
        ax.clear()

        # Every frame must be loaded, whether or not it is aligned.
        image, _ = _load_masked_image(path)

        dx, dy = 0., 0.
        if args.image_align:
            dx, _, dy, _ = find_offset(reference_image, image, xlim=align_xlim, ylim=align_ylim,
                                       method='gaussian_fit')
            # Treat implausibly large shifts as failed measurements.
            if abs(dx) > max_shift:
                dx = 0.
            if abs(dy) > max_shift:
                dy = 0.
            aperture, annulus_aperture = _make_apertures(reference_positions + [dx, dy],
                                                         args.r1, args.skyin, args.skyout)
        X.append(dx)
        Y.append(dy)

        ax.imshow(image)
        aperture.plot(color='blue', lw=1.5, alpha=0.5)
        annulus_aperture.plot(color='red', lw=1.5, alpha=0.5)
        ax.set_xlim(args.target[0] - 30, args.target[0] + 30)
        ax.set_ylim(args.target[1] - 30, args.target[1] + 30)

        plt.pause(0.05)

        fluxes = _aperture_fluxes(image, aperture, annulus_aperture)
        flux1.append(fluxes[0])
        flux2.append(fluxes[1])
        flux3.append(fluxes[2])

    flux1, flux2, flux3 = np.array(flux1), np.array(flux2), np.array(flux3)
    frames = np.arange(len(flux1))
    target_comparison = flux1 / flux2

    # Quadratic trend in target/comparison; fall back to a flat trend when the
    # fit is impossible (e.g. NaN/inf ratios).
    try:
        trend = np.poly1d(np.polyfit(frames, target_comparison, 2))
    except (TypeError, ValueError, np.linalg.LinAlgError):
        trend = p_
    trend_values = trend(frames)
    detrended = target_comparison / trend_values

    fig, axs = plt.subplots(3, 2, figsize=(10, 15))
    axs[0, 0].scatter(frames, target_comparison, c='k', s=10)
    axs[0, 0].plot(frames, trend_values, 'r')

    axs[0, 1].scatter(frames, detrended, c='k', s=10)
    axs[0, 1].set_title('RMS : {:.0f} ppm'.format(1e6 * np.std(detrended)))
    axs[1, 0].scatter(frames, flux1 / flux3, c='k', s=10)
    axs[2, 0].scatter(frames, flux2 / flux3, c='k', s=10)

    axs[1, 1].scatter(frames, X, c='k', s=10)
    axs[2, 1].scatter(frames, Y, c='k', s=10)

    axs[0, 0].set_ylabel('target / comparison')
    axs[1, 0].set_ylabel('target / check ')
    axs[2, 0].set_ylabel('comparison / check')
    axs[1, 1].set_ylabel('X [pix]')
    axs[2, 1].set_ylabel('Y [pix]')
    axs[2, 0].set_xlabel('Frame')
    axs[2, 1].set_xlabel('Frame')
    axs[0, 0].set_title(name)

    plt.savefig('{:}_photometry.pdf'.format(name))
    plt.show()
    





# NOTE(review): unused scratch snippet kept as a bare string literal, so it is
# never executed. It refers to `sources`, which is not defined in the visible
# code, and its trailing `for` loop has no body, so it would not run as-is.
'''
positions = np.transpose((sources['xcentroid'], sources['ycentroid']))  
apertures = CircularAperture(positions, r=4.)  
phot_table = aperture_photometry(image, apertures)  
for col in phot_table.colnames:  
'''