import copy
import json
import os
import re

import yaml

from customize_pipeline import get_region_id, check_rule_existence, get_merge_rule_name, get_var_path_map


# Workflow configuration file; its content is what the `config` dictionary used below holds
configfile: "./config.yaml"

# Load default pipeline as a module, handing it the same `config` as this workflow
default_snakefile = "../snakefile"
module default_pipeline:
    snakefile: default_snakefile
    config: config

# Import rules from default pipeline (all of them, under their original names) so
# that they can be looked up and extended by the customization code below
use rule * from default_pipeline

# Local aliases for values defined by the default pipeline. Every alias keeps the
# name it has in the default pipeline, except PUSH_DATASET_CONFIG which aliases
# default_pipeline.PUSH_DATASET_CONFIG_FILE.
# NOTE(review): several of these aliases are not referenced in this file; presumably
# they are kept for code that shares this namespace — confirm before removing any.
WORKING_DIR = default_pipeline.WORKING_DIR
LOG_DIR = default_pipeline.LOG_DIR
FORGE_CONFIG = default_pipeline.FORGE_CONFIG
REPO_PATH = default_pipeline.REPO_PATH
NEXUS_ATLAS_ENV = default_pipeline.NEXUS_ATLAS_ENV
NEXUS_ATLAS_ORG = default_pipeline.NEXUS_ATLAS_ORG
NEXUS_ATLAS_PROJ = default_pipeline.NEXUS_ATLAS_PROJ
PUSH_DATASET_CONFIG = default_pipeline.PUSH_DATASET_CONFIG_FILE
NEXUS_DESTINATION_ENV = default_pipeline.NEXUS_DESTINATION_ENV
NEXUS_DESTINATION_ORG = default_pipeline.NEXUS_DESTINATION_ORG
NEXUS_DESTINATION_PROJ = default_pipeline.NEXUS_DESTINATION_PROJ
IS_PROD_ENV = default_pipeline.IS_PROD_ENV
CELL_COMPOSITION_NAME = default_pipeline.CELL_COMPOSITION_NAME
CELL_COMPOSITION_SUMMARY_NAME = default_pipeline.CELL_COMPOSITION_SUMMARY_NAME
CELL_COMPOSITION_VOLUME_NAME = default_pipeline.CELL_COMPOSITION_VOLUME_NAME
METADATA_PATH = default_pipeline.METADATA_PATH

# Lower-case values of the default pipeline, aliased the same way
nexus_dryrun = default_pipeline.nexus_dryrun
atlas_release_id = default_pipeline.atlas_release_id
atlas_release_rev = default_pipeline.atlas_release_rev
cell_composition_id = default_pipeline.cell_composition_id
brain_region_id = default_pipeline.brain_region_id
root_region_name = default_pipeline.root_region_name

# Read the YAML description of the available variables shipped with the repository,
# then build the variable -> path map from its "input" section
available_vars_path = os.path.join(REPO_PATH, "customize_pipeline/available_vars.yaml")
with open(available_vars_path, "r") as vars_file:
    available_vars_text = vars_file.read()
AVAILABLE_VARS = yaml.safe_load(available_vars_text.strip())
var_path_map = get_var_path_map(AVAILABLE_VARS["input"], PUSH_DATASET_CONFIG)

# Path of the user configuration (JSON) and name of the rule that must depend on
# the merged outputs; both are read from the workflow configuration
user_config_path = config["USER_CONFIG"]
target_rule = config["TARGET_RULE"]

# Parse the user configuration inside a context manager so that the file handle is
# closed as soon as the content has been read
with open(user_config_path) as user_config_file:
    user_config = json.load(user_config_file)
# List of rule customizations; each entry is processed by the loop below
user_rules = user_config["rules"]
# For every rule customized in the user configuration: define one rule per customized
# region, then one merge rule consuming their outputs, and finally make the target rule
# depend on the flag file of the merge rule.
for user_rule in user_rules:
    # Look up the default rule being customized and let check_rule_existence validate the result
    rule_name = user_rule["rule"]
    default_rule = workflow.get_rule(rule_name)
    check_rule_existence(default_rule, rule_name)

    default_output = default_rule.output
    # Fall back to the whole output object when it has no item named "dir"
    default_output_dir = getattr(default_output, "dir", default_output)
    default_output_file = getattr(default_output, "file", None)
    merged_output = f"{default_output_dir}_merged"

    # Pairs of (region rule, named inputs of that region rule)
    custom_regions_rules = []
    for region_customization in user_rule["execute"]:
        region_id = get_region_id(region_customization["brainRegion"])
        region_rule_name = f"{rule_name}_region{region_id}"
        region_rule_cli = region_customization["CLI"]
        # Resolve the WORKING_DIR placeholder in place: the updated dictionary is also
        # handed to the merge rule through params.user_rule
        region_customization['output_dir'] = region_customization['output_dir'].format(WORKING_DIR=WORKING_DIR)

        # Keep only the available variables that the CLI arguments reference through
        # placeholders of the form {input.<name>}: they become the named inputs of the region rule
        args_vars = re.findall("{(.*?)}", region_rule_cli["args"])
        args_vars = [args_var.split(".")[1] for args_var in args_vars]
        region_rule_vars = {
            var: copy.deepcopy(path)
            for var, path in var_path_map.items()
            if var in args_vars
        }

        # Define region-specific rule for the region provided in the user-configuration
        rule:
            name: region_rule_name
            input:
                # default_output is not a real dependency, but if default_rule gets executed after this
                # region rule it may recreate default_output_dir and hence delete this rule's output
                default_output,
                **region_rule_vars,
            output:
                directory(region_customization['output_dir'])
            params:
                # Snakemake calls the args function only after the loops have completed, at which point
                # the global region_rule_cli refers to the last region of the user configuration.
                # cmd is therefore formatted right away, and the args template of this region is bound
                # as a default argument of the lambda instead of being looked up when it is called.
                cmd = region_rule_cli["command"].format(WORKING_DIR=WORKING_DIR),
                args = lambda wildcards, input, cli_args=region_rule_cli["args"]: cli_args.format(input = input)
            log:
                f"{LOG_DIR}/{region_rule_name}.log"
            container:
                region_customization['container']
            shell:
                "{params.cmd} {params.args}"

        custom_regions_rules.append((workflow.get_rule(region_rule_name), region_rule_vars))

    # Union of the named inputs of all region rules of this customization
    merge_rule_vars = {}
    for _, region_vars in custom_regions_rules:
        merge_rule_vars.update(region_vars)
    merge_rule_name = get_merge_rule_name(rule_name)
    merge_rule_flag = f"{WORKING_DIR}/{merge_rule_name.replace('merge_', 'merged_')}.log"
    # NOTE(review): merge_rule_vars only holds variables referenced by at least one region CLI,
    # so the "hierarchy" and "annotation_ccfv3" lookups below raise KeyError when no region
    # uses them — confirm that every user configuration is expected to reference both.
    ##>merge_rule : Merge the outputs from region-specific rules into the default rule output
    rule:
        name: merge_rule_name
        input:
            [region_rule.output for region_rule, _ in custom_regions_rules],
            hierarchy = merge_rule_vars["hierarchy"],
            annotation = merge_rule_vars["annotation_ccfv3"],
            default_output_dir = default_output_dir,
            default_output_file = default_output_file
        output:
            dir = directory(merged_output),
            flag = touch(merge_rule_flag)
        params:
            user_rule = user_rule,
            metadata_path = default_output.metadata if rule_name == "placement_hints" else None
        log:
            f"{LOG_DIR}/{merge_rule_name}.log"
        script:
            "../scripts/merge_rule.py"

    # Make the target rule depend on the flag file touched by the merge rule
    workflow.get_rule(target_rule).set_input(workflow.get_rule(merge_rule_name).output.flag)
