#!/usr/bin/env python

from __future__ import annotations

import csv
import datetime
import getopt
import os.path
import pickle
import subprocess
import sys
import time
import urllib.parse

from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import Select

from selenium.webdriver.chrome.service import Service as ChromeService
from selenium.webdriver.firefox.service import Service as FirefoxService
from selenium.webdriver.edge.service import Service as EdgeService

from webdriver_manager.chrome import ChromeDriverManager
from webdriver_manager.core.utils import ChromeType
from webdriver_manager.firefox import GeckoDriverManager
from webdriver_manager.microsoft import EdgeChromiumDriverManager

from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.firefox.options import Options as FirefoxOptions
from selenium.webdriver.edge.options import Options as EdgeOptions

# from chromedriver_py import binary_path

from SRA_Importer.utilities import run_cmd, ReadsData, qiime2_version

# Path of the active conda environment taken from the environment; None when the variable is not set.
CONDA_PREFIX = os.environ.get("CONDA_PREFIX")


_SUPPORTED_WEB_BROWSERS = ("chrome", "chromium", "brave", "firefox", "edge")


def _print_help():
    """Print the documentation of the create-visualization stage."""
    print("create-visualization:\n"
          "--------------------\n"
          "The first stage is in charge of taking the data from the SRA database and making a visualization of "
          "the reads. The purpose of the visualization is to help assess the range in which one should trim and "
          "truncate the reads generated by qiime2.\n\n"
          "Parameters:\n"
          "----------\n"
          " [-h] / [--help]: print documentation\n"
          " [--sra-study]: SRA experiment code.\n"
          " [--acc-list]: Accession list file. This file must be stored locally.\n"
          " [--output-vis-path]: An output path for the visualization; must be a .qzv file located in an "
          "existing directory. (Optional)\n"
          " [--web-browser]: A web browser installed on the machine, required together with --sra-study. "
          "Options are (chrome/chromium/brave/firefox/edge)\n\n"
          "Note: Exactly one of `--sra-study`/`--acc-list` must be given. If both are given, an error will be "
          "raised.\n\n"
          "Usage: "
          "create_visualization (--sra-study <sra-study-code> --web-browser <web-browser> | "
          "--acc-list <accession-list-file-path>) [--output-vis-path <final-output-path-of-visualisation>]")


def _resolve_output_vis_path(requested):
    """Return *requested* if it is a ``.qzv`` path inside an existing directory, else ``""``.

    An empty result tells the caller to use the default visualization location.
    """
    # A bare file name ("vis.qzv") has no directory part and refers to the current directory.
    parent = os.path.dirname(requested) or "."
    if os.path.isdir(parent) and os.path.splitext(requested)[1] == ".qzv":
        return requested
    print("Invalid output-vis-path. output-vis-path must be a qzv file and be located in an existing directory."
          "\noutput-vis-path will be saved in default location, under the created directory inside 'vis'")
    return ""


def _create_web_driver(browser, dir_path):
    """Start a headless Selenium driver for *browser* whose downloads are saved into *dir_path*.

    The matching driver binary is installed via webdriver_manager; only the manager of the
    requested browser is instantiated.
    """
    if browser == "firefox":
        options = FirefoxOptions()
        options.add_argument("--headless")
        # Firefox only honours browser.download.dir when folderList == 2 ("use a custom directory").
        options.set_preference("browser.download.folderList", 2)
        options.set_preference("browser.download.dir", dir_path)
        options.set_preference("browser.helperApps.neverAsk.saveToDisk", "application/sra")
        return webdriver.Firefox(service=FirefoxService(GeckoDriverManager().install()), options=options)

    # Chromium-based browsers share the same download preference and headless switch.
    # `options.headless` is a property (removed in newer Selenium), not a method, so the
    # flag is passed as a command-line argument instead.
    options = EdgeOptions() if browser == "edge" else ChromeOptions()
    options.add_experimental_option("prefs", {"download.default_directory": dir_path})
    options.add_argument("--headless=new")

    if browser == "edge":
        return webdriver.Edge(service=EdgeService(EdgeChromiumDriverManager().install()), options=options)
    if browser == "chromium":
        manager = ChromeDriverManager(chrome_type=ChromeType.CHROMIUM)
    elif browser == "brave":
        manager = ChromeDriverManager(chrome_type=ChromeType.BRAVE)
    else:
        manager = ChromeDriverManager()
    return webdriver.Chrome(service=ChromeService(manager.install()), options=options)


def process_input(opts, dir_path):
    """Validate the parsed command-line options and prepare what the importer needs.

    :param opts: ``(option, value)`` pairs as returned by ``getopt.getopt``.
    :param dir_path: The run directory; used as the browser download directory when
        ``--sra-study`` is given.
    :return: ``(sra_study, acc_list, output_vis_path, web_driver)``. Exactly one of
        ``sra_study`` / ``acc_list`` is non-empty, ``output_vis_path`` is ``""`` when the
        default location should be used, and ``web_driver`` is ``None`` unless
        ``--sra-study`` was given.

    Prints the help and exits with status 0 on ``-h``/``--help``; prints an error and
    exits with status 2 on invalid input.
    """
    opts = dict(opts)

    if "-h" in opts or "--help" in opts:
        _print_help()
        sys.exit()

    has_study = "--sra-study" in opts
    has_acc_list = "--acc-list" in opts
    if not has_study and not has_acc_list:
        print("one of '--sra-study' or '--acc-list' must be given.")
        sys.exit(2)
    if has_study and has_acc_list:
        print("one of '--sra-study' or '--acc-list' must be given. Cannot get both.")
        sys.exit(2)

    if has_acc_list:
        acc_list = opts["--acc-list"]
        if not os.path.isfile(acc_list):
            print("The given acc-list does not exist or is not a file.")
            sys.exit(2)
        sra_study = ""
    else:
        sra_study = opts["--sra-study"]
        acc_list = ""
        # `--sra-study=` yields an empty value, which would otherwise be mistaken for "no study".
        if not sra_study:
            print("--sra-study must not be empty.")
            sys.exit(2)

    if "--output-vis-path" in opts:
        output_vis_path = _resolve_output_vis_path(opts["--output-vis-path"])
    else:
        output_vis_path = ""

    if not sra_study:
        return sra_study, acc_list, output_vis_path, None

    browser = opts.get("--web-browser")
    if browser is None:
        print("--web-browser must be given with --sra-study given")
        sys.exit(2)
    if browser not in _SUPPORTED_WEB_BROWSERS:
        print("The given web-browser is not supported!")
        sys.exit(2)

    return sra_study, acc_list, output_vis_path, _create_web_driver(browser, dir_path)


def download_acc_list(driver: webdriver.Chrome | webdriver.Firefox | webdriver.Edge | None, sra_study: str):
    """Trigger the browser download of the accession list for *sra_study* from the NCBI SRA site.

    Searches SRA for *sra_study*, then clicks the elements with IDs ``sendto`` and
    ``dest_File``, selects ``acclist`` in ``file_format`` and presses the
    ``...SendToSubmit`` button (the page's "Send to" -> "File" flow). The file lands in
    the driver's configured download directory. The function returns as soon as the
    download is triggered; it does not wait for the download to finish.

    :raises ValueError: if *driver* is None.
    """
    if driver is None:
        raise ValueError("A web driver is required to download the accession list of an SRA study.")

    # Percent-encode the search term so arbitrary input still forms a valid query string.
    term = urllib.parse.quote_plus(sra_study)
    driver.get(f"https://www.ncbi.nlm.nih.gov/sra?term={term}&cmd=DetailsSearch")

    driver.find_element(by=By.ID, value="sendto").click()
    driver.find_element(by=By.ID, value="dest_File").click()

    format_select = Select(driver.find_element(by=By.ID, value="file_format"))
    format_select.select_by_value("acclist")

    driver.find_element(by=By.NAME,
                        value="EntrezSystem2.PEntrez.Sra.Sra_ResultsPanel.Sra_DisplayBar.SendToSubmit").click()


def _wait_for_download(file_path: str, timeout: float = 120.0, poll_interval: float = 0.5) -> None:
    """Block until the browser has finished writing *file_path*.

    Chromium-based browsers write into ``*.crdownload`` and Firefox into ``*.part`` files
    before producing the final name, so the download counts as complete once *file_path*
    exists, is non-empty and no such partial file remains in its directory.

    :raises TimeoutError: if the download is not complete after *timeout* seconds.
    """
    directory = os.path.dirname(file_path) or "."
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if os.path.isfile(file_path) and os.path.getsize(file_path) > 0:
            in_progress = [name for name in os.listdir(directory) if name.endswith((".crdownload", ".part"))]
            if not in_progress:
                return
        time.sleep(poll_interval)
    raise TimeoutError(f"Timed out after {timeout} seconds waiting for the download of {file_path}.")


def download_data_from_sra(dir_path: str, sra_study: str = "", acc_list: str = "",
                           driver: webdriver.Chrome | webdriver.Firefox | webdriver.Edge | None = None):
    """Prefetch the SRA runs of a study or of a local accession list into ``<dir_path>/sra``.

    When *sra_study* is non-empty, its accession list is first downloaded through *driver*
    into ``<dir_path>/SraAccList.txt`` (the driver must download into *dir_path*), and that
    file is passed to ``prefetch``. Otherwise, the local *acc_list* file is passed directly.

    Returns None.

    :raises ValueError: if *sra_study* is given without a *driver*.
    :raises TimeoutError: if the browser download of the accession list does not finish.
    """
    sra_dir = os.path.join(dir_path, "sra")
    run_cmd(["mkdir", sra_dir])

    if sra_study != "":
        if driver is None:
            raise ValueError("A web driver is required to download the accession list of an SRA study.")
        download_acc_list(driver, sra_study)
        acc_list = os.path.join(dir_path, "SraAccList.txt")
        # The browser download is asynchronous; make sure the file is complete before prefetch reads it.
        _wait_for_download(acc_list)

    run_cmd(["prefetch", "--option-file", acc_list, "--output-directory", sra_dir])


def sra_to_fastq(dir_path: str):
    """Convert the prefetched runs under ``<dir_path>/sra`` to FASTQ files in ``<dir_path>/fastq``.

    Every sub-directory of ``<dir_path>/sra`` is treated as one run. Its ``.sra``/``.sralite``
    file is passed to ``fasterq-dump --split-files``; if there is no such file, the first
    entry in sorted order is used. Non-directory entries and empty run directories are
    skipped.

    :return: ``ReadsData(dir_path, fwd=True, rev=has_reverse)``, where ``has_reverse`` is True
        when any produced FASTQ file name (up to the first '.') ends with ``_2``.
    """
    sra_root = os.path.join(dir_path, "sra")
    fastq_path = os.path.join(dir_path, "fastq")
    run_cmd(["mkdir", fastq_path])

    for run_dir in sorted(os.listdir(sra_root)):
        run_path = os.path.join(sra_root, run_dir)
        if not os.path.isdir(run_path):
            continue
        entries = sorted(os.listdir(run_path))
        # Prefer the run archive itself over any auxiliary files stored next to it.
        candidates = [entry for entry in entries if entry.endswith((".sra", ".sralite"))] or entries
        if not candidates:
            print(f"No SRA file found in {run_path}; skipping.")
            continue
        run_cmd(["fasterq-dump", "--split-files", os.path.join(run_path, candidates[0]), "-O", fastq_path])

    # With --split-files, reverse reads are written to "<run>_2.fastq". Checking for them works
    # for any number of samples. The previous "first three files share one prefix" test
    # misclassified multi-sample paired data as single-end.
    has_reverse = any(name.split(".")[0].endswith("_2") for name in os.listdir(fastq_path))
    return ReadsData(dir_path, fwd=True, rev=has_reverse)


def create_manifest(reads_data: ReadsData):
    """Write ``<dir_path>/manifest.tsv`` describing the FASTQ files in ``<dir_path>/fastq``.

    The data is treated as single-end when ``reads_data.rev`` is falsy:
    - Header ``SampleID``/``absolute-filepath``, one row per file.
    - The sample ID is the file name up to its first '.'.

    Otherwise, the data is treated as paired-end:
    - Header ``SampleID``/``forward-absolute-filepath``/``reverse-absolute-filepath``.
    - One row per sample that has both a ``<sample>_1.*`` and a ``<sample>_2.*`` file.
    - Samples missing one of the two files are reported and skipped.

    Rows are sorted by sample ID.
    """
    fastq_path = os.path.join(reads_data.dir_path, "fastq")
    fastq_files = sorted(name for name in os.listdir(fastq_path)
                         if os.path.isfile(os.path.join(fastq_path, name)))

    if not reads_data.rev:
        header = ["SampleID", "absolute-filepath"]
        rows = [[name.split(".")[0], os.path.join(fastq_path, name)] for name in fastq_files]
    else:
        forward, reverse = {}, {}
        for name in fastq_files:
            stem = name.split(".")[0]
            if stem.endswith("_1"):
                forward[stem[:-2]] = os.path.join(fastq_path, name)
            elif stem.endswith("_2"):
                reverse[stem[:-2]] = os.path.join(fastq_path, name)

        unmatched = sorted(forward.keys() ^ reverse.keys())
        if unmatched:
            print(f"Skipping samples without both forward and reverse reads: {', '.join(unmatched)}")

        header = ["SampleID", "forward-absolute-filepath", "reverse-absolute-filepath"]
        # Pair files by sample ID rather than by position so rows can never be misaligned.
        rows = [[sample, forward[sample], reverse[sample]]
                for sample in sorted(forward.keys() & reverse.keys())]

    # newline="" lets the csv module control line endings, as its documentation requires.
    with open(os.path.join(reads_data.dir_path, "manifest.tsv"), "w", newline="") as manifest:
        tsv_writer = csv.writer(manifest, delimiter="\t")
        tsv_writer.writerow(header)
        tsv_writer.writerows(rows)


def _find_conda_base(prefix: str) -> str | None:
    """Return *prefix* or its closest ancestor whose last path component contains "conda".

    This is a heuristic for locating the conda installation from an environment path.
    Returns None when no component matches. The walk stops at the filesystem root, or at
    the start of a relative path, instead of looping forever.
    """
    path = prefix
    while True:
        head, tail = os.path.split(path)
        if "conda" in tail:
            return path
        # os.path.split("/") == ("/", ""): without this check the walk never terminates.
        if not head or head == path:
            return None
        path = head


def conda_activate_qiime2():
    """Source conda's ``etc/profile.d/conda.sh`` and ``conda activate`` the qiime2 env in bash.

    The conda installation directory is derived from ``CONDA_PREFIX``. If it is unset or
    no installation directory can be found, a message is printed and nothing is run.

    NOTE(review): the activation runs in a child bash process that exits immediately, so it
    cannot change this Python process's environment. Later ``run_cmd`` calls only see the
    qiime2 environment if ``run_cmd`` arranges that itself — confirm this step is needed.
    """
    qiime_version = qiime2_version()

    if CONDA_PREFIX is None:
        print("CONDA_PREFIX is not set; skipping the activation of the qiime2 conda environment.")
        return
    conda_base = _find_conda_base(CONDA_PREFIX)
    if conda_base is None:
        print(f"Could not locate the conda installation from CONDA_PREFIX ({CONDA_PREFIX}); "
              f"skipping the activation of the qiime2 conda environment.")
        return

    # Pass the values as positional parameters ($1, $2) rather than splicing them into the
    # script text, so paths containing spaces or shell metacharacters still work.
    script = '. "$1/etc/profile.d/conda.sh" && conda activate "$2"'
    subprocess.call(["/bin/bash", "-c", script, "bash", conda_base, str(qiime_version)])


def qiime_import(reads_data: ReadsData):
    """Import the reads listed in ``<dir_path>/manifest.tsv`` into a QIIME 2 artifact.

    The data is treated as paired-end when both ``reads_data.rev`` and ``reads_data.fwd`` are
    truthy; otherwise it is single-end. The artifact is written to
    ``<dir_path>/qza/demux-<paired|single>-end.qza``, and that path is returned.
    """
    qza_dir = os.path.join(reads_data.dir_path, "qza")
    run_cmd(["mkdir", qza_dir])

    if reads_data.rev and reads_data.fwd:
        layout = "paired"
        semantic_type = "PairedEndSequencesWithQuality"
        input_format = "PairedEndFastqManifestPhred33V2"
    else:
        layout = "single"
        semantic_type = "SequencesWithQuality"
        input_format = "SingleEndFastqManifestPhred33V2"

    artifact_path = os.path.join(qza_dir, "demux-" + layout + "-end.qza")
    manifest_path = os.path.join(reads_data.dir_path, "manifest.tsv")

    # NOTE(review): the semantic type is wrapped in literal single quotes. That is only correct
    # if run_cmd hands the command to a shell; with a plain argv list, qiime would receive the
    # quotes too. Confirm against run_cmd.
    run_cmd([
        "qiime", "tools", "import",
        "--type", "'SampleData[" + semantic_type + "]'",
        "--input-path", manifest_path,
        "--input-format", input_format,
        "--output-path", artifact_path,
    ])
    return artifact_path


def qiime_demux(reads_data: ReadsData, input_path: str, output_vis_path: str = ""):
    """Run ``qiime demux summarize`` on the artifact at *input_path* and return the output path.

    When *output_vis_path* is the empty string, the visualization goes to
    ``<dir_path>/vis/<artifact name up to its first '.'>.qzv``. That directory is created
    through run_cmd.
    """
    output_path = output_vis_path
    if output_vis_path == "":
        vis_dir = os.path.join(reads_data.dir_path, "vis")
        run_cmd(["mkdir", vis_dir])
        artifact_stem = os.path.basename(input_path).split(".")[0]
        output_path = os.path.join(vis_dir, artifact_stem + ".qzv")

    run_cmd(["qiime", "demux", "summarize", "--i-data", input_path, "--o-visualization", output_path])
    return output_path


def main_importer():
    """Command-line entry point: fetch SRA reads, import them into QIIME 2 and summarize them.

    Creates ``SRA-Importer-<dd-mm-YYYY_HH-MM-SS>`` in the current working directory. The
    intermediate files and the pickled ReadsData (``reads_data.pkl``) are written there.
    Exits with status 2 on unrecognized command-line options.
    """
    start_import = datetime.datetime.now().strftime('%d-%m-%Y_%H-%M-%S')
    run_cmd(["mkdir", f"SRA-Importer-{start_import}"])
    dir_path = os.path.join(os.path.abspath("."), f"SRA-Importer-{start_import}")

    def usage():
        print("usage: create_vis.py (--sra-study <sra-study-code> --web-browser <web-browser> | "
              "--acc-list <accession-list-file-path>) "
              "[--output-vis-path <final-output-path-of-visualisation>]")

    try:
        opts, _ = getopt.getopt(sys.argv[1:], "h", ["help", "sra-study=", "acc-list=",
                                                    "output-vis-path=", "web-browser="])
    except getopt.GetoptError as err:
        # print help information and exit:
        print(err)  # will print something like "option -a not recognized"
        usage()
        sys.exit(2)

    sra_study, acc_list, output_vis_path, driver = process_input(opts, dir_path)

    try:
        download_data_from_sra(dir_path, sra_study, acc_list, driver)
    finally:
        # The browser is only needed to fetch the accession list; always shut it down.
        if driver is not None:
            driver.quit()

    reads_data = sra_to_fastq(dir_path)

    create_manifest(reads_data)

    conda_activate_qiime2()

    output_path = qiime_import(reads_data)

    vis_path = qiime_demux(reads_data, output_path, output_vis_path)

    with open(os.path.join(reads_data.dir_path, "reads_data.pkl"), "wb") as pkl_file:
        pickle.dump(reads_data, pkl_file)

    print(f"Visualization file is located in {vis_path}\n"
          f"Please drag this file to https://view.qiime2.org/ and continue.\n")
    if reads_data.fwd and reads_data.rev:
        print("Note: The data has both forward and reverse reads.\n"
              "Therefore, you must give the parameters '--trim' and '--trunc' of 'export.py' "
              "two integer values separated with a comma without a space between them. "
              "The first value relates to the forward read and the second to the reverse.\n"
              "For example: export.py --trim 20,28 --trunc 200,200")
    else:
        print("Note: The data has only a forward read.\n"
              "Therefore, you must give the parameters '--trim' and '--trunc' of 'export.py' "
              "exactly one integer value, which relates to the forward read.\n"
              "For example: export.py --trim 20 --trunc 200")


if __name__ == "__main__":
    # Run the importer only when this file is executed as a script, not when it is imported.
    main_importer()
