#!/usr/bin/env python
from __future__ import print_function, division
import sys
import errno
import warnings
import logging

import math as mt
import numpy as np
from glob import glob

try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    mpi_rank = comm.Get_rank()
    mpi_size = comm.Get_size()
    with_mpi = True
except ImportError:
    mpi_rank = 0
    mpi_size = 1
    with_mpi = False

from copy import copy   
from collections import namedtuple
from numpy import (any, array, ones_like, fromstring, tile, median,
                   zeros_like, exp, isfinite, argmax, argmin)
import numpy
from numpy.random import normal

from time import time, sleep
from os.path import join, exists, abspath, basename, isdir
from argparse import ArgumentParser

from k2sc.core import *
from k2sc.detrender import Detrender
from k2sc.kernels import kernels, BasicKernel, BasicKernelEP, QuasiPeriodicKernel, QuasiPeriodicKernelEP
from k2sc.k2io import select_reader, FITSWriter, SPLOXReader
from k2sc.cdpp import cdpp
from k2sc.de import DiffEvol
from k2sc.ls import fasper
from k2sc.utils import medsig, sigma_clip

# Start from an empty warning-filter list, then silence user and deprecation
# warnings for the rest of the run.
warnings.resetwarnings()
warnings.filterwarnings('ignore', category=UserWarning, append=True)
warnings.filterwarnings('ignore', category=DeprecationWarning, append=True)

## TODO: Copy the FITS E0 header if a FITS file is available
## TODO: Define the covariance matrix splits for every campaign

# Rank of the MPI master process: workers receive their datasets from this
# rank and send their results back to it.
mpi_root = 0
logging.basicConfig(level=logging.DEBUG, format='%(levelname)s %(name)s: %(message)s')

# Default kernel split values, keyed by campaign number. They are used by
# `detrend` whenever no `--splits` value is given on the command line.
# Campaigns missing from this table need an explicit `--splits`.
default_splits = {3:[2154,2190], 4:[2240,2273], 5:[2344], 6:[2390,2428], 7:[2468.5,2515],
                  8: [2579,2598], 10: [2772,2800], 12: [2916,2949], 13: [2997,3032], 14: [3086.4,3123.65], 
                  15: [3170,3207.5], 16: [3297.5,3331], 17: [3367,3400], 18: [3425,3460]}

# Prefer NumPy's own nanmedian; fall back to a minimal local version on
# NumPy releases that do not provide one.
try:
    from numpy import nanmedian
except ImportError:
    def nanmedian(a):
        """Return the median of the finite elements of array ``a``."""
        finite_values = a[isfinite(a)]
        return np.median(finite_values)

def psearch(time, flux, min_p, max_p):
    """Find the strongest periodogram peak with a period inside (min_p, max_p).

    Parameters
    ----------
    time, flux : array-like
        Sample times and the corresponding flux values, passed to `fasper`.
    min_p, max_p : float
        Exclusive lower and upper limits for the accepted periods.

    Returns
    -------
    (period, fap) : tuple of float
        Period of the highest-power peak inside the window and its
        false-alarm probability estimate. If no trial period falls inside
        the window, ``(nan, 1.0)`` is returned, i.e. "no detection".
    """
    # Presumably the oversampling factor of the periodogram (TODO confirm
    # against `fasper`); the same value scales the effm estimate below.
    ofac = 6
    freq, power, nout, jmax, prob = fasper(time, flux, ofac, 0.5)
    period = 1/freq
    m = (period > min_p) & (period < max_p)
    period, power = period[m], power[m]

    # argmax raises on an empty array, so report a non-detection instead of
    # crashing when the window contains no trial periods.
    if period.size == 0:
        return np.nan, 1.0

    j = argmax(power)

    expy = mt.exp(-power[j])
    effm = 2*nout/ofac
    fap  = expy*effm

    # expy*effm is the small-probability approximation of the expression
    # below; switch to the full expression once it is no longer small.
    if fap > 0.01:
        fap = 1.0-(1.0-expy)**effm

    return period[j], fap


def detrend(dataset):
    """Detrend every flux set of `dataset` and write the result to disk.

    The function relies on the module-level `args` (parsed command line
    options) and `reader` (the selected input reader), both of which are set
    up by the script entry point.

    A per-rank log file named ``<args.logfile>.<rank>`` is opened for the
    duration of the call. The file and its logging handler are always
    released, also when the dataset is skipped or an exception propagates.

    Parameters
    ----------
    dataset : object
        The light curve dataset as returned by ``reader.read``.

    Returns
    -------
    None
        The results are written with ``FITSWriter.write``. Nothing is written
        if the dataset is skipped (unknown campaign without ``--splits``, or a
        failing local optimiser).
    """
    ## Setup the logger
    ## ----------------
    logger  = logging.getLogger('Worker %i'%mpi_rank)
    logfile = open('{:s}.{:03d}'.format(args.logfile, mpi_rank), mode='w')
    fh = logging.StreamHandler(logfile)
    fh.setFormatter(logging.Formatter('%(levelname)s %(name)s: %(message)s'))
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    try:
        _detrend_dataset(dataset, logger)
    finally:
        # Always detach and close the log file. Otherwise an early return
        # leaks the file and leaves a stale handler on the shared logger.
        fh.flush()
        logger.removeHandler(fh)
        fh.close()
        logfile.close()


def _detrend_dataset(dataset, logger):
    """Run the detrending pipeline for `dataset`, logging through `logger`."""
    fpath_out = join(args.save_dir, reader.fn_out_template.format(dataset.epic))
    logger.name = 'Worker {:d} <{:d}>'.format(mpi_rank, dataset.epic)

    np.seterrcall(lambda e,f: logger.info(e))
    np.seterr(invalid='ignore')

    ## Main variables
    ## --------------
    Result  = namedtuple('SCResult', 'detrender pv tr_time tr_position cdpp_r cdpp_t cdpp_c warn')
    results = []  # a list of Result tuples, one per aperture
    masks   = []  # a list of light curve masks, one per aperture

    ## Initialise utility variables
    ## ----------------------------
    ds   = dataset
    info = logger.info
    error = logger.error

    ## Define the splits
    ## -----------------
    if args.splits is not None:
        splits = args.splits
        info('Using split values {:s} given from the command line'.format(str(splits)))
    elif ds.campaign in default_splits:
        splits = default_splits[ds.campaign]
        info('Using default splits {:s} for campaign {:d}'.format(str(splits), ds.campaign))
    else:
        # Without splits the detrender cannot be set up, so skip the dataset
        # instead of failing later on an undefined variable.
        error('The campaign not known and no splits given.')
        error('Skipping the file')
        return

    ## Periodic signal masking
    ## -----------------------
    if args.p_mask_center and args.p_mask_period and args.p_mask_duration:
        ds.mask_periodic_signal(args.p_mask_center, args.p_mask_period, args.p_mask_duration)

    ## Initial outlier and period detection
    ## ------------------------------------
    ## We carry out an initial outlier and period detection using
    ## a default GP hyperparameter vector based on campaign 4 fits
    ## done using (almost) noninformative priors.

    ## The --ls-max-fap option is given as log10(FAP); the default of -50
    ## corresponds to a threshold of 1e-50.
    fap_limit = 10.0**args.ls_max_fap

    for iset in range(ds.nsets):
        flux = ds.fluxes[iset]
        mask  = isfinite(flux)
        mask &= ~(ds.mflags[iset] & M_PERIODIC).astype(bool)  # Apply the transit mask, if any
        mask &= ~(ds.quality & 2**20).astype(bool)            # Mask out the thruster firings
        inputs = np.transpose([ds.time, ds.x, ds.y])
        masks.append(mask)

        detrender = Detrender(flux, inputs, mask = mask, splits = splits,
                              kernel = BasicKernelEP(),
                              tr_nrandom = args.tr_nrandom,
                              tr_nblocks = args.tr_nblocks,
                              tr_bspan = args.tr_bspan)

        ttrend,ptrend = detrender.predict(detrender.kernel.pv0+1e-5, components=True)
        cflux = flux-ptrend+median(ptrend)-ttrend+median(ttrend)
        cflux /= nanmedian(cflux)

        ## Iterative sigma-clipping
        ## ------------------------
        info('Starting initial outlier detection')
        omask = mask & sigma_clip(cflux, max_iter=10, max_sigma=5, mexc=mask)
        ofrac = (~omask).sum() / omask.size
        if ofrac < 0.25:
            mask &= omask
            info('  Flagged %i (%4.1f%%) outliers.', (~omask).sum(), 100*ofrac)
        else:
            info('  Found %i (%4.1f%%) outliers. Not flagging.', (~omask).sum(), 100*ofrac)

        ## Lomb-Scargle period search
        ## --------------------------
        if ofrac < 0.9:
            info('Starting Lomb-Scargle period search')
            nflux = flux - ptrend + nanmedian(ptrend)
            ntime = ds.time - ds.time.mean()
            pflux = np.poly1d(np.polyfit(ntime[mask], nflux[mask], 9))(ntime)

            period, fap = psearch(ds.time[mask], (nflux-pflux)[mask], args.ls_min_period, args.ls_max_period)

            if fap < fap_limit:
                ds.is_periodic = True
                ds.ls_fap    = fap
                ds.ls_period = period
        else:
            info('Too many outliers, skipping the Lomb-Scargle period search')

    ## Kernel selection
    ## ----------------
    if args.kernel:
        info('Overriding automatic kernel selection, using %s kernel as given in the command line', args.kernel)
        if 'periodic' in args.kernel and not args.kernel_period:
            logger.critical('Need to give period (--kernel-period) if overriding automatic kernel detection with a periodic kernel. Quitting.')
            sys.exit(1)
        kernel = kernels[args.kernel](period=args.kernel_period)
    else:
        info('  Using %s position kernel', args.default_position_kernel)
        if ds.is_periodic:
            info('  Found periodicity p = {:7.2f} (fap {:7.4e} < {:.0e}), will use a quasiperiodic kernel'.format(ds.ls_period, ds.ls_fap, fap_limit))
        else:
            info('  No strong periodicity found, using a basic kernel')

        if args.default_position_kernel.lower() == 'sqrexp':
            kernel = QuasiPeriodicKernel(period=ds.ls_period)   if ds.is_periodic else BasicKernel()
        else:
            kernel = QuasiPeriodicKernelEP(period=ds.ls_period) if ds.is_periodic else BasicKernelEP()


    ## Detrending
    ## ----------
    for iset in range(ds.nsets):
        if ds.nsets > 1:
            logger.name = 'Worker {:d} <{:d}-{:d}>'.format(mpi_rank, dataset.epic, iset+1)
        np.random.seed(args.seed)
        tstart = time()

        inputs = np.transpose([ds.time,ds.x,ds.y])
        detrender = Detrender(ds.fluxes[iset], inputs, mask=masks[iset],
                              splits=splits, kernel=kernel, tr_nrandom=args.tr_nrandom,
                              tr_nblocks=args.tr_nblocks, tr_bspan=args.tr_bspan)
        de = DiffEvol(detrender.neglnposterior, kernel.bounds, args.de_npop)

        ## Period population generation
        ## ----------------------------
        if isinstance(kernel, QuasiPeriodicKernel):
            de._population[:,2] = np.clip(normal(kernel.period, 0.1*kernel.period, size=de.n_pop),
                                          args.ls_min_period, args.ls_max_period)
        ## Hyperparameter optimisation
        ## ---------------------------
        if np.isfinite(ds.fluxes[iset]).sum() >= 100:
            ## Global hyperparameter optimisation
            ## ----------------------------------
            info('Starting global hyperparameter optimisation using DE')
            tstart_de = time()
            tcur_de = tstart_de  # stays defined even if DE yields no iterations
            for i,r in enumerate(de(args.de_niter)):
                info('  DE iteration %3i -ln(L) %4.1f', i, de.minimum_value)
                tcur_de = time()
                ## Stop early once the population has converged or the time
                ## budget is used up, but never before the fourth iteration.
                if ((np.ptp(de._fitness) < 3) or (tcur_de - tstart_de > args.de_max_time)) and (i>2):
                    break
            info('  DE finished in %i seconds', tcur_de-tstart_de)
            info('  DE minimum found at: %s', np.array_str(de.minimum_location, precision=3, max_line_width=250))
            info('  DE -ln(L) %4.1f', de.minimum_value)

            ## Local hyperparameter optimisation
            ## ---------------------------------
            info('Starting local hyperparameter optimisation')
            try:
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', category=RuntimeWarning, append=True)
                    pv, warn = detrender.train(de.minimum_location)
            except ValueError as e:
                logger.error('Local optimiser failed, %s', e)
                logger.error('Skipping the file')
                return
            info('  Local minimum found at: %s', np.array_str(pv, precision=3))

            ## Trend computation
            ## -----------------
            ## Split both trend components into a median level and the
            ## variation around it (time component first, then position).
            (med_t,tt),(med_p,tp) = [(nanmedian(a), a-nanmedian(a))
                                     for a in detrender.predict(pv, components=True)]

            ## Iterative sigma-clipping
            ## ------------------------
            info('Starting final outlier detection')
            flux = detrender.data.unmasked_flux
            cflux = flux-tp-tt
            cflux /= nanmedian(cflux)

            mper = ~(ds.mflags[iset] & M_PERIODIC).astype(bool)  # Apply the transit mask, if any
            mthf = ~(ds.quality & 2**20).astype(bool)            # Mask out the thruster firings
            minf = np.isfinite(cflux)

            mlow, mhigh = sigma_clip(cflux, max_iter = 10, max_sigma = 5, separate_masks = True, mexc = mper&mthf)
            ds.mflags[iset][~minf]  |= M_NOTFINITE
            ds.mflags[iset][~mhigh]  |= M_OUTLIER_U
            ds.mflags[iset][~mlow]   |= M_OUTLIER_D

            info('  %5i too high', (~mhigh).sum())
            info('  %5i too low',  (~mlow).sum())
            info('  %5i not finite', (~minf).sum())

            ## Detrending and CDPP computation
            ## -------------------------------
            info('Computing time and position trends')
            dd = detrender.data
            cdpp_r = cdpp(dd.masked_time,   dd.masked_flux)
            cdpp_t = cdpp(dd.unmasked_time, dd.unmasked_flux-tp,    exclude=~dd.mask)
            cdpp_c = cdpp(dd.unmasked_time, dd.unmasked_flux-tp-tt, exclude=~dd.mask)
        else:
            ## Too few finite points: store placeholder results so that the
            ## output still contains one entry per flux set.
            info('Skipping dataset %i, not enough finite datapoints', iset)
            cdpp_r, cdpp_t, cdpp_c, warn = -1, -1, -1, -1
            med_t, med_p = np.nan, np.nan
            tt = np.full_like(detrender.data.unmasked_flux, np.nan)
            tp = np.full_like(detrender.data.unmasked_flux, np.nan)
            pv = np.full(kernel.npar, np.nan)
            detrender.tr_pv = pv.copy()

        results.append(Result(detrender, pv, tt+med_t, tp+med_p, cdpp_r, cdpp_t, cdpp_c, warn))
        info('  CDPP - raw - %6.3f', cdpp_r)
        info('  CDPP - position component removed - %6.3f', cdpp_t)
        info('  CDPP - full reduction - %6.3f', cdpp_c)
        info('Detrending time %6.3f', time()-tstart)

    FITSWriter.write(fpath_out, splits, ds, results)
    info('Finished')


## Script entry point: parse the command line, select the input reader, and
## detrend the requested light curves either serially or with an MPI
## master/worker scheme (rank 0 reads the data and hands it to the workers).
if __name__ == '__main__':
    ap = ArgumentParser(description='K2SC: K2 systematics correction using Gaussian processes')
    gts = ap.add_argument_group('Training set options')
    gps = ap.add_argument_group('Period search', description='Options to control the initial Lomb-Scargle period search')
    god = ap.add_argument_group('Outlier detection')
    gde = ap.add_argument_group('Global optimisation', description='Options to control the global hyperparameter optimisation')
    ap.add_argument('files', type=str, nargs='*', help='Input light curve file name(s) or the input directory name.')
    ap.add_argument('-c', '--campaign', metavar='C', type=int, help='Campaign number', default=None)
    ap.add_argument('--splits', default=None, type=lambda s:fromstring(s.strip('[]'), sep=','), help='List of time values for kernel splits')
    ap.add_argument('--quiet', action='store_true', default=False, help='suppress messages')
    ap.add_argument('--save-dir', default='.', help='The directory to save the output file in')
    ap.add_argument('--start-i', default=0, type=int)
    ap.add_argument('--end-i', default=None, type=int)
    ap.add_argument('--seed', default=0, type=int)
    ap.add_argument('--logfile', default='', type=str)
    ap.add_argument('--flux-type', default='pdc', type=str)
    ap.add_argument('--default-position-kernel', choices=['SqrExp','Exp'], default='SqrExp')
    ap.add_argument('--kernel', choices=kernels.keys(), default=None, help='Kernel to use (overrides the automatic kernel selection)')
    ap.add_argument('--kernel-period', type=float, default=None, help='Period for the (quasi)periodic kernels (overrides the period detection)')
    ap.add_argument('--p-mask-center', type=float, default=None, help='Period mask zero epoch')
    ap.add_argument('--p-mask-period', type=float, default=None, help='Period mask period')
    ap.add_argument('--p-mask-duration', type=float, default=None, help='Period mask event duration')
    gts.add_argument('--tr-nrandom', default=400, type=int, help='Number of random samples')
    gts.add_argument('--tr-nblocks', default=6, type=int, help='Number of sample blocks')
    gts.add_argument('--tr-bspan', default=50, type=int, help='Span of a single block')
    gde.add_argument('--de-npop', default=100, type=int, help='Size of the differential evolution parameter vector population')
    gde.add_argument('--de-niter', default=150, type=int, help='Number of differential evolution iterations')
    gde.add_argument('--de-max-time', default=300, type=float, help='Maximum time used for differential evolution')
    gps.add_argument('--ls-max-fap', default=-50, type=float, help='Maximum Lomb-Scargle log10(false alarm) threshold to use the periodic kernel')
    gps.add_argument('--ls-min-period', default=0.05, type=float, help='Minimum period to search for')
    gps.add_argument('--ls-max-period', default=25, type=float, help='Maximum period to search for')
    god.add_argument('--outlier-sigma', default=5, type=float)
    god.add_argument('--outlier-mwidth', default=25, type=int)
    args = ap.parse_args()

    if not args.files:
        ap.error('no input files or input directory given')

    ## Logging
    ##
    ## Every rank gets a logger so that the error paths below work on the
    ## workers as well; only the master writes to the main log file.
    logger = logging.getLogger('Master' if mpi_rank == 0 else 'Rank {:d}'.format(mpi_rank))
    if mpi_rank == 0 and args.logfile:
        logfile = open(args.logfile, mode='w')
        fh = logging.StreamHandler(logfile)
        fh.setFormatter(logging.Formatter('%(levelname)s %(name)s: %(message)s'))
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)

    if not exists(args.save_dir):
        logger.error("Error: the save directory {:s} doesn't exist".format(args.save_dir))
        sys.exit(errno.ENOENT)

    ## Test whether we're dealing with a directory or a list of files and select
    ## the data reader (NOTE: We don't allow mixed input types per run)
    ##
    ## DIRECTORY BRANCH
    ## ----------------
    if len(args.files) == 1 and isdir(args.files[0]):
        singlefile = False
        ## A real list is needed: the file names are indexed and sliced below.
        files = [abspath(f) for f in sorted(glob(join(args.files[0],'ktwo*.fits')))]
        if not files:
            logger.critical("No ktwo*.fits files found in directory {:s}".format(args.files[0]))
            sys.exit(errno.ENOENT)

        reader = select_reader(files[0])
        if reader is None:
            logger.critical("Unrecognized input file type for file {:s}".format(files[0]))
            sys.exit(1)
        infile = None   # set per item in the processing loops
        sid = 0
        all_items = files[args.start_i:args.end_i]

    ## LIST OF FILES BRANCH
    ## --------------------
    else:
        reader = select_reader(args.files[0])
        if reader is None:
            logger.critical("Unrecognized input file type for file {:s}".format(args.files[0]))
            sys.exit(1)

        ## Check if we're dealing with a single input file with many stars
        if len(args.files) == 1 and issubclass(reader, SPLOXReader):
            singlefile = True
            infile = args.files[0]
            sid = args.start_i
            n_stars = reader.nstars(args.files[0])
            all_items = list(range(args.start_i, min(args.end_i or n_stars, n_stars)))
        else:
            singlefile = False
            infile = None   # set per item in the processing loops
            sid = 0
            all_items = args.files[args.start_i:args.end_i]

    ## The work queue must be a mutable list, since the master pops from it.
    items = list(all_items)
    n_items = len(items)

    if mpi_rank == 0:
        if (not with_mpi) or (mpi_size==1):
            logger.info("Detrending {:d} light curves without MPI".format(n_items))
        else:
            logger.info("Detrending {:d} light curves using {:d} worker nodes".format(n_items, mpi_size-1))
        logger.info('')
        logger.info('Saving the results to %s', args.save_dir)
        logger.info('')
        logger.info('Differential evolution parameters')
        logger.info('  Population size: {:3d}'.format(args.de_npop))
        logger.info('  Number of iterations: {:3d}'.format(args.de_niter))
        logger.info('  Maximum DE time: {:6.2f} seconds'.format(args.de_max_time))
        logger.info('')

    ## Without MPI or running with a single node
    ## =========================================
    if (not with_mpi) or (mpi_size==1):
        for item in items:
            if singlefile:
                sid = item
            else:
                infile = item
                if not exists(infile):
                    logger.warning("The input file {:s} doesn't exists, skipping the file".format(infile))
                    continue
            ## Pass the campaign like the MPI branch does, so that both
            ## execution modes read the data the same way.
            dataset = reader.read(infile, sid=sid, type=args.flux_type, campaign=args.campaign)
            detrend(dataset)


    ## With MPI
    ## ========
    else:
        ## Master node
        ## -----------
        if mpi_rank == 0:
            free_workers = list(range(1,mpi_size))
            active_workers = []
            n_finished_items = 0

            while items or active_workers:
                ## Send an item
                while items and free_workers:
                    item = items.pop()

                    if singlefile:
                        sid = item
                        logger.info("Processing star %i", sid)
                    else:
                        infile = item
                        logger.info("Processing file %s",abspath(infile))
                        if not exists(infile):
                            logger.warning("The input file {:s} doesn't exists, skipping the file".format(infile))
                            continue

                    dataset = reader.read(infile, sid=sid, type=args.flux_type, campaign=args.campaign)

                    ## Reserve a worker only once there is something to
                    ## send, so a skipped file never strands a worker.
                    w = free_workers.pop()
                    comm.send(dataset, dest=w, tag=0)
                    active_workers.append(w)

                ## Receive the results
                ## (iterate over a copy, the list is modified in the loop)
                for w in list(active_workers):
                    if comm.Iprobe(w, 2):
                        res = comm.recv(source=w, tag=2)
                        free_workers.append(w)
                        active_workers.remove(w)
                        n_finished_items += 1

                        if args.logfile:
                            ## Append the worker's log to the main log file
                            with open('{:s}.{:03d}'.format(args.logfile, w), 'r') as logfile_w:
                                logfile.write(logfile_w.read())
                        logger.info("Finished {:3d} of {:3d} light curves".format(n_finished_items,n_items))

            ## Tell the workers to stop
            for w in free_workers:
                comm.send(-1, dest=w, tag=0)

        ## Worker node
        ## -----------
        else:
            while True:
                dataset = comm.recv(source=mpi_root, tag=0)
                ## The master sends the integer -1 as the stop signal
                if isinstance(dataset, int) and dataset == -1:
                    break

                detrend(dataset)
                comm.send(dataset.epic, dest=mpi_root, tag=2)

    ## Release the master log file
    if mpi_rank == 0 and args.logfile:
        fh.flush()
        logger.removeHandler(fh)
        fh.close()
        logfile.close()


        
