#!python

"""
STARTING OVER WITH A COMPLETE REWRITE

PROCESS:
    Take in which object to analyse
    Go find its folder and phot output

    For each action
        For each aperture size
            For each comparison star
                work out the RMS/quality
                flag any potentially bad stars

            Make the light curve
            Analyse the RMS (OOT if transit) in data
        Make a final light curve from the best aperture and comparisons
"""
import sys
import logging
import matplotlib.pyplot as plt
import numpy as np
import glob
from astropy.table import Table, Column
import matplotlib.gridspec as gridspec
from scipy.stats import sem 
from astropy.time import Time
import warnings
from scipy.signal import find_peaks
# Silence FutureWarnings explicitly, then every other warning category.
warnings.simplefilter(action='ignore', category=FutureWarning)
# ``np.warnings`` was merely a re-export of the stdlib ``warnings`` module and
# was removed in NumPy 1.24 (the old spelling raised AttributeError on load);
# calling the stdlib module directly has the identical effect.
warnings.filterwarnings('ignore')

# Default font size for the figures this script produces.
plt.rcParams.update({'font.size': 7})


def round_of_rating(number):
    """Snap *number* to the nearest multiple of one half.

    The value is expressed in half-units (doubled), rounded to a whole
    number with the built-in ``round`` (so exact ties go to the even
    neighbour, e.g. 1.25 -> 1.0), then converted back by halving.

    >>> round_of_rating(1.3)
    1.5
    >>> round_of_rating(2.6)
    2.5
    >>> round_of_rating(3.0)
    3.0
    >>> round_of_rating(4.1)
    4.0"""
    halves = round(2 * number)
    return halves / 2


def lc_bin(time, flux, bin_width):
    '''
    Bin a light curve into fixed-width bins in time.

    Bins start at ``min(time)`` and are ``bin_width`` wide; ``time`` and
    ``bin_width`` must use the same units. Every sample is assigned to a
    bin, including samples in the final, partially filled bin (previously
    anything at or beyond the last full bin edge, including the sample at
    ``max(time)``, was silently discarded).

    Parameters
    ----------
    time : array_like
        Sample times.
    flux : array_like
        Value measured at each time; same length as ``time``.
    bin_width : float
        Width of each bin, in the units of ``time``.

    Returns
    -------
    time_bin, flux_bin, err_bin : numpy.ndarray
        Bin centres, mean flux per bin and standard error of the mean
        (``scipy.stats.sem``) per bin. Bins with fewer than two samples
        (standard error undefined) or whose standard error is NaN are
        omitted. All three arrays are empty when ``time`` is empty.
    '''
    time = np.asarray(time, dtype=float)
    flux = np.asarray(flux, dtype=float)
    if time.size == 0:
        return np.array([]), np.array([]), np.array([])

    t0 = np.min(time)
    # Left-closed bins [t0 + k*w, t0 + (k+1)*w), as np.digitize gave before.
    bin_idx = np.floor((time - t0) / bin_width).astype(int)
    n_bins = bin_idx.max() + 1
    centres = t0 + (np.arange(n_bins) + 0.5) * bin_width

    # Group samples by bin with one sort rather than re-masking the whole
    # array for every bin and every statistic.
    order = np.argsort(bin_idx, kind='stable')
    counts = np.bincount(bin_idx, minlength=n_bins)
    groups = np.split(flux[order], np.cumsum(counts)[:-1])

    flux_binned = np.full(n_bins, np.nan)
    err_binned = np.full(n_bins, np.nan)
    for k, members in enumerate(groups):
        # sem() needs at least two samples; smaller bins stay NaN and are dropped.
        if members.size > 1:
            flux_binned[k] = members.mean()
            err_binned[k] = sem(members)

    keep = ~np.isnan(err_binned)
    return centres[keep], flux_binned[keep], err_binned[keep]



# TODO Add method of rejecting frames rather than stars if many stars have failed phot
# TODO Add docstrings
# pylint: disable = invalid-name
# pylint: disable = redefined-outer-name
# pylint: disable = no-member
# pylint: disable = too-many-locals
# pylint: disable = too-many-arguments
# pylint: disable = unused-variable
# pylint: disable = line-too-long

# Default legend font size for figures created after this point.
plt.rc('legend', fontsize=10)

def _load_phot(action, aperture):
    """Load the photometry table ``<action>.phot<aperture>``.

    The first column holds the frame filename, which ``np.genfromtxt``
    cannot parse as a float (it reads back as NaN), so it is dropped.
    """
    return np.genfromtxt('{:}.phot{:}'.format(action, aperture))[:, 1:]


def _list_apertures(action):
    """Return the aperture strings (e.g. ``'3.5'``) available for *action*.

    Apertures are parsed from ``<action>.phot<int>.<frac>`` filenames and
    returned sorted numerically.
    """
    apertures = []
    for path in glob.glob('{:}.phot*'.format(action)):
        parts = path.split('.')
        apertures.append(parts[1][4:] + '.' + parts[2])
    apertures.sort(key=float)
    return apertures


def _reference_xy(phot_table, n_stars, row=50):
    """Return the X and Y columns of the first *n_stars* star blocks at frame *row*.

    Row 50 matches the original script; tables with fewer rows fall back
    to their last frame instead of raising IndexError.
    """
    row = min(row, phot_table.shape[0] - 1)
    return phot_table[row, 8::7][:n_stars], phot_table[row, 9::7][:n_stars]


def _find_reference_regions(actions, number_per_action, aperture, n_stars,
                            diff_thresh=70000):
    """Split the actions into contiguous runs that share the same reference stars.

    A new run starts at action ``i`` when the summed absolute X+Y offset of
    the first *n_stars* star blocks, between the first frame of action ``i``
    and the reference frame of action ``i - 1``, exceeds *diff_thresh*, or
    when the number of aperture files changes by more than one.

    (The original skipped the action after each break and re-referenced
    against the pre-break action; its loop index could overshoot so the
    final region was never recorded. Every action is now compared with its
    predecessor and the final region is always closed.)

    Returns
    -------
    regions : list of [first, last]
        Inclusive action-index ranges that together cover every action.
    diffs : list of float
        Position offset for each action (0 for the first).
    """
    phot_table = _load_phot(actions[0], aperture)
    x_ref, y_ref = _reference_xy(phot_table, n_stars)
    diffs = [0]
    regions = []
    start = 0
    for i in range(1, len(actions)):
        phot_table = _load_phot(actions[i], aperture)
        diff = (np.sum(np.abs(phot_table[0, 8::7][:n_stars] - x_ref))
                + np.sum(np.abs(phot_table[0, 9::7][:n_stars] - y_ref)))
        diffs.append(diff)
        if (diff > diff_thresh) or (abs(number_per_action[i] - number_per_action[i - 1]) > 1):
            regions.append([start, i - 1])
            start = i
        x_ref, y_ref = _reference_xy(phot_table, n_stars)
    regions.append([start, len(actions) - 1])
    return regions, diffs


def _region_of(index, regions):
    """Return the position in *regions* of the inclusive range containing *index*."""
    for r, (first, last) in enumerate(regions):
        if first <= index <= last:
            return r
    raise ValueError('Action index {:} is not covered by any reference region'.format(index))


def _flag_bad_comparisons(actions, regions, aperture, n_comp):
    """Section 1: flag bad comparison stars using a single aperture.

    A star is flagged for an action when more than one frame has its X, Y
    or FLUX column below 10 (presumably edge-of-frame or lost stars - TODO
    confirm). An action in which every star is flagged is marked as a bad
    night and its flags are cleared. A summary figure is saved per action.

    Returns
    -------
    good_star_mask : ndarray of bool, shape (len(regions), n_comp)
        True for stars never flagged in any action of that region; a region
        with no surviving stars falls back to using all of them.
    bad_night : ndarray of bool, shape (len(actions),)
    """
    bad_star_mask = np.zeros((len(regions), len(actions), n_comp), dtype=bool)
    bad_night = np.zeros(len(actions), dtype=bool)

    for i, action in enumerate(actions):
        r = _region_of(i, regions)
        phot_table = _load_phot(action, aperture)
        time = phot_table[:, 0]

        fig, axes = plt.subplots(nrows=n_comp, ncols=1, figsize=(5, 50))
        for j, ax in enumerate(np.atleast_1d(axes)):
            x_col, y_col, flux_col = 8 + j * 7, 9 + j * 7, 10 + j * 7
            if any(np.sum(phot_table[:, col] < 10) > 1 for col in (x_col, y_col, flux_col)):
                bad_star_mask[r, i, j] = True

            # NOTE(review): divides by column -5 (FLUX of the last star
            # block), as the original did - confirm that block is the target.
            ratio = phot_table[:, flux_col] / phot_table[:, -5]
            ax.scatter(time, ratio, alpha=0.1, c='k', s=5)
            x_, y_, _ = lc_bin(time, ratio, 0.25 / 24)
            ax.scatter(x_, y_, alpha=1, c='r', s=10)
            ax.set_ylim(*np.percentile(ratio, [1, 99]))
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_ylabel(str(j + 1), rotation=0)
            if bad_star_mask[r, i, j]:
                ax.set_facecolor((1.0, 0.47, 0.42))  # salmon: flagged star

        # Every star failing points at the night, not the stars.
        if bad_star_mask[r, i].all():
            bad_star_mask[r, i] = False
            bad_night[i] = True

        print('\tAction {:} has {:} bad comparison star(s) in reference image {:}'.format(action, int(bad_star_mask[r, i].sum()), r + 1))

        fig.tight_layout()
        fig.savefig('Action_{:}_comparison_star_summary_first.png'.format(action))
        plt.close(fig)

    good_star_mask = ~bad_star_mask.any(axis=1)
    for r in range(len(regions)):
        print('\tReference image {:}: {:} good comparison star(s)'.format(r + 1, int(good_star_mask[r].sum())))
        if not good_star_mask[r].any():
            good_star_mask[r] = True
    print('\tA total of {:} unique stars have been rejected'.format(int((~good_star_mask).sum())))
    return good_star_mask, bad_night


def _best_aperture_for_action(action, good_stars, n_comp, cloud_thresh=0.1):
    """Section 2: measure the comparison-ensemble RMS in every aperture of *action*.

    For each aperture the summed flux of the good comparison stars is
    divided by column -5, frames more than 10 sigma from the median or in a
    noisy stretch are masked, a weighted quadratic in time is fitted, and
    the RMS of data/model is recorded. Saves
    ``Action_<action>_aperture_rms.png``.

    Returns
    -------
    best : float
        Lowest-RMS aperture, rounded to the nearest half pixel.
    apertures : list of str
        Aperture strings found for this action.
    """
    apertures = _list_apertures(action)
    aperture_values = np.array(apertures, dtype=float)

    fig = plt.figure(constrained_layout=False, figsize=(5, 20))
    gs = gridspec.GridSpec(len(apertures), 2, figure=fig)
    rms = np.empty(len(apertures))

    for j, aperture in enumerate(apertures):
        phot_table = _load_phot(action, aperture)
        # Time relative to the integer part of the first timestamp.
        time = phot_table[:, 0] - int(np.min(phot_table[:, 0]))

        # NOTE(review): per the original layout comment there are 8 frame
        # columns (JD-MID BJD-TDB-MID HJD-MID EXPTIME AIRMASS FWHM AGERRX
        # AGERRY) then 7 per star (X Y FLUX FLUXERR SKY SKYERR MAXPIX). The
        # last star block is excluded from the sum and its FLUX (column -5)
        # is the divisor - confirm that block is the target.
        star_flux = phot_table[:, 10::7][:, :-1][:, :n_comp]
        comp_flux = star_flux[:, good_stars].sum(axis=1) / phot_table[:, -5]

        # 10-sigma clip about the median (the original compared each value
        # with itself, so nothing was ever clipped).
        clip = np.abs(comp_flux - np.median(comp_flux)) > 10 * np.std(comp_flux)

        # Mask the span covered by noisy 15-minute bins (0.25/24 of a day,
        # assuming time is in days) padded by 10 minutes; presumably cloud.
        time_bin, _, err_bin = lc_bin(time, comp_flux, 0.25 / 24)
        if time_bin.shape[0] > 4 and np.max(err_bin) > cloud_thresh:
            noisy_times = time_bin[err_bin > cloud_thresh]
            noisy_start = np.min(noisy_times) - (10. / 24. / 60)
            noisy_end = np.max(noisy_times) + (10. / 24. / 60)
            clip |= (time > noisy_start) & (time < noisy_end)

        # If everything ended up masked, fall back to using every frame.
        if clip.all():
            clip[:] = False

        t_kept = time[~clip]
        f_kept = comp_flux[~clip]
        # Inverse-square weights, normalised. The original multiplied by the
        # sum instead of dividing; np.polyfit is insensitive to a uniform
        # rescaling of w, so the fitted model is unchanged.
        weights = f_kept ** -2
        weights = weights / np.sum(weights)
        model = np.poly1d(np.polyfit(t_kept, f_kept, 2, w=weights))(t_kept)
        rms[j] = np.std(f_kept / model)

        ax = fig.add_subplot(gs[j, 0])
        ax.scatter(t_kept, f_kept, alpha=1, c='k', s=5)
        ax.plot(t_kept, model, 'b--')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_ylabel('{:} pix'.format(aperture), fontsize=7)
        ax2 = ax.twinx()
        ax2.set_ylabel('{:,} ppm'.format(int(1e6 * rms[j])))

        print('\t\t{:>8} {:>5} has an rms of {:>8,} ppm [{:>3.1f}% exclusion]'.format('Aperture', aperture, int(1e6 * rms[j]), 100 * (np.sum(clip) / clip.shape[0])))

    best = round_of_rating(float(aperture_values[np.argmin(rms)]))

    # RMS vs aperture on the right-hand column; numeric axis, small apertures on top.
    ax = fig.add_subplot(gs[:, 1])
    ax.plot(1e6 * rms, aperture_values, 'b')
    ax.set_ylim(max(16, aperture_values.max() + 0.5), 0)
    ax.set_xlabel('RMS [ppm]')
    ax.set_ylabel('Aperture size [pixel]')
    ax.axhline(best, c='b', ls='--')

    fig.suptitle('Action {:}'.format(action))
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig('Action_{:}_aperture_rms.png'.format(action))
    plt.close(fig)
    return best, apertures


def main():
    """Choose the best photometric aperture from the ``*.phot*`` files in the cwd.

    Writes diagnostic PNGs and a ``.aperture`` file containing the chosen
    aperture string and the most recent action.
    """
    # np.unique already returns the action IDs sorted.
    actions = np.unique([path.split('.')[0] for path in glob.glob('*.phot*')])
    number_per_action = [len(glob.glob('{:}.phot*'.format(a))) for a in actions]

    if actions.shape[0] == 0:
        raise NameError('No actions to reduce.')
    print('{:} actions found:'.format(actions.shape[0]))
    for action, n_apertures in zip(actions, number_per_action):
        print('\t{:} with {:>3} apertures'.format(action, n_apertures))

    # NOTE(review): this reads '.monoaperture' but the end of this script
    # writes '.aperture', so its own output never marks the data as up to
    # date. Left unchanged in case another tool writes '.monoaperture' -
    # confirm which file is intended.
    try:
        _, last_action = np.genfromtxt('.monoaperture', dtype=str)
    except (OSError, ValueError, TypeError):
        # Missing or malformed file: treat as never processed.
        last_action = -np.inf
    if float(last_action) >= float(actions[-1]):
        print('Up to date')
        sys.exit()
    print('Continuing since last action was {:} and current action is {:}'.format(last_action, actions[-1]))

    print('\n\nSearching for inhomegenous reference stars where the field has been reset')
    aperture_to_use = '3.5'
    n_comp = 10  # number of comparison stars examined (fixed, as before)

    regions, diffs = _find_reference_regions(actions, number_per_action, aperture_to_use, n_comp)
    print('\tFound {:} regions with different refernece stars... '.format(len(regions)))
    if len(regions) > 1:
        for first, last in regions:
            print('\t\tBetween actions ', actions[first], ' and ', actions[last])
    else:
        print('All appear to be the same')

    fig = plt.figure()
    plt.plot(range(len(diffs)), diffs)
    plt.axhline(np.median(diffs))
    if len(regions) > 1:
        for _, last in regions:
            plt.axvline(last, ls='--')
    fig.savefig('ref_XY_diff.png')
    plt.close(fig)

    print('\nSearching for bad comparison stars in the {:} pixel aperture for {:} reference images'.format(aperture_to_use, len(regions)))
    good_star_mask, bad_night = _flag_bad_comparisons(actions, regions, aperture_to_use, n_comp)

    print('\nSearching for best aperture using best filtered comparison stars')
    # NaN marks skipped (bad) nights so they do not bias the median.
    best_apertures = np.full(actions.shape[0], np.nan)
    aperture_strings = {}
    for i, action in enumerate(actions):
        if bad_night[i]:
            print('Skipping action {:} since it looks like a shit night'.format(action))
            continue
        print('\tProcessing action {:}'.format(action))
        best, apertures = _best_aperture_for_action(action, good_star_mask[_region_of(i, regions)], n_comp)
        best_apertures[i] = best
        print('\t\tBest aperture for action: {:} pixels'.format(best))
        for ap in apertures:
            aperture_strings.setdefault(float(ap), ap)

    print('\nSummary of apertures:')
    for action, best in zip(actions, best_apertures):
        print('\taction {:>8} -> {:} pixel aperture'.format(action, best))
    if np.all(np.isnan(best_apertures)):
        print('\tNo usable actions; .aperture not written')
        return
    best_aperture = round_of_rating(float(np.nanmedian(best_apertures)))
    # Map the numeric median back to the aperture string used in filenames.
    best_aperture_string = aperture_strings.get(best_aperture, '')
    print('\tMedian : {:} pixel aperture'.format(best_aperture_string))

    with open('.aperture', "w+") as fh:
        fh.write('{:}\n{:}'.format(best_aperture_string, actions[-1]))

    fig = plt.figure(figsize=(15, 5))
    plt.scatter(range(len(best_apertures)), best_apertures, c='k', s=10)
    plt.axhline(best_aperture, ls='--', c='b')
    plt.xlabel('Action')
    plt.ylabel('Best aperture size [pix]')
    # Set tick positions explicitly so each label sits under its action.
    plt.xticks(range(len(actions)), actions, rotation=90)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig('Actions_aperture_summary.png')
    plt.close(fig)


if __name__ == "__main__":
    main()
