#!/usr/bin/env python
"""
Computes a multimodel mean.
"""
from ILAMB.Scoreboard import Scoreboard
from ILAMB.constants import bnd_months
from ILAMB.Variable import Variable
import ILAMB.ilamblib as il
from netCDF4 import Dataset
from sympy import sympify
import ILAMB.run as r
import numpy as np
import argparse
import pickle
import sys
import os

def Interpolate(v,t,d,lat,lon):
    """Sample `v.data` at the source coordinates nearest to the target coordinates.

    Despite the name, no interpolation weights are computed: for every
    target coordinate the index of the closest source coordinate is
    looked up and the data is gathered on the outer product of those
    indices.

    Parameters
    ----------
    v : object
        provides `data` plus whichever of `time`, `depth`, `lat` and
        `lon` are needed for the targets that are not None
    t, d, lat, lon : 1D numpy arrays or None
        target coordinates; a dimension passed as None is skipped, so
        the targets given must match the dimensions of `v.data` in the
        order time, depth, lat, lon

    Returns
    -------
    array
        `v.data` indexed on the target grid, one axis per target given
    """
    def _closest(target,source):
        # distance of every target value to every source value, then
        # the position of the smallest distance along the source axis
        return np.abs(target[:,np.newaxis]-source).argmin(axis=1)

    dims = ((t,"time"),(d,"depth"),(lat,"lat"),(lon,"lon"))
    index = [_closest(target,getattr(v,attr)) for target,attr in dims if target is not None]
    return v.data[np.ix_(*index)]

def MultiModelMean(M,vname,wgt_file=None,maxV=20):
    """Compute the (optionally weighted) mean of a variable over models and write it to netCDF.

    The first model which provides `vname` determines the unit and
    whether the variable is temporal, layered and/or spatial. A common
    target grid is then built, every model is sampled onto it with
    `Interpolate` (nearest neighbour) and the weighted mean is written
    to `<build_dir>/<vname>.nc`.

    This function reads the module-level `args` namespace: `args.res`
    is passed to `il.GlobalLatLonGrid` and `args.build_dir[0]` is the
    output directory.

    Parameters
    ----------
    M : list
        model objects; each must provide `name`, `variables` and
        `extractTimeSeries`
    vname : str
        name of the variable to average
    wgt_file : sequence of str, optional
        the first entry is the path of a pickle file holding a mapping
        from model name to weight; if None every model has weight 1
    maxV : int, optional
        width of the field in which the variable name is printed

    Returns
    -------
    None
        also when no model provides `vname`, in which case nothing is
        written

    Raises
    ------
    KeyError
        if weights are given but a model name is missing from them
    """
    wgts = None
    if wgt_file is not None:
        # NOTE(security): unpickling can execute arbitrary code, so only
        # weight files from a trusted source should be passed here.
        with open(wgt_file[0],"rb") as f:
            wgts = pickle.load(f)
        print("Loaded weights:",wgts)

    sys.stdout.write(('{0:>%d} ' % maxV).format(vname)); sys.stdout.flush()

    # Based on the variable, we need to set up a space-time grid. So
    # first we need find a model with the variable and then check what
    # kind of variable it is.
    Ms = [m for m in M if vname in m.variables]
    if len(Ms) == 0:
        sys.stdout.write('\n'); sys.stdout.flush()
        return
    if vname == "sftlf":
        v = Variable(filename = Ms[0].variables['sftlf'][0],variable_name = 'sftlf').convert("1")
    else:
        v = Ms[0].extractTimeSeries(vname)
    unit = v.unit
    shp = []

    # If the variable is temporal, build a monthly axis expressed in
    # days since 1850 using 365 day years. The wide default (t0,tf)
    # leaves non-temporal extractions unrestricted.
    t = None; t_bnd = None; t0 = -1e20; tf = 1e20
    if v.temporal:
        y0 = 1850 if vname == "nbp" else 1950
        yrs = (np.asarray(range(y0,2016),dtype=float)-1850)*365
        # presumably bnd_months holds the day-of-year at which each
        # month begins -- TODO confirm against ILAMB.constants
        yrs = (yrs[:,np.newaxis] + bnd_months[:12]).flatten()
        # NOTE(review): intervals are formed from consecutive entries
        # of yrs, so there is one interval fewer than entries in yrs
        t_bnd = np.asarray([yrs[:-1],yrs[1:]]).T
        t = t_bnd.mean(axis=1)
        shp += [t.size,]
        t0 = t_bnd[0,0]; tf = t_bnd[-1,1]

    # If the variable is layered, use a fixed set of 17 layer bounds
    # (presumably depths in meters -- TODO confirm)
    d = None; d_bnd = None
    if v.layered:
        d_bnd = np.asarray([[0.000, 0.025],
                            [0.025, 0.065],
                            [0.065, 0.125],
                            [0.125, 0.21],
                            [0.21, 0.33],
                            [0.33, 0.49],
                            [0.49, 0.69],
                            [0.69, 0.9299999],
                            [0.9299999, 1.21],
                            [1.21, 1.53],
                            [1.53, 1.89],
                            [1.89, 2.29],
                            [2.29, 2.745],
                            [2.745, 3.285],
                            [3.285, 3.925],
                            [3.925, 4.665],
                            [4.665, 5.505]])
        d = d_bnd.mean(axis=1)
        shp += [d.size,]

    # If the variable is spatial
    lat = None; lat_bnd = None
    lon = None; lon_bnd = None
    if v.spatial:
        lat_bnd,lon_bnd,lat,lon = il.GlobalLatLonGrid(args.res)
        # turn the edge vectors into (n,2) arrays of lower/upper bounds
        lat_bnd = np.asarray([lat_bnd[:-1],lat_bnd[1:]]).T
        lon_bnd = np.asarray([lon_bnd[:-1],lon_bnd[1:]]).T
        shp += [lat.size,lon.size]

    # Let's start summing it up. The weights are accumulated per cell
    # so that a cell is normalized only by the models which actually
    # have data there.
    data  = np.zeros(shp)
    count = np.zeros(shp,dtype=int)
    sumw  = np.zeros(shp)
    for m in M:
        sys.stdout.write('.'); sys.stdout.flush()
        w = wgts[m.name] if wgts is not None else 1.

        # Some models provide the variable under an alternate name. The
        # alias is kept in its own local so that it never leaks into
        # the next model or into the name of the output.
        mname = vname
        if vname == "cSoil" and "cSoilAbove1m" in m.variables: mname = "cSoilAbove1m"
        if vname == "burntArea" and "burntFractionAll" in m.variables: mname = "burntFractionAll"

        # Best effort: a model which cannot provide the variable in the
        # requested unit/time span is left out of the mean.
        try:
            if mname == "sftlf":
                v = Variable(filename = m.variables['sftlf'][0],variable_name = 'sftlf').convert("1")
            else:
                v = m.extractTimeSeries(mname,initial_time=t0,final_time=tf).convert(unit)
        except Exception:
            continue
        v = Interpolate(v,t,d,lat,lon)
        valid = ~np.ma.getmaskarray(v)
        # np.where keeps whatever is stored under the mask (possibly
        # nan/inf) out of the running sum
        data  += np.where(valid,np.ma.getdata(v),0.)*w
        count += valid
        sumw  += valid*w

    # Take the mean and write it out. Cells without any contribution
    # divide by zero, which is silenced here and masked.
    with np.errstate(all='ignore'):
        data = np.ma.masked_array(data=(data/sumw),mask=((count==0)|(sumw==0)))
    with Dataset(os.path.join(args.build_dir[0],"%s.nc" % vname),mode="w") as dset:
        Variable(data       = data,
                 unit       = unit,
                 name       = vname,
                 time       = t,
                 time_bnds  = t_bnd,
                 lat        = lat,
                 lat_bnds   = lat_bnd,
                 lon        = lon,
                 lon_bnds   = lon_bnd,
                 depth      = d,
                 depth_bnds = d_bnd).toNetCDF4(dset)

    sys.stdout.write('\n'); sys.stdout.flush()

# Command line interface. Options declared with nargs=1 are stored as
# one-element lists, hence the [0] indexing further down.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--model_root', dest="model_root", metavar='root', type=str, nargs=1, default=["./"],
                    help='root at which to search for models')
parser.add_argument('--build_dir', dest="build_dir", metavar='build_dir', type=str, nargs=1,default=["./_build"],
                    help='path of where to save the output')
parser.add_argument('--config', dest="config", metavar='config', type=str, nargs=1,
                    help='path to configuration file to use')
parser.add_argument('--wgt', dest="wgt", metavar='wgt', type=str, nargs=1,
                    help='path to weight file to use')
parser.add_argument('--models', dest="models", metavar='m', type=str, nargs='+',default=[],
                    help='specify which models to run, list model names with no quotes and only separated by a space.')
parser.add_argument('--vars', dest="vars", metavar='v', type=str, nargs='+',default=[],
                    help='list of the variables to take the mean of')
parser.add_argument('-q','--quiet', dest="quiet", action="store_true",
                    help='enable to silence screen output')
parser.add_argument('--filter', dest="filter", metavar='filter', type=str, nargs=1, default=[""],
                    help='a string which must be in the model filenames')
parser.add_argument('--regex', dest="regex", metavar='regex', type=str, nargs=1, default=[""],
                    help='a regular expression which filenames must conform to in order to be included')
parser.add_argument('--model_setup', dest="model_setup", type=str, nargs='+',default=None,
                    help='list files model setup information')
# NOTE(review): --same_grid is parsed but not read anywhere in this script
parser.add_argument('-g','--same_grid', dest="same_grid", action="store_true",
                    help='enable if all models are on the same grid')
parser.add_argument('-r','--res', dest="res", type=float, default=1.0,
                    help='resolution in degrees of the global lat/lon grid on which the mean is computed')

# Setup models to be used. `args` stays at module level because
# MultiModelMean reads args.res and args.build_dir from it.
args = parser.parse_args()
if not os.path.isdir(args.build_dir[0]): os.makedirs(args.build_dir[0])
if args.model_setup is None:
    M = r.InitializeModels(args.model_root[0],
                           args.models,
                           False,
                           filter=args.filter[0],
                           regex=args.regex[0],
                           models_path=args.build_dir[0])
else:
    # NOTE(review): unlike the branch above, --regex is not forwarded
    # here -- confirm whether ParseModelSetup should receive it too
    M = r.ParseModelSetup(args.model_setup[0],args.models,False,filter=args.filter[0],models_path=args.build_dir[0])
# Leave out anything that looks like a previously computed mean
M = [m for m in M if "Mean" not in m.name]
if not args.quiet:
    print("\nTaking mean of the following models:\n")
    for m in M: print("{0:>20}".format(m.name))

# What variables do we need the means of?
Vs = []
if args.config is None:
    # no configure file, take every variable any model provides
    for m in M: Vs += [v for v in m.variables.keys() if v not in Vs]
else:
    if not args.quiet: print("\nParsing config file %s...\n" % args.config[0])
    S = Scoreboard(args.config[0],
                   master    = True,
                   verbose   = not args.quiet,
                   build_dir = args.build_dir[0])
    for c in S.list():
        vs  = [c.variable,]
        vs +=  c.alternate_vars
        # a derived expression contributes every symbol it references
        if c.derived is not None: vs += [str(s) for s in sympify(c.derived).free_symbols]
        Vs += [v for v in vs if v not in Vs]
    # co2 is not averaged; the other two are handled inside
    # MultiModelMean as alternates of cSoil and burntArea
    for skip in ("co2","cSoilAbove1m","burntFractionAll"):
        if skip in Vs: Vs.remove(skip)
if len(args.vars) > 0: Vs = args.vars
# sftlf always goes first, and only once even if it was already listed
Vs = ['sftlf',] + [v for v in Vs if v != 'sftlf']
if not args.quiet: print("\nTaking the mean of [%s]...\n" % (", ".join(Vs)))

for vname in Vs:
    MultiModelMean(M,vname,args.wgt)

