#!/usr/bin/env python
__author__ = 'Matias Carrasco Kind'
from numpy import *
from scipy.interpolate import interp1d as spl
import sys, os
import utils_mlz
import pyfits as pf

# Select the parallel backend. When mpi4py can be imported the script runs
# under MPI and takes its size/rank from COMM_WORLD; otherwise it falls back
# to a single-process run (size 1, rank 0).
# Only ImportError triggers the fallback, so Ctrl-C or an unrelated failure
# during the import is no longer silently turned into a serial run.
try:
    from mpi4py import MPI
except ImportError:
    PLL = 'SERIAL'
else:
    PLL = 'MPI'

if PLL == 'MPI':
    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()
else:
    size = 1
    rank = 0

# No input file on the command line: rank 0 prints the usage text, then every
# process exits with status 0.
if len(sys.argv) < 2:
    if rank == 0:
        rule = "-----------------------------------------------------------------------------"
        # One tuple of positional arguments per printpz call; an empty tuple
        # is a call with no arguments.
        usage_calls = (
            (),
            ("Usage:: ",),
            (rule,),
            ("./use_pdfs <input file> <run No (def=0)> <zConf (def = 0)> <Nbins (def = 30)>",),
            (rule,),
            (),
            ("<input file>  : Same used in runMLZ",),
            ("<run No>      : result id on results/ folder , default is 0",),
            ("<zConf>       : confidence level for PDFs , default is 0",),
            ("<Nbins>       : Number of bins in N(z) , default is 30",),
            (),
            ("** Can be run in parallel! **",),
            (),
        )
        for call_args in usage_calls:
            utils_mlz.printpz(*call_args)

    sys.exit(0)

# Rank 0 prints the welcome banner, reports the serial fallback if mpi4py was
# not available, and starts the stopwatch that is read at the end of the run.
if rank == 0:
    utils_mlz.print_welcome()
    if PLL == 'SERIAL':
        banner = "***********************************************"
        for text in (banner, "mpi4py not found, using serial version instead!", banner):
            utils_mlz.printpz(text)
    clock_all = utils_mlz.Stopwatch()

# Every rank reads the parameter file; only rank 0 asks for verbose output.
FilePars = sys.argv[1]
verb = rank == 0
Pars = utils_mlz.read_dt_pars(FilePars, verbose=verb)

# Locations and base name taken from the parameter object.
path_tree, path_results, filebase = (
    Pars.path_output,
    Pars.path_results,
    Pars.finalfilename,
)

# Optional positional arguments, each falling back to its default when absent:
#   argv[2] -> run number (0), argv[3] -> confidence cut (0.), argv[4] -> bins (30)
n_args = len(sys.argv)
ir = int(sys.argv[2]) if n_args > 2 else 0
zconf = float(sys.argv[3]) if n_args > 3 else 0.
Nbins = int(sys.argv[4]) if n_args > 4 else 30

# Common prefix of every result file belonging to this run.
fileroot = ''.join([path_results, filebase, '.', str(ir), '.'])

# Columns 0 and 4 of the '<fileroot>mlz' text file, read by every rank.
# NOTE(review): presumably the reference redshift and the per-object
# confidence; confirm against the code that writes the .mlz file.
zt, zCb = loadtxt(fileroot + 'mlz', usecols=(0, 4), unpack=True)
ntot = len(zt)

if PLL == 'MPI':
    comm.Barrier()

# Row range [L0:L1] handled by this rank, as returned by get_limits.
L0, L1 = utils_mlz.get_limits(ntot, size, rank)

# Rank 0 reports how the catalogue is split across the processes.
if rank == 0:
    utils_mlz.printpz('-> NUMBER OF GALAXIES IN TEST CATALOG : ', ntot)
    for core in xrange(size):
        core_lo, core_hi = utils_mlz.get_limits(ntot, size, core)
        utils_mlz.printpz(core_lo, ' ', core_hi, ' -------------> to core ', core)

# Load the PDF table for this rank.
#   * multiplefiles == 'yes': the rank opens its own '<fileroot>P_<rank>' file
#     and uses every row except the last.
#   * otherwise: the shared '<fileroot>P' file is opened and rows L0:L1 are used.
# In both cases the last row of the table is taken as zfine.
# The table is a .npy file when writefits == 'no', else column 'PDF values'
# of HDU 1 of a .fits file.
split_files = Pars.multiplefiles == 'yes'
if split_files:
    # Informational output is printed by rank 0 only, like everywhere else in
    # this script (the leading blank line used to be printed by every rank).
    if rank == 0:
        utils_mlz.printpz()
        utils_mlz.printpz("Reading from multiple files")
    pdf_base = fileroot + 'P_' + str(rank)
else:
    pdf_base = fileroot + 'P'

if Pars.writefits == 'no':
    Bpdf = load(pdf_base + '.npy')
else:
    Temp = pf.open(pdf_base + '.fits')
    try:
        Bpdf = Temp[1].data.field('PDF values')
    finally:
        # Release the file handle even if the HDU/column lookup fails.
        Temp.close()

zfine = Bpdf[-1]
spdf = Bpdf[:-1] if split_files else Bpdf[L0:L1]
del Bpdf

# Catalogue rows matching this rank's share; the full columns are dropped
# afterwards.
# NOTE(review): in multiple-files mode this assumes P_<rank> holds exactly
# rows L0:L1 of the catalogue; confirm against the writer.
zspec = zt[L0:L1]
zC = zCb[L0:L1]
del zt, zCb

# Redshift binning: Nbins equal-width bins between Pars.minz and Pars.maxz.
minz, maxz = Pars.minz, Pars.maxz
Nz = linspace(minz, maxz, Nbins + 1)    # bin edges
Nzmid = (Nz[:-1] + Nz[1:]) * 0.5        # bin centres

# Per-rank accumulators: H1 is the Nbins x Nbins map, N_z the stacked N(z).
H1 = zeros((Nbins, Nbins))
N_z = zeros(Nbins)

if rank == 0:
    utils_mlz.printpz()
    utils_mlz.printpz("Creating map and computing N(z) ...")

for gal, pdf in enumerate(spdf):
    # Objects whose confidence is below the requested cut are left out.
    if zC[gal] < zconf:
        continue
    # Column of the map selected from this object's zspec value.
    col = utils_mlz.inbin(zspec[gal], minz, maxz, Nbins)
    # Per-bin values of this object's PDF over the Nz edges, as returned by
    # get_prob_Nz; added to column `col` of the map and to N(z).
    area = utils_mlz.get_prob_Nz(zfine, pdf, Nz)
    H1[:, col] += area
    N_z += area

# Sum the per-rank partial results on rank 0, first the map, then N(z).
# A serial run has nothing to collect. Under MPI every non-root rank sends its
# array to rank 0, which adds each contribution in place; a barrier follows
# each of the two exchanges.
if PLL == 'MPI':
    for partial in (H1, N_z):
        if rank == 0:
            for srank in xrange(1, size):
                incoming = zeros(partial.shape)
                comm.Recv(incoming, source=srank, tag=srank * 2)
                partial += incoming
        else:
            comm.Send(partial, dest=0, tag=rank * 2)
        comm.Barrier()

# Rank 0 writes the results: the map through numpy's save and the normalised
# N(z) as a two-column text table (bin centre, value) with six decimals.
if rank == 0:
    out_base = path_results + filebase + '.' + str(ir)
    map_file = out_base + '_map'
    zdist_file = out_base + '_zdist'

    save(map_file, H1)

    # Normalise N(z) to unit sum. If nothing was accumulated (e.g. no object
    # passed the confidence cut) the zeros are kept instead of writing NaNs.
    total = sum(N_z)
    if total > 0:
        N_z = N_z / total
    # column_stack builds the same (Nbins, 2) table as zip() did, without
    # relying on zip() returning a list.
    savetxt(zdist_file, column_stack((Nzmid, N_z)), fmt='%.6f')

    clock_all.elapsed()
    utils_mlz.printpz()
    utils_mlz.printpz("N(zphot) saved in ", zdist_file)
    utils_mlz.printpz("Zspec vs Zphot map saved in  ", map_file)

if PLL == 'MPI':
    MPI.Finalize()
