#!python

import numpy as np, os, sys, math
from bruce.binarystar import lc
import argparse
from scipy.signal import find_peaks
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import matplotlib 
matplotlib.rcParams.update({'font.size': 22})

def gaus(x,a,x0,sigma):
    """Evaluate an un-normalised Gaussian profile.

    Parameters
    ----------
    x : scalar or array
        Point(s) at which the profile is evaluated.
    a : float
        Value of the profile at its centre.
    x0 : float
        Position of the centre.
    sigma : float
        Standard deviation; only its square enters the result.

    Returns
    -------
    Same shape as ``x``: ``a * exp(-(x - x0)**2 / (2 * sigma**2))``.
    """
    offset = x - x0
    exponent = -offset**2/(2*sigma**2)
    return a*np.exp(exponent)


description = '''A program to fit binary star observations elegantly. 
Use the -h flag to see all available options for the fit. For any questions, 
please email samuel.gill@warwick.ac.uk'''

# Command-line interface: one positional input file plus an optional output
# directory.
# NOTE(review): the program name 'ngtsfit' and the description above talk about
# fitting, while the script writes '*_monofind_*' peak-search outputs -- confirm
# whether they should be renamed.
parser = argparse.ArgumentParser('ngtsfit', description=description)
#parser.add_argument('-t',
#                help='The transit epoch in arbitrary time units consistent with the input file.',
#                dest="t_zero",
#                action='store')

parser.add_argument("filename",
                    help='The filename of the binary star information')

# The help text describes what the option actually controls: the directory
# that receives the results table and the plot.
parser.add_argument('-a',
                    '--saveplace',
                    help='Directory in which the results table and plot are saved '
                         '(default: current directory).',
                    type=str,
                    default='.')


def output_stem(filename):
    """Return the stem used to name the output files for an input path.

    Directory components are removed first and the base name is then cut at
    its first '.', e.g. 'lc.dat' -> 'lc' and './data/lc.dat' -> 'lc'.
    Splitting the raw path on '.' instead would give an empty stem for paths
    such as './lc.dat' or '../lc.dat'.

    If the cut leaves nothing (a base name starting with '.'), the whole base
    name is returned.
    """
    base = os.path.basename(filename)
    return base.split('.')[0] or base


def main():
    """Search a light curve for isolated peaks, fit each one and save the results.

    Steps:
      1. Load a three-column text file named on the command line.
      2. Flag samples of the second column that rise above
         median(|m|) + 8 * std(|m|), at least 10 samples apart.
      3. Fit a Gaussian to |m| within +/-0.5 time units of every flagged sample.
      4. Write '<stem>_monofind_results.dat' (only when peaks exist) and a PNG
         ('<stem>_monofind_plots.png' with peaks, '<stem>_monofind_plot.png'
         without) into the directory given by --saveplace.
    """
    # Parse the args
    args = parser.parse_args()

    # Load the data. The first two columns are plotted as 'Time [BTJD]' and
    # 'T mag'; the third is presumably the magnitude uncertainty and is not
    # used below.
    t, m, _m_err = np.loadtxt(args.filename).T

    # The detection statistic is |m| itself (equivalent to the zeroth-order
    # np.diff of m, which returns its input unchanged).
    signal = np.abs(m)
    threshold = np.median(signal) + 8*np.std(signal)

    # The search runs on the signed series, so only positive excursions above
    # the threshold are flagged.
    peaks, _ = find_peaks(m, height=threshold, distance=10)
    #print('Number of peaks : {:}'.format(len(peaks)))

    # Two overview panels plus one zoom panel per peak.
    n_panels = len(peaks) + 2
    fig, axs = plt.subplots(nrows=n_panels, ncols=1, figsize=(15, 5*n_panels))

    axs[0].plot(t, signal, 'k')
    axs[0].axhline(threshold, c='k', ls='--', alpha=0.4)

    axs[1].scatter(t, m, c='k', s=10)
    axs[1].set_xlabel('Time [BTJD]')
    axs[1].set_ylabel('T mag')
    axs[1].invert_yaxis()

    stem = output_stem(args.filename)

    if len(peaks) > 0:
        axs[0].plot(t[peaks], signal[peaks], "x")
        axs[1].plot(t[peaks], m[peaks], "x")

        results_path = os.path.join(args.saveplace, stem + '_monofind_results.dat')
        # The context manager guarantees the table is flushed and closed even
        # if plotting fails part-way through.
        with open(results_path, "w") as results:
            results.write('Peak, Epoch, Depth, Width')

            for i, peak in enumerate(peaks):
                window = (t > (t[peak] - 0.5)) & (t < (t[peak] + 0.5))
                t_win = t[window]
                m_win = m[window]
                signal_win = signal[window]

                # Gaussian fit the peak
                try:
                    popt, pcov = curve_fit(gaus, t_win, signal_win,
                                           p0=[signal[peak], t[peak], 0.2])
                except (RuntimeError, ValueError, TypeError) as err:
                    # Non-convergence, NaNs or too few points in the window:
                    # skip this peak (its panel stays empty) but say so.
                    print('Peak {:}: Gaussian fit failed ({:}); skipped'.format(i+1, err),
                          file=sys.stderr)
                    continue

                depth = popt[0]*1e3
                # The model depends on sigma**2 only, so the optimiser may
                # return a negative sigma; report the magnitude. The factor 48
                # turns sigma into 2*sigma in hours if time is in days.
                width = abs(popt[2])*48
                epoch = popt[1]

                axs[0].plot(t_win, gaus(t_win, *popt), 'r')
                axs[0].text(epoch, popt[0], 'Peak {:}\nDepth : {:.1f} mmag\nWidth : {:.1f} hrs\nT0 : {:.2f}'.format(i+1, depth, width, epoch))
                results.write('\n{:},{:},{:},{:}'.format(i+1, epoch, depth, width))

                panel = axs[i+2]
                panel.plot(t_win, m_win, 'ko-', label='Peak {:}'.format(i+1))
                leg = panel.legend(handlelength=0, handletextpad=0, fancybox=True)
                # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles;
                # support both so the script runs on old and new releases.
                handles = getattr(leg, 'legend_handles', None)
                if handles is None:
                    handles = leg.legendHandles
                for item in handles:
                    item.set_visible(False)
                panel.set_xlim(t_win[0] - 0.5, t_win[-1] + 0.5)
                # Upper/lower limits are swapped so brighter (smaller) magnitudes are up.
                panel.set_ylim(np.max(m_win) + 0.001, np.min(m_win) - 0.001)
                panel.set_ylabel('T mag')
                panel.set_xlabel('Time [BTJD]')

        plot_name = stem + '_monofind_plots.png'
    else:
        plot_name = stem + '_monofind_plot.png'

    fig.savefig(os.path.join(args.saveplace, plot_name))
    # The figure is only written to disk; nothing is displayed interactively.
    plt.close(fig)


if __name__ == "__main__":
    main()



