import os
import yaml
import datetime
import logging
from archngv.app.utils import load_ngv_manifest

# Directory names used by the rules below: log files go under LOGS_DIR
# (made absolute by log_path) and MORPH_DIR is used as a relative path.
LOGS_DIR = "logs"
MORPH_DIR = "morphologies"
# Resolved path of the bioname directory. `config` is not defined in this
# file; it is the workflow configuration mapping that Snakemake makes
# available inside a Snakefile. Falls back to "bioname" when the key is absent.
BIONAME = os.path.realpath(config.get("bioname", "bioname"))


def _get_seed(manifest_path):
    """Read the seed stored under ``common`` / ``seed`` in the manifest.

    Args:
        manifest_path (str): path handed to ``load_ngv_manifest``.

    Returns:
        The seed value found in the manifest, or the string ``"0"`` when a
        ``KeyError`` is raised while loading or looking up the entry.
    """
    try:
        manifest = load_ngv_manifest(manifest_path)
        seed = manifest["common"]["seed"]
    except KeyError:
        seed = "0"
    return seed


def bioname_path(filename):
    """Return the path of ``filename`` joined onto the BIONAME directory."""
    location = os.path.join(BIONAME, filename)
    return location


# Manifest file inside the bioname directory, and the seed read from it
# (_get_seed falls back to "0" when the lookup raises KeyError).
MANIFEST_PATH = bioname_path("MANIFEST.yaml")
SEED = _get_seed(MANIFEST_PATH)


def log_path(name):
    """Build an absolute, timestamped log file path for a workflow step.

    Args:
        name (str): step name used as the file name prefix.

    Returns:
        str: absolute path ``<LOGS_DIR>/<name>.<YYYYmmddTHHMMSS>.log``, where
        the timestamp is the local time at which this function is called.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
    filename = f"{name}.{stamp}.log"
    return os.path.abspath(os.path.join(LOGS_DIR, filename))


def escape_single_quotes(value):
    """Replace every single quote in ``value`` with the sequence ``'\\''``.

    That sequence closes a single-quoted shell string, emits an escaped
    quote, and reopens the string.
    """
    escaped_quote = "'\\''"
    return escaped_quote.join(value.split("'"))


def get_log_level_for_cli(level):
    """Translate a logging level into a ``-v`` style verbosity flag.

    Args:
        level (str | int): a level name (``"ERROR"``, ``"WARNING"``,
            ``"INFO"``, ``"DEBUG"``; case-insensitive) or the matching
            numeric ``logging`` constant.

    Returns:
        str: ``""`` for ERROR, ``"-v"`` for WARNING, ``"-vv"`` for INFO and
        ``"-vvv"`` for DEBUG.

    Raises:
        ValueError: if ``level`` is not one of the four supported levels.
    """
    levels = (logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)

    # getLevelName maps a registered level *name* to its number; numbers are
    # used as they are (passing them through getLevelName would turn them
    # back into names and make the membership test below fail).
    if isinstance(level, str):
        numeric = logging.getLevelName(level.upper())
    else:
        numeric = level

    # Explicit exception rather than ``assert`` so the check survives ``-O``.
    if numeric not in levels:
        allowed = [logging.getLevelName(lv) for lv in levels]
        raise ValueError(f"Unknown log level {level!r}. Allowed levels: {allowed}")

    verbosity = levels.index(numeric)
    return "-" + "v" * verbosity if verbosity else ""


# Load the manifest for the settings below (it is also read once by
# _get_seed when SEED is computed).
MANIFEST = load_ngv_manifest(MANIFEST_PATH)
COMMON = MANIFEST["common"]

# "atlas", "vasculature" and "tetrahedral_mesh" are looked up with [] and so
# must be present in the manifest; ATLAS_CACHE_DIR is a fixed relative path.
ATLAS = COMMON["atlas"]
ATLAS_CACHE_DIR = ".atlas"
VASCULATURE_MORPHOLOGY = COMMON["vasculature"]
TETMESH = MANIFEST["tetrahedral_mesh"]
# Optional entries: "parallel" defaults to True, "log_level" to "WARNING".
PARALLEL = COMMON.get("parallel", True)
# CLI verbosity flag ("", "-v", "-vv" or "-vvv") inserted into the ngv commands.
LOG_LEVEL = get_log_level_for_cli(COMMON.get("log_level", "WARNING"))


def refinement_subdividing_steps():
    """Return the number of mesh refinement passes.

    Returns:
        int: ``TETMESH["refinement_subdividing_steps"]`` converted with
        ``int`` when that key exists, otherwise ``1``.
    """
    key = "refinement_subdividing_steps"
    return int(TETMESH[key]) if key in TETMESH else 1


def get_node_population_name(population_type, default=None):
    """Return the configured node population name for ``population_type``.

    Args:
        population_type (str): key looked up in ``COMMON["node_population_name"]``.
        default: value returned when the section or the key is missing.

    Returns:
        The configured name, or ``default``.
    """
    if "node_population_name" in COMMON:
        # The fallback is passed positionally: ``dict.get`` does not accept a
        # ``default=`` keyword and would raise TypeError.
        return COMMON["node_population_name"].get(population_type, default)

    return default


# Node population names used in the sonata output paths and passed to the
# ngv/touchdetector commands; the defaults apply when the manifest does not
# provide a name for the population type.
NODES_ASTROCYTE_NAME = get_node_population_name("astrocyte", default="astrocytes")
NODES_VASCULATURE_NAME = get_node_population_name("vasculature", default="vasculature")


def get_edge_population_name(population_type, default=None):
    """Return the configured edge population name for ``population_type``.

    Args:
        population_type (str): key looked up in ``COMMON["edge_population_name"]``.
        default: value returned when the section or the key is missing.

    Returns:
        The configured name, or ``default``.
    """
    if "edge_population_name" in COMMON:
        # The fallback is passed positionally: ``dict.get`` does not accept a
        # ``default=`` keyword and would raise TypeError.
        return COMMON["edge_population_name"].get(population_type, default)

    return default


# Edge population names used in the sonata output paths and passed to the
# ngv commands; the defaults apply when the manifest does not override them.
EDGES_SYNAPSE_ASTROCYTE_NAME = get_edge_population_name("synapse_astrocyte", default="neuroglial")
EDGES_ENDFOOT_NAME = get_edge_population_name("endfoot", default="gliovascular")
EDGES_GLIALGLIAL_NAME = get_edge_population_name("glialglial", default="glialglial")

# Base circuit inputs; all four keys are looked up with [] and so must be
# present in the manifest's "common" section.
BASE_CIRCUIT_SONATA = COMMON["base_circuit_sonata"]
BASE_CIRCUIT_CELLS = COMMON["base_circuit_cells"]
BASE_CIRCUIT_CONNECTOME = COMMON["base_circuit_connectome"]
BASE_SPATIAL_SYNAPSE_INDEX_DIR = COMMON["base_spatial_synapse_index_dir"]
# Recipe file inside the bioname directory, passed to touchdetector by the
# glial_gap_junctions rule.
BUILDER_RECIPE = bioname_path("astrocyte_gap_junction_recipe.xml")

# Value exported as MODULEPATH by modules_cmd before modules are loaded.
SPACK_MODULEPATH = "/gpfs/bbp.cscs.ch/ssd/apps/bsd/modules/_meta"

# Module environments keyed by the `module_env` name given to run_cmd.
# Each value is (MODULEPATH, [modules passed to a single `module load`]).
MODULES = {
    "glial_gap_junctions": (SPACK_MODULEPATH, ["archive/2022-03", "touchdetector/5.6.1"]),
    "synthesis": (SPACK_MODULEPATH, ["archive/2023-08", "spatial-index/2.1.0", "py-mpi4py/3.1.4"]),
    "refine_tetrahedral": (SPACK_MODULEPATH, ["archive/2023-05", "gmsh/4.10.3"]),
    "build_tetrahedral": (SPACK_MODULEPATH, ["archive/2023-05", "gmsh/4.10.3"]),
}


# Relative directory given to touchdetector as --output and to
# `ngv glialglial-connectivity` as --touches-dir; also holds the _SUCCESS marker.
TOUCHES_DIR = "connectome/touches"


def propagate_cluster_default_values():
    """Copy missing ``__default__`` options into every other cluster entry.

    ``cluster_config`` (the mapping Snakemake provides inside a Snakefile) is
    updated in place: for each entry other than ``"__default__"``, keys that
    the entry lacks are filled in from the default. Keys the entry already
    defines are left untouched.

    When ``cluster_config`` is empty or has no ``"__default__"`` entry there
    is nothing to propagate and the function returns without raising, which
    matches ``salloc_cmd`` treating an empty cluster config as "no salloc".
    """
    default = cluster_config.get("__default__")
    if not default:
        return

    for slurm_env, slurm_dict in cluster_config.items():
        if slurm_env == "__default__":
            continue
        for key, value in default.items():
            # setdefault keeps any value the entry defines itself.
            slurm_dict.setdefault(key, value)


def salloc_cmd(slurm_env):
    """Build the ``salloc`` prefix for a slurm environment.

    Args:
        slurm_env (str | None): name of an entry in ``cluster_config``;
            unknown names fall back to the ``"__default__"`` entry.

    Returns:
        str: ``"salloc <options>"``, or ``""`` when ``slurm_env`` is ``None``
        or ``cluster_config`` is empty.
    """
    if slurm_env is None or not cluster_config:
        return ""

    entry = slurm_env if slurm_env in cluster_config else "__default__"
    settings = cluster_config[entry]

    # Mandatory options, in the order they appear on the command line.
    parts = [
        "salloc",
        f"-C {settings['constraint']}",
        f"-A {settings['account']}",
        f"-N {settings['nodes']}",
        f"-p {settings['partition']}",
        f"-J {settings['jobname']}",
        f"-t {settings['time']}",
    ]

    # Optional options are added only when the key is present.
    if "tasks" in settings:
        parts.append(f"-n {settings['tasks']}")
    if "cpus-per-task" in settings:
        parts.append(f"--cpus-per-task {settings['cpus-per-task']}")
    if "exclusive" in settings:
        # Presence of the key is enough; its value is not inspected.
        parts.append("--exclusive")
    if "mem" in settings:
        parts.append(f"--mem {settings['mem']}")

    return " ".join(parts)


def modules_cmd(module_env):
    """Build the shell prefix that loads the modules of ``module_env``.

    Args:
        module_env (str | None): key of ``MODULES``.

    Returns:
        str: the steps joined with ``" && "``, or ``""`` when ``module_env``
        is ``None``. The final ``" "`` step leaves a trailing ``" && "`` so
        the caller's command can be appended directly.
    """
    if module_env is None:
        return ""

    search_path, names = MODULES[module_env]
    to_load = " ".join(names)
    steps = (
        f"export MODULEPATH={search_path}",
        ". /etc/profile.d/modules.sh",
        "module purge",
        f"module load {to_load}",
        f"echo MODULEPATH={search_path}",
        "module list",
        " ",
    )
    return " && ".join(steps)


def run_cmd(cmd, dump_log=False, module_env=None, slurm_env=None):
    """Assemble the shell string for a rule.

    The result is ``"<module prefix> <salloc prefix> <command>"``; either
    prefix is an empty string when the matching environment is ``None``.

    Args:
        cmd (array): shell command to run with its arguments; every item is
            converted with ``str`` and the items are joined with spaces
        dump_log (bool): when true, append ``2>&1 | tee {log}`` (the
            ``{log}`` placeholder is left in the string as is)
        module_env (str): modules specified for this command in MODULES
        slurm_env (str): slurm environment specified in cluster config

    Returns:
        str: the assembled shell command.
    """
    joined_args = " ".join(str(item) for item in cmd)
    pieces = (modules_cmd(module_env), salloc_cmd(slurm_env), joined_args)
    shell_line = " ".join(pieces)

    if not dump_log:
        return shell_line
    return shell_line + " 2>&1 | tee {log}"


# Per the original note, cluster default propagation does not work on its
# own, so the "__default__" options are copied into the other cluster
# entries explicitly here, while the Snakefile is being loaded.
propagate_cluster_default_values()


rule all:
    # Aggregate target: lists the final files of the workflow (microdomains,
    # endfeet meshes, the morphology and touches markers, the ngv config, the
    # node and edge population files and the refined tetrahedral mesh).
    input:
        "microdomains.h5",
        "endfeet_meshes.h5",
        MORPH_DIR + "/_DONE",
        TOUCHES_DIR + "/_SUCCESS",
        "ngv_config.json",
        f"sonata/networks/nodes/{NODES_ASTROCYTE_NAME}/nodes.h5",
        f"sonata/networks/nodes/{NODES_VASCULATURE_NAME}/nodes.h5",
        f"sonata/networks/edges/{EDGES_SYNAPSE_ASTROCYTE_NAME}/edges.h5",
        f"sonata/networks/edges/{EDGES_GLIALGLIAL_NAME}/edges.h5",
        f"sonata/networks/edges/{EDGES_ENDFOOT_NAME}/edges.h5",
        "ngv_refined_tetrahedral_mesh.msh",


rule ngv_config:
    # Runs `ngv config-file` with the bioname directory to write
    # ngv_config.json. The rule has no inputs and produces no log file.
    output:
        "ngv_config.json",
    shell:
        run_cmd(
            [
                f"ngv {LOG_LEVEL} config-file",
                "--bioname",
                BIONAME,
                "--output",
                "{output}",
            ]
        )


rule sonata_vasculature:
    # Runs `vascpy morphology-to-sonata <input> <output>` on the vasculature
    # file named in the manifest to write the vasculature node population
    # file. Command output is teed to the log.
    input:
        VASCULATURE_MORPHOLOGY,
    output:
        f"sonata/networks/nodes/{NODES_VASCULATURE_NAME}/nodes.h5",
    log:
        log_path("sonata_vasculature"),
    shell:
        run_cmd([f"vascpy morphology-to-sonata", "{input} {output}"], dump_log=True)


rule cell_placement:
    # Runs `ngv cell-placement` with the manifest, the atlas (and its cache
    # directory) and the vasculature node file, writing
    # sonata.tmp/nodes/glia.somata.h5 under population name
    # NODES_ASTROCYTE_NAME. Command output is teed to the log.
    input:
        f"sonata/networks/nodes/{NODES_VASCULATURE_NAME}/nodes.h5",
    output:
        "sonata.tmp/nodes/glia.somata.h5",
    log:
        log_path("cell_placement"),
    shell:
        run_cmd(
            [
                f"ngv {LOG_LEVEL} cell-placement",
                f'--config {bioname_path("MANIFEST.yaml")}',
                f"--atlas {ATLAS}",
                f"--atlas-cache {ATLAS_CACHE_DIR}",
                f"--vasculature {input}",
                f"--population-name {NODES_ASTROCYTE_NAME}",
                "--output {output}",
                f"--seed {SEED}",
            ],
            dump_log=True,
        )


rule assign_emodels:
    # Runs `ngv assign-emodels` on the somata file, passing the hoc template
    # from the manifest entry assign_emodels / hoc_template, and writes
    # sonata.tmp/nodes/glia.emodels.h5. No log file is produced.
    input:
        "sonata.tmp/nodes/glia.somata.h5",
    output:
        "sonata.tmp/nodes/glia.emodels.h5",
    shell:
        run_cmd(
            [
                f"ngv {LOG_LEVEL} assign-emodels",
                "--input {input}",
                "--output {output}",
                f'--hoc {MANIFEST["assign_emodels"]["hoc_template"]}',
            ]
        )


rule finalize_astrocytes:
    # Runs `ngv finalize-astrocytes` on the somata and emodels files to write
    # the astrocyte node population file. Command output is teed to the log.
    input:
        somata="sonata.tmp/nodes/glia.somata.h5",
        emodels="sonata.tmp/nodes/glia.emodels.h5",
    output:
        f"sonata/networks/nodes/{NODES_ASTROCYTE_NAME}/nodes.h5",
    log:
        log_path("finalize_astrocytes"),
    shell:
        run_cmd(
            [
                f"ngv {LOG_LEVEL} finalize-astrocytes",
                "--somata-file {input[somata]}",
                "--emodels-file {input[emodels]}",
                "--output {output}",
            ],
            dump_log=True,
        )


rule microdomains:
    # Runs `ngv microdomains` with the manifest, the astrocyte node file and
    # the atlas (and its cache directory) to write microdomains.h5.
    # Command output is teed to the log.
    input:
        f"sonata/networks/nodes/{NODES_ASTROCYTE_NAME}/nodes.h5",
    output:
        "microdomains.h5",
    log:
        log_path("microdomains"),
    shell:
        run_cmd(
            [
                f"ngv {LOG_LEVEL} microdomains",
                f'--config {bioname_path("MANIFEST.yaml")}',
                "--astrocytes {input}",
                f"--atlas {ATLAS}",
                f"--atlas-cache {ATLAS_CACHE_DIR}",
                "--output-file-path {output}",
                f"--seed {SEED}",
            ],
            dump_log=True,
        )


rule gliovascular_connectivity:
    # Runs `ngv gliovascular-connectivity` with the manifest, the astrocyte
    # and vasculature node files and the microdomains, writing the temporary
    # edge file sonata.tmp/edges/gliovascular.connectivity.h5 under population
    # name EDGES_ENDFOOT_NAME. Command output is teed to the log.
    input:
        astrocytes=f"sonata/networks/nodes/{NODES_ASTROCYTE_NAME}/nodes.h5",
        microdomains="microdomains.h5",
        vasculature=f"sonata/networks/nodes/{NODES_VASCULATURE_NAME}/nodes.h5",
    output:
        "sonata.tmp/edges/gliovascular.connectivity.h5",
    log:
        log_path("gliovascular_connectivity"),
    shell:
        run_cmd(
            [
                f"ngv {LOG_LEVEL} gliovascular-connectivity",
                f'--config {bioname_path("MANIFEST.yaml")}',
                "--astrocytes {input[astrocytes]}",
                "--microdomains {input[microdomains]}",
                "--vasculature {input[vasculature]}",
                f"--seed {SEED}",
                f"--population-name {EDGES_ENDFOOT_NAME}",
                "--output {output}",
            ],
            dump_log=True,
        )


rule neuroglial_connectivity:
    # Runs `ngv neuroglial-connectivity` with the astrocyte node file and the
    # microdomains plus the base circuit cells, connectome and spatial synapse
    # index directory from the manifest, writing the temporary edge file
    # sonata.tmp/edges/neuroglial.connectivity.h5 under population name
    # EDGES_SYNAPSE_ASTROCYTE_NAME. The base circuit paths are not declared as
    # rule inputs. Command output is teed to the log.
    input:
        astrocytes=f"sonata/networks/nodes/{NODES_ASTROCYTE_NAME}/nodes.h5",
        microdomains="microdomains.h5",
    output:
        "sonata.tmp/edges/neuroglial.connectivity.h5",
    log:
        log_path("neuroglial_connectivity"),
    shell:
        run_cmd(
            [
                f"ngv {LOG_LEVEL} neuroglial-connectivity",
                f"--neurons-path {BASE_CIRCUIT_CELLS}",
                "--astrocytes-path {input[astrocytes]}",
                "--microdomains-path {input[microdomains]}",
                f"--neuronal-connectivity-path {BASE_CIRCUIT_CONNECTOME}",
                f"--spatial-synapse-index-dir {BASE_SPATIAL_SYNAPSE_INDEX_DIR}",
                "--output-path {output}",
                f"--population-name {EDGES_SYNAPSE_ASTROCYTE_NAME}",
                f"--seed {SEED}",
            ],
            dump_log=True,
        )


rule glial_gap_junctions:
    # Runs touchdetector inside `srun --mpi pmi2 sh -c "..."`, using the
    # astrocyte population as both --from and --to, with BUILDER_RECIPE and
    # MORPH_DIR as positional arguments and TOUCHES_DIR as output directory.
    # The doubled braces in the f-strings leave a literal {input[astrocytes]}
    # placeholder for Snakemake to fill in. The declared output is a touch()
    # marker file; the "glial_gap_junctions" module and slurm environments are
    # used. Command output is teed to the log.
    message:
        "Detect touches between astrocytes"
    input:
        astrocytes=f"sonata/networks/nodes/{NODES_ASTROCYTE_NAME}/nodes.h5",
        morphologies=MORPH_DIR + "/_DONE",
    output:
        touch(TOUCHES_DIR + "/_SUCCESS"),
    log:
        log_path("touchdetector"),
    shell:
        run_cmd(
            [
                'srun --mpi pmi2 sh -c "'
                + " ".join(
                    [
                        "touchdetector",
                                f"--output {TOUCHES_DIR}",
                                "--save-state",
                                f"--from {{input[astrocytes]}} {NODES_ASTROCYTE_NAME}",
                                f"--to {{input[astrocytes]}} {NODES_ASTROCYTE_NAME}",
                                BUILDER_RECIPE,
                                MORPH_DIR,
                            ]
                        )
                + '"'
            ],
            dump_log=True,
            module_env="glial_gap_junctions",
            slurm_env="glial_gap_junctions",
        )


rule glialglial_connectivity:
    # Runs `ngv glialglial-connectivity` with the astrocyte node file and
    # TOUCHES_DIR to write the glialglial edge population file. The _SUCCESS
    # marker is declared as an input so this rule waits for the
    # glial_gap_junctions rule. Command output is teed to the log.
    message:
        " Extract glial glial connectivity from touches"
    input:
        astrocytes=f"sonata/networks/nodes/{NODES_ASTROCYTE_NAME}/nodes.h5",
        touches=TOUCHES_DIR + "/_SUCCESS",
    output:
        glialglial_connectivity=f"sonata/networks/edges/{EDGES_GLIALGLIAL_NAME}/edges.h5",
    log:
        log_path("glialglial_connectivity"),
    shell:
        run_cmd(
            [
                f"ngv {LOG_LEVEL} glialglial-connectivity",
                "--astrocytes {input[astrocytes]}",
                f"--touches-dir {TOUCHES_DIR}",
                f"--population-name {EDGES_GLIALGLIAL_NAME}",
                "--output-connectivity {output[glialglial_connectivity]}",
                f"--seed {SEED}",
            ],
            dump_log=True,
        )


rule endfeet_area:
    # Runs `ngv endfeet-area` with the manifest, the vasculature mesh named by
    # the manifest entry common / vasculature_mesh (not declared as a rule
    # input) and the temporary gliovascular connectivity file, writing
    # endfeet_meshes.h5. Command output is teed to the log.
    input:
        gliovascular_connectivity="sonata.tmp/edges/gliovascular.connectivity.h5",
    output:
        "endfeet_meshes.h5",
    log:
        log_path("endfeet_area"),
    shell:
        run_cmd(
            [
                f"ngv {LOG_LEVEL} endfeet-area",
                f'--config-path {bioname_path("MANIFEST.yaml")}',
                f'--vasculature-mesh-path {COMMON["vasculature_mesh"]}',
                "--gliovascular-connectivity-path {input[gliovascular_connectivity]}",
                "--output-path {output}",
                f"--seed {SEED}",
            ],
            dump_log=True,
        )


rule synthesis:
    # Runs `ngv synthesis` with the manifest, the tns_* and er_data JSON files
    # from the bioname directory, the astrocyte nodes, microdomains, both
    # temporary connectivity files and the endfeet meshes, writing
    # morphologies into MORPH_DIR. The declared output is a touch() marker.
    # When PARALLEL is true the command is prefixed with `srun`, gets
    # `--parallel`, and uses the "synthesis" module and slurm environments;
    # otherwise those items are empty strings / None and the command runs
    # directly. Command output is teed to the log.
    input:
        astrocytes=f"sonata/networks/nodes/{NODES_ASTROCYTE_NAME}/nodes.h5",
        microdomains="microdomains.h5",
        gliovascular_connectivity="sonata.tmp/edges/gliovascular.connectivity.h5",
        neuroglial_connectivity="sonata.tmp/edges/neuroglial.connectivity.h5",
        endfeet_meshes="endfeet_meshes.h5",
    output:
        touch(MORPH_DIR + "/_DONE"),
    log:
        log_path("synthesis"),
    shell:
        run_cmd(
            [
                "srun" if PARALLEL else "",
                f"ngv {LOG_LEVEL} synthesis",
                f'--config-path {bioname_path("MANIFEST.yaml")}',
                f'--tns-distributions-path {bioname_path("tns_distributions.json")}',
                f'--tns-parameters-path {bioname_path("tns_parameters.json")}',
                f'--tns-context-path {bioname_path("tns_context.json")}',
                f'--er-data-path {bioname_path("er_data.json")}',
                "--astrocytes-path {input[astrocytes]}",
                "--microdomains-path {input[microdomains]}",
                "--gliovascular-connectivity-path {input[gliovascular_connectivity]}",
                "--neuroglial-connectivity-path {input[neuroglial_connectivity]}",
                "--endfeet-meshes-path {input[endfeet_meshes]}",
                f"--neuronal-connectivity-path {BASE_CIRCUIT_CONNECTOME}",
                f"--out-morph-dir {MORPH_DIR}",
                ("--parallel" if PARALLEL else ""),
                f"--seed {SEED}",
            ],
            dump_log=True,
            module_env="synthesis" if PARALLEL else None,
            slurm_env="synthesis" if PARALLEL else None,
        )


rule finalize_gliovascular_connectivity:
    # Runs `ngv attach-endfeet-info-to-gliovascular-connectivity` on the
    # temporary gliovascular connectivity file, with the astrocyte and
    # vasculature node files, the endfeet meshes and MORPH_DIR, to write the
    # endfoot edge population file. The morphologies marker is declared as an
    # input so this rule waits for synthesis. `--parallel` is added when
    # PARALLEL is true. Command output is teed to the log.
    input:
        astrocytes=f"sonata/networks/nodes/{NODES_ASTROCYTE_NAME}/nodes.h5",
        connectivity="sonata.tmp/edges/gliovascular.connectivity.h5",
        endfeet_meshes="endfeet_meshes.h5",
        morphologies=MORPH_DIR + "/_DONE",
        vasculature_sonata=f"sonata/networks/nodes/{NODES_VASCULATURE_NAME}/nodes.h5",
    output:
        f"sonata/networks/edges/{EDGES_ENDFOOT_NAME}/edges.h5",
    log:
        log_path("finalize_gliovascular_connectivity"),
    shell:
        run_cmd(
            [
                f"ngv {LOG_LEVEL} attach-endfeet-info-to-gliovascular-connectivity",
                "--input-file {input[connectivity]}",
                "--output-file {output}",
                "--astrocytes {input[astrocytes]}",
                "--endfeet-meshes-path {input[endfeet_meshes]}",
                "--vasculature-sonata {input[vasculature_sonata]}",
                f"--morph-dir {MORPH_DIR}",
                ("--parallel" if PARALLEL else ""),
                f"--seed {SEED}",
            ],
            dump_log=True,
        )


rule finalize_neuroglial_connectivity:
    # Runs `ngv attach-morphology-info-to-neuroglial-connectivity` on the
    # temporary neuroglial connectivity file, with the astrocyte node file,
    # the microdomains, the base circuit connectome and MORPH_DIR, to write
    # the synapse_astrocyte edge population file. The morphologies marker is
    # declared as an input so this rule waits for synthesis. `--parallel` is
    # added when PARALLEL is true. Command output is teed to the log.
    input:
        astrocytes=f"sonata/networks/nodes/{NODES_ASTROCYTE_NAME}/nodes.h5",
        microdomains="microdomains.h5",
        connectivity="sonata.tmp/edges/neuroglial.connectivity.h5",
        morphologies=MORPH_DIR + "/_DONE",
    output:
        f"sonata/networks/edges/{EDGES_SYNAPSE_ASTROCYTE_NAME}/edges.h5",
    log:
        log_path("finalize_neuroglial_connectivity"),
    shell:
        run_cmd(
            [
                f"ngv {LOG_LEVEL} attach-morphology-info-to-neuroglial-connectivity",
                "--input-file-path {input[connectivity]}",
                "--output-file-path {output}",
                "--astrocytes-path {input[astrocytes]}",
                "--microdomains-path {input[microdomains]}",
                f"--synaptic-data-path {BASE_CIRCUIT_CONNECTOME}",
                f"--morph-dir {MORPH_DIR}",
                ("--parallel" if PARALLEL else ""),
                f"--seed {SEED}",
            ],
            dump_log=True,
        )


rule prepare_tetrahedral:
    # Generates a mesh file and a gmsh script for the next step by running
    # `ngv refined-surface-mesh` with the manifest and the astrocyte,
    # vasculature and base circuit neuron files.
    # Only the .stl path is passed on the command line (--output-path); the
    # .geo script is a declared output, presumably written next to the mesh
    # by the command itself -- TODO confirm.
    # Command output is teed to the log.
    input:
        astrocytes=f"sonata/networks/nodes/{NODES_ASTROCYTE_NAME}/nodes.h5",
        vasculature=f"sonata/networks/nodes/{NODES_VASCULATURE_NAME}/nodes.h5",
        neuron=f"{BASE_CIRCUIT_CELLS}",
    output:
        mesh="ngv_prepared_tetrahedral_mesh.stl",
        script="ngv_prepared_tetrahedral_mesh.geo",
    log:
        log_path("prepare_tetrahedral"),
    shell:
        run_cmd(
            [
                f"ngv {LOG_LEVEL} refined-surface-mesh",
                f'--config-path {bioname_path("MANIFEST.yaml")}',
                "--astrocytes-path {input[astrocytes]}",
                "--neurons-path {input[neuron]}",
                "--vasculature-path {input[vasculature]}",
                "--output-path {output.mesh}",
            ],
            dump_log=True,
        )


rule build_tetrahedral:
    # Runs gmsh on the .geo script to generate ngv_tetrahedral_mesh.msh,
    # with the "build_tetrahedral" modules loaded. The .stl mesh is not passed
    # on the command line; it is declared as an input so the rule waits for
    # it (presumably the script refers to it -- TODO confirm).
    # The rule declares a log file, so the gmsh output is teed into it like
    # in the other logged rules.
    input:
        mesh_file="ngv_prepared_tetrahedral_mesh.stl",
        script_file="ngv_prepared_tetrahedral_mesh.geo",
    output:
        "ngv_tetrahedral_mesh.msh",
    log:
        log_path("build_tetrahedral"),
    shell:
        run_cmd(
            [
                "gmsh",
                "{input.script_file}",
                "-3",  #  Perform mesh generation from 2d (surface) to 3d (tetrahedral).
                "-o {output}",
                "-algo initial3d",
            ],
            dump_log=True,
            module_env="build_tetrahedral",
        )


rule refine_tetrahedral:
    # Copies the tetrahedral mesh to tmp.msh (a fixed scratch name in the
    # working directory), runs `gmsh -refine` on it in place
    # refinement_subdividing_steps() times, then moves the result to the
    # output, with the "refine_tetrahedral" modules loaded.
    # The whole chain is grouped in a subshell so that the output of every
    # step, not only of the final `mv`, is teed to the declared log file.
    input:
        "ngv_tetrahedral_mesh.msh",
    output:
        "ngv_refined_tetrahedral_mesh.msh",
    log:
        log_path("refine_tetrahedral"),
    shell:
        run_cmd(
            [
                "(",
                "cp -v {input} tmp.msh && ",
                f"for (( c=1; c<={refinement_subdividing_steps()}; c++ )); ",
                "do gmsh -refine tmp.msh -o tmp.msh; done && ",
                "mv -v tmp.msh {output}",
                ")",
            ],
            dump_log=True,
            module_env="refine_tetrahedral",
        )
