#!/usr/bin/env python
__author__ = 'Matias Carrasco Kind'
from numpy import *
import copy
import random as rn
import os, sys
import pyfits as pf

try:
    from ml_codes import *
    from utils import *
except:
    from mlz.ml_codes import *
    from mlz.utils import *
# Look for the optional compiled Fortran SOM module: first inside the
# ml_codes package, then as a stand-alone module. SF90 records whether
# either import worked.
try:
    import ml_codes.somF
    SF90 = True
except:
    try:
        import somF
        SF90 = True
    except:
        SF90 = False

sys.setrecursionlimit(8000)

# Run under MPI when mpi4py can be imported, otherwise as a single serial
# process that plays the role of rank 0.
try:
    from mpi4py import MPI
except ImportError:
    PLL = 'SERIAL'
    size = 1
    rank = 0
else:
    PLL = 'MPI'
    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

Nproc = size  # number of processors
Narg = len(sys.argv)

# Only the root rank prints the banner and owns the global stopwatch.
if rank == 0:
    utils_mlz.print_welcome()
    clock_all = utils_mlz.Stopwatch()

# Exactly one command-line argument (the inputs file) is expected.
if Narg != 2:
    if rank == 0:
        utils_mlz.usage()
    if PLL == 'MPI':
        MPI.Finalize()
    sys.exit(0)


def join_oob(rank, rawA, PLL):
    """Sum the per-rank accumulator ``rawA`` onto rank 0.

    Under MPI every non-root rank sends ``rawA`` to rank 0 (tag ``2 * rank``)
    and rank 0 adds each received contribution to a deep copy of its own
    ``rawA``; all ranks then meet at a barrier. In serial mode nothing is
    exchanged.

    Relies on the module-level ``comm`` and ``Nproc``.

    :param rank: rank of the calling process
    :param rawA: local accumulator (must support ``+=``)
    :param PLL: ``'MPI'`` or ``'SERIAL'``
    :return: the combined accumulator on rank 0, ``0.`` on every other rank
    """
    use_mpi = (PLL == 'MPI')

    # Non-root ranks: ship the local data and wait for the others.
    if rank != 0:
        if use_mpi:
            comm.send(rawA, dest=0, tag=2 * rank)
            comm.Barrier()
        return 0.

    # Root rank: start from a private copy so the caller's array is untouched.
    total = copy.deepcopy(rawA)
    for sender in xrange(1, Nproc):
        if use_mpi:
            total += comm.recv(source=sender, tag=sender * 2)
    if use_mpi:
        comm.Barrier()
    return total


FilePars = sys.argv[1]  # inputs file name
# READ PARAMETERS
# Only the root rank reads the parameter file verbosely.
verbo = (rank == 0)
Pars_in = utils_mlz.read_dt_pars(FilePars, verbose=verbo)

# The variable-importance pass works on the out-of-bag sample, so asking for
# it must switch the OOB pass on as well. This has to be an assignment: the
# previous `==` only compared the two values and discarded the result.
if Pars_in.varimportance == 'yes': Pars_in.ooberror = 'yes'

if Pars_in.checkonly == 'yes':
    if rank == 0:
        utils_mlz.printpz()
        utils_mlz.printpz('********************************')
        utils_mlz.printpz('* Check Mode only for testing  *')
        utils_mlz.printpz('********************************')
        utils_mlz.printpz()

if rank == 0:
    utils_mlz.printpz("Starting with ", size, " processors")
    utils_mlz.printpz()
    # Stopwatch for the current phase; read out at the end of training.
    local_clock = utils_mlz.Stopwatch('no')

Pmode = Pars_in.predictionmode
# Classification trees imply the classification output mode.
if Pmode == 'TPZ_C': Pars_in.predictionclass = 'Class'

Cmode = Pars_in.predictionclass
#READ TRAIN AND TEST FILES

# TRAINING PHASE
# Each rank builds its share of the trees/maps, optionally accumulating
# out-of-bag (OOB) predictions and per-attribute permuted OOB predictions,
# which are then summed onto rank 0 and written out.
if Pars_in.dotrain == 'yes':
    Train = data.catalog(Pars_in, cat_type='train', rank=rank)

    # Only the root rank calls make_random; every rank calls load_random
    # once all ranks have passed the barrier.
    if Pars_in.nrandom > 1:
        if rank == 0: Train.make_random(ntimes=int(Pars_in.nrandom))
    if PLL == 'MPI': comm.Barrier()
    if Pars_in.nrandom > 1: Train.load_random()
    if PLL == 'MPI': comm.Barrier()  ##<>##

    if rank == 0:
        utils_mlz.printpz('-> NUMBER OF GALAXIES IN TRAIN CATALOG : ', len(Train.cat))

    if rank == 0:
        utils_mlz.print_mode(Pmode)
        utils_mlz.print_mode(Cmode)
        if not SF90 and Pmode == 'SOM':
            utils_mlz.printpz()
            utils_mlz.printpz('******************************************************')
            utils_mlz.printpz('* Fortran module somF not found, using python module *')
            utils_mlz.printpz('*   try: f2py -c -m somF som.f90 to compile it       *')
            utils_mlz.printpz('******************************************************')
        utils_mlz.printpz()

    # Attribute weights: all ones unless an importance file is supplied.
    if Pars_in.importancefile == 'none':
        Pars_in.importance = ones(len(Pars_in.att))
        Pars_in.importance_all = ones(len(Pars_in.att))
    else:
        Pars_in.importance = loadtxt(Pars_in.importancefile)
        Pars_in.importance_all = loadtxt(Pars_in.importancefile)

    #TRAIN FIRST
    ntot = int(Pars_in.nrandom * Pars_in.ntrees)
    if rank == 0: utils_mlz.printpz('Total trees (maps) : ', str(ntot))
    # [s0, s1) is the range of tree/map indices handled by this rank.
    s0, s1 = utils_mlz.get_limits(ntot, Nproc, rank)
    if rank == 0:
        for i in xrange(Nproc):
            Xs_0, Xs_1 = utils_mlz.get_limits(ntot, Nproc, i)
            utils_mlz.printpz(Xs_0, ' ', Xs_1, ' -------------> to core ', i)

    zfine, zfine2, resz, resz2, wzin = analysis.get_zbins(Pars_in)
    zfine2 = zfine2[wzin]
    train_nobj = Train.nobj

    # OOB accumulators, one row per training object.
    # 'Class' columns are: sum of values, sum of squared values, count.
    if Pars_in.ooberror == 'yes':
        Train.get_XY()
        train_yvals = Train.Y
        if Cmode == 'Reg':
            OP0raw = zeros((Train.nobj, len(zfine) - 1))
            Train_S = analysis.GetPz_short(Pars_in)
        if Cmode == 'Class':
            OS1 = zeros((Train.nobj, 3))

    # Same accumulators with an extra trailing axis: one slice per attribute.
    if Pars_in.varimportance == 'yes':
        if Cmode == 'Reg':
            OP0rawV = zeros((Train.nobj, len(zfine) - 1, len(Pars_in.att)))
            Train_SV = analysis.GetPz_short(Pars_in)
        if Cmode == 'Class':
            OS1V = zeros((Train.nobj, 3, len(Pars_in.att)))

    # Bind T up front so the `del T` after the loop cannot raise NameError on
    # a rank whose range [s0, s1) is empty (more processes than trees).
    T = None
    for kss in xrange(s0, s1):
        if Pars_in.nrandom > 1:
            # Explicit floor division; identical to the former `/` on ints.
            ir = kss // int(Pars_in.ntrees)
            if ir != 0: Train.newcat(ir)
        if Pmode == 'SOM': DD = Train.sample_dim(int(Pars_in.natt))
        if Pmode == 'TPZ': DD = 'all'
        if Pmode == 'TPZ_C': DD = 'all'
        if Pars_in.ooberror == 'yes': Train.oob_data_cat()

        Train.get_XY(bootstrap='yes', curr_at=DD)
        if DD != 'all':
            # Keep only the weights of the attributes selected for this map.
            impp = []
            for jk, ik in enumerate(Pars_in.att):
                if ik in DD: impp.append(Pars_in.importance_all[jk])
            Pars_in.importance = array(impp)

        if Pmode == 'TPZ':
            T = TPZ.Rtree(Train.X, Train.Y, forest='yes', minleaf=int(Pars_in.minleaf), mstar=int(Pars_in.natt),
                          dict_dim=DD)
            T.save_tree(kss, fileout=Train.Pars.treefilename, path=Train.Pars.path_output_trees)
        if Pmode == 'TPZ_C':
            YC = arange(int(Pars_in.minz), int(Pars_in.maxz) + 1, dtype='int')
            T = TPZ.Ctree(Train.X, Train.Y, forest='yes', minleaf=int(Pars_in.minleaf), mstar=int(Pars_in.natt),
                          dict_dim=DD,
                          impurity=Pars_in.impurityindex, nclass=YC)
            T.save_tree(kss, fileout=Train.Pars.treefilename, path=Train.Pars.path_output_trees)
        if Pmode == 'SOM':
            aps = Pars_in.alphastart
            ape = Pars_in.alphaend
            T = SOMZ.SelfMap(Train.X, Train.Y, Ntop=int(Pars_in.ntop), topology=Pars_in.topology,
                             som_type=Pars_in.somtype,
                             iterations=int(Pars_in.iterations), periodic=Pars_in.periodic, dict_dim=DD, astart=aps,
                             aend=ape,
                             importance=Pars_in.importance)
            # Use the compiled implementation when it was importable.
            if SF90:
                T.create_mapF()
            else:
                T.create_map()
            T.evaluate_map(inputX=Train.X, inputY=Train.Y)
            T.save_map(kss, fileout=Train.Pars.somfilename, path=Train.Pars.path_output_maps)

        #OOB DATA
        if Pars_in.ooberror == 'yes':
            for io in xrange(len(Train.Xoob)):
                # Row of this OOB object in the full-catalog accumulators.
                ij = Train.oob_index_or[io]
                tempo = T.get_vals(Train.Xoob[io])
                # A leading -1. marks "no prediction"; such results are skipped.
                if tempo[0] != -1.:
                    if Cmode == 'Reg': OP0raw[ij, :] += Train_S.get_hist(tempo)
                    if Cmode == 'Class':
                        OS1[ij, 0] += sum(array(tempo))
                        OS1[ij, 1] += sum(array(tempo) * array(tempo))
                        OS1[ij, 2] += 1. * len(tempo)

        if Pars_in.varimportance == 'yes':
            for ka in xrange(len(Train.indx)):
                kname = Train.cols[Train.indx[ka]]
                k_index = where(array(Pars_in.att) == kname)[0][0]
                # Shuffle column ka of the OOB sample, leaving the others intact,
                # and accumulate the resulting predictions for attribute k_index.
                Xoob = copy.deepcopy(Train.Xoob)
                indexp = rn.sample(xrange(len(Xoob)), len(Xoob))
                Xoob[:, ka] = Train.Xoob[indexp, ka]
                for io in xrange(len(Train.Xoob)):
                    ij = Train.oob_index_or[io]
                    tempo = T.get_vals(Xoob[io])
                    if tempo[0] != -1.:
                        if Cmode == 'Reg': OP0rawV[ij, :, k_index] += Train_SV.get_hist(tempo)
                        if Cmode == 'Class':
                            OS1V[ij, 0, k_index] += sum(array(tempo))
                            OS1V[ij, 1, k_index] += sum(array(tempo) * array(tempo))
                            OS1V[ij, 2, k_index] += 1. * len(tempo)

    del T
    if PLL == 'MPI': comm.Barrier()
    del Train

    if Pars_in.ooberror == 'yes':
        if rank == 0:
            utils_mlz.printpz()
            utils_mlz.printpz("Cross validation with OOB data")
            utils_mlz.printpz()

        # Sum every rank's accumulator onto rank 0 (other ranks receive 0.).
        if Cmode == 'Reg':
            Oraw = join_oob(rank, OP0raw, PLL)
            del OP0raw
        if Cmode == 'Class':
            Oraw = join_oob(rank, OS1, PLL)
            del OS1
        if PLL == 'MPI': comm.Barrier()

        if rank == 0:
            Z0b2 = zeros((train_nobj, 7))
            if Cmode == 'Reg':
                BP0b2 = zeros((train_nobj, len(zfine2)))
                for k in xrange(train_nobj):
                    # Objects that never received an OOB prediction keep zero rows.
                    if sum(Oraw[k]) > 0.:
                        z_phot_t, pdf_phot_t = Train_S.get_pdf(Oraw[k], train_yvals[k])
                        Z0b2[k, :] = z_phot_t
                        BP0b2[k, :] = pdf_phot_t
                del Oraw
                analysis.save_single(Z0b2, Pars_in, oob='yes')
                analysis.save_PDF(zfine2, BP0b2, Pars_in, oob='yes')
                del Z0b2, BP0b2
            if Cmode == 'Class':
                z_0_t, z_1_t, s_0_t = analysis.class_stat(Oraw[:, 0], Oraw[:, 1], Oraw[:, 2], Pars_in)
                Z0b2[:, 0] = train_yvals
                Z0b2[:, 1] = z_0_t
                Z0b2[:, 2] = z_1_t
                Z0b2[:, 3] = s_0_t
                del Oraw
                analysis.save_single(Z0b2, Pars_in, oob='yes')
                del Z0b2

    if Pars_in.varimportance == 'yes':
        if rank == 0: utils_mlz.printpz("Computing importance ranking")
        if PLL == 'MPI': comm.Barrier()
        # One reduction and one output file per attribute.
        for ka in xrange(len(Pars_in.att)):
            if Cmode == 'Reg':
                Oraw = join_oob(rank, OP0rawV[:, :, ka], PLL)
            if Cmode == 'Class':
                Oraw = join_oob(rank, OS1V[:, :, ka], PLL)
            if PLL == 'MPI': comm.Barrier()
            if rank == 0:
                Z0b2 = zeros((train_nobj, 7))
                if Cmode == 'Reg':
                    for k in xrange(train_nobj):
                        if sum(Oraw[k]) > 0.:
                            z_phot_t, pdf_phot_t = Train_SV.get_pdf(Oraw[k], train_yvals[k])
                            Z0b2[k, :] = z_phot_t
                    analysis.save_single(Z0b2, Pars_in, oob='yes', var='_' + Pars_in.att[ka])
                    del Z0b2
                if Cmode == 'Class':
                    z_0_t, z_1_t, s_0_t = analysis.class_stat(Oraw[:, 0], Oraw[:, 1], Oraw[:, 2], Pars_in)
                    Z0b2[:, 0] = train_yvals
                    Z0b2[:, 1] = z_0_t
                    Z0b2[:, 2] = z_1_t
                    Z0b2[:, 3] = s_0_t
                    analysis.save_single(Z0b2, Pars_in, oob='yes', var='_' + Pars_in.att[ka])
                    del Z0b2
        del Oraw
        if Cmode == 'Reg': del OP0rawV
        if Cmode == 'Class': del OS1V
    if Pars_in.ooberror == 'yes': del train_yvals

    if PLL == 'MPI': comm.Barrier()
    if rank == 0:
        utils_mlz.printpz()
        utils_mlz.printpz('+-+-+-+-+-+-+-+-+-+-+-+-+')
        utils_mlz.printpz(PLL, ' time Training')
        utils_mlz.printpz('+-+-+-+-+-+-+-+-+-+-+-+-+')
        local_clock.elapsed()

# Training-only run: say so from the root rank and stop before the test phase.
if Pars_in.dotest == 'no':
    if rank == 0:
        utils_mlz.printpz()
        for note in ("Only training was selected", "Check utils/utils_mlz.py "):
            utils_mlz.printpz(note)
        utils_mlz.printpz()
    if PLL == 'MPI':
        MPI.Finalize()
    sys.exit(0)

#------------------------------------------------------------------------
# NOW TEST
zfine, zfine2, resz, resz2, wzin = analysis.get_zbins(Pars_in)
zfine2 = zfine2[wzin]
ntot = int(Pars_in.nrandom * Pars_in.ntrees)

# Ng holds the test-catalog length. Only the root rank reads the catalog to
# count it; the other ranks get the value through the broadcast below.
Ng = array(0, 'i')
if rank == 0:
    local_clock = utils_mlz.Stopwatch('no')
    cat_temp = data.read_catalog(Pars_in.path_test + Pars_in.testfile, check=Pars_in.checkonly)
    Ng = array(len(cat_temp), 'i')
    del cat_temp

if PLL == 'MPI':
    comm.Barrier()
    comm.Bcast([Ng, MPI.INT], root=0)

# [s0, s1) is the slice of the test catalog assigned to this rank.
s0, s1 = utils_mlz.get_limits(Ng, Nproc, rank)
if rank == 0:
    utils_mlz.printpz('-> NUMBER OF GALAXIES IN TEST CATALOG : ', Ng)
    for i in xrange(Nproc):
        Xs_0, Xs_1 = utils_mlz.get_limits(Ng, Nproc, i)
        utils_mlz.printpz(Xs_0, ' ', Xs_1, ' -------------> to core ', i)

Test = data.catalog(Pars_in, cat_type='test', L1=s0, L2=s1, rank=rank)
Test.get_XY()

# Fall back to zeros when the test catalog carries no target column.
yvals = Test.Y if Test.has_Y() else zeros(Test.nobj)

test_nobj = Test.nobj

Z0 = zeros((Test.nobj, 7))

# Where the trained trees / maps were stored and their base file name.
if Pmode in ('TPZ', 'TPZ_C'):
    path1 = Test.Pars.path_output_trees
    fileb = Pars_in.treefilename
elif Pmode == 'SOM':
    path1 = Test.Pars.path_output_maps
    fileb = Pars_in.somfilename

# Per-object accumulators filled while looping over the trees / maps.
if Cmode == 'Reg':
    BP0 = zeros((Test.nobj, len(zfine2)))
    BP0raw = zeros((Test.nobj, len(zfine) - 1))
    Test_S = analysis.GetPz_short(Pars_in)
elif Cmode == 'Class':
    S1 = zeros(Test.nobj)
    S2 = zeros(Test.nobj)
    Nv = zeros(Test.nobj)

# Run every stored tree / map over this rank's test objects.
for k in xrange(ntot):
    filec = path1 + fileb + ('_%04d' % k) + '.npy'
    # The file stores a single pickled object inside a 0-d array.
    S = load(filec).item()
    DD = S.dict_dim
    if Pmode == 'SOM':
        # Rebuild X with the attribute subset recorded for this map.
        Test.get_XY(curr_at=DD)
    for i in xrange(Test.nobj):
        temp = S.get_vals(Test.X[i])
        # A leading -1. marks "no prediction"; skip those.
        if temp[0] == -1.:
            continue
        if Cmode == 'Reg':
            BP0raw[i, :] += Test_S.get_hist(temp)
        if Cmode == 'Class':
            vals = array(temp)
            S1[i] += sum(vals)
            S2[i] += sum(vals * vals)
            Nv[i] += 1. * len(temp)

# Turn the raw accumulators into the per-object summary rows of Z0.
if Cmode == 'Reg':
    for k in xrange(test_nobj):
        z_phot, pdf_phot = Test_S.get_pdf(BP0raw[k], yvals[k])
        Z0[k, :] = z_phot
        BP0[k, :] = pdf_phot
    del BP0raw, yvals
if Cmode == 'Class':
    z_0, z_1, s_0 = analysis.class_stat(S1, S2, Nv, Pars_in)
    for col, col_vals in enumerate((yvals, z_0, z_1, s_0)):
        Z0[:, col] = col_vals
    del S1, S2, Nv

if rank == 0:
    rule = '+-+-+-+-+-+-+-+-+-+-+-+-+'
    utils_mlz.printpz()
    utils_mlz.printpz(rule)
    utils_mlz.printpz(PLL, ' time Testing')
    utils_mlz.printpz(rule)
    local_clock.elapsed()

if PLL == 'MPI':
    comm.Barrier()
####################################

# Collect every rank's Z0 rows into one array on rank 0.
s0, s1 = utils_mlz.get_limits(Ng, Nproc, rank)

if rank != 0:
    if PLL == 'MPI':
        comm.Send(Z0, dest=0, tag=rank * 2)
    del Z0
else:
    BIGZ = zeros((Ng, 7))
    BIGZ[s0:s1, :] = Z0
    for srank in xrange(1, Nproc):
        # Rows owned by the sending rank; the receive buffer matches that size.
        s0, s1 = utils_mlz.get_limits(Ng, Nproc, srank)
        ZT = zeros((s1 - s0, 7))
        if PLL == 'MPI':
            comm.Recv(ZT, source=srank, tag=srank * 2)
        BIGZ[s0:s1, :] = ZT
        del ZT

path_r, filebase_r, num_r = analysis.get_path_new(Pars_in)
if PLL == 'MPI':
    comm.Barrier()

# Only the root rank holds the full table and writes it.
if rank == 0:
    analysis.save_single(BIGZ, Pars_in)
    del BIGZ

# SPARSE REPRESENTATION
# Encode each PDF row of BP0 as at most Nsparse packed integers (one per
# selected dictionary basis) and prepare the header that describes the
# dictionary.
# NOTE(review): BP0 is only allocated when Cmode == 'Reg'; with sparserep
# set to 'yes' in 'Class' mode this section fails on `len(BP0)` — confirm
# whether that combination should be rejected when parameters are read.
if Pars_in.sparserep == 'yes':
    if rank == 0: utils_mlz.printpz("Sparse Representation for PDFs")
    mu = [min(zfine2), max(zfine2)]
    Nmu, Nsig, Nv = map(int, Pars_in.sparsedims)
    Ncoef = int(Pars_in.numbercoef)
    Nsparse = int(Pars_in.numberbases)
    # Coefficients are quantized onto Ncoef evenly spaced levels in [0, 1].
    AA = linspace(0, 1, Ncoef)
    dz_sp = zfine2[1] - zfine2[0]
    Da = AA[1] - AA[0]
    max_sig = (max(zfine2) - min(zfine2)) / 12.
    min_sig = dz_sp / 6.
    sig = [min_sig, max_sig]
    NA = Nmu * Nsig * Nv

    head = pf.Header()
    head['N_MU'] = Nmu
    head['N_SIGMA'] = Nsig
    head['N_VOIGT'] = Nv
    head['N_COEF'] = Ncoef
    head['N_SPARSE'] = Nsparse
    head['MU1'] = mu[0]
    head['MU2'] = mu[1]
    head['SIGMA1'] = sig[0]
    head['SIGMA2'] = sig[1]

    if rank == 0:
        utils_mlz.printpz("Nmu, Nsigm Nv = [ ", Nmu, ", ", Nsig, ", ", Nv, " ]")
        utils_mlz.printpz("Number of total bases in dictionary : ", NA)
        utils_mlz.printpz("Number of bases sparse representation : ", Nsparse)
        utils_mlz.printpz()
        utils_mlz.printpz("Creating Dictionary...")
    AD = pdf_storage.create_voigt_dict(zfine2, mu, Nmu, sig, Nsig, Nv)
    if rank == 0:
        utils_mlz.printpz("Creating Sparse Representation...")
    SPR = zeros((len(BP0), Nsparse), dtype='int32')
    for ip in xrange(len(BP0)):
        # pdf0 is a view into BP0, so the normalization below is applied to
        # the stored PDF as well.
        pdf0 = BP0[ip]
        if sum(pdf0) > 0:
            pdf0 /= sum(pdf0)
        else:
            # Empty PDF: leave its SPR row at zero.
            continue
        Dind, Dval = pdf_storage.sparse_basis(AD, pdf0, Nsparse)
        if len(Dind) <= 1: continue
        if max(Dval) > 0:
            # Scale so the largest coefficient is 1, then quantize to a level.
            Dvalm = Dval / max(Dval)
            indexm = array(map(round, (Dvalm / Da)), dtype='int')
            del Dvalm
        else:
            indexm = zeros(len(Dind), dtype='int')
        temp_int = array(map(pdf_storage.combine_int, indexm, Dind))
        SPR[ip, 0:len(temp_int)] = temp_int
        # NOTE(review): looks like this undoes a column reordering of AD made
        # inside sparse_basis — confirm against pdf_storage.
        AD[:, [Dind]] = AD[:, [arange(len(Dind))]]
        # Dvalm is released in the branch that creates it; deleting it here
        # raised NameError whenever the max(Dval) <= 0 branch was taken.
        del temp_int, indexm, Dind, Dval
if PLL == 'MPI': comm.Barrier()

# WRITE PDFs (regression mode only)
# Either every rank writes its own file(s) (multiplefiles == 'yes'), or the
# per-rank arrays are sent to rank 0, which writes a single file. The
# Send/Recv pairs, barriers and deletes below must stay in this order on
# every rank.
if Cmode == 'Reg' and Pars_in.writepdf == 'yes':
    if Pars_in.multiplefiles == 'yes':
        if rank == 0 and PLL == 'MPI':
            utils_mlz.printpz()
            utils_mlz.printpz('************************************')
            utils_mlz.printpz('** Writing multiple file for PDFs **')
            utils_mlz.printpz('************************************')
            utils_mlz.printpz()
        # Each rank saves its own slice, passing its rank to the writer.
        if Pars_in.originalpdffile == 'yes':
            analysis.save_PDF(zfine2, BP0, Pars_in, path=path_r, filebase=filebase_r, num=num_r, multiple='yes',
                              rank=rank)
            del BP0
        if Pars_in.sparserep == 'yes':
            # N_TOT here is the number of rows held by this rank only.
            head['N_TOT'] = len(SPR)
            analysis.save_PDF_sparse(zfine2, SPR, head, Pars_in, path=path_r, filebase=filebase_r, num=num_r,
                                     multiple='yes', rank=rank)
            del SPR
    else:
        if rank == 0:
            utils_mlz.printpz()
            utils_mlz.printpz('*************************************')
            utils_mlz.printpz('** Collecting data from processors **')
            utils_mlz.printpz('*************************************')
            utils_mlz.printpz()
        # [s0, s1) is this rank's row range within the full test catalog.
        s0, s1 = utils_mlz.get_limits(Ng, Nproc, rank)
        if Pars_in.originalpdffile == 'yes':
            if rank == 0:
                utils_mlz.printpz("Saving Original PDFs")
                utils_mlz.printpz()
                BIGP = zeros((Ng, len(zfine2)))
                BIGP[s0:s1, :] = BP0
                # Receive each other rank's block into its own row range.
                for srank in xrange(1, Nproc):
                    s0, s1 = utils_mlz.get_limits(Ng, Nproc, srank)
                    size_dat = s1 - s0
                    BPT = zeros((size_dat, len(zfine2)))
                    if PLL == 'MPI': comm.Recv(BPT, source=srank, tag=srank * 2)
                    BIGP[s0:s1, :] = BPT
                    del BPT
            else:
                if PLL == 'MPI':
                    comm.Send(BP0, dest=0, tag=rank * 2)
                    del BP0
            if PLL == 'MPI': comm.Barrier()
            if rank == 0:
                analysis.save_PDF(zfine2, BIGP, Pars_in, path=path_r, filebase=filebase_r, num=num_r)
                del BIGP

        if Pars_in.sparserep == 'yes':
            # Recompute the limits: on rank 0 the loop above left s0, s1
            # pointing at the last sender's range.
            s0, s1 = utils_mlz.get_limits(Ng, Nproc, rank)
            if rank == 0:
                utils_mlz.printpz("Saving Sparse Represenation PDFs")
                utils_mlz.printpz()
                # Receive buffers use int32 to match the dtype of SPR.
                BIGP = zeros((Ng, Nsparse), dtype='int32')
                BIGP[s0:s1, :] = SPR
                for srank in xrange(1, Nproc):
                    s0, s1 = utils_mlz.get_limits(Ng, Nproc, srank)
                    size_dat = s1 - s0
                    BPT = zeros((size_dat, Nsparse), dtype='int32')
                    if PLL == 'MPI': comm.Recv(BPT, source=srank, tag=srank * 2)
                    BIGP[s0:s1, :] = BPT
                    del BPT
            else:
                if PLL == 'MPI':
                    comm.Send(SPR, dest=0, tag=rank * 2)
                    del SPR
            if PLL == 'MPI': comm.Barrier()
            if rank == 0:
                # N_TOT here is the size of the whole test catalog.
                head['N_TOT'] = int(Ng)
                analysis.save_PDF_sparse(zfine2, BIGP, head, Pars_in, path=path_r, filebase=filebase_r, num=num_r)
                del BIGP

# Root rank reports the total run time, then MPI is shut down.
if rank == 0:
    rule = '+-+-+-+-+-+-+-+-+'
    utils_mlz.printpz(rule)
    utils_mlz.printpz(PLL, ' TOTAL TIME')
    utils_mlz.printpz(rule)
    clock_all.elapsed()
if PLL == 'MPI':
    MPI.Finalize()


