# ---- begin snakebids boilerplate ----------------------------------------------

from snakebids import bids, generate_inputs, get_wildcard_constraints


# Load the snakebids configuration file; its location ("../config/snakebids.yml")
# is handed to workflow.source_path rather than used as a bare string.
configfile: workflow.source_path("../config/snakebids.yml")


# Build the snakebids inputs for the raw dataset in config["bids_dir"], using the
# component definitions in config["pybids_inputs"]. Every optional setting is
# read with config.get, so a key that is absent is passed through as None.
inputs_raw = generate_inputs(
    bids_dir=config["bids_dir"],
    pybids_inputs=config["pybids_inputs"],
    pybids_database_dir=config.get("pybids_db_dir"),
    pybids_reset_database=config.get("pybids_db_reset"),
    derivatives=config.get("derivatives"),
    participant_label=config.get("participant_label"),
    exclude_participant_label=config.get("exclude_participant_label"),
)
# Build the snakebids inputs for the AFIDs dataset in config["afids_dir"], using
# the component definitions in config["pybids_inputs_afids"]. Optional settings
# that are absent from the config are passed through as None.
inputs_afids = generate_inputs(
    bids_dir=config["afids_dir"],
    pybids_inputs=config["pybids_inputs_afids"],
    pybids_database_dir=config.get("pybids_db_dir"),
    pybids_reset_database=config.get("pybids_db_reset"),
    # Both pybids configs are needed so that entities such as desc are allowed.
    pybids_config=["bids", "derivatives"],
    derivatives=config.get("derivatives"),
    participant_label=config.get("participant_label"),
    exclude_participant_label=config.get("exclude_participant_label"),
)
# The validation raw dataset and the validation AFIDs dataset are only usable as
# a pair: reject a configuration in which exactly one of the two directories is
# set. Comparing the truthiness of both values covers either direction.
if bool(config["validation_bids_dir"]) != bool(config["validation_afids_dir"]):
    raise ValueError(
        "The validation bids directory and the validation afids directory must "
        "either both be defined or both be left undefined."
    )
# Validation raw inputs: built only when a validation BIDS directory is
# configured, otherwise None. No participant filters are passed for this
# dataset.
if config["validation_bids_dir"]:
    inputs_raw_validation = generate_inputs(
        bids_dir=config["validation_bids_dir"],
        pybids_inputs=config["pybids_inputs"],
        pybids_database_dir=config.get("pybids_db_dir"),
        pybids_reset_database=config.get("pybids_db_reset"),
        derivatives=config.get("derivatives"),
    )
else:
    inputs_raw_validation = None
# Validation AFIDs inputs: built only when a validation AFIDs directory is
# configured, otherwise None. No participant filters are passed for this
# dataset.
if config["validation_afids_dir"]:
    inputs_afids_validation = generate_inputs(
        bids_dir=config["validation_afids_dir"],
        pybids_inputs=config["pybids_inputs_afids"],
        pybids_database_dir=config.get("pybids_db_dir"),
        pybids_reset_database=config.get("pybids_db_reset"),
        # Both pybids configs are needed so that entities such as desc are allowed.
        pybids_config=["bids", "derivatives"],
        derivatives=config.get("derivatives"),
    )
else:
    inputs_afids_validation = None


def subject_in_path(subject_id, root_path) -> bool:
    """Return whether ``root_path`` directly contains a ``sub-<subject_id>`` folder.

    Args:
        subject_id: Subject label without the ``sub-`` prefix.
        root_path: Directory expected to hold the ``sub-*`` subject folders.

    Returns:
        True if ``<root_path>/sub-<subject_id>`` exists and is a directory;
        False otherwise (including when ``root_path`` itself does not exist).
    """
    # Probe the single candidate directory instead of listing and stat-ing every
    # entry under root_path: this helper is called from rule input functions, so
    # it can run once per job.
    return (Path(root_path) / f"sub-{subject_id}").is_dir()


def choose_correct_raw_dataset(wildcards):
    """Pick the raw inputs object (training or validation) for a job's subject.

    Without a validation dataset the training inputs are always returned.
    Otherwise the subject folder is looked up first in ``config["bids_dir"]``
    and then in ``config["validation_bids_dir"]``.

    Raises:
        ValueError: If the subject is found in neither directory.
    """
    if not inputs_raw_validation:
        return inputs_raw

    subject = wildcards["subject"]
    # Order matters: the training directory takes precedence over validation.
    candidates = (
        ("bids_dir", inputs_raw),
        ("validation_bids_dir", inputs_raw_validation),
    )
    for config_key, dataset in candidates:
        if subject_in_path(subject, config[config_key]):
            return dataset
    raise ValueError(f"Raw dataset for sub-{subject} not found.")


def choose_correct_afids_dataset(wildcards):
    """Pick the AFIDs inputs object (training or validation) for a job's subject.

    Without a validation dataset the training AFIDs inputs are always returned.
    Otherwise the choice follows where the subject folder is found: first
    ``config["bids_dir"]``, then ``config["validation_bids_dir"]``.

    Raises:
        ValueError: If the subject is found in neither directory.
    """
    if not inputs_afids_validation:
        return inputs_afids

    subject = wildcards["subject"]
    # NOTE(review): membership is tested against the *raw* BIDS directories, not
    # config["afids_dir"] / config["validation_afids_dir"] -- confirm that this
    # is intended.
    candidates = (
        ("bids_dir", inputs_afids),
        ("validation_bids_dir", inputs_afids_validation),
    )
    for config_key, dataset in candidates:
        if subject_in_path(subject, config[config_key]):
            return dataset
    raise ValueError(f"AFIDs dataset for sub-{subject} not found.")


# Apply workflow-wide wildcard constraints derived from config["pybids_inputs"]
# so that the BIDS entity wildcards used in the rules are constrained.
# Only the raw inputs' definitions are used here (not "pybids_inputs_afids").
wildcard_constraints:
    **get_wildcard_constraints(config["pybids_inputs"]),


# ---- end snakebids boilerplate ------------------------------------------------


rule extract_afid:
    """Run scripts/extract_afids.py on a subject's AFIDs file for one ``afid``.

    The fcsv comes from the training or validation AFIDs dataset, whichever
    choose_correct_afids_dataset selects for the subject; the result is a
    per-AFID text file under <output_dir>/c3d.
    """
    input:
        fcsv=lambda wildcards: (
            choose_correct_afids_dataset(wildcards)["afids"].path
        ),
    output:
        txt=bids(
            root=str(Path(config["output_dir"]) / "c3d"),
            suffix="afid-{afid}.txt",
            **inputs_raw["T1w"].wildcards,
        ),
    log:
        bids(
            root="logs",
            suffix="landmark.log",
            afid="{afid}",
            **inputs_raw["T1w"].wildcards,
        ),
    script:
        "./scripts/extract_afids.py"


rule gen_sphere:
    """Render the extracted AFID as a sphere image on the subject's T1w grid.

    c3d loads the T1w, scales it by 0 and then applies -landmarks-to-spheres
    with the landmark text file and a trailing argument of 1
    (presumably the sphere radius -- see the c3d documentation).
    """
    input:
        t1w=lambda wildcards: (
            choose_correct_raw_dataset(wildcards)["T1w"].path
        ),
        txt=rules.extract_afid.output.txt,
    output:
        sphere=bids(
            root=str(Path(config["output_dir"]) / "c3d"),
            suffix="afid-{afid}.nii.gz",
            **inputs_raw["T1w"].wildcards,
        ),
    container:
        config["containers"]["c3d"]
    shell:
        "c3d {input.t1w} -scale 0 "
        "-landmarks-to-spheres {input.txt} 1 "
        "-o {output.sphere}"


rule gen_prob:
    """Apply c3d's -sdt operation to the AFID sphere image.

    The result is written as the ``_prob`` image that the normalisation, mask
    and patch rules consume.
    """
    input:
        sphere=rules.gen_sphere.output.sphere,
    output:
        prob=bids(
            root=str(Path(config["output_dir"]) / "c3d"),
            suffix="afid-{afid}_prob.nii.gz",
            **inputs_raw["T1w"].wildcards,
        ),
    container:
        config["containers"]["c3d"]
    shell:
        "c3d {input.sphere} "
        "-sdt "
        "-o {output.prob}"


rule gen_norm:
    """Scale the ``_prob`` image by the ``k`` wildcard and exponentiate it.

    The scale factor is taken from the ``k`` wildcard in the output file name.
    """
    input:
        prob=rules.gen_prob.output.prob,
    output:
        norm=bids(
            root=str(Path(config["output_dir"]) / "c3d"),
            suffix="afid-{afid}_prob_norm_-{k}.nii.gz",
            **inputs_raw["T1w"].wildcards
        ),
    container:
        config["containers"]["c3d"]
    shell:
        # Wildcards must be referenced as {wildcards.<name>} inside shell
        # commands; a bare {k} is not a name Snakemake provides and makes the
        # job fail when the command is formatted.
        # NOTE(review): the file name encodes "-{k}" while the command scales by
        # +k -- confirm which sign is intended.
        "c3d {input.prob} -scale {wildcards.k} -exp -o {output.norm}"


rule gen_mask:
    """Binarise the ``_prob`` image with ``c3d -threshold 0 10 1 0``.

    Per c3d's -threshold <lower> <upper> <inside> <outside> form, values in
    [0, 10] become 1 and everything else becomes 0.
    """
    input:
        prob=rules.gen_prob.output.prob,
    output:
        bin=bids(
            root=str(Path(config["output_dir"]) / "c3d"),
            suffix="afid-{afid}_bin.nii.gz",
            **inputs_raw["T1w"].wildcards,
        ),
    container:
        config["containers"]["c3d"]
    shell:
        "c3d {input.prob} "
        "-threshold 0 10 1 0 "
        "-o {output.bin}"


rule gen_patches:
    """Sample patches for one subject and AFID with c3d (-xpa / -xp).

    The T1w, ``_prob`` and mask images are passed to c3d in that order. The
    augmentation count and angle standard deviation go to -xpa; the patch
    radius (expanded to "RxRxR") and the sampling frequency go to -xp. The
    patch file is temporary.
    """
    input:
        t1w=lambda wildcards: (
            choose_correct_raw_dataset(wildcards)["T1w"].path
        ),
        prob=rules.gen_prob.output.prob,
        mask=rules.gen_mask.output.bin,
    output:
        patch=temp(
            bids(
                root=str(Path(config["output_dir"]) / "c3d"),
                desc="num{num_augment}angle{angle_stdev}radius{radius}freq{frequency}",
                suffix="afid-{afid}_patch.dat",
                **inputs_raw["T1w"].wildcards,
            )
        ),
    params:
        # Repeat the radius for all three axes, e.g. "3" -> "3x3x3".
        radius_arg=lambda wildcards: "x".join([wildcards["radius"]] * 3),
    container:
        config["containers"]["c3d"]
    shell:
        "c3d {input.t1w} {input.prob} {input.mask} "
        "-xpa {wildcards.num_augment} {wildcards.angle_stdev} "
        "-xp {output.patch} {params.radius_arg} {wildcards.frequency}"


rule cat_patches:
    """Concatenate the training subjects' patch files for one AFID.

    The gen_patches output is expanded over the raw T1w entries; the remaining
    wildcards (afid and patch settings) stay open. The combined file is
    write-protected.
    """
    input:
        patch=inputs_raw["T1w"].expand(
            rules.gen_patches.output.patch,
            allow_missing=True,
        ),
    output:
        combined_patch=protected(
            bids(
                root=str(Path(config["output_dir"]) / "c3d"),
                desc="num{num_augment}angle{angle_stdev}radius{radius}freq{frequency}",
                suffix="afid-{afid}_allpatches.dat",
            )
        ),
    shell:
        "cat {input.patch} "
        "> {output.combined_patch}"


rule cat_patches_validation:
    """Concatenate the validation subjects' patch files for one AFID.

    When no validation dataset is configured the input list is empty; in that
    case train_model does not request this rule's output.
    """
    input:
        patch=(
            []
            if not inputs_raw_validation
            else inputs_raw_validation["T1w"].expand(
                rules.gen_patches.output.patch,
                allow_missing=True,
            )
        ),
    output:
        combined_patch=protected(
            bids(
                root=str(Path(config["output_dir"]) / "c3d"),
                desc="num{num_augment}angle{angle_stdev}radius{radius}freq{frequency}",
                suffix="afid-{afid}_allvalidationpatches.dat",
            )
        ),
    shell:
        "cat {input.patch} "
        "> {output.combined_patch}"


rule train_model:
    """Train the model for one AFID by calling ``auto_afids_cnn_train``.

    Training hyperparameters are read from the config. The validation patches
    and the --validation_patches_path flag are added only when a validation
    dataset is configured; --do_early_stopping only when the config enables it.
    Outputs are the model directory and a loss-history CSV.
    """
    input:
        combined_patch=rules.cat_patches.output.combined_patch,
        combined_patch_validation=(
            []
            if not inputs_raw_validation
            else rules.cat_patches_validation.output.combined_patch
        ),
    output:
        model=directory(
            bids(
                root=str(Path(config["output_dir"]) / "afids-cnn-train"),
                desc="num{num_augment}angle{angle_stdev}radius{radius}freq{frequency}",
                suffix="afid-{afid}_cnn.model",
            )
        ),
        history=bids(
            root=str(Path(config["output_dir"]) / "afids-cnn-train"),
            desc="num{num_augment}angle{angle_stdev}radius{radius}freq{frequency}",
            suffix="afid-{afid}_loss.csv",
        ),
    params:
        num_channels=config["num_channels"],
        epochs=config["epochs"],
        steps_per_epoch=config["steps_per_epoch"],
        loss_fn=config["loss_fn"],
        optimizer=config["optimizer"],
        # The metrics are passed as one space-separated argument list.
        metrics=" ".join(config["metrics"]),
        validation_steps=config["validation_steps"],
        # Optional flags collapse to an empty string when they do not apply.
        validation_arg=(
            "" if not inputs_raw_validation else "--validation_patches_path"
        ),
        do_early_stopping=(
            "" if not config["do_early_stopping"] else "--do_early_stopping"
        ),
    shell:
        "auto_afids_cnn_train {params.num_channels} {wildcards.radius} "
        "{input.combined_patch} {output.model} "
        "--loss_out_path {output.history} "
        "--epochs {params.epochs} "
        "--steps_per_epoch {params.steps_per_epoch} "
        "--loss_fn {params.loss_fn} "
        "--optimizer {params.optimizer} "
        "--metrics {params.metrics} "
        "--validation_steps {params.validation_steps} "
        "{params.do_early_stopping} "
        "{params.validation_arg} {input.combined_patch_validation} "


rule all:
    """Default target: train all per-AFID models and assemble them into one file.

    Models are requested for AFIDs "01" to "32" and for every configured
    combination of patch settings; scripts/assemble_models.py then receives the
    model paths, the configured radius and the combined output path.
    """
    input:
        models=expand(
            rules.train_model.output.model,
            # Zero-padded two-digit AFID labels, 01 through 32.
            afid=[f"{index:02}" for index in range(1, 33)],
            num_augment=config["num_augment"],
            angle_stdev=config["angle_stdev"],
            radius=config["radius"],
            frequency=config["frequency"],
        ),
    output:
        combined_model=bids(
            root=str(Path(config["output_dir"]) / "afids-cnn-model"),
            suffix="model-combined.afidsmodel",
        ),
    params:
        radius=config["radius"],
    resources:
        # The assembling script's path is carried as a string resource.
        script=str(Path(workflow.basedir, "scripts", "assemble_models.py")),
    default_target: True
    shell:
        "python3 {resources.script} {input.models} "
        "{params.radius} {output.combined_model}"
