import importlib.util
import itertools
import json
import os
import pathlib
import re
import shutil
import sys
from typing import Literal

import pandas as pd
import pyranges as pr
import snakemake.utils

import capcruncher.pipeline.utils


# Require Snakemake 7.19.1 or newer before anything else is evaluated.
snakemake.utils.min_version('7.19.1')


# Workflow configuration file; its contents are exposed as the `config` dict.
configfile: "capcruncher_config.yml"


# Container image declared for the workflow.
container: "library://asmith151/capcruncher/capcruncher:latest"


# Pipeline set-up
# The return value is discarded, so this helper presumably normalises `config`
# in place -- confirm against capcruncher.pipeline.utils.
capcruncher.pipeline.utils.format_config_dict(config)

## Get fastq files
# Every path in the working directory whose name matches "*.fastq*"
# (non-recursive) is handed to the CapCruncher sample collector.
FASTQ_SAMPLES = capcruncher.pipeline.utils.FastqSamples.from_files(
    list(pathlib.Path(".").glob("*.fastq*"))
)

## Convert FASTQ files to design matrix
# The `design` entry is optional. Guard against a missing/empty value before
# touching the filesystem: os.path.exists(None) raises TypeError, which would
# make the FASTQ-derived fallback below unreachable.
_design_path = config["analysis"].get("design")
if _design_path and os.path.exists(_design_path):
    # Accept whitespace-, comma- or tab-delimited design files.
    DESIGN = pd.read_table(_design_path, sep=r"\s+|,|\t", engine="python")
else:
    # No usable design file: use the design exposed by the FASTQ samples.
    DESIGN = FASTQ_SAMPLES.design

## Read viewpoints
# Path to the viewpoints table, taken from the `analysis` section of the config.
VIEWPOINTS = config["analysis"]["viewpoints"]

# The table is read without a header row as chrom/start/end/name; only the
# `name` column is kept, with every entry coerced to `str`.
VIEWPOINT_NAMES = list(
    map(
        str,
        pd.read_table(
            VIEWPOINTS,
            header=None,
            names=["chrom", "start", "end", "name"],
        )["name"].to_list(),
    )
)


# Number of distinct values in the design table's `sample` column.
N_SAMPLES = DESIGN["sample"].nunique()
# Analysis method from the config; "capture" is used when none is given.
ANALYSIS_METHOD = config["analysis"].get("method", "capture")
# Bin sizes as computed from the config by the CapCruncher helper.
BIN_SIZES = capcruncher.pipeline.utils.get_bin_sizes(config)
# Result of the helper's check on the viewpoints file and config; presumably a
# flag for an unusually large viewpoint set -- confirm in capcruncher.pipeline.utils.
HIGH_NUMBER_OF_VIEWPOINTS = capcruncher.pipeline.utils.has_high_viewpoint_number(
    VIEWPOINTS, config
)
# Optional config switch; False when the key is absent.
IGNORE_MULTIPLE_FRAGMENTS_PER_VIEWPOINT = config["analysis"].get(
    "ignore_multiple_fragments_per_viewpoint", False
)

# Details
# Summary method names come from `compare.summary_methods` as one string
# (default "mean,"). The string is split on runs of commas, semicolons,
# whitespace or "+" -- "+" is kept as a separator for backward compatibility
# with the previous pattern -- and empty tokens (e.g. from a leading or
# trailing separator) are dropped.
SUMMARY_METHODS = [
    method
    for method in re.split(
        r"[,;\s+]+", config["compare"].get("summary_methods", "mean,")
    )
    if method
]


## Optional
# Aggregation needs more than one sample; comparison needs more than one
# condition in the design table.
AGGREGATE_SAMPLES = DESIGN["sample"].nunique() > 1
COMPARE_SAMPLES = DESIGN["condition"].nunique() > 1

# Differential analysis runs only when the configured contrast is a column of
# the design table, several conditions exist, and the method is capture or tri.
PERFORM_DIFFERENTIAL_ANALYSIS = all(
    (
        config["differential"]["contrast"] in DESIGN.columns,
        COMPARE_SAMPLES,
        ANALYSIS_METHOD in ("capture", "tri"),
    )
)

# Whether plotting / binning can run is decided by the CapCruncher helpers.
PERFORM_PLOTTING = capcruncher.pipeline.utils.can_perform_plotting(config)
PERFORM_BINNING = capcruncher.pipeline.utils.can_perform_binning(config)

## Pipeline variables
# Reuse ANALYSIS_METHOD so the assay honours the same "capture" default;
# indexing config["analysis"]["method"] directly raised KeyError when the
# key was omitted even though a default is defined for it.
ASSAY = ANALYSIS_METHOD
SAMPLE_NAMES = FASTQ_SAMPLES.sample_names_all

# Post-run clean-up is currently switched off: the `onsuccess` handler only
# acts on the values "full" or "partial". The config-driven setting is kept
# below for reference.
# CLEANUP = "full" if config["analysis"].get("cleanup", False) else "partial"
CLEANUP = False


# Rule definitions are split across the files below. The `rules.multiqc` and
# `rules.create_ucsc_hub` objects used by `rule all` are not defined in this
# file, so they presumably come from these includes.
include: "rules/digest.smk"
include: "rules/fastq.smk"
include: "rules/qc.smk"
include: "rules/align.smk"
include: "rules/annotate.smk"
include: "rules/filter.smk"
include: "rules/pileup.smk"
include: "rules/compare.smk"
include: "rules/statistics.smk"
include: "rules/visualise.smk"


# Global wildcard constraints. Sample and viewpoint names are user-supplied
# text, so each one is regex-escaped before being joined into an alternation;
# otherwise a name containing a metacharacter (e.g. "+" or "(") would produce
# a pattern that does not match the literal name, or fails to compile.
wildcard_constraints:
    sample="|".join(re.escape(name) for name in SAMPLE_NAMES),
    part=r"\d+",
    viewpoint="|".join(re.escape(name) for name in VIEWPOINT_NAMES),
    combined="flashed|pe",


# Default target rule: lists every final output requested from the pipeline.
# Optional outputs fall back to an empty list when their feature is disabled.
rule all:
    input:
        # First output of the multiqc rule.
        qc=rules.multiqc.output[0],
        # HTML report for the run.
        report="capcruncher_output/results/capcruncher_report.html",
        # Pileup targets are computed by the CapCruncher helper from the
        # assay, design table, sample/viewpoint names and summary methods.
        pileups=capcruncher.pipeline.utils.get_pileups(
            assay=ASSAY,
            design=DESIGN,
            samples_aggregate=AGGREGATE_SAMPLES,
            samples_compare=COMPARE_SAMPLES,
            sample_names=SAMPLE_NAMES,
            summary_methods=SUMMARY_METHODS,
            viewpoints=VIEWPOINT_NAMES,
        ),
        # One HDF5 file per sample.
        counts=expand(
            "capcruncher_output/results/{sample}/{sample}.hdf5",
            sample=SAMPLE_NAMES,
        ),
        # UCSC hub is requested only for the "capture" and "tri" methods.
        hub=rules.create_ucsc_hub.output[0]
        if ANALYSIS_METHOD in ["capture", "tri"]
        else [],
        # One differential-analysis output path per viewpoint, when enabled.
        differential=expand(
            "capcruncher_output/results/differential/{viewpoint}",
            viewpoint=VIEWPOINT_NAMES,
        )
        if PERFORM_DIFFERENTIAL_ANALYSIS
        else [],
        # One PDF figure per viewpoint, when plotting is possible.
        plots=expand(
            "capcruncher_output/results/figures/{viewpoint}.pdf",
            viewpoint=VIEWPOINT_NAMES,
        )
        if PERFORM_PLOTTING
        else [],


# Failure handler: keep a copy of the Snakemake log in the working directory
# and tell the user where to find it.
onerror:
    log_out = "capcruncher_error.log"
    shutil.copyfile(log, log_out)
    failure_message = (
        f"An error occurred. Please check the log file {log_out} for more information."
    )
    print(failure_message)


# Success handler: keep a copy of the Snakemake log in the working directory
# and, depending on CLEANUP, remove intermediate files.
#   CLEANUP == "full"    -> delete the whole interim directory
#   CLEANUP == "partial" -> delete only the per-sample "*_part*.fastq.gz"
#                           files under the split, flashed and rebalanced
#                           FASTQ directories
#   anything else        -> leave intermediate files untouched
onsuccess:
    log_out = "capcruncher.log"
    shutil.copyfile(log, log_out)
    print(f"Pipeline completed successfully. See {log_out} for more information.")

    interim_dir = pathlib.Path("capcruncher_output/interim/")

    # Both modes are no-ops when the interim directory is already gone, so a
    # finished run is never reported as failing because of clean-up.
    if CLEANUP == "full" and interim_dir.exists():
        shutil.rmtree(interim_dir)

    elif CLEANUP == "partial" and interim_dir.exists():
        fastq_dir = interim_dir / "fastq"

        # Gather everything first so no directory is modified while it is
        # still being globbed.
        files_to_remove = []
        for sample_name in SAMPLE_NAMES:
            part_pattern = f"{sample_name}_part*.fastq.gz"
            part_dirs = [
                fastq_dir / "split" / sample_name,
                fastq_dir / "flashed" / sample_name,
                fastq_dir / "rebalanced" / sample_name / "flashed",
                fastq_dir / "rebalanced" / sample_name / "pe",
            ]
            for part_dir in part_dirs:
                files_to_remove.extend(part_dir.glob(part_pattern))

        for fastq_part in files_to_remove:
            # A file that has already disappeared is not an error here.
            fastq_part.unlink(missing_ok=True)
