#!/usr/bin/env python3
import argparse
import datetime
import logging
import os
import numpy as np

# TODO consider using real python setup.py utilities
# For now, this allows importing upsp.* from both the
# source-tree and the install tree.
import site
site.addsitedir(os.path.join(os.path.dirname(__file__), "..", "python"))
import warnings  # noqa
warnings.filterwarnings("ignore")
import upsp.processing.context as context  # noqa
import upsp.processing.tree as tree # noqa
warnings.filterwarnings("default")


# Maximum walltime ("HH:MM:SS") of a single job, keyed by PBS queue name.
# The script entry point looks up the configured queue here and uses the value
# as the upper bound when packing datapoints into jobs.
# NOTE(review): presumably mirrors the NAS queue limits -- "long" currently has
# the same cap as "normal"; confirm that is intentional.
_NAS_MAX_WALLTIME_BY_QUEUE = {
    "normal": "08:00:00",
    "long": "08:00:00",
    "devel": "02:00:00",
}


def qsub_args(ctx, step_name, job_name, job_walltime):
    """Build the PBS ``qsub`` command-line arguments for one job.

    Args:
        ctx: Pipeline context. ``ctx.ctx['nas'][step_name]`` must provide the
            ``queue``, ``charge_group``, ``number_nodes`` and ``node_model``
            entries.
        step_name: Pipeline step name; selects the NAS settings and is passed
            to ``tree._app_logs_path`` to locate the job log files.
        job_name: Name given to the job (``-N``) and embedded in the names of
            its stdout/stderr log files.
        job_walltime: Walltime string, inserted verbatim into the ``-l``
            resource request.

    Returns:
        A list of strings with one argv token per element (every flag and its
        value are separate elements), so the list can either be joined with
        spaces or handed to a process-spawning API unchanged.
    """
    nas = ctx.ctx['nas'][step_name]
    # Computed once so the stdout and stderr files share the same prefix.
    now_iso_str = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')
    log_stem = "%s-%s" % (now_iso_str, job_name)
    stdout = tree._app_logs_path(
        ctx.ctx, step_name, 'jobs', "%s.stdout" % log_stem
    )
    stderr = tree._app_logs_path(
        ctx.ctx, step_name, 'jobs', "%s.stderr" % log_stem
    )
    return [
        '-N', job_name,
        '-q', nas['queue'],
        '-W', 'group_list=%s' % nas['charge_group'],
        '-l', 'select=%s:model=%s,walltime=%s' % (
            nas['number_nodes'], nas['node_model'], job_walltime
        ),
        # Flag and path are separate tokens, consistent with the options
        # above; a fused '-o <path>' element is not a valid single argv entry.
        '-o', str(stdout),
        '-e', str(stderr),
    ]


def _timedelta_from_string(s):
    hours, minutes, seconds = [int(f) for f in s.split(":")]
    return np.sum([
        np.timedelta64(hours, 'h'),
        np.timedelta64(minutes, 'm'),
        np.timedelta64(seconds, 's')
    ])


def _timedelta_to_string(td: np.timedelta64):
    total_seconds = int(td.item().total_seconds())
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return "%02d:%02d:%02d" % (hours, minutes, seconds)


def _allocate_tasks_to_jobs(_max_td_per_job, _td_per_task, n_tasks):
    n_tasks_per_full_job = _max_td_per_job // _td_per_task
    n_full_jobs, n_rem_tasks = divmod(n_tasks, n_tasks_per_full_job)
    jobs_tasks = np.reshape(
        np.arange(n_tasks_per_full_job * n_full_jobs),
        (n_full_jobs, n_tasks_per_full_job)
    ).tolist()
    td_per_full_job = _td_per_task * n_tasks_per_full_job
    jobs_td = [td_per_full_job] * n_full_jobs
    if n_rem_tasks > 0:
        rem_tasks = np.arange(n_full_jobs * n_tasks_per_full_job, n_tasks).tolist()
        td_per_rem_job = _td_per_task * n_rem_tasks
        jobs_tasks.append(rem_tasks)
        jobs_td.append(td_per_rem_job)
    return jobs_td, jobs_tasks


if __name__ == "__main__":
    # Entry point: read the NAS settings of one pipeline step, split the
    # requested datapoints into jobs that respect the queue's walltime cap,
    # and print two lines per job -- the qsub arguments, then the datapoints.
    logging.basicConfig(level=logging.INFO)
    ap = argparse.ArgumentParser(
        description="Compute args for multiple PBS qsub's to run given datapoints"
    )
    ap.add_argument("pipeline_dir", help="Pipeline directory")
    ap.add_argument("step_name", help="Pipeline step name")
    ap.add_argument("datapoints", nargs='+', help="Datapoints (space-separated list)")
    args = ap.parse_args()

    ctx = context.Pipeline(args.pipeline_dir)
    nas = ctx.ctx['nas'][args.step_name]
    wt_per_dp = nas['wall_time']
    queue = nas['queue']
    launcher_type = nas['launcher']
    number_nodes = nas['number_nodes']
    # Duplicates are dropped and the order is made deterministic.
    datapoints = sorted(set(args.datapoints))
    number_datapoints = len(datapoints)

    # Validate the configuration up front so a bad value is reported clearly
    # instead of surfacing later as a KeyError or NameError.
    if queue not in _NAS_MAX_WALLTIME_BY_QUEUE:
        raise ValueError(
            "unknown queue %r for step %r (known queues: %s)"
            % (queue, args.step_name, ', '.join(sorted(_NAS_MAX_WALLTIME_BY_QUEUE)))
        )
    if launcher_type not in ("parallel", "serial"):
        raise ValueError(
            "unknown launcher %r for step %r (expected 'parallel' or 'serial')"
            % (launcher_type, args.step_name)
        )

    td_per_dp = _timedelta_from_string(wt_per_dp)
    max_td_per_job = _timedelta_from_string(_NAS_MAX_WALLTIME_BY_QUEUE[queue])

    if launcher_type == "parallel":
        # Group datapoints into batches of 'number_nodes' datapoints.
        # Each batch will (in theory) take a single datapoint's walltime to run.
        # Allocate batches to jobs, then map that back to a list of datapoints per job.
        td_per_batch = td_per_dp
        n_dp_per_batch = number_nodes
        # Consecutive datapoint indices per batch; only the last may be short.
        batches_dps = [
            list(range(first, min(first + n_dp_per_batch, number_datapoints)))
            for first in range(0, number_datapoints, n_dp_per_batch)
        ]
        jobs_td, jobs_batches = _allocate_tasks_to_jobs(
            max_td_per_job, td_per_batch, len(batches_dps)
        )
        jobs_dps = [
            [dp for ib in batch for dp in batches_dps[ib]]
            for batch in jobs_batches
        ]
    else:
        # Serial launcher: parallelization from 'number_nodes' is already
        # accounted for in the per-datapoint walltime. Allocate datapoints
        # directly to jobs.
        jobs_td, jobs_dps = _allocate_tasks_to_jobs(
            max_td_per_job, td_per_dp, number_datapoints
        )

    for idx, (td, dps) in enumerate(zip(jobs_td, jobs_dps)):
        wt = _timedelta_to_string(td)
        name = 'j%d' % (idx)
        print(' '.join(qsub_args(ctx, args.step_name, name, wt)))
        print(' '.join([datapoints[ii] for ii in dps]))
