#!/usr/bin/env python
import math as mt
import numpy as np
import astropy.io.fits as pf
import pandas as pd
import seaborn as sb
import matplotlib.pyplot as pl

from argparse import ArgumentParser
from os.path import join, abspath, basename
from mpi4py import MPI
from numpy import arange, zeros, nanmax, nanmin, nanmedian
from numpy.random import uniform
from time import sleep
from IPython.display import display
from glob import glob
from copy import copy

from glob import glob
from os.path import join, basename
from matplotlib.backends.backend_pdf import PdfPages

# Global plotting defaults: seaborn 'white' style with the 'paper' context,
# and a default matplotlib figure of 14 x 5 inches at 100 dpi.
sb.set_style('white')
sb.set_context('paper')
pl.rc('figure', figsize=(14,5), dpi=100)

# Colours used on the report pages.
c_ob = '#002147'  # dark blue; fills the header/footer bands drawn by create_page
c_bo = '#BF5700'  # orange

def normalise(a):
    """Return ``a`` divided by its NaN-ignoring median.

    Parameters
    ----------
    a : array
        Values to rescale; NaN entries are skipped when the median is
        computed and stay NaN in the result.

    Returns
    -------
    array
        ``a / nanmedian(a)``.
    """
    median = nanmedian(a)
    return a / median

def plot_lc(ax, ap, data):
    """Plot the raw and the trend-corrected light curve of one aperture.

    Two curves are drawn into ``ax`` for the samples selected by the
    aperture's mask: the median-normalised raw flux, and the corrected flux
    lowered by a constant offset so the two curves do not overlap.  The
    y-limits of ``ax`` are then set to enclose both curves.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into.
    ap : int
        Aperture index used to build the column names ``mask_ol_<ap>``,
        ``flux_<ap>`` and ``trend_p_<ap>``.
    data : mapping from column name to array
        Must provide ``time`` plus the three aperture columns named above.
    """
    time = data['time']
    # Builtin ``bool`` instead of ``np.bool``: the alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    mask = data['mask_ol_%d' % ap].astype(bool)
    fraw = normalise(data['flux_%d' % ap])
    cpos = normalise(data['trend_p_%d' % ap])
    # NOTE(review): fraw and cpos are each median-normalised before being
    # subtracted -- confirm this is the intended correction.
    fred = normalise(fraw - cpos)

    # Vertical offset applied to the corrected curve.
    shift = 1.5 * nanmax(fred - fraw)

    ax.plot(time[mask], fraw[mask])
    ax.plot(time[mask], fred[mask] - shift)

    pl.setp(ax, ylim=(0.99*nanmin(fred[mask]) - shift, 1.01*nanmax(fraw[mask])))

def create_page():
    """Create an empty report page with a title banner and a footer band.

    Builds an 8.3 x 11.7 inch figure, writes the ``EPIC <id>`` title near the
    top-left corner and adds a full-figure background axes (ticks hidden,
    pushed behind everything else) holding two ``c_ob``-coloured rectangles:
    one across the top and one across the bottom of the page.

    Reads the globals ``epic`` and ``c_ob``; ``epic`` is only assigned when
    this file is run as a script.

    Returns
    -------
    tuple
        ``(figure, background_axes)``.
    """
    page = pl.figure(figsize=(8.3,11.7))
    page.text(0.04, 0.97, 'EPIC {:9d}'.format(epic), va='top', size=24, color='w', weight='bold')

    backdrop = page.add_axes([0,0,1,1])
    for axis in (backdrop.xaxis, backdrop.yaxis):
        axis.set_visible(False)
    backdrop.set_zorder(-1000)

    # (bottom edge, height) of the header band and the footer band.
    for y0, height in ((0.92, 0.08), (0, 0.02)):
        backdrop.add_patch(pl.Rectangle((0,y0), 1, height, fill=True, color=c_ob))
    return page, backdrop
    

if __name__ == '__main__':
    # Command line: the FITS file to plot, the directory for the PDF, and an
    # optional flag to also show the figure on screen.
    ap = ArgumentParser()
    ap.add_argument('file_name', type=str)
    ap.add_argument('--save-dir', default='.', type=str)
    ap.add_argument('--show', default=False, action='store_true')
    args = ap.parse_args()

    fname = args.file_name
    # The target id is the second '_'-separated token of the file name.
    # create_page() reads this value as a global.
    epic = int(basename(fname).split('_')[1])

    data = pf.getdata(fname, 1)
    
    #colnames_bs  = 'object cdpp2r cdpp2c cdpp3r cdpp3c ap2_warn ap3_warn'.split()
    #colnames_hp  = pf.getval(fname, 'ker_pars', ext=1).split()
    #colnames_hp2 = map(lambda s:'ap2_'+s, colnames_hp)
    #colnames_hp3 = map(lambda s:'ap3_'+s, colnames_hp)

    #mask_1 = data.omask_1.astype(np.bool)
    #mmag_1 = 23 - 2.5*mt.log10(np.median(data.flux_r_1[mask_1]))
    #mask_2 = data.omask_2.astype(np.bool)
    #mmag_2 = 23 - 2.5*mt.log10(np.median(data.flux_r_2[mask_2]))

    #pars = []
    #hd = pf.getheader(fname,1)
    #pars.append(map(hd.get, colnames_bs))
    #pars[-1].extend(np.fromstring(hd.get('ker_hps2').strip('[]'), sep=' '))
    #pars[-1].extend(np.fromstring(hd.get('ker_hps3').strip('[]'), sep=' '))
    #df = pd.DataFrame(pars, columns=colnames_bs)#+colnames_hp2+colnames_hp3)

    with PdfPages(join(args.save_dir,'EPIC_{:9d}.pdf'.format(epic))) as pdf:
        fig1, ax = create_page()
        gs = pl.GridSpec(3,4, height_ratios=[0.4,1,1], bottom=0.065, top=0.91, 
                         left=0.09, right=0.99, hspace=0.02, wspace=0.01)
        
        ## First page
        ## ----------
        # ax0: summary-text row (currently left blank), ax2/ax3: light-curve panels.
        ax0 = fig1.add_subplot(gs[0,:])
        ax2 = fig1.add_subplot(gs[1,:])
        ax3 = fig1.add_subplot(gs[2,:])

        plot_lc(ax2, 1, data)
        # plot_lc(ax3, 2, data)

        # ax2.text(0.01,0.94, 'A2', size=13, weight='bold',transform=ax2.transAxes)
        # ax3.text(0.01,0.94, 'A3', size=13, weight='bold',transform=ax3.transAxes)
        
        # ax0.text(0.05, 0.50, 'A2', size=18, weight='bold')
        # ax0.text(0.15,0.65,  'Fit ok: {:s}'.format(str(not df.irow(0).ap2_warn)), size=13)
        # ax0.text(0.15,0.20, r'23 - 2.5 log$_{10}$ <flux>: '+'{:4.2f}'.format(mmag_1), size=13)
        # ax0.text(0.15,0.50, 'CDPP raw: {:4.0f} ppm'.format(df.irow(0).cdpp2r), size=13)
        # ax0.text(0.15,0.35, 'CDPP red: {:4.0f} ppm'.format(df.irow(0).cdpp2c), size=13)
        
        # ax0.text(0.55, 0.50, 'A3', size=18, weight='bold')
        # ax0.text(0.65, 0.65,  'Fit ok: {:s}'.format(str(not df.irow(0).ap3_warn)), size=13)
        # ax0.text(0.65,0.20, r'23 - 2.5 log$_{10}$ <flux>: '+'{:4.2f}'.format(mmag_2), size=13)
        # ax0.text(0.65,0.50, 'CDPP raw: {:4.0f} ppm'.format(df.irow(0).cdpp3r), size=13)
        # ax0.text(0.65,0.35, 'CDPP red: {:4.0f} ppm'.format(df.irow(0).cdpp3c), size=13)

        pl.setp(ax0, frame_on=False, xticks=[], yticks=[])

        # Give both light-curve panels a common y-range, merged only from the
        # panels that actually contain data.  An empty axes reports the
        # default (0, 1) limits, which would otherwise stretch the range and
        # flatten the plotted light curve.
        ylims = [a.get_ylim() for a in (ax2, ax3) if a.has_data()]
        if ylims:
            pl.setp([ax2,ax3], ylim=(min(l[0] for l in ylims), max(l[1] for l in ylims)))
        pl.setp([ax2,ax3], xlim=data['time'][[0,-1]], ylabel='Normalised flux')
        # The panels share the time axis, so only the bottom one shows tick labels.
        pl.setp(ax2.get_xticklabels(), visible=False)
        pl.setp(ax3, xlabel='Date')
        pdf.savefig(fig1)
        
        ## Second page
        ## -----------
        # gs2 = pl.GridSpec(7,4, bottom=0.065, top=0.91, left=0.11, right=0.99, hspace=0.02, wspace=0.01)
        # for ap in [1,2]:
        #     fig2, ax = create_page()
        #     axes = [fig2.add_subplot(gs2[i,:]) for i in range(7)]
            
        #     time = data['time']
        #     mask = data['omask_%d' % ap].astype(np.bool)
        #     fraw = normalise(data['flux_r_%d'%ap])
        #     fred = normalise(data['flux_c_%d'%ap])
        #     ctim = normalise(data['corr_t_%d'%ap])
        #     cpos = normalise(data['corr_p_%d'%ap])
            
        #     axes[0].plot(time,fraw, 'k', alpha=0.5, lw=1)
        #     axes[0].plot(time[mask],fraw[mask], lw=1)
        #     axes[1].plot(time[mask],fred[mask], lw=1)
        #     axes[2].plot(time[mask],fred[mask]-ctim[mask],  '.', lw=1,)
        #     axes[3].plot(time[mask],ctim[mask], lw=1)
        #     axes[4].plot(time[mask],cpos[mask], lw=1)
        #     axes[5].plot(time[mask],data['x'][mask], lw=1)
        #     axes[6].plot(time[mask],data['y'][mask], lw=1)
            
        #     [pl.setp(a, ylabel=yl) 
        #      for a,yl in zip(axes, 'Raw flux, Raw - XY, Raw - XY - Time, Time component, XY component, X, Y'.split(', '))]
        #     pl.setp(axes[0], ylim=(0.995*fraw[mask].min(),1.005*fraw[mask].max()))
        #     pl.setp(axes, xlim=time[[0,-1]])
        #     pl.setp(axes[1:5], ylim=axes[0].get_ylim())
        #     pl.setp(axes[2], ylim=np.array([-5,5])*(fred[mask]-ctim[mask]).std())
        #     pl.setp(axes[4], xlabel='Date')
        #     pdf.savefig(fig2)

    if args.show:
        pl.show()
