#!/usr/bin/env python
"""
Computes a multimodel mean.
"""
import argparse
import os
import pickle
import sys

import ILAMB.ilamblib as il
import ILAMB.run as r
import numpy as np
from ILAMB.constants import bnd_months
from ILAMB.Scoreboard import Scoreboard
from ILAMB.Variable import Variable
from netCDF4 import Dataset
from sympy import sympify


def Interpolate(v, t, d, lat, lon):
    """Sample ``v.data`` at the nearest source coordinates to the targets.

    For each target coordinate array that is not None, the index of the
    closest value in the matching coordinate of ``v`` (``v.time``,
    ``v.depth``, ``v.lat``, ``v.lon``) is found, and ``v.data`` is indexed
    with the outer product of those index arrays.

    Parameters
    ----------
    v : object
        Source with a ``data`` array and ``time``/``depth``/``lat``/``lon``
        coordinate attributes. Only the coordinates whose target is given
        are read.
        NOTE(review): assumes the axes of ``v.data`` are ordered
        (time, depth, lat, lon) for whichever of them are present, and
        that no other axes exist -- confirm against the callers.
    t, d, lat, lon : 1-D numpy array or None
        Target coordinates; pass None for a dimension ``v`` does not have.

    Returns
    -------
    array
        ``v.data`` restricted to the nearest-neighbour indices, with one
        axis per target that was given.
    """
    targets = ((t, "time"), (d, "depth"), (lat, "lat"), (lon, "lon"))
    nearest = [
        # distance matrix is (n_target, n_source); argmin picks the closest source
        np.abs(wanted[:, np.newaxis] - getattr(v, coord)).argmin(axis=1)
        for wanted, coord in targets
        if wanted is not None
    ]
    return v.data[np.ix_(*nearest)]


def MultiModelMean(M, vname, wgt_file=None, maxV=20):
    """Compute the (optionally weighted) multimodel mean of one variable.

    Every model in ``M`` that can provide ``vname`` is sampled onto a common
    time/depth/lat/lon grid by nearest neighbour (see ``Interpolate``),
    accumulated, and the result is written to ``<build_dir>/<vname>.nc``.
    Progress is reported on stdout: the variable name, one ``.`` per model
    visited, then a newline.

    This function reads the module-level ``args`` namespace: ``args.res``
    for the spatial grid and ``args.build_dir[0]`` for the output directory.

    Parameters
    ----------
    M : list
        Model objects exposing ``name``, ``variables`` and
        ``extractTimeSeries``.
    vname : str
        Name of the variable to average.
    wgt_file : list of str, optional
        One-element list holding the path of a pickled mapping from model
        name to weight. When omitted every model has weight 1.
    maxV : int, optional
        Width to which the variable name is right-justified in the progress
        output.

    Returns
    -------
    None
        Nothing is written if no model lists ``vname`` in its variables.
    """
    wgts = None
    if wgt_file is not None:
        # NOTE(security): unpickling can execute arbitrary code; only use
        # weight files that come from a trusted source.
        with open(wgt_file[0], "rb") as f:
            wgts = pickle.load(f)
            print("Loaded weights:", wgts)

    sys.stdout.write(("{0:>%d} " % maxV).format(vname))
    sys.stdout.flush()

    # Based on the variable, we need to set up a space-time grid. So
    # first we need find a model with the variable and then check what
    # kind of variable it is.
    Ms = [m for m in M if vname in m.variables]
    if len(Ms) == 0:
        sys.stdout.write("\n")
        sys.stdout.flush()
        return
    if vname == "sftlf":
        ref = Variable(
            filename=Ms[0].variables["sftlf"][0], variable_name="sftlf"
        ).convert("1")
    else:
        ref = Ms[0].extractTimeSeries(vname)
    # All models are converted to the unit of this first model.
    unit = ref.unit
    shp = []

    # If the variable is temporal
    t = None
    t_bnd = None
    t0 = -1e20
    tf = 1e20
    if ref.temporal:
        y0 = 1850 if vname == "nbp" else 1950
        # Year starts in days since 1850 using 365-day years.
        yrs = (np.asarray(range(y0, 2016), dtype=float) - 1850) * 365
        # NOTE(review): assumes bnd_months[:12] holds the day-of-year at
        # which each month starts -- confirm in ILAMB.constants.
        yrs = (yrs[:, np.newaxis] + bnd_months[:12]).flatten()
        # Consecutive starts form the interval bounds.
        t_bnd = np.asarray([yrs[:-1], yrs[1:]]).T
        t = t_bnd.mean(axis=1)
        shp.append(t.size)
        t0 = t_bnd[0, 0]
        tf = t_bnd[-1, 1]

    # If the variable is layered
    d = None
    d_bnd = None
    if ref.layered:
        d_bnd = np.asarray(
            [
                [0.000, 0.025],
                [0.025, 0.065],
                [0.065, 0.125],
                [0.125, 0.21],
                [0.21, 0.33],
                [0.33, 0.49],
                [0.49, 0.69],
                [0.69, 0.9299999],
                [0.9299999, 1.21],
                [1.21, 1.53],
                [1.53, 1.89],
                [1.89, 2.29],
                [2.29, 2.745],
                [2.745, 3.285],
                [3.285, 3.925],
                [3.925, 4.665],
                [4.665, 5.505],
            ]
        )
        d = d_bnd.mean(axis=1)
        shp.append(d.size)

    # If the variable is spatial
    lat = None
    lat_bnd = None
    lon = None
    lon_bnd = None
    if ref.spatial:
        lat_bnd, lon_bnd, lat, lon = il.GlobalLatLonGrid(args.res)
        # Turn the edge arrays into (n, 2) lower/upper bound pairs.
        lat_bnd = np.asarray([lat_bnd[:-1], lat_bnd[1:]]).T
        lon_bnd = np.asarray([lon_bnd[:-1], lon_bnd[1:]]).T
        shp += [lat.size, lon.size]

    # Let's start summing it up
    data = np.zeros(shp)
    count = np.zeros(shp, dtype=int)
    sumw = 0.0
    for m in M:
        sys.stdout.write(".")
        sys.stdout.flush()
        w = wgts[m.name] if wgts is not None else 1.0

        # Some models provide the quantity under an alternate name. Keep the
        # substitution in a per-model local so that it cannot leak into
        # later iterations or the output file name when a model is skipped.
        mname = vname
        if vname == "cSoil" and "cSoilAbove1m" in m.variables:
            mname = "cSoilAbove1m"
        if vname == "burntArea" and "burntFractionAll" in m.variables:
            mname = "burntFractionAll"

        try:
            if mname == "sftlf":
                v = Variable(
                    filename=m.variables["sftlf"][0], variable_name="sftlf"
                ).convert("1")
            else:
                v = m.extractTimeSeries(mname, initial_time=t0, final_time=tf).convert(
                    unit
                )
        except Exception:
            # Best effort: a model that cannot supply the variable is left
            # out of the mean. KeyboardInterrupt/SystemExit still propagate.
            continue
        vals = Interpolate(v, t, d, lat, lon)
        valid = ~np.ma.getmaskarray(vals)
        # np.where keeps NaN/inf stored under masked cells out of the sum.
        data += np.where(valid, np.ma.getdata(vals), 0.0) * w
        count += valid
        sumw += w

    # Take the mean and write it out
    # NOTE(review): the sum is divided by the total weight of all models
    # that contributed, not by the per-cell weight of models with valid
    # data, so cells masked in some models are pulled toward zero --
    # confirm this is intended.
    with np.errstate(all="ignore"):
        data = np.ma.masked_array(data=(data / sumw), mask=(count == 0))
    with Dataset(os.path.join(args.build_dir[0], "%s.nc" % vname), mode="w") as dset:
        Variable(
            data=data,
            unit=unit,
            name=vname,
            time=t,
            time_bnds=t_bnd,
            lat=lat,
            lat_bnds=lat_bnd,
            lon=lon,
            lon_bnds=lon_bnd,
            depth=d,
            depth_bnds=d_bnd,
        ).toNetCDF4(dset)

    sys.stdout.write("\n")
    sys.stdout.flush()


# Command-line interface. The options declared with nargs=1 are stored as
# one-element lists, which is why the rest of the script reads them as
# args.<name>[0].
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--model_root",
    dest="model_root",
    metavar="root",
    type=str,
    nargs=1,
    default=["./"],
    help="root at which to search for models",
)
parser.add_argument(
    "--build_dir",
    dest="build_dir",
    metavar="build_dir",
    type=str,
    nargs=1,
    default=["./_build"],
    help="path of where to save the output",
)
parser.add_argument(
    "--config",
    dest="config",
    metavar="config",
    type=str,
    nargs=1,
    help="path to configuration file to use",
)
parser.add_argument(
    "--wgt",
    dest="wgt",
    metavar="wgt",
    type=str,
    nargs=1,
    help="path to weight file to use",
)
parser.add_argument(
    "--models",
    dest="models",
    metavar="m",
    type=str,
    nargs="+",
    default=[],
    help="specify which models to run, list model names with no quotes and only separated by a space.",
)
parser.add_argument(
    "--vars",
    dest="vars",
    metavar="v",
    type=str,
    nargs="+",
    default=[],
    help="list of the variables to take the mean of",
)
parser.add_argument(
    "-q",
    "--quiet",
    dest="quiet",
    action="store_true",
    help="enable to silence screen output",
)
parser.add_argument(
    "--filter",
    dest="filter",
    metavar="filter",
    type=str,
    nargs=1,
    default=[""],
    help="a string which must be in the model filenames",
)
parser.add_argument(
    "--regex",
    dest="regex",
    metavar="regex",
    type=str,
    nargs=1,
    default=[""],
    help="a regular expression which filenames must conform to in order to be included",
)
parser.add_argument(
    "--model_setup",
    dest="model_setup",
    type=str,
    nargs="+",
    default=None,
    help="list files model setup information",
)
parser.add_argument(
    "-g",
    "--same_grid",
    dest="same_grid",
    action="store_true",
    help="enable if all models are on the same grid",
)
# The value is handed to ILAMB's GlobalLatLonGrid to build the common grid.
parser.add_argument(
    "-r",
    "--res",
    dest="res",
    type=float,
    default=1.0,
    help="resolution in degrees of the global latitude/longitude grid on which the mean is computed",
)

# Parse the command line, make sure the output directory exists, and build
# the list of models whose mean will be taken.
args = parser.parse_args()
build_dir = args.build_dir[0]
if not os.path.isdir(build_dir):
    os.makedirs(build_dir)
if args.model_setup is not None:
    # Models are described by a setup file; only the first file given is read.
    M = r.ParseModelSetup(
        args.model_setup[0],
        args.models,
        False,
        filter=args.filter[0],
        models_path=build_dir,
    )
else:
    # Otherwise look for models under the model root directory.
    M = r.InitializeModels(
        args.model_root[0],
        args.models,
        False,
        filter=args.filter[0],
        regex=args.regex[0],
        models_path=build_dir,
    )
# Leave out any model whose name contains "Mean".
M = [model for model in M if "Mean" not in model.name]
if not args.quiet:
    print("\nTaking mean of the following models:\n")
    for model in M:
        print("{0:>20}".format(model.name))

# What variables do we need the means of?
Vs = []
if args.config is None:
    # No configuration file: take every variable that any model provides.
    for m in M:
        Vs += [v for v in m.variables.keys() if v not in Vs]
else:
    if not args.quiet:
        print("\nParsing config file %s...\n" % args.config[0])
    S = Scoreboard(
        args.config[0], master=True, verbose=not args.quiet, build_dir=args.build_dir[0]
    )
    for c in S.list():
        vs = [
            c.variable,
        ]
        vs += c.alternate_vars
        if c.derived is not None:
            vs += [str(s) for s in sympify(c.derived).free_symbols]
        # Append one name at a time so that a name repeated within a single
        # entry (e.g. both the variable and a symbol of its derived
        # expression) is only recorded once.
        for v in vs:
            if v not in Vs:
                Vs.append(v)
    # cSoilAbove1m and burntFractionAll are used by MultiModelMean as
    # per-model substitutes for cSoil and burntArea, so they get no mean of
    # their own. NOTE(review): the reason for dropping co2 is not visible in
    # this file.
    excluded = ("co2", "cSoilAbove1m", "burntFractionAll")
    Vs = [v for v in Vs if v not in excluded]
if len(args.vars) > 0:
    # An explicit list on the command line overrides everything above.
    Vs = args.vars

# The land fraction always goes first. De-duplicate while keeping order so
# that no mean (sftlf in particular, which models list as a variable) is
# computed and written twice.
ordered = ["sftlf"]
for v in Vs:
    if v not in ordered:
        ordered.append(v)
Vs = ordered
if not args.quiet:
    print("\nTaking the mean of [%s]...\n" % (", ".join(Vs)))

for vname in Vs:
    MultiModelMean(M, vname, args.wgt)
