#!/usr/bin/env python
"""
Computes a multimodel mean.
"""
import os,sys,argparse
from ILAMB.ModelResult import ModelResult
from ILAMB.Variable import Variable
from netCDF4 import Dataset
import numpy as np

def InitializeModels(model_root,models=None,verbose=False,filter="",model_year=None):
    """Initializes a list of models

    Creates one ModelResult for each immediate subdirectory of the
    given model root directory. Deeper levels of the directory tree
    are not searched here.

    Parameters
    ----------
    model_root : str
        the directory whose subdirectories will become the model results
    models : list of str, optional
        only initialize a model whose name is in this list, an empty
        list or None selects every subdirectory
    verbose : bool, optional
        enable to print information to the screen
    filter : str, optional
        passed unchanged to ModelResult as its 'filter' keyword
    model_year : 2-tuple, optional
        passed to ModelResult as its 'model_year' keyword (described
        there as a shift of model years from the first to the second
        entry), anything which does not have exactly two entries is
        replaced by None

    Returns
    -------
    M : list of ILAMB.ModelResult.ModelResult
       a list of the model results, sorted case-insensitively by name

    Raises
    ------
    SystemExit
        with status 0 if no model results are found

    """
    if models is None: models = []

    # anything other than a (from,to) pair means no shift is requested
    if model_year is None or len(model_year) != 2: model_year = None

    if verbose: print("\nSearching for model results in %s\n" % model_root)

    # only the first level yielded by os.walk is used, a root which
    # cannot be listed simply yields no models
    subdir,dirs,_ = next(os.walk(model_root),(model_root,[],[]))
    M = []
    for mname in dirs:
        if len(models) > 0 and mname not in models: continue
        M.append(ModelResult(os.path.join(subdir,mname),
                             modelname  = mname,
                             filter     = filter,
                             model_year = model_year))
    M.sort(key=lambda m: m.name.upper())

    # optionally output models which were found
    if verbose:
        for m in M:
            print("    {0:>45}".format(m.name))

    if len(M) == 0:
        if verbose: print("No model results found")
        sys.exit(0)

    return M


def _regular_axis(lo,hi,res):
    """Returns the cell centers and (n,2) cell bounds of a uniform axis from lo to hi."""
    edges   = np.arange(lo,hi+res/2.,res) # the extra res/2 keeps hi itself as the last edge
    centers = 0.5*(edges[:-1]+edges[1:])
    bounds  = np.column_stack((edges[:-1],edges[1:]))
    return centers,bounds

def CombineModelVars(var,res=1.0):
    """Averages the same variable from several models on a common grid

    The common time span and lat/lon extent are the intersection of
    the extents of all inputs. Each input is interpolated to that
    grid and the mean is taken over the models which are unmasked in
    each cell. Cells where no model has data are masked in the result.

    The inputs are modified: every variable is converted to the unit
    of the first one, and the first one is trimmed to the common time
    span (this relies on Variable.convert and Variable.trim working
    in place, as the original code did). The time axis of the result
    is that of the first variable after trimming. Depth information
    is not carried into the result.

    Parameters
    ----------
    var : dict of ILAMB.Variable.Variable
        the variable from each model, keyed by model name
    res : float, optional
        the lat/lon cell size of the output grid

    Returns
    -------
    mean : ILAMB.Variable.Variable
        the multimodel mean, with the unit and name of the first input

    Raises
    ------
    ValueError
        if var is empty, res is not positive, or the inputs do not
        all have the same number of dimensions

    """
    if not var: raise ValueError("no variables were given to combine")
    if res <= 0: raise ValueError("res must be positive, got %s" % res)
    m0 = next(iter(var)) # reference model for unit, name, and time axis

    # Which axes exist is tracked explicitly. The clamping of lat/lon
    # below would otherwise hide the 'untouched' sentinel values.
    has_time = has_lat = has_lon = False
    t0   = -1e20; tf   = +1e20
    lat0 = -1e20; latf = +1e20
    lon0 = -1e20; lonf = +1e20
    for m in var:
        v = var[m]
        v.convert(var[m0].unit)
        if v.data.ndim != var[m0].data.ndim:
            raise ValueError("%s has %d dimensions but %s has %d" % (m,v.data.ndim,m0,var[m0].data.ndim))
        if v.time is not None:
            has_time = True
            t0 = max(v.time_bnds.min(),t0)
            tf = min(v.time_bnds.max(),tf)
        if v.lat is not None:
            has_lat = True
            lat0 = max(v.lat_bnds.min(),lat0)
            latf = min(v.lat_bnds.max(),latf)
        if v.lon is not None:
            has_lon = True
            lon0 = max(v.lon_bnds.min(),lon0)
            lonf = min(v.lon_bnds.max(),lonf)
    lat0 = max(lat0,- 90.); latf = min(latf,+ 90.)
    lon0 = max(lon0,-180.); lonf = min(lonf,+180.)

    # Create space/time grid, an axis which no input has stays None.
    # NOTE(review): assumes Variable.interpolate and the Variable
    # constructor treat None as 'axis absent' -- confirm against ILAMB.
    t   = t_bnds  = None
    lat = lat_bnd = None
    lon = lon_bnd = None
    shp = ()
    if has_time:
        var[m0].trim(t=[t0,tf])
        t      = np.copy(var[m0].time)
        t_bnds = np.copy(var[m0].time_bnds)
        shp   += (t.size,)
    if has_lat:
        lat,lat_bnd = _regular_axis(lat0,latf,res)
        shp += (lat.size,)
    if has_lon:
        lon,lon_bnd = _regular_axis(lon0,lonf,res)
        shp += (lon.size,)
    dsum  = np.zeros(shp)
    count = np.zeros(shp,dtype=int)

    for m in var:
        print("   Averaging in %s..." % m)
        intv   = var[m].interpolate(time=t,lat=lat,lon=lon)
        valid  = (intv.data.mask==0)
        dsum  += valid*intv.data
        count += valid

    # cells without any contribution are masked, clip avoids 0/0 there
    dsum  = np.ma.masked_array(dsum,mask=(count==0))
    dsum /= count.clip(1)
    return Variable(data       = dsum,
                    unit       = var[m0].unit,
                    name       = var[m0].name,
                    time       = t,
                    time_bnds  = t_bnds,
                    lat        = lat,
                    lat_bnds   = lat_bnd,
                    lon        = lon,
                    lon_bnds   = lon_bnd)
    
def CreateMeanModel(M,res=1.0):
    """Writes a multimodel mean file for each variable found in the models

    For every variable name found in the models (except names which
    look like coordinates or bounds), the time series is extracted
    from each model which has it, the results are combined with
    CombineModelVars, and the mean is written to the file
    '<variable>_MeanModel.nc' in the current working directory. A
    variable which is present in fewer than two models, or which fails
    for any other reason, is reported on the screen and skipped.

    Parameters
    ----------
    M : list of ILAMB.ModelResult.ModelResult
        the models to average
    res : float, optional
        the lat/lon cell size of the output grid, passed on to
        CombineModelVars

    """
    def _keep(v):
        # This is a substring test, so a name which merely contains
        # one of these fragments is skipped as well.
        for skip in ["_bnds","time","lat","lon","layer","depth","lev","areacella"]:
            if skip in v: return False
        return True

    # Find a list of variables across all models, in first-seen order
    Vs = []
    for m in M:
        for v in m.variables:
            if v not in Vs and _keep(v): Vs.append(v)

    for v in Vs:
        print(v)
        var = {}
        try:
            for m in M:
                if v not in m.variables: continue
                print("   Reading from %s..." % m.name)
                var[m.name] = m.extractTimeSeries(v)
            if len(var) < 2:
                raise ValueError("found in %d model(s), a mean needs at least 2" % len(var))
            mean = CombineModelVars(var,res)
            with Dataset("%s_MeanModel.nc" % v,mode="w") as dset:
                mean.toNetCDF4(dset)
                print("Writing %s_MeanModel.nc...\n" % v)
        except Exception as ex:
            # Best effort: one bad variable must not stop the others.
            # Ctrl-C and sys.exit are no longer swallowed here.
            print("Failed to create %s (%s: %s)\n" % (v,type(ex).__name__,ex))
            
def ParseArguments(argv=None):
    """Parses the command line options of this script

    Parameters
    ----------
    argv : list of str, optional
        the arguments to parse, by default argparse reads sys.argv[1:]

    Returns
    -------
    args : argparse.Namespace
        with attributes model_root (str), models (list of str),
        filter (str), and quiet (bool)

    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--model_root', dest="model_root", metavar='root', type=str, default="./",
                        help='root at which to search for models')
    parser.add_argument('--models', dest="models", metavar='m', type=str, nargs='+',default=[],
                        help='specify which models to run, list model names with no quotes and only separated by a space.')
    parser.add_argument('--filter', dest="filter", metavar='filter', type=str, default="",
                        help='a string which must be in the model filenames')
    parser.add_argument('-q','--quiet', dest="quiet", action="store_true",
                        help='enable to silence screen output')
    return parser.parse_args(argv)

def main(argv=None):
    """Finds the models beneath the requested root and writes their mean files"""
    args = ParseArguments(argv)
    # --quiet only controls the output of the model search, the
    # averaging step takes no such option and always prints progress
    M = InitializeModels(args.model_root,args.models,not args.quiet,filter=args.filter)
    CreateMeanModel(M)

# only run when executed as a script, so that importing this file has no side effects
if __name__ == "__main__":
    main()
