#!python
"""
Computes a multimodel mean.
"""
import os,sys,argparse
from ILAMB.ModelResult import ModelResult
from ILAMB.Variable import Variable
import ILAMB.ilamblib as il
from ILAMB.constants import bnd_months
from netCDF4 import Dataset
import numpy as np
from traceback import format_exc

def InitializeModels(model_root,models=[],verbose=False,filter="",model_year=[]):
    """Builds a model result for each directory directly below a root

    Only the immediate subdirectories of the model root are examined;
    deeper levels are not searched. If constructing the model result
    for a subdirectory raises an exception, the traceback is printed
    and that subdirectory is left out. If no model results remain, the
    process ends through ``sys.exit(0)``.

    Parameters
    ----------
    model_root : str
        the directory whose subdirectories will become the model results
    models : list of str, optional
        when not empty, only subdirectories named in this list are used
    verbose : bool, optional
        enable to print the search location and the models found
    filter : str, optional
        passed unchanged to each ModelResult as its ``filter`` argument
    model_year : 2-tuple, optional
        shift model years from the first to the second part of the
        tuple; anything whose length is not 2 is replaced by None

    Returns
    -------
    M : list of ILAMB.ModelResults.ModelsResults
       a list of the model results, sorted by upper-cased name

    """
    if len(model_year) != 2:
        model_year = None
    if verbose:
        print("\nSearching for model results in %s\n" % model_root)

    # os.walk reports the top directory first and that is the only
    # level wanted; a root which cannot be listed yields nothing at all
    top, candidates, _ = next(os.walk(model_root), (model_root, [], []))

    M = []
    for mname in candidates:
        wanted = len(models) == 0 or mname in models
        if not wanted:
            continue
        try:
            result = ModelResult(os.path.join(top,mname),
                                 modelname  = mname,
                                 filter     = filter,
                                 model_year = model_year)
        except Exception:
            # report the failure but keep looking at the other models
            print(format_exc())
        else:
            M.append(result)
    M.sort(key=lambda result: result.name.upper())

    # optionally output models which were found
    if verbose:
        for result in M:
            print(("    {0:>45}").format(result.name))

    if not M:
        if verbose:
            print("No model results found")
        sys.exit(0)

    return M


def ComputeMeanSameGrid(vname,M,outfile):
    """Averages a variable over models which share one space-time grid

    The variable is extracted from each model with
    ``extractTimeSeries``; a model for which the extraction raises an
    exception is left out of the mean. Every further contributing model
    must agree with the first contributing model on time, latitude,
    longitude, their bounds and the cell areas, as judged by
    ``np.allclose``. The mean is written to ``outfile`` and the shared
    cell areas are written next to it, in a file named like ``outfile``
    but with ``vname`` replaced by 'areacella' in the file name.

    Parameters
    ----------
    vname : str
        the name of the variable to average
    M : list of ILAMB.ModelResult.ModelResult
        the models over which to average
    outfile : str
        the path of the netCDF4 file to which the mean is written

    Returns
    -------
    v : ILAMB.Variable.Variable or None
        the multimodel mean, or None if no model provided the variable
        (in which case no file is written)

    Raises
    ------
    AssertionError
        if a model is not on the grid of the first contributing model

    """
    grid_names = ("time","time_bnds","lat","lat_bnds","lon","lon_bnds","area")
    grid  = None
    unit  = None
    total = None
    count = 0
    for m in M:
        try:
            v = m.extractTimeSeries(vname)
        except Exception:
            # best effort: a model which cannot provide the variable is
            # simply not part of the mean
            continue
        if grid is None:
            grid  = dict((name,getattr(v,name)) for name in grid_names)
            unit  = v.unit
            # accumulate into a copy so that the first model's own data
            # is not modified in place
            total = v.data.copy()
        else:
            # verify the grid before adding, so a mismatched model can
            # neither broadcast silently nor corrupt the running sum;
            # raised explicitly so the check also runs under python -O
            for name in grid_names:
                if not np.allclose(grid[name],getattr(v,name)):
                    raise AssertionError("%s: the %s of model %s differs from that of the first model providing this variable" % (vname,name,m.name))
            total += v.data
        count += 1

    # nothing to average, mirror the interpolated code path and skip
    if count == 0: return None
    total /= count

    with Dataset(outfile,mode="w") as dset:
        mean = Variable(name = vname,
                        unit = unit,
                        data = total,
                        time = grid["time"], time_bnds = grid["time_bnds"],
                        lat  = grid["lat"] ,  lat_bnds = grid["lat_bnds"],
                        lon  = grid["lon"] ,  lon_bnds = grid["lon_bnds"],
                        area = grid["area"])
        mean.toNetCDF4(dset)

    # substitute the variable name in the file name only, a directory
    # which happens to contain the variable name must stay untouched
    head,tail = os.path.split(outfile)
    areafile  = os.path.join(head,tail.replace(vname,"areacella"))
    with Dataset(areafile,mode="w") as dset:
        area = Variable(name = "areacella",
                        unit = "m2",
                        data = grid["area"],
                        lat  = grid["lat"],  lat_bnds = grid["lat_bnds"],
                        lon  = grid["lon"],  lon_bnds = grid["lon_bnds"])
        area.toNetCDF4(dset)
    return mean

# Command line interface. The string options use nargs=1, so their parsed
# values are one-element lists which the script indexes with [0].
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--model_root', dest="model_root", metavar='root', type=str, nargs=1, default=["./"],
                    help='root at which to search for models')
parser.add_argument('--output_dir', dest="output", metavar='output', type=str, nargs=1, default=["./"],
                    help='root at which to save model output')
parser.add_argument('--models', dest="models", metavar='m', type=str, nargs='+',default=[],
                    help='specify which models to run, list model names with no quotes and only separated by a space.')
parser.add_argument('--filter', dest="filter", metavar='filter', type=str, nargs=1, default=[""],
                    help='a string which must be in the model filenames')
parser.add_argument('-q','--quiet', dest="quiet", action="store_true",
                    help='enable to silence screen output')
parser.add_argument('-g','--same_grid', dest="same_grid", action="store_true",
                    help='enable if all models are on the same grid')
# the resolution is only read when the models have to be interpolated
parser.add_argument('-r','--res', dest="res", type=float, default=1.0,
                    help='resolution in degrees of the global grid onto which models are interpolated, not used when --same_grid is given')

# setup: read the command line and locate the models
args = parser.parse_args()
M  = InitializeModels(args.model_root[0],
                      args.models,
                      not args.quiet,
                      filter = args.filter[0])

# variable names which are never averaged
skip = ['time','time_bnds','lat','lat_bnds','lon','lon_bnds','plev','time_bounds','type','nav_lat','nav_lon','bounds_nav_lon','bounds_nav_lat','area','areacello','areacella','mcdata']

# collect every other variable name found in any model, keeping the
# order in which the names are first encountered
Vs = []
for m in M:
    for name in m.variables.keys():
        if name in skip or name in Vs: continue
        Vs.append(name)

# the land fraction is not averaged when the models share a grid
if args.same_grid and 'sftlf' in Vs:
    Vs.remove('sftlf')

if not args.quiet:
    print("Found variables: %s" % (",".join(Vs)))

if args.same_grid:
    # the models share one grid, so each variable is averaged directly
    for vname in Vs:
        print("Processing %s..." % vname)
        ComputeMeanSameGrid(vname,M,os.path.join(args.output[0],"%s_mean.nc" % vname))
else:
    # compute space-time grid on which everything will be interpolated:
    # each year from 1850 through 2015 starts at a multiple of 365 and
    # the first 12 entries of bnd_months are added as offsets within the
    # year (presumably the days at which the months begin, to be
    # confirmed against ILAMB.constants)
    yrs = (np.arange(1850,2016,dtype=float)-1850)*365
    yrs = (yrs[:,np.newaxis] + bnd_months[:12]).flatten()
    # NOTE(review): consecutive entries are paired into intervals, so the
    # grid has one interval fewer than yrs has entries and the last entry
    # only closes the interval before it - confirm this end is intended
    t_bnd = np.asarray([yrs[:-1],yrs[1:]]).T
    t = t_bnd.mean(axis=1)
    lat_bnd,lon_bnd,lat,lon = il.GlobalLatLonGrid(args.res)
    lat_bnd = np.asarray([lat_bnd[:-1],lat_bnd[1:]]).T
    lon_bnd = np.asarray([lon_bnd[:-1],lon_bnd[1:]]).T

    for v in Vs:
        sys.stdout.write('\nProcessing %s.' % v)

        # running sum of the valid values and the number of models which
        # contributed to each cell
        data  = np.zeros((t.size,lat.size,lon.size))
        count = np.zeros((t.size,lat.size,lon.size),dtype=int)

        unsupported = False  # set once a model has the variable in a shape not handled here
        unit        = None   # unit of the first model, all others are converted to it
        temporal    = False  # whether the last contributing model was temporal
        for m in M:
            sys.stdout.write('.');sys.stdout.flush()
            if unsupported: continue
            try:
                u = m.extractTimeSeries(v)
            except Exception:
                # only the land fraction has a fallback, it is read
                # straight from the first file the model lists for it
                if v == 'sftlf':
                    if 'sftlf' not in m.variables: continue
                    u = Variable(filename = m.variables['sftlf'][0],variable_name = 'sftlf')
                    # values above 1 are taken to be percentages
                    if u.data.max() > 1: u.data /= 100.
                else:
                    continue
            if u.layered or u.data.ndim > 3 or u.data.ndim == 1:
                unsupported = True
                continue
            if unit is None: unit = u.unit
            u.convert(unit)
            sys.stdout.write('.');sys.stdout.flush()
            ui = u.interpolate(time=t if u.temporal else None,
                               lat=lat,lon=lon,lat_bnds=lat_bnd,lon_bnds=lon_bnd)
            sys.stdout.write('.');sys.stdout.flush()
            with np.errstate(under='ignore',over='ignore'):
                # masked cells contribute an exact zero; multiplying the
                # raw values by the mask would turn a nan or inf hidden
                # under the mask into nan in the sum
                valid  = ~np.ma.getmaskarray(ui.data)
                values = ui.data.filled(0)
                if u.temporal:
                    data  += values
                    count += valid
                else:
                    # variables without time use the first time slot only
                    data [0,...] += values
                    count[0,...] += valid
            temporal = u.temporal

        # no model contributed anything, so there is nothing to write
        if count.sum() == 0: continue
        sys.stdout.write('.');sys.stdout.flush()
        with np.errstate(all='ignore'):
            # cells without any contribution are masked in the mean
            data = np.ma.masked_array(data=(data/count.clip(1)),mask=(count==0))
        with Dataset(os.path.join(args.output[0],"%s_mean.nc" % v),mode="w") as dset:
            Variable(data       = data if temporal else data[0,...],
                     unit       = unit,
                     name       = v,
                     time       = t     if temporal else None,
                     time_bnds  = t_bnd if temporal else None,
                     lat        = lat,
                     lat_bnds   = lat_bnd,
                     lon        = lon,
                     lon_bnds   = lon_bnd).toNetCDF4(dset)
        sys.stdout.flush()
