#!python
import warnings
import logging
import re

import pandas as pd
import astropy.io.fits as pf

from argparse import ArgumentParser
from datetime import datetime
from glob import glob
from os.path import join, exists
from os import makedirs

from numpy import nanmean

# Root logger: emit every record from DEBUG upwards as "<LEVEL> : <message>".
logging.basicConfig(
    format='%(levelname)s : %(message)s',
    level=logging.DEBUG,
)

# Discard any previously installed warning filters, then ignore UserWarning.
warnings.resetwarnings()
warnings.filterwarnings(action='ignore', category=UserWarning, append=True)

# Captures the run of digits between "EPIC_" and the next underscore.
re_epic = re.compile('EPIC_([0-9]+)_')
# Output file name pattern, filled in through the ``epic``, ``c`` and ``v``
# format fields.
fntemplate = 'hlsp_k2sc_k2_llc_{epic:9d}-c{c:02d}_kepler_v{v:1d}_lc.fits'

def create_series(path, name):
    """Return the files matching a glob pattern as a Series indexed by EPIC ID.

    Parameters
    ----------
    path : str
        Glob pattern selecting the input files.
    name : str
        Name given to the returned Series.

    Returns
    -------
    pandas.Series
        File paths in sorted order, indexed by the integer EPIC ID parsed
        from each file's base name with ``re_epic``.

    Notes
    -----
    Each file name is parsed on its own, so a directory component that
    happens to match ``re_epic`` cannot shift the IDs, and every path stays
    aligned with its own ID. Files whose base name carries no EPIC ID are
    skipped with a warning instead of breaking the index/value alignment.
    """
    from os.path import basename

    files, epics = [], []
    for fname in sorted(glob(path)):
        match = re_epic.search(basename(fname))
        if match is None:
            logging.warning('Skipping %s: no EPIC ID found in the file name', fname)
            continue
        files.append(fname)
        epics.append(int(match.group(1)))
    return pd.Series(files, index=epics, name=name)

if __name__ == '__main__':
    # Merge the PDC and SAP light-curve files of each EPIC target into a single
    # three-HDU FITS file (primary, PDC, SAP) stored under
    # <out_path>/c<campaign>/<first four EPIC digits>00000/.
    ap = ArgumentParser()
    ap.add_argument('pdc_path', type=str)
    ap.add_argument('sap_path', type=str)
    ap.add_argument('out_path', type=str)
    ap.add_argument('--release-version', type=int, default=1)
    ap.add_argument('--logfile', type=str, default=None)
    args = ap.parse_args()

    # Pair the PDC and SAP files by EPIC ID; dropna() discards every EPIC that
    # is missing from either of the two directories.
    spdc = create_series(join(args.pdc_path, 'EPIC*fits'), 'pdc')
    ssap = create_series(join(args.sap_path, 'EPIC*fits'), 'sap')
    df = pd.concat([spdc, ssap], axis=1).dropna()

    logger = logging.getLogger('Master')
    if args.logfile:
        # FileHandler owns the file object, so it is flushed and closed when
        # logging shuts down instead of being left open for the whole run.
        fh = logging.FileHandler(args.logfile, mode='w')
        fh.setFormatter(logging.Formatter('%(levelname)s %(name)s: %(message)s'))
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)

    n_targets = len(df)
    for i, (epic, r) in enumerate(df.iterrows()):
        with pf.open(r.pdc) as fpdc, pf.open(r.sap) as fsap:
            # The primary HDU is taken from the PDC file.
            epri, epdc, esap = fpdc[0], fpdc[1], fsap[1]
            campaign = epri.header['campaign']

            epri.header['origin'] = epdc.header['origin']
            epri.header['creator'] = epdc.header['program']
            epri.header['date'] = datetime.today().strftime('%Y-%m-%dT%H:%M:%S')

            # Header.remove() matches keywords case-insensitively, whereas a
            # membership test against header.keys() compares the lowercase
            # names with the stored upper-case keywords and never matches.
            for k in 'procver filever timversn checksum object'.split():
                epri.header.remove(k, ignore_missing=True)

            for e in (epdc, esap):
                e.header.remove('object', ignore_missing=True)
                # Subtract the trposi and trtime columns from the flux while
                # adding back the NaN-ignoring mean of each of them.
                e.data.flux[:] = (e.data.flux
                                  - e.data.trposi + nanmean(e.data.trposi)
                                  - e.data.trtime + nanmean(e.data.trtime))

            hdu_list = pf.HDUList([epri, epdc, esap])
            hdu_list[1].header['EXTNAME'] = 'PDC'
            hdu_list[2].header['EXTNAME'] = 'SAP'

            cdirname = 'c{:02d}'.format(campaign)
            fdirname = '{:s}00000'.format(str(epic)[:4])
            fpath = join(args.out_path, cdirname, fdirname)
            fname = fntemplate.format(epic=epic, c=campaign, v=args.release_version)

            if not exists(fpath):
                makedirs(fpath)
            # 'overwrite' is the current name of the former 'clobber' keyword.
            hdu_list.writeto(join(fpath, fname), overwrite=True)
        logger.info('Finished %05i/%05i -- campaign %i EPIC %9i', i+1, n_targets, campaign, epic)
