#!/usr/bin/env python3


import argparse
import os
import sys
from typing import Dict, List
from Deepurify.IOUtils import writePickle

from Deepurify.RUN_Functions import cleanMAGs
from Deepurify.DataTools.DataUtils import insert


def bulid_tree(weight_file_path: str) -> Dict:
    """Build the taxonomy tree from a taxonomic lineage weights file.

    Each non-empty line is split on tabs; its first column is split on "@"
    into a list of taxon names, which is handed to ``insert`` together with
    the root node. Any further columns (e.g. the weight) are ignored here.

    Args:
        weight_file_path: Path of the tab-separated lineage weights file.

    Returns:
        The root node, a dict with the keys "TaxoLevel", "Name" and
        "Children". The root is hard-coded as superkingdom "bacteria".
    """
    taxonomyTree = {"TaxoLevel": "superkingdom", "Name": "bacteria", "Children": []}
    with open(weight_file_path, mode="r") as rh:
        for line in rh:
            lineage = line.strip("\n").split("\t")[0]
            if not lineage:
                # Blank line (typically a trailing newline at EOF): there is
                # no lineage, so do not insert an empty-name node.
                continue
            # NOTE(review): `insert` presumably attaches the lineage to the
            # tree in place (its return value is unused) — confirm in DataUtils.
            insert(lineage.split("@"), taxonomyTree)
    return taxonomyTree


def build_taxo_vocabulary(weight_file_path: str) -> Dict[str, int]:
    """Build the taxon-name -> index vocabulary from a lineage weights file.

    Only the first tab-separated column of each line (the "@"-joined lineage)
    is used; the weight column is ignored. Every distinct taxon name receives
    the next free index the first time it is seen, so indices are unique and
    contiguous. Index 0 is reserved for the "[PAD]" token.

    Args:
        weight_file_path: Path of the tab-separated lineage weights file.

    Returns:
        Mapping from taxon name to its integer index, in first-seen order.
    """
    vocab_dict = {"[PAD]": 0}
    k = 1
    with open(weight_file_path, "r") as rh:
        for line in rh:
            # Drop the weight column so the last taxon is not stored as
            # "name\tweight".
            lineage = line.strip("\n").split("\t")[0]
            if not lineage:
                # Blank line: nothing to add to the vocabulary.
                continue
            for word in lineage.split("@"):
                # Index only on first sight; re-indexing a repeated taxon
                # would leave gaps and indices beyond len(vocab_dict) - 1.
                if word not in vocab_dict:
                    vocab_dict[word] = k
                    k += 1
    return vocab_dict


if __name__ == "__main__":
    # Command-line entry point. Two sub-commands are exposed:
    #   clean - forwards every option positionally to cleanMAGs
    #   build - writes the taxonomy tree (pickle) and the vocabulary (TSV)
    myparser = argparse.ArgumentParser(
        prog=os.path.basename(sys.argv[0]), description="Deepurify is a tool to improving the quality of MAGs by decontaminating."
    )
    subparsers = myparser.add_subparsers(dest="command")

    #### clean parser ####
    clean_parser = subparsers.add_parser("clean", help="Filtering the contamination in MAGs.")

    # Add parameters
    clean_parser.add_argument(
        "-i",
        "--input_path",
        required=True,
        help="The input folder of MAGs.")
    clean_parser.add_argument(
        "-o",
        "--output_path",
        required=True,
        help="The output folder of the clean MAGs")
    clean_parser.add_argument(
        "--bin_suffix",
        required=True,
        help="The bin suffix of MAGs",
        type=str)
    clean_parser.add_argument(
        "--gpu_num",
        default=0,
        help="The number of GPUs would be used. 0 means to use CPU. (ATTENTION: CPU is much slower !!!!)",
        type=int
    )
    clean_parser.add_argument(
        "--batch_size_per_gpu",
        default=4,
        help="The batch size for per GPU. The number of sequences would be loaded to one GPU. It is useless if --gpu_num is 0.",
        type=int)
    clean_parser.add_argument(
        "--num_worker",
        default=2,
        type=int,
        help="The number of workers in one GPU or CPU. The batch size would divide this value for per worker."
    )

    ### optional ###
    clean_parser.add_argument(
        "--overlapping_ratio",
        default=0.5,
        type=float,
        help="The overlapping ratio if the length of contig exceeds the --cut_seq_length. Defaults to 0.5."
    )
    clean_parser.add_argument(
        "--cut_seq_length",
        default=8192,
        type=int,
        help="The length to cut the contig if the length of it longer than this value. Defaults to 8192.")
    clean_parser.add_argument(
        "--num_cpus_call_genes",
        default=64,
        type=int,
        help="The number of threads to call genes. Defaults to 64.")
    clean_parser.add_argument(
        "--hmm_acc_cutoff",
        default=0.7,
        type=float,
        help="The threshold when the hmm model decides to treat the called gene's sequence as SCG. Defaults to 0.7.",
    )
    clean_parser.add_argument(
        "--hmm_align_ratio_cutoff",
        default=0.4,
        type=float,
        help="The threshold of alignment coverage when the called gene's sequence aligned to the SCG. Defaults to 0.4.",
    )
    clean_parser.add_argument(
        "--estimate_completeness_threshold",
        default=0.5,
        type=float,
        help="The threshold of estimated completeness for filtering MAGs generated by applying those SCGs. Defaults to 0.5.",
    )
    clean_parser.add_argument(
        "--seq_length_threshold",
        default=550000,
        type=int,
        help="The threshold of a MAG's contigs' total length for filtering generated MAGs after applying SCGs.  Defaults to 550000.",
    )
    clean_parser.add_argument(
        "--checkM_parallel_num",
        default=3,
        choices=[1, 2, 3, 6],
        type=int,
        help="The number of processes to run CheckM simultaneously. Defaults to 3.")
    clean_parser.add_argument(
        "--num_cpus_per_checkm",
        default=25,
        type=int,
        help="The number of threads to run a CheckM process. Defaults to 25.")
    clean_parser.add_argument(
        "--dfs_or_greedy",
        default="dfs",
        choices=["dfs", "greedy"],
        type=str,
        help="Depth first searching or greedy searching to label a contig. Defaults to 'dfs'."
    )
    clean_parser.add_argument(
        "--topK",
        default=3,
        type=int,
        help="The Top-k nodes that have maximum cosine similarity with the contig encoded vector would be searched (Useless for greedy search). Defaults to 3.")
    clean_parser.add_argument(
        "--temp_output_folder",
        default=None,
        type=str,
        help="The path to store temporary files. Defaults to None",
    )
    clean_parser.add_argument(
        "--output_bins_meta_info_path",
        default=None,
        type=str,
        help="The path to record the meta informations of final clean MAGs. It records the completeness, contamination, quality, annotation of each MAG.  Defaults to None.",
    )
    clean_parser.add_argument(
        "--info_files_path",
        default=None,
        help="The path of DeepurifyInfoFiles folder. Defaults to None.",
        type=str
    )
    clean_parser.add_argument(
        "--model_weight_path",
        default=None,
        type=str,
        help="The path of model weight. (In DeepurifyInfoFiles folder) Defaults to None.")
    clean_parser.add_argument(
        "--taxo_vocab_path",
        default=None,
        type=str,
        help="The path of taxon vocabulary. (In DeepurifyInfoFiles folder) Defaults to None.",
    )
    clean_parser.add_argument(
        "--taxo_tree_path",
        default=None,
        type=str,
        help="The path of taxonomic tree. (In DeepurifyInfoFiles folder) Defaults to None.",
    )
    clean_parser.add_argument(
        "--taxo_lineage_vector_file_path",
        default=None,
        type=str,
        help="The path of taxonomic lineage encoded vectors. (In DeepurifyInfoFiles folder) Defaults to None. ",
    )
    clean_parser.add_argument(
        "--hmm_model_path",
        default=None,
        type=str,
        help="The path of SCGs' hmm file. (In DeepurifyInfoFiles folder) Defaults to None.",
    )

    #### build parser ####
    bulid_parser = subparsers.add_parser("build", help="Build the files like taxonomy tree and the taxonomy vocabulary for training.")
    # Add parameter
    bulid_parser.add_argument(
        "-i",
        "--input_taxo_lineage_weight_file_path",
        required=True,
        type=str,
        help="The path of the taxonomic lineages weights file. This file has two columns. " +
        "This first column is the taxonomic lineage of one species from phylum to species level, split with @ charactor. \n" +
        "The second colums is the weight value of the species." +
        "The two columns are split with '\\t'.")
    bulid_parser.add_argument(
        "-ot",
        "--output_tree_path",
        type=str,
        required=True,
        help="The output path of the taxonomy tree that build from your taxonomic lineages weights file.")
    bulid_parser.add_argument(
        "-ov",
        "--output_vocabulary_path",
        type=str,
        required=True,
        help="the output path of the taxonomy vocabulary that build from your taxonomic lineages weights file.")

    ### main part ###
    args = myparser.parse_args()

    if args.command == "clean":
        # Arguments are passed positionally, in the same order as before.
        cleanMAGs(
            args.input_path,
            args.output_path,
            args.bin_suffix,
            args.gpu_num,
            args.batch_size_per_gpu,
            args.num_worker,
            args.overlapping_ratio,
            args.cut_seq_length,
            args.num_cpus_call_genes,
            args.hmm_acc_cutoff,
            args.hmm_align_ratio_cutoff,
            args.estimate_completeness_threshold,
            args.seq_length_threshold,
            args.checkM_parallel_num,
            args.num_cpus_per_checkm,
            args.dfs_or_greedy,
            args.topK,
            args.temp_output_folder,
            args.output_bins_meta_info_path,
            args.info_files_path,
            args.model_weight_path,
            args.taxo_vocab_path,
            args.taxo_tree_path,
            args.taxo_lineage_vector_file_path,
            args.hmm_model_path
        )

    elif args.command == "build":
        # argparse derives the attribute name from the long option, so the
        # value lives in `input_taxo_lineage_weight_file_path`; the previous
        # `args.input_weight_file_path` raised AttributeError on every run.
        weight_file_path = args.input_taxo_lineage_weight_file_path
        taxo_tree = bulid_tree(weight_file_path)
        writePickle(args.output_tree_path, taxo_tree)
        vocab = build_taxo_vocabulary(weight_file_path)
        # One "word<TAB>index" line per vocabulary entry.
        with open(args.output_vocabulary_path, "w") as wh:
            for word, index in vocab.items():
                wh.write(word+"\t"+str(index) + "\n")
    else:
        # No sub-command was given: point the user at the help screens.
        print("Please use 'deepurify -h' or 'deepurify clean -h' or 'deepurify build -h' for helping.")
