# Directory containing this Snakefile; the templates/ and scripts/
# directories used below are resolved relative to it.
SNAKE_DIR = Path(workflow.basedir)
TEMPLATE_DIR = SNAKE_DIR.joinpath("templates")  # Jinja-rendered BEAST XML templates
SCRIPT_DIR = SNAKE_DIR.joinpath("scripts")  # helper Python scripts invoked by rules
# Destination for every generated file; required key in the workflow config.
OUTDIR = Path(config["output_dir"])


# Run BEAST on the generated skygrid XML. BEAST's stdout is captured in
# skygrid.out; stderr is not redirected and goes to Snakemake's console/log.
rule beast:
    input:
        beast_XML_file = OUTDIR / "skygrid.xml",
    output:
        # NOTE(review): the .log and .trees paths are never passed to BEAST on
        # the command line below; presumably the XML template names these
        # files. Confirm the paths BEAST actually writes (which may be resolved
        # against the working directory) match OUTDIR, otherwise Snakemake
        # will report these outputs as missing after a successful run.
        log_file = OUTDIR / "skygrid.log",
        tree_file = OUTDIR / "skygrid.trees",
        stdout = OUTDIR / "skygrid.out",
    params:
        # Extra BEAST command-line options taken verbatim from the config and
        # inserted unquoted before the XML path (so they may hold several flags).
        beast = config["beast_params"],
    conda:
        "envs/beast.yaml",
    shell:
        """
        beast \
        {params.beast} \
        {input.beast_XML_file} \
        > {output.stdout}
        """

# Optional constant-site counts forwarded to the template script; kept as a
# module-level name for any other rules that may reference it.
constant_sites = config.get("constant_sites")


def _optional_flag(flag, value):
    """Return ``flag`` followed by the shell-quoted ``value``, or "" when unset.

    The result is spliced raw into a ``shell:`` command, so the value is quoted
    here with ``shlex.quote``. ``None`` and the empty string mean "not
    configured" and produce no flag at all, letting the called script fall back
    to its own defaults instead of receiving the literal text ``None``.
    """
    import shlex

    if value is None or value == "":
        return ""
    return f"{flag} {shlex.quote(str(value))}"


# Render the skygrid BEAST XML from the Jinja template and the input alignment.
rule create_beast_xml:
    input:
        alignment = config["alignment"],
        # Tracked as inputs (rather than params) so Snakemake checks they exist
        # before running and treats edits to them as a reason to rebuild.
        template = TEMPLATE_DIR / "skygrid.jinja.xml",
        populate_script = SCRIPT_DIR / "populate_skygrid_template.py",
    output:
        beast_XML_file = OUTDIR / "skygrid.xml",
    params:
        # Each param holds a complete, pre-quoted "--flag value" string, or ""
        # when the corresponding config key is absent.
        dimensions = _optional_flag("--dimensions", config.get("dimensions")),
        cutoff = _optional_flag("--cutoff", config.get("cutoff")),
        clock = _optional_flag("--clock", config.get("clock")),
        chain_length = _optional_flag("--chain-length", config.get("chain_length")),
        samples = _optional_flag("--samples", config.get("samples")),
        # `or None` keeps the original truthiness rule: any falsy value is omitted.
        constant_sites = _optional_flag("--constant-sites", constant_sites or None),
    shell:
        """
        python {input.populate_script:q} \
            {input.template:q} \
            {input.alignment:q} \
            {params.dimensions} \
            {params.cutoff} \
            --output {output.beast_XML_file:q} \
            {params.clock} \
            {params.chain_length} \
            {params.samples} \
            {params.constant_sites}
        """

