#!/home/sam/anaconda3/bin/python

import numpy as np, os, sys, math
from bruce.binarystar import lc
import argparse
from scipy.signal import find_peaks
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def gaus(x, a, x0, sigma):
    """Evaluate the Gaussian ``a * exp(-(x - x0)**2 / (2 * sigma**2))``.

    Used as the model function passed to ``scipy.optimize.curve_fit``; the
    three free parameters are, in order, the amplitude ``a``, the centre
    ``x0`` and the width ``sigma``.

    Args:
        x: Point(s) at which to evaluate; scalars and numpy arrays both work.
        a: Peak amplitude.
        x0: Centre of the Gaussian.
        sigma: Standard-deviation-like width. Only ``sigma**2`` enters the
            formula, so its sign does not matter.

    Returns:
        The Gaussian evaluated at ``x`` (same shape as ``x``).
    """
    offset = x - x0
    exponent = -offset**2 / (2 * sigma**2)
    return a * np.exp(exponent)


description = '''A program to fit binary star observations elegantly. 
Use the -h flag to see all available options for the fit. For any questions, 
please email samuel.gill@warwick.ac.uk'''

# Argument parser: one positional input file plus an optional output
# directory for the results file and plot.
parser = argparse.ArgumentParser('ngtsfit', description=description)

parser.add_argument(
    "filename",
    help='The filename of the binary star information',
)

parser.add_argument(
    '-a',
    '--saveplace',
    help='Directory in which to save the results file and plot (default: current directory).',
    type=str,
    default='.',
)


if __name__ == "__main__":
    # Parse the command-line arguments (filename, --saveplace).
    args = parser.parse_args()

    # Load the data: three columns unpacked as time, magnitude and magnitude
    # error. The error column is not used below; unpacking into three names
    # still requires the file to have exactly three columns.
    t, m, _me = np.loadtxt(args.filename).T

    # Search signal: |m| at every time stamp, so it stays aligned with t.
    # np.abs(np.diff(m, 0)) evaluates to exactly this, because np.diff with
    # n=0 returns its input unchanged.
    # NOTE(review): if a point-to-point difference was intended, that would be
    # np.diff(m) (n=1), which is one element shorter than t — confirm intent.
    diff = np.abs(m)
    med = np.median(diff)
    std = np.std(diff)
    threshold = med + 5 * std  # peaks must exceed median + 5 standard deviations

    # Output files are named after the input file's base name (directory
    # stripped, text before the first '.'), so inputs such as ./data/lc.dat or
    # /abs/path/lc.dat give "lc_..." rather than an empty or nested name.
    stem = os.path.basename(args.filename).split('.')[0]
    os.makedirs(args.saveplace, exist_ok=True)
    results_path = os.path.join(args.saveplace, '{:}_monofind_results.dat'.format(stem))
    plot_path = os.path.join(args.saveplace, '{:}_monofind_plot.png'.format(stem))

    fig = plt.figure(figsize=(15, 5))
    plt.plot(t, diff)
    plt.axhline(threshold, c='k', ls='--', alpha=0.4)

    # distance=10: detected peaks are at least 10 samples apart.
    peaks, _ = find_peaks(diff, height=threshold, distance=10)
    print('Number of peaks : {:}'.format(len(peaks)))

    if len(peaks) > 0:
        plt.plot(t[peaks], diff[peaks], "x")
        # The results file is only written when peaks exist; the context
        # manager guarantees it is closed, and it has its own name so it is
        # never confused with the figure object.
        with open(results_path, "w+") as results:
            results.write('Peak, Epoch, Depth, Width\n')
            # Only the first five peaks (in index order) are fitted.
            for i, peak in enumerate(peaks[:5]):
                # Fit a Gaussian to the samples within +/-0.5 time units of the peak.
                window = (t > t[peak] - 0.5) & (t < t[peak] + 0.5)
                t_ = t[window]
                diff_ = diff[window]
                if t_.size < 3:
                    # curve_fit needs at least as many points as free parameters (3).
                    print('Peak {:}: too few points to fit, skipping'.format(i + 1))
                    continue
                try:
                    popt, _ = curve_fit(gaus, t_, diff_, p0=[diff[peak], t[peak], 0.2])
                except RuntimeError:
                    # curve_fit raises RuntimeError when the fit does not converge;
                    # skip this peak rather than abort the whole run.
                    print('Peak {:}: Gaussian fit did not converge, skipping'.format(i + 1))
                    continue

                amplitude, epoch = popt[0], popt[1]
                # The model depends only on sigma**2, so the optimiser may return
                # a negative sigma; report its magnitude.
                sigma = abs(popt[2])
                depth_mmag = amplitude * 1e3  # assumes m is in magnitudes
                # NOTE(review): 48*sigma is 2*sigma in hours only if t is in days — confirm.
                width_hrs = sigma * 48

                plt.plot(t_, gaus(t_, *popt), 'r')
                label = 'Peak {:}\nDepth : {:.1f} mmag\nWidth : {:.1f} hrs\nT0 : {:.2f}'.format(
                    i + 1, depth_mmag, width_hrs, epoch)
                plt.text(epoch, amplitude, label)
                results.write('{:},{:},{:},{:}\n'.format(i + 1, epoch, depth_mmag, width_hrs))

    # The plot is always saved, whether or not any peaks were found.
    plt.savefig(plot_path)
    plt.close(fig)



