#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

import argparse
import logging
import os
import re
from datetime import timedelta
from getpass import getuser

from ParProcCo.job_controller import JobController
from ParProcCo.job_scheduling_information import JobResources
from ParProcCo.passthru_wrapper import PassThruWrapper
from ParProcCo.utils import get_token, set_up_wrapper


def create_parser() -> argparse.ArgumentParser:
    """
    Build the argument parser for the cluster submission script

    Example invocation:

    $ ppc_cluster_submit program [--partition hpc00] [--token path/to/token/file]
     [--output cluster_output_dir] [--jobs 4] [--timeout 1h30m] --memory 4000M --cores 6
     -s 0.01 ... [input files]

    Only the submission options are declared here; --partition, --memory and
    --cores are mandatory. The parser defines no positional arguments.
    """
    # Each entry is (option strings, keyword options for add_argument); the
    # order is kept as it determines the order shown in the help output
    option_table = [
        (
            ("--partition",),
            {
                "help": "str: partition on which to run jobs",
                "type": str,
                "required": True,
            },
        ),
        (("--token",), {"help": "str: slurm token filepath", "type": str}),
        (("-o", "--output"), {"help": "str: cluster output file or directory"}),
        (
            ("--jobs",),
            {
                "help": "int: number of cluster jobs to split processing into",
                "type": int,
                "default": 1,
            },
        ),
        (
            ("--timeout",),
            {
                "help": "str: timeout for cluster jobs to finish - xxh[yym]",
                "default": "2h",
            },
        ),
        (
            ("--memory",),
            {
                "help": "maximum memory to use (e.g. 1024M, 4G, etc)",
                "required": True,
            },
        ),
        (
            ("--cores",),
            {
                "help": "int: number of cores to use per cluster job",
                "type": int,
                "required": True,
            },
        ),
        (
            ("-D", "--debug"),
            {"help": "show debugging information", "action": "store_true"},
        ),
    ]

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="ParProcCo run script",
    )
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    return cli


# Matches an optional hours part ("<digits>h") followed by an optional minutes
# part ("<digits>m"), ignoring case, e.g. "2h", "90m" or "1h30m"
TIMEOUT_PATTERN = re.compile(
    r"^((?P<hours>\d+)h)?((?P<minutes>\d+)m)?$", re.I
)  # @UndefinedVariable


def parse_timeout(timeout: str) -> timedelta:
    """
    Convert a timeout specification such as "1h30m" into a timedelta

    Leading and trailing whitespace is ignored. The hours and minutes parts
    are each optional but at least one of them must be present.

    Raises ValueError if the string does not match TIMEOUT_PATTERN or if it
    contains neither an hours nor a minutes part (e.g. an empty string)
    """
    mo = TIMEOUT_PATTERN.match(timeout.strip())
    if not mo:
        raise ValueError(f"Could not parse {timeout} as time interval")
    to_dict = {
        k: int(v) for k, v in mo.groupdict().items() if v is not None
    }  # filter out None-valued items
    if not to_dict:
        # Both groups are optional so a blank string matches the pattern;
        # reject it rather than silently returning a zero timeout
        raise ValueError(
            "Timeout must specify hours and/or minutes (e.g. 1h30m), none given"
        )
    logging.debug("Parsed time as %s", to_dict)
    return timedelta(**to_dict)


def parse_memory(memory: str) -> int:
    """
    Convert a memory specification such as "4000M" or "4G" into a number of M units

    The string must be a positive integer followed by the unit letter M or G;
    surrounding whitespace is ignored. A value given in G is multiplied by 1024.
    A warning is logged for values under 512M or over 64G.

    Raises ValueError if the string is empty, lacks an M or G suffix, does not
    start with an integer, or the integer is not greater than zero
    """
    spec = memory.strip()
    # Slicing (unlike indexing) yields "" for an empty string, so blank input
    # is reported as a ValueError like any other malformed specification
    unit = spec[-1:]
    if unit not in ('M', 'G'):
        raise ValueError('Memory specified must end with M or G')
    try:
        amount = int(spec[:-1])
    except ValueError as e:
        raise ValueError('Memory specified must start with a number') from e

    if amount <= 0:
        raise ValueError('Memory specified must be greater than 0')
    if unit == 'M' and amount < 512:
        logging.warning('Memory specified is recommended to be over 512M')
    if unit == 'G':
        if amount > 64:
            logging.warning('Memory specified (>64G) seems to be excessive')
        amount *= 1024
    return amount


def run_ppc(args: argparse.Namespace, script_args: list[str]) -> None:
    """
    Submit a program to the cluster through a JobController

    args holds the parsed submission options (token, timeout, memory, jobs,
    output, debug, partition, cores) and script_args is the program name
    followed by its own arguments.

    Raises ValueError if script_args is empty or args.jobs is less than one;
    errors from parsing the timeout and memory options are also propagated
    """
    from ParProcCo.utils import load_cfg

    config = load_cfg()
    slurm_url = config.url
    slurm_token = get_token(args.token)
    username = getuser()

    # Collect optional job properties from the environment variables named in
    # the configuration; unset or empty variables are left out
    extra_properties = None
    if config.extra_property_envs:
        extra_properties = {}
        logging.debug("Extra job properties:")
        # NOTE(review): this unpacks each element as a (property, env var) pair;
        # if extra_property_envs is a dict this would need .items() - confirm
        for prop_name, env_name in config.extra_property_envs:
            env_value = os.getenv(env_name)
            logging.debug("\t%s: %s", prop_name, env_value)
            if not env_value:
                continue
            extra_properties[prop_name] = env_value

    job_timeout = parse_timeout(args.timeout)
    memory_limit = parse_memory(args.memory)
    logging.info(
        "Running with timeout %s and memory limit %dM", job_timeout, memory_limit
    )

    if not script_args:
        raise ValueError("No script and any of its arguments given")

    program = script_args[0]
    if args.jobs < 1:
        raise ValueError(f"Number of jobs must be one or more, given {args.jobs}")

    wrapper = set_up_wrapper(config, program)
    if args.jobs == 1:
        # a single job is run via the pass-through wrapper
        wrapper = PassThruWrapper(wrapper)

    output = wrapper.get_output(args.output, script_args[1:])
    wrapper_args = wrapper.get_args(script_args, args.debug)
    controller = JobController(
        slurm_url, wrapper, output, args.partition, username, slurm_token, job_timeout
    )
    controller.run(
        args.jobs,
        wrapper_args,
        "PPC-" + program,
        JobResources(
            memory=memory_limit,
            cpu_cores=args.cores,
            extra_properties=extra_properties,
        ),
    )
    print("Jobs completed")


if __name__ == "__main__":
    # Options not declared by the parser are passed on as the program to run
    # and its arguments
    args, script_args = create_parser().parse_known_args()

    root_logger = logging.getLogger()
    if args.debug:
        root_logger.setLevel(logging.DEBUG)
        # verbose record format is only set up when debugging
        logging.basicConfig(
            format="%(asctime)s:%(levelname)s:%(name)s:%(module)s:%(funcName)s:%(lineno)d:%(message)s"
        )
    else:
        root_logger.setLevel(logging.INFO)

    run_ppc(args, script_args)
