#!python
# vim: set fileencoding=utf-8 :
#
# Copyright (C) 2022 Satria A. Kautsar
# Joint Genome Institute
"""
Run this script on a collection of BGCs in a genome
to get the resulting BiG-SLiCE features in a TSV format
"""
from pyhmmer import plan7, hmmer
from pyhmmer.easel import TextSequence, Alphabet
from os import path
import pandas as pd
import glob
from Bio import SeqIO
import argparse
import json


def main():
    """Run the command-line workflow: check arguments, extract features, write outputs.

    Exactly one input mode must be selected: an antiSMASH result folder
    (-a/--antismash_folder) or an annotated genome (-g/--genome_file, which
    additionally requires -l/--locus_strings). The feature table is written
    to every requested output file and is always printed to stdout as CSV.

    Returns:
        int: 1 if the arguments or paths fail validation, 0 on success.
    """

    # setup and check args
    args = get_arg_parser()

    if not args.antismash_folder and args.genome_file:
        if not args.locus_strings:
            print((
                "ERROR: please use -l/--locus_strings alongside -g/--genome_file"
            ))
            return 1
        mode = "genome_file"
    elif args.antismash_folder and not args.genome_file:
        mode = "regiongbks"
    else:
        print((
            "ERROR: please use ONE of the following: '-a/--antismash_folder'"
            " or '-g/--genome_file' to define your input"
        ))
        return 1

    if not path.exists(args.program_db_folder):
        print((
            "ERROR: can't find HMM Database folder at {}"
            " (use --program_db_folder to define a custom path)"
        ).format(args.program_db_folder))
        return 1

    # refuse to overwrite any output file that already exists
    for output_path in (args.output_tsv, args.output_json, args.output_parquet):
        if output_path and path.exists(output_path):
            print((
                "ERROR: path exists! {}"
            ).format(output_path))
            return 1

    # parse BGCs data
    if mode == "genome_file":
        bgcs = extract_bgcs_from_genome_loci(
            args.genome_file, args.locus_strings.split(",")
        )
    else:
        bgcs = extract_bgcs_from_regiongbks(args.antismash_folder)

    # extract features; forward -t/--num_threads so the requested CPU count
    # actually reaches the pyHMMer searches
    features_df = extract_features(
        bgcs, args.program_db_folder, num_cpus=args.num_threads
    )

    if args.output_tsv:
        features_df.to_csv(args.output_tsv, sep="\t")

    if args.output_json:
        with open(args.output_json, "w") as json_file:
            json_file.write(json.dumps(json.loads(features_df.to_json())))

    if args.output_parquet:
        features_df.to_parquet(args.output_parquet)

    # the table is always echoed to stdout as CSV
    print(features_df.to_csv())
    return 0


def get_arg_parser():
    """Declare the command-line options and parse the current arguments.

    Despite its name, this function returns the parsed arguments, not the
    parser object.

    Returns:
        argparse.Namespace: with attributes antismash_folder, genome_file,
        locus_strings, num_threads, program_db_folder, output_tsv,
        output_json and output_parquet.
    """
    cli = argparse.ArgumentParser()

    # --- input selection ---
    cli.add_argument(
        "-a", "--antismash_folder",
        type=str,
        metavar='<path_folder>',
        help="path to an antiSMASH result folder as the input",
    )
    cli.add_argument(
        "-g", "--genome_file",
        type=str,
        metavar='<path_gbk>',
        help="path to an ANNOTATED genome file (in genbank format) as the input",
    )
    cli.add_argument(
        "-l", "--locus_strings",
        type=str,
        metavar='<delimited_list>',
        help=(
            "list of locus strings to be paired with --genome_file, use COMMA-delimited"
            " strings with the format <contig_name>:<nt_start>:<nt_end>"
        ),
    )

    # --- runtime settings ---
    cli.add_argument(
        "-t", "--num_threads",
        type=int,
        default=0,
        help="number of CPUs to use for pyHMMer",
    )

    # by default the models are looked up next to this script
    script_folder = path.dirname(path.realpath(__file__))
    cli.add_argument(
        "--program_db_folder",
        metavar='<path>',
        default=path.join(script_folder, "bigslice-models"),
        type=str,
        help="Path to the BiG-SLiCE HMM libraries (default: %(default)s).",
    )

    # --- optional output files, all declared the same way ---
    output_options = (
        ("--output_tsv", "output to a TSV file"),
        ("--output_json", "output to a JSON file"),
        ("--output_parquet", "output to a Parquet file"),
    )
    for flag, description in output_options:
        cli.add_argument(flag, type=str, metavar='<path>', help=description)

    return cli.parse_args()


def extract_bgcs_from_genome_loci(genome_path, locus_strings):
    """Collect the CDS translations located inside user-defined genome loci.

    Args:
        genome_path: path to an annotated genome in GenBank format.
        locus_strings: iterable of "<contig_name>:<nt_start>:<nt_end>"
            strings. Start and end may be given in either order. The contig
            name itself may contain ':' because the string is split from the
            right.

    Returns:
        list of dict, one per locus whose contig was found in the genome
        (loci on unknown contigs are ignored). Each dict has:
          - "name": "<genome_path>:<contig_id>:<start>-<end>nt:<flag>", where
            flag is 0 if the locus starts at the contig start or reaches the
            contig end, and 1 otherwise;
          - "cds": digitized amino-acid sequences named
            "<bgc_index>-<cds_index>", for every CDS lying fully inside the
            locus.

    Raises:
        ValueError: if a locus string is not of the expected format.
    """

    # parse locus strings into {contig_id: [(start, end), ...]}
    locus_dict = {}
    for locus_string in locus_strings:
        try:
            # split from the right so contig names containing ':' still parse
            contig_id, start, end = locus_string.rsplit(":", 2)
            start, end = int(start), int(end)
        except ValueError as err:
            raise ValueError(
                "invalid locus string '{}', expected the format"
                " <contig_name>:<nt_start>:<nt_end>".format(locus_string)
            ) from err
        locus_dict.setdefault(contig_id, []).append(
            (min(start, end), max(start, end))
        )

    # nothing was requested: no need to parse the genome at all
    if not locus_dict:
        return []

    # parse record and extract bgcs
    amino = Alphabet.amino()
    results = []
    for contig in SeqIO.parse(genome_path, "genbank"):
        for start, end in locus_dict.get(contig.id, ()):
            bgc = {
                "name": "{}:{}:{}-{}nt:{}".format(
                    genome_path, contig.id, start, end,
                    # 0 = locus touches a contig boundary, 1 = it does not
                    0 if (start <= 1 or end >= len(contig)) else 1
                ),
                "cds": []
            }
            for feature in contig.features:
                if feature.type != "CDS":
                    continue
                # keep only CDS fully inside the locus; every feature is
                # checked so the result does not depend on feature ordering
                # (location.start is 0-based, locus coordinates are 1-based)
                if feature.location.start + 1 < start:
                    continue
                if feature.location.end > end:
                    continue
                # CDS without a translation (e.g. pseudogenes) cannot be
                # scanned, so they are skipped instead of raising KeyError
                translation = feature.qualifiers.get("translation")
                if not translation:
                    continue
                cds_name = "{}-{}".format(len(results), len(bgc["cds"]))
                bgc["cds"].append(
                    TextSequence(
                        name=bytes(cds_name, "utf-8"),
                        sequence=translation[0]
                    ).digitize(amino)
                )
            results.append(bgc)

    return results


def extract_bgcs_from_regiongbks(antismash_folder):
    """Collect the CDS translations of every region GenBank file in a folder.

    Args:
        antismash_folder: folder searched (non-recursively) for files
            matching "*.region*.gbk".

    Returns:
        list of dict, one per GenBank record, in sorted file-path order.
        Each dict has:
          - "name": the file path followed by ":0" if a "region" feature of
            the record has contig_edge == "True", and by ":1" otherwise;
          - "cds": digitized amino-acid sequences named
            "<bgc_index>-<cds_index>", for every CDS that has a translation.
    """

    # sorted so that the BGC order (and with it the order of the output
    # rows) does not depend on the directory listing order
    region_files = sorted(glob.glob(path.join(
        antismash_folder,
        "*.region*.gbk"
    )))

    # parse record and extract bgcs
    results = []
    for fp in region_files:
        for contig in SeqIO.parse(fp, "genbank"):
            contig_edge = False
            bgc = {
                "name": fp,
                "cds": []
            }
            for feature in contig.features:
                if feature.type == "CDS":
                    # a CDS without translation cannot be scanned: skip it
                    # instead of raising KeyError
                    translation = feature.qualifiers.get("translation")
                    if not translation:
                        continue
                    cds_name = "{}-{}".format(len(results), len(bgc["cds"]))
                    bgc["cds"].append(
                        TextSequence(
                            name=bytes(cds_name, "utf-8"),
                            sequence=translation[0]
                        ).digitize(Alphabet.amino())
                    )
                elif feature.type == "region":
                    # a region lacking the qualifier is treated as not
                    # being on a contig edge (same as having no region)
                    contig_edge = feature.qualifiers.get(
                        "contig_edge", ["False"]
                    )[0] == "True"
            bgc["name"] += ":1" if not contig_edge else ":0"

            results.append(bgc)

    return results


def extract_features(bgcs, hmmdb_folder, top_k=3, num_cpus=0):
    """Scan the CDS of each BGC against the HMM libraries and build a feature table.

    Args:
        bgcs: list of dicts with a "name" and a "cds" list of digitized
            sequences whose names are "<bgc_index>-<cds_index>".
        hmmdb_folder: folder holding
            "biosynthetic_pfams/Pfam-A.biosynthetic.hmm" and
            "sub_pfams/hmm/<accession>.subpfams.hmm".
        top_k: number of best-scoring sub-pfam models kept per aligned
            domain fragment.
        num_cpus: forwarded to pyhmmer's hmmsearch as `cpus`.

    Returns:
        pandas.DataFrame of ints, indexed by BGC name, with one column per
        model NAME read from the HMM files (biosynthetic pfams first, then
        sub-pfams). A biosynthetic hit is stored as 255; the sub-pfam ranked
        r-th (0-based) for a fragment is stored as
        255 - int((255 / top_k) * r). Missing combinations are 0. Only BGCs
        with at least one recorded hit get a row.
    """

    def read_model_names(hmm_path):
        # values of all the "NAME " lines of an HMM text file, in file order
        with open(hmm_path, "r") as handle:
            return [
                line.split("NAME ")[-1].strip()
                for line in handle
                if line.startswith("NAME ")
            ]

    # store hmmdb model data
    biosyn_pfam_model = path.join(
        hmmdb_folder, "biosynthetic_pfams", "Pfam-A.biosynthetic.hmm"
    )
    biosyn_pfams = read_model_names(biosyn_pfam_model)

    core_pfams = set()  # accessions that have a dedicated sub-pfam library
    sub_pfams = []
    sub_model_pattern = path.join(
        hmmdb_folder, "sub_pfams", "hmm", "*.subpfams.hmm"
    )
    for sub_model_path in glob.iglob(sub_model_pattern):
        core_pfams.add(path.basename(sub_model_path).split(".subpfams.hmm")[0])
        sub_pfams.extend(read_model_names(sub_model_path))

    # perform biosyn_scan over the CDS of all BGCs at once
    all_cds = [cds for bgc in bgcs for cds in bgc["cds"]]
    biosyn_hits = []
    subpfam_to_scan = {}
    with plan7.HMMFile(biosyn_pfam_model) as biosyn_hmms:
        biosyn_search = hmmer.hmmsearch(
            biosyn_hmms, all_cds,
            cpus=num_cpus,
            bit_cutoffs="gathering"
        )
        for top_hits in biosyn_search:
            for hit in top_hits:
                domain = hit.best_domain
                if domain.score < top_hits.domT:
                    continue
                bgc_id, _cds_id = map(int, hit.name.decode().split("-"))
                # NOTE(review): hits are labelled with the model accession,
                # while the output columns come from NAME lines; presumably
                # both coincide in the BiG-SLiCE model files -- confirm.
                accession = top_hits.query_accession.decode()
                biosyn_hits.append((bgc_id, accession, 255))

                # queue the aligned fragment when a sub-pfam library exists
                if accession not in core_pfams:
                    continue
                pending = subpfam_to_scan.setdefault(accession, [])
                fragment_name = "{}-{}".format(bgc_id, len(pending))
                pending.append(
                    TextSequence(
                        name=bytes(fragment_name, "utf-8"),
                        sequence=domain.alignment.target_sequence
                    ).digitize(Alphabet.amino())
                )

    # perform subpfam_scan, one sub-pfam library per core accession
    subpfam_hits = []
    for core_accession, fragments in subpfam_to_scan.items():
        sub_pfam_model = path.join(
            hmmdb_folder, "sub_pfams", "hmm",
            "{}.subpfams.hmm".format(core_accession)
        )
        with plan7.HMMFile(sub_pfam_model) as sub_hmms:
            # {fragment name: {sub-pfam name: best score}}
            best_scores = {}
            sub_search = hmmer.hmmsearch(
                sub_hmms, fragments,
                cpus=num_cpus,
                T=20, domT=20
            )
            for top_hits in sub_search:
                for hit in top_hits:
                    score = hit.best_domain.score
                    if score < top_hits.domT:
                        continue
                    per_fragment = best_scores.setdefault(hit.name.decode(), {})
                    sub_name = top_hits.query_name.decode()
                    per_fragment[sub_name] = max(
                        score, per_fragment.get(sub_name, score)
                    )

            # keep the top_k best sub-pfams of every fragment
            for fragment_name, scores in best_scores.items():
                bgc_id, _fragment_idx = map(int, fragment_name.split("-"))
                ranked = sorted(scores.items(), key=lambda n: n[1], reverse=True)
                for rank, (sub_name, _score) in enumerate(ranked):
                    if rank >= top_k:
                        break
                    subpfam_hits.append(
                        (bgc_id, sub_name, 255 - int((255 / top_k) * rank))
                    )

    # one row per (bgc, model); sorting first makes the highest value survive
    hits_table = pd.DataFrame(
        biosyn_hits + subpfam_hits,
        columns=["bgc_id", "hmm_name", "value"]
    )
    hits_table = hits_table.sort_values("value", ascending=False)
    hits_table = hits_table.drop_duplicates(["bgc_id", "hmm_name"])

    # wide table with a fixed column layout
    features = hits_table.pivot(
        index="bgc_id", columns="hmm_name", values="value"
    )
    features = features.reindex(biosyn_pfams + sub_pfams, axis="columns")
    features = features.fillna(0).astype(int)

    # swap the numeric BGC indices for the BGC names
    features.index = features.index.map(lambda i: bgcs[i]["name"])

    return features

if __name__ == "__main__":
    # hand main()'s return value to the interpreter so validation errors
    # (return value 1) produce a non-zero exit status; None maps to 0
    raise SystemExit(main())