#!/usr/bin/env python

from __future__ import annotations

import getopt
import os
import pickle
import sys

from selenium import webdriver

from selenium.webdriver.chrome.service import Service as ChromeService
from selenium.webdriver.firefox.service import Service as FirefoxService
from selenium.webdriver.ie.service import Service as IEService
from selenium.webdriver.edge.service import Service as EdgeService

from webdriver_manager.chrome import ChromeDriverManager
from webdriver_manager.core.utils import ChromeType
from webdriver_manager.firefox import GeckoDriverManager
from webdriver_manager.microsoft import IEDriverManager, EdgeChromiumDriverManager

from SRA_Importer.utilities import ReadsData, run_cmd, qiime2_version


def _exit_with(message: str):
    """Print *message* and terminate with the usage-error status 2."""
    print(message)
    sys.exit(2)


def _require_option(opts: dict, option: str) -> str:
    """Return the value of the mandatory *option* (e.g. ``--trim``), exiting with status 2 if absent."""
    if option not in opts:
        _exit_with(f"{option[2:]} is not found, and must be given")
    return opts[option]


def _parse_range_option(value: str, name: str, paired: bool):
    """Parse a ``--trim``/``--trunc`` value.

    Paired-end reads need ``"<forward>,<reverse>"`` and yield a ``(forward, reverse)`` tuple of ints;
    single-end reads need one integer and yield an ``int``. Exits with status 2 on malformed input.
    """
    parts = value.split(",")
    if paired:
        if len(parts) != 2:
            _exit_with(f"The read consist of both forward and reverse, so '{name}' must be two non-negative integers.")
        # isdecimal (unlike isnumeric) only accepts characters int() can parse, e.g. it rejects "½".
        if not all(part.isdecimal() for part in parts):
            _exit_with(f"'{name}' must get two non-negative integers.")
        return int(parts[0]), int(parts[1])

    if len(parts) != 1:
        _exit_with(f"The read consist only forward, so '{name}' must be one non-negative integer.")
    if not value.isdecimal():
        _exit_with(f"'{name}' must get one non-negative integer.")
    return int(value)


def _validate_output_file(path: str, name: str, extensions: tuple) -> str:
    """Return *path* after checking its directory exists and its extension is in *extensions*.

    Exits with status 2 otherwise; *name* is the option name used in the error messages.
    """
    # A bare file name has an empty dirname, which refers to the current directory.
    directory = os.path.dirname(path) or os.curdir
    if not os.path.isdir(directory):
        _exit_with(f"The directory of the file given in {name} is not found. "
                   f"Create directory {directory} or change to an existing one.")
    extension = os.path.basename(path).split(".")[-1]
    if extension not in extensions:
        _exit_with(f"{name} must be a {'/'.join(extensions)} file. Instead got a {extension} file.")
    return path


def process_input(opts):
    """Validate the export-stage command-line options and resolve them into run parameters.

    :param opts: ``(option, value)`` pairs as returned by :func:`getopt.getopt`.
    :return: ``(reads_data, trim, trunc, threads, otu_output_file, taxonomy_output_file, service)``.
        ``trim``/``trunc`` are an ``int`` for single-end reads and a ``(forward, reverse)`` tuple for
        paired-end reads; ``threads`` defaults to 12; ``service`` is the Selenium driver service of
        the requested ``--web-browser``.

    Prints the help text and exits with status 0 on ``-h``/``--help``. Prints a message and exits
    with status 2 when a required option is missing or invalid.
    """
    opts = dict(opts)

    if "-h" in opts or "--help" in opts:
        print("export-data:\n"
              "-----------\n"
              "The second stage is in charge of creating OTU and Taxonomy tables and exporting them into usable file formats.\n\n"
              "Parameters:\n"
              "----------\n"
              " [-h] / [--help]: print documentation\n"
              " [--output-dir]: The path of the directory created by the first stage.\n"
              " [--otu-output-file]: An output path for the OTU table. The directory must exist, and the file's format must be `txt`/`tsv`.\n"
              " [--taxonomy-output-file]: An output path for the taxonomy table. The directory must exist, and the file's format must be `tsv`.\n"
              " [--web-browser]: The browser whose driver is used to download the taxonomy classifier: "
              "`chrome`, `chromium`, `brave`, `firefox`, `ie` or `edge`.\n"
              "DADA2 parameters:\n"
              " [--trim]: a non-negative integer of the right edge of the trimming range. If the reads are both forward and reverse 2 values should be given separated with a comma: `20,28`\n"
              " [--trunc]: a non-negative integer of the left edge of the truncating range. If the reads are both forward and reverse 2 values should be given separated with a comma: `200,220`\n"
              " [--threads]: Number of threads to run on. Default is `12`. (Optional)\n\n"
              "Note: All the parameters except `--threads` must be given.\n"
              "Usage: export-data --output-dir <output-directory-path> --trim <trim-from> --trunc <trunc-to> "
              "--threads <number-of-threads-to-use> --otu-output-file <otu-output-file-path> "
              "--taxonomy-output-file <taxonomy-output-file-path> --web-browser <browser>")
        sys.exit()

    dir_path = _require_option(opts, "--output-dir")
    if not os.path.isdir(dir_path):
        _exit_with("The given output-dir does not exist")
    # NOTE(review): unpickling can execute arbitrary code; --output-dir must point at a trusted
    # directory produced by the first stage.
    with open(os.path.join(dir_path, "reads_data.pkl"), "rb") as handle:
        reads_data: ReadsData = pickle.load(handle)

    paired = bool(reads_data.fwd and reads_data.rev)
    trim = _parse_range_option(_require_option(opts, "--trim"), "trim", paired)
    trunc = _parse_range_option(_require_option(opts, "--trunc"), "trunc", paired)

    threads = 12
    if "--threads" in opts:
        value = opts["--threads"]
        # Compare the parsed number, so spellings such as "00" are rejected as well as "0".
        if value.isdecimal() and int(value) > 0:
            threads = int(value)
        else:
            print("'threads' must be a positive integer. set to default: threads=12")

    otu_output_file = _validate_output_file(
        _require_option(opts, "--otu-output-file"), "otu-output-file", ("tsv", "txt"))
    taxonomy_output_file = _validate_output_file(
        _require_option(opts, "--taxonomy-output-file"), "taxonomy-output-file", ("tsv",))

    # Factories rather than ready-made services: only the requested driver gets downloaded and
    # installed. Installing every driver is slow, and may fail for drivers that have no build for
    # the current OS or whose browser is not installed.
    service_factories = {
        "chrome": lambda: ChromeService(ChromeDriverManager().install()),
        "chromium": lambda: ChromeService(ChromeDriverManager(chrome_type=ChromeType.CHROMIUM).install()),
        "brave": lambda: ChromeService(ChromeDriverManager(chrome_type=ChromeType.BRAVE).install()),
        "firefox": lambda: FirefoxService(GeckoDriverManager().install()),
        "ie": lambda: IEService(IEDriverManager().install()),
        "edge": lambda: EdgeService(EdgeChromiumDriverManager().install()),
    }
    browser = opts.get("--web-browser")
    if browser is None:
        _exit_with("--web-browser must be given")
    if browser not in service_factories:
        _exit_with("The given web-browser is not supported!")
    service = service_factories[browser]()

    return reads_data, trim, trunc, threads, otu_output_file, taxonomy_output_file, service


def qiime_dada2(reads_data: ReadsData, input_path: str,
                left: int | tuple[int, int], right: int | tuple[int, int], threads: int = 12):
    """Denoise the demultiplexed reads with ``qiime dada2 denoise-paired``/``denoise-single``.

    :param reads_data: run metadata; ``fwd``/``rev`` select paired vs single mode and ``dir_path``
        is where the ``qza`` outputs are written.
    :param input_path: path of the demultiplexed ``.qza`` artifact.
    :param left: trim length, as a ``(forward, reverse)`` tuple for paired reads or an int otherwise.
    :param right: truncation length, same shape as ``left``.
    :param threads: value passed to ``--p-n-threads``.

    Writes ``dada2_table.qza``, ``dada2_rep-seqs.qza`` and ``dada2_denoising-stats.qza`` to
    ``<dir_path>/qza``.
    """
    paired = reads_data.fwd and reads_data.rev
    qza_dir = os.path.join(reads_data.dir_path, "qza")

    # Every argv entry is converted to str: a subprocess argv may only contain strings/paths, and
    # left/right/threads arrive as ints. (Assumes run_cmd hands the list to subprocess; its
    # implementation is not visible here. The other commands in this module also pass numbers as str.)
    if paired:
        trim_range = ["--p-trim-left-f", str(left[0]), "--p-trim-left-r", str(left[1])]
        trunc_range = ["--p-trunc-len-f", str(right[0]), "--p-trunc-len-r", str(right[1])]
    else:
        trim_range = ["--p-trim-left", str(left)]
        trunc_range = ["--p-trunc-len", str(right)]

    command = [
        "qiime", "dada2", "denoise-paired" if paired else "denoise-single",
        "--i-demultiplexed-seqs", input_path,
        *trim_range,
        *trunc_range,
        "--o-table", os.path.join(qza_dir, "dada2_table.qza"),
        "--p-n-threads", str(threads),
        "--p-chimera-method", "consensus",
        "--o-representative-sequences", os.path.join(qza_dir, "dada2_rep-seqs.qza"),
        "--o-denoising-stats", os.path.join(qza_dir, "dada2_denoising-stats.qza"),
    ]
    run_cmd(command)


def cluster_features(reads_data: ReadsData):
    """Cluster the DADA2 features de novo at 99% identity with ``qiime vsearch``.

    Reads ``dada2_table.qza`` and ``dada2_rep-seqs.qza`` from ``<dir_path>/qza`` and writes
    ``table-dn-99.qza`` and ``rep-seqs-dn-99.qza`` into the same directory.
    """
    qza_dir = os.path.join(reads_data.dir_path, "qza")
    table_in = os.path.join(qza_dir, "dada2_table.qza")
    seqs_in = os.path.join(qza_dir, "dada2_rep-seqs.qza")
    table_out = os.path.join(qza_dir, "table-dn-99.qza")
    seqs_out = os.path.join(qza_dir, "rep-seqs-dn-99.qza")

    run_cmd(["qiime", "vsearch", "cluster-features-de-novo",
             "--i-table", table_in,
             "--i-sequences", seqs_in,
             "--p-perc-identity", "0.99",
             "--o-clustered-table", table_out,
             "--o-clustered-sequences", seqs_out])


def download_taxonomy_classifier(reads_data: ReadsData, qiime_version: str,
                                 service: ChromeService | FirefoxService | IEService | EdgeService | None = None,
                                 timeout: float = 1800.0):
    """Download ``gg-13-8-99-nb-classifier.qza`` for *qiime_version* into ``reads_data.dir_path``.

    A headless Chrome is driven to the data.qiime2.org URL. The call blocks until the file is
    complete, then quits the browser.

    :param reads_data: run metadata; the classifier is saved directly in ``dir_path``.
    :param qiime_version: release segment of the URL, e.g. ``2023.5``.
    :param service: Selenium driver service; ``None`` lets Selenium locate a driver itself.
        NOTE(review): the driver is always ``webdriver.Chrome``, so only a Chrome-compatible
        service can work. A Firefox/IE service will fail, and Brave/Edge would also need the
        matching browser binary configured.
    :param timeout: seconds to wait for the download before giving up.
    :raises TimeoutError: if the file has not appeared within *timeout* seconds.
    """
    import time

    filename = "gg-13-8-99-nb-classifier.qza"
    # Chrome needs an absolute download directory.
    download_dir = os.path.abspath(reads_data.dir_path)
    target = os.path.join(download_dir, filename)
    if os.path.exists(target):
        # Already downloaded by an earlier run; re-downloading would only create "... (1).qza".
        return

    options = webdriver.ChromeOptions()
    # The download directory is a browser preference, not a command-line switch.
    options.add_experimental_option("prefs", {
        "download.default_directory": download_dir,
        "download.prompt_for_download": False,
    })
    options.add_argument("--headless=new")

    driver = webdriver.Chrome(service=service, options=options)
    try:
        # Headless Chrome blocks downloads unless explicitly allowed through DevTools.
        driver.execute_cdp_cmd("Page.setDownloadBehavior",
                               {"behavior": "allow", "downloadPath": download_dir})
        driver.get(f"https://data.qiime2.org/{qiime_version}/common/{filename}")

        # driver.get returns as soon as the download starts. Chrome writes to a ".crdownload"
        # file and renames it on completion, so poll for the final name before quitting.
        deadline = time.monotonic() + timeout
        while not os.path.exists(target) or os.path.exists(target + ".crdownload"):
            if time.monotonic() > deadline:
                raise TimeoutError(f"Downloading {filename} did not finish within {timeout} seconds.")
            time.sleep(1)
    finally:
        driver.quit()


def assign_taxonomy(reads_data: ReadsData):
    """Classify the clustered sequences with ``qiime feature-classifier classify-sklearn``.

    The classifier ``gg-13-8-99-nb-classifier.qza`` is read from ``dir_path`` itself (not from
    ``qza``). The input is ``qza/rep-seqs-dn-99.qza``; the output is
    ``qza/gg-13-8-99-nb-classified.qza``.
    """
    base_dir = reads_data.dir_path
    qza_dir = os.path.join(base_dir, "qza")
    reads = os.path.join(qza_dir, "rep-seqs-dn-99.qza")
    classifier = os.path.join(base_dir, "gg-13-8-99-nb-classifier.qza")
    classification = os.path.join(qza_dir, "gg-13-8-99-nb-classified.qza")

    run_cmd(["qiime", "feature-classifier", "classify-sklearn",
             "--i-reads", reads,
             "--i-classifier", classifier,
             "--o-classification", classification])


def clean_taxonomy1(reads_data: ReadsData):
    """Drop mitochondria and chloroplast features with ``qiime taxa filter-table``.

    Reads ``qza/table-dn-99.qza`` and writes ``qza/clean_table.qza``.
    """
    qza_dir = os.path.join(reads_data.dir_path, "qza")
    source_table = os.path.join(qza_dir, "table-dn-99.qza")
    filtered_table = os.path.join(qza_dir, "clean_table.qza")

    run_cmd(["qiime", "taxa", "filter-table",
             "--i-table", source_table,
             "--p-exclude", "mitochondria,chloroplast",
             "--o-filtered-table", filtered_table])


def clean_taxonomy2(reads_data: ReadsData):
    """Remove rare features with ``qiime feature-table filter-features``.

    Keeps features present in at least 3 samples with a total frequency of at least 10.
    Reads ``qza/clean_table.qza`` and writes ``qza/feature-frequency-filtered-table.qza``.
    """
    qza_dir = os.path.join(reads_data.dir_path, "qza")
    source_table = os.path.join(qza_dir, "clean_table.qza")
    filtered_table = os.path.join(qza_dir, "feature-frequency-filtered-table.qza")

    run_cmd(["qiime", "feature-table", "filter-features",
             "--i-table", source_table,
             "--p-min-samples", "3",
             "--p-min-frequency", "10",
             "--o-filtered-table", filtered_table])


def export_otu(reads_data: ReadsData, output_file: str):
    """Export the filtered feature table and write it to *output_file* as a tab-separated OTU table.

    ``qiime tools export`` unpacks ``qza/feature-frequency-filtered-table.qza`` into
    ``<dir_path>/exports/feature-table.biom``; ``biom convert`` then writes *output_file*.
    """
    exports_dir = os.path.join(reads_data.dir_path, "exports")

    run_cmd([
        "qiime", "tools", "export",
        "--input-path", os.path.join(reads_data.dir_path, "qza", "feature-frequency-filtered-table.qza"),
        "--output-path", exports_dir,
    ])

    # `biom convert` refuses to run without an explicit output format. The OTU output file is
    # required to be .tsv/.txt, so request the classic tab-separated table.
    run_cmd([
        "biom", "convert",
        "-i", os.path.join(exports_dir, "feature-table.biom"),
        "-o", output_file,
        "--to-tsv",
    ])


def export_taxonomy(reads_data: ReadsData, output_file: str):
    """Export the taxonomy classification and copy the resulting ``taxonomy.tsv`` to *output_file*.

    ``qiime tools export`` unpacks ``qza/gg-13-8-99-nb-classified.qza`` into ``<dir_path>/exports``.
    """
    import shutil

    exports_dir = os.path.join(reads_data.dir_path, "exports")

    run_cmd([
        "qiime", "tools", "export",
        "--input-path", os.path.join(reads_data.dir_path, "qza", "gg-13-8-99-nb-classified.qza"),
        "--output-path", exports_dir,
    ])

    # Copy in-process rather than spawning `cp`: portable and needs no extra subprocess.
    shutil.copyfile(os.path.join(exports_dir, "taxonomy.tsv"), output_file)


def main_importer():
    """Entry point of the export stage.

    Parses ``sys.argv`` and runs the pipeline in order:
    1. DADA2 denoising.
    2. De-novo clustering at 99% identity.
    3. Download of ``gg-13-8-99-nb-classifier.qza`` and taxonomy assignment.
    4. Mitochondria/chloroplast and low-abundance filtering.
    5. Export of the OTU and taxonomy tables to the requested paths.

    Exits with status 2 on invalid command-line options.
    """
    def usage():
        print("usage: export.py --output-dir <output-directory-path> --trim <trim-from> "
              "--trunc <trunc-to> --threads <number-of-threads-to-use> "
              "--otu-output-file <otu-output-file-path> --taxonomy-output-file <taxonomy-output-file-path> "
              "--web-browser <browser>")

    try:
        opts, _args = getopt.getopt(sys.argv[1:], "h", ["help", "output-dir=", "trim=", "trunc=",
                                                        "threads=", "otu-output-file=", "taxonomy-output-file=",
                                                        "web-browser="])
    except getopt.GetoptError as err:
        # e.g. "option --foo not recognized"
        print(err)
        usage()
        sys.exit(2)

    reads_data, trim, trunc, threads, otu_output_file, taxonomy_output_file, service = process_input(opts)

    paired = reads_data.rev and reads_data.fwd
    # Demultiplexed reads (presumably produced by the first stage) that DADA2 takes as input.
    demux_path = os.path.join(reads_data.dir_path, "qza", f"demux-{'paired' if paired else 'single'}-end.qza")

    qiime_dada2(reads_data, demux_path, left=trim, right=trunc, threads=threads)

    cluster_features(reads_data)

    download_taxonomy_classifier(reads_data, qiime2_version().split("-")[1], service)

    assign_taxonomy(reads_data)

    # exist_ok: `mkdir` errors when the directory already exists, so re-running the stage on the
    # same output-dir would fail.
    os.makedirs(os.path.join(reads_data.dir_path, "exports"), exist_ok=True)

    clean_taxonomy1(reads_data)
    clean_taxonomy2(reads_data)

    export_otu(reads_data, otu_output_file)
    export_taxonomy(reads_data, taxonomy_output_file)
    # No `conda deactivate` here. It only affects the shell it runs in, so running it as a child
    # process cannot deactivate the caller's environment, and it errors without conda's shell
    # integration.


if __name__ == '__main__':
    # Run the export stage only when executed as a script, not when imported as a module.
    main_importer()
