# -*- coding: utf-8 -*-

import os
import argparse

import anvio
import anvio.workflows as w
import anvio.terminal as terminal

from anvio.workflows.contigs import ContigsDBWorkflow
from anvio.errors import ConfigError

__author__ = "Alon Shaiber"
__copyright__ = "Copyright 2017, The anvio Project"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "Alon Shaiber"
__email__ = "alon.shaiber@gmail.com"

run = terminal.Run()

# `slave_mode` is False only when the first included Snakefile is the contigs
# workflow itself, i.e. when this file is the entry point. Otherwise this file
# is being included by another workflow, which is presumably responsible for
# providing `M` and `dirs_dict` (they are used unconditionally by the rules below).
slave_mode = 'workflows/contigs' not in workflow.included[0]
if not slave_mode:
    # by the time we get here snakemake has already filled in the global
    # variable `config`, so it can be handed straight to the workflow class:
    M = ContigsDBWorkflow(argparse.Namespace(config=config, slave_mode=slave_mode))
    M.init()
    dirs_dict = M.dirs_dict

localrules: annotate_contigs_database


rule contigs_target_rule:
    # Target rule with no output or command of its own: it only requests the
    # "<group>-annotate_contigs_database.done" flag file of every group in M.group_names.
    input: expand(dirs_dict['CONTIGS_DIR'] + "/{group}-annotate_contigs_database.done", group=M.group_names)


rule annotate_contigs_database:
    '''This is a dummy rule to mark that annotation was completed for each database

    This allows workflows that inherit the contigs workflow to just track the output
    of this rule, instead of tracking the individual annotation outputs.
    '''
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-annotate_contigs_database.log"
    # the wildcards are not forwarded: M.get_contigs_target_files() is called without arguments
    input: lambda wildcards: M.get_contigs_target_files()
    # touch() creates the flag file; no shell/run section is needed for this rule
    output: touch(dirs_dict['CONTIGS_DIR'] + "/{group}-annotate_contigs_database.done")


rule gunzip_fasta:
    # Decompress the raw fasta file of a group into "<FASTA_DIR>/<group>-temp.fa".
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-gunzip_fasta.log"
    # remove_gz_suffix=False: presumably returns the path of the compressed file
    # as given by the user (verify against M.get_raw_fasta)
    input: lambda wildcards: M.get_raw_fasta(wildcards, remove_gz_suffix=False)
    # temp(): the decompressed copy is flagged as a temporary file
    output: temp(os.path.join(M.dirs_dict['FASTA_DIR'], \
                            '{group}' + '-temp.fa'))
    threads: M.T('gunzip_fasta')
    resources: nodes = M.T('gunzip_fasta')
    # gunzip reads from stdin and writes to stdout, so the input file is left untouched;
    # only stderr goes to the log
    shell: "gunzip < {input} > {output} 2>{log}"


rule anvi_script_reformat_fasta_prefix_only:
    '''
        Reformatting the headers of the contigs fasta files.

        This is required to make sure that the headers don't contain
        any characters that anvi'o doesn't like. It gives contigs
        meaningful names; so that if the group name is 'MYSAMPLE01', the
        contigs would look like this:
        > MYSAMPLE01_000000000001
        > MYSAMPLE01_000000000002

        We do it in two steps in this workflow with the rules:
        anvi_script_reformat_fasta_prefix_only and anvi_script_reformat_fasta
        The first rule anvi_script_reformat_fasta_prefix_only does all the formatting
        of headers without considering the min_len parameter. The second rule takes
        the output of the first rule as input and considers the min_len param.
        The reasoning is that this way we have a fasta file that includes all the contigs
        from the raw fasta file, but with headers that match the final fasta file.
        Hence, if later we decide we want to include shorter contigs, we can do it using
        the output of anvi_script_reformat_fasta_prefix_only and the headers would be consistent.
    '''
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-anvi_script_reformat_fasta_prefix_only.log"
    input:
        contigs = M.get_raw_fasta
    output:
        # write protecting the contigs fasta file using protected() because
        # running the assembly is probably the most time consuming step and
        # we don't want anyone accidentally deleting or changing this file.
        contigs = protected(dirs_dict["FASTA_DIR"] + "/{group}/{group}-contigs-prefix-formatted-only.fa"),
        report = dirs_dict["FASTA_DIR"] + "/{group}/{group}-reformat-report.txt"
    params:
        # all parameters are read from the "anvi_script_reformat_fasta" section of the config
        prefix = M.get_rule_param("anvi_script_reformat_fasta", "--prefix"),
        keep_ids = M.get_rule_param("anvi_script_reformat_fasta", "--keep-ids"),
        exclude_ids = M.get_rule_param("anvi_script_reformat_fasta", "--exclude-ids"),
        simplify_names = M.get_rule_param("anvi_script_reformat_fasta", "--simplify-names")
    threads: M.T('anvi_script_reformat_fasta')
    resources: nodes = M.T('anvi_script_reformat_fasta'),
    # note that --min-len is deliberately absent from this command (see docstring)
    shell: w.r("""anvi-script-reformat-fasta {input} \
                                             -o {output.contigs} \
                                             -r {output.report} \
                                             {params.prefix} \
                                             {params.exclude_ids} \
                                             {params.keep_ids} \
                                             {params.simplify_names} >> {log} 2>&1""")


rule anvi_script_reformat_fasta:
    '''
        See documentation for anvi_script_reformat_fasta_prefix_only
    '''
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-anvi_script_reformat_fasta.log"
    input:
        # the header-formatted fasta produced by anvi_script_reformat_fasta_prefix_only
        contigs = dirs_dict["FASTA_DIR"] + "/{group}/{group}-contigs-prefix-formatted-only.fa"
    output:
        contigs = temp(os.path.join(dirs_dict["FASTA_DIR"], "{group}", "{group}-contigs.fa"))
    params:
        # NOTE(review): `prefix` is not referenced by the shell command below
        prefix = "{group}",
        min_len = M.get_rule_param("anvi_script_reformat_fasta", "--min-len"),
    threads: M.T('anvi_script_reformat_fasta')
    resources: nodes = M.T('anvi_script_reformat_fasta'),
    # only the min-len parameter is applied in this second pass
    shell: w.r("""anvi-script-reformat-fasta {input} \
                                             -o {output.contigs} \
                                             {params.min_len} >> {log} 2>&1""")


rule reformat_external_gene_calls_table:
    '''Rename the contigs in an external gene calls table to their reformatted names

    The report written by anvi-script-reformat-fasta is used to map the contig names
    found in the external gene calls file to the names used in the reformatted fasta
    file. Gene calls on contigs that are absent from the reformatted fasta file are
    dropped, and a ConfigError is raised for contig names unknown to the report.
    '''
    version: 1.0
    log: os.path.join(dirs_dict["LOGS_DIR"], "{group}-reformat_external_gene_calls_table.log")
    input: unpack(M.get_input_for_reformat_external_gene_calls_table)
    output:
        os.path.join(dirs_dict['FASTA_DIR'], "{group}", "{group}-external-gene-calls.txt")
    threads: M.T('reformat_external_gene_calls_table')
    resources: nodes = M.T('reformat_external_gene_calls_table')
    run:
        import pandas as pd
        import anvio.utils as u
        import anvio.fastalib as f

        # contig names present in the reformatted fasta file. A set is used because
        # membership is tested once per contig and once per gene call below.
        contigs = set(f.ReadFasta(input.contigs[0]).ids)

        # the report has no header line; the second column is used as index and the
        # first remaining column holds the values, so that `contig_dict` maps the
        # names in the second column to the names in the first one
        reformat_report = pd.read_csv(input.reformat_report[0], sep = '\t',
                                      index_col=1, header=None)
        contig_dict = next(iter(reformat_report.to_dict().values()))
        external_gene_calls = u.get_TAB_delimited_file_as_dictionary(input.external_gene_calls)

        contigs_in_external_gene_calls_not_in_contigs_database = \
            [c for c in contig_dict.values() if c not in contigs]
        if contigs_in_external_gene_calls_not_in_contigs_database:
            warning = ("The following contigs are in your external gene calls "
                       "but not in your reformatted fasta file (%s): %s. This "
                       "is no cause for concern since they have been removed due "
                       "to the --min-len criteria of anvi-script-reformat-fasta, "
                       "but we still thought you should know. Any gene that was called "
                       "within these contigs will naturally not be included in your contigs database."
                       % (input.contigs[0],
                          ", ".join(contigs_in_external_gene_calls_not_in_contigs_database)))
            # the log is written directly instead of through `echo` in a shell: the
            # list of contigs can be long enough to exceed the command line length
            # limit, and quotes or braces in contig names would break the command.
            with open(log[0], 'w') as log_file:
                log_file.write(warning + '\n')

        new_external_gene_calls = {}
        for g, gene_call in external_gene_calls.items():
            contig_name = gene_call['contig']
            if contig_name not in contig_dict:
                raise ConfigError('When proccessing external gene calls for %s '
                                  'we bumped into the following problem: The '
                                  'following contig name: "%s" appears in your external '
                                  'gene calls file, but not in your fasta file.' % (wildcards.group, contig_name))
            new_contig_name = contig_dict[contig_name]
            # gene calls on contigs that did not make it to the reformatted fasta are skipped
            if new_contig_name in contigs:
                new_external_gene_calls[g] = gene_call.copy()
                new_external_gene_calls[g]['contig'] = new_contig_name

        D = pd.DataFrame.from_dict(new_external_gene_calls, orient='index')
        D.to_csv(output[0], index_label='gene_callers_id', sep='\t')


rule anvi_gen_contigs_database:
    """ Generates a contigs database using anvi-gen-contigs-database"""
    # Setting the version to the same as that of the contigs__version in anvi'o
    version: anvio.__contigs__version__
    log: dirs_dict["LOGS_DIR"] + "/{group}-anvi_gen_contigs_database.log"
    # depending on whether human contamination using centrifuge was done
    # or not, the input to this rule will be the raw assembly or the
    # filtered.
    # the input function must provide at least a `fasta` entry, which is the
    # only input referenced by the shell command below.
    input: unpack(M.get_input_for_anvi_gen_contigs_database)
    output:
        db = os.path.join(dirs_dict["CONTIGS_DIR"], "{group}-contigs.db")
    params:
        description = M.get_rule_param("anvi_gen_contigs_database", "--description"),
        skip_gene_calling = M.get_rule_param("anvi_gen_contigs_database", "--skip-gene-calling"),
        # resolved per group, since it depends on the wildcards
        external_gene_calls = lambda wildcards: M.get_external_gene_calls_param(wildcards),
        ignore_internal_stop_codons = M.get_rule_param("anvi_gen_contigs_database", "--ignore-internal-stop-codons"),
        skip_mindful_splitting = M.get_rule_param("anvi_gen_contigs_database", "--skip-mindful-splitting"),
        # NOTE(review): `-f {input.fasta}` is always passed in the command below;
        # confirm that a configured --contigs-fasta cannot end up duplicating it.
        contigs_fasta = M.get_rule_param("anvi_gen_contigs_database", "--contigs-fasta"),
        project_name = M.get_rule_param("anvi_gen_contigs_database", "--project-name"),
        split_length = M.get_rule_param("anvi_gen_contigs_database", "--split-length"),
        kmer_size = M.get_rule_param("anvi_gen_contigs_database", "--kmer-size"),
        # resolved per group, since it depends on the wildcards
        prodigal_translation_table = lambda wildcards: M.get_prodigal_translation_table_param(wildcards)
    threads: M.T('anvi_gen_contigs_database')
    resources: nodes = M.T('anvi_gen_contigs_database'),
    shell: w.r("anvi-gen-contigs-database -f {input.fasta} \
                                          -o {output.db} \
                                          {params.kmer_size} \
                                          {params.split_length} \
                                          {params.project_name} \
                                          {params.contigs_fasta} \
                                          {params.skip_mindful_splitting} \
                                          {params.ignore_internal_stop_codons} \
                                          {params.external_gene_calls} \
                                          {params.skip_gene_calling} \
                                          {params.prodigal_translation_table} \
                                          {params.description} >> {log} 2>&1")


rule reformat_external_functions:
    ''' Remove any gene that is not included in the contigs database due to fasta reformat with --min-len

    Only the rows of the functional annotation table whose gene caller id (first
    column) is present in the reformatted external gene calls table are kept.
    Gene caller ids that had to be dropped are reported in the log file.
    '''
    version: 1.0
    log: os.path.join(dirs_dict["LOGS_DIR"], "{group}-reformat_external_functions.log")
    input: unpack(M.get_input_for_reformat_external_functions)
    output: os.path.join(dirs_dict['FASTA_DIR'], "{group}", "{group}" + "-gene-functional-annotation.txt")
    threads: M.T('reformat_external_functions')
    resources: nodes = M.T('reformat_external_functions')
    run:
        import pandas as pd

        gene_functional_annotation = pd.read_csv(input.gene_functional_annotation, sep='\t', index_col=0)
        external_gene_calls = pd.read_csv(input.external_gene_calls, sep='\t', index_col=0)

        # A boolean mask keeps every annotation row exactly once. Selecting with
        # .loc and a list of labels would return duplicated rows whenever a gene
        # has more than one annotation row (i.e. a non-unique index).
        has_gene_call = gene_functional_annotation.index.isin(external_gene_calls.index)
        genes_only_in_functional_annotation_file = gene_functional_annotation.index[~has_gene_call].unique()

        if not genes_only_in_functional_annotation_file.empty:
            # gene caller ids are usually parsed as integers, hence the explicit str()
            missing_ids = ', '.join(str(g) for g in genes_only_in_functional_annotation_file)
            warning = "When configuring gene functional annotations for %s\n" % wildcards.group + \
                      "we noticed that the following gene caller ids are not in your contigs database: \n'%s'.\n" % missing_ids + \
                      "These are probably missing because you used the --min-len parameter when you ran\n" + \
                      "anvi-script-reformat-fasta. And hence, there is probably no reason for concern.\n"
            # written directly rather than with `echo -e` in a shell, so that neither
            # the length nor the content of the message can break the command
            with open(log[0], 'w') as log_file:
                log_file.write(warning)

        new_gene_functional_annotation = gene_functional_annotation[has_gene_call]
        new_gene_functional_annotation.to_csv(output[0], sep='\t')


rule import_external_functions:
    # Import a functional annotation table into the contigs database of a group
    # using anvi-import-functions.
    version: 1.0
    log: os.path.join(dirs_dict["LOGS_DIR"], "{group}-import_external_functions.log")
    # the input function must provide the `contigs` and `gene_functional_annotation`
    # entries used by the shell command below
    input: unpack(M.get_input_for_import_external_functions)
    output:
        # flag file: the command itself only modifies the contigs database
        functions_imported = touch(os.path.join(dirs_dict["CONTIGS_DIR"], "{group}-external-functions-imported.done")),
    threads: M.T('import_external_functions')
    resources: nodes = M.T('import_external_functions')
    shell: "anvi-import-functions -c {input.contigs} -i {input.gene_functional_annotation} >> {log} 2>&1"


rule export_gene_calls_for_centrifuge:
    ''' Export gene calls and use for centrifuge'''
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-export_gene_calls_for_centrifuge.log"
    # marking the input as ancient in order to ignore timestamps.
    input: ancient(M.get_contigs_db_path())
    # output is temporary. No need to keep this file.
    output: temp(dirs_dict["CONTIGS_DIR"] + "/{group}-gene-calls.fa")
    threads: M.T('export_gene_calls_for_centrifuge')
    resources: nodes = M.T('export_gene_calls_for_centrifuge'),
    # the exported fasta file is the input of the `centrifuge` rule;
    # stdout and stderr of the command are appended to the log.
    shell: "anvi-get-sequences-for-gene-calls -c {input} -o {output} >> {log} 2>&1"


rule centrifuge:
    ''' Run centrifuge on the exported gene calls of the contigs.db'''
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-centrifuge.log"
    # the fasta file written by export_gene_calls_for_centrifuge
    input: rules.export_gene_calls_for_centrifuge.output
    output:
        # both files are later consumed by anvi_import_taxonomy_for_genes
        hits = dirs_dict["CONTIGS_DIR"] + "/{group}-centrifuge_hits.tsv",
        report = dirs_dict["CONTIGS_DIR"] + "/{group}-centrifuge_report.tsv"
    # the centrifuge database is taken from the "centrifuge" -> "db" entry of the config
    params: db=M.get_param_value_from_config(["centrifuge", "db"])
    threads: M.T('centrifuge')
    resources: nodes = M.T('centrifuge'),
    shell: w.r("centrifuge -f \
                           -x {params.db} \
                           {input} \
                           -S {output.hits} \
                           --report-file {output.report} \
                           --threads {threads} >> {log} 2>&1")


rule anvi_import_taxonomy_for_genes:
    ''' Run anvi-import-taxonomy-for-genes'''
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-anvi_import_taxonomy_for_genes.log"
    input:
        hits = rules.centrifuge.output.hits,
        report = rules.centrifuge.output.report,
        # marking the contigs.db as ancient in order to ignore timestamps.
        contigs = ancient(M.get_contigs_db_path())
    # using a flag file because no file is created by this rule.
    # for more information see:
    # http://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#flag-files
    # NOTE(review): the flag file name contains a doubled "anvi_anvi". Code outside
    # this rule may rely on this exact name, so check (e.g. M.get_contigs_target_files)
    # before renaming it.
    output: touch(dirs_dict["CONTIGS_DIR"] + "/{group}-anvi_anvi_import_taxonomy_for_genes.done")
    # name of the parser given to anvi-import-taxonomy-for-genes with -p
    params: parser = "centrifuge"
    threads: M.T('anvi_import_taxonomy_for_genes')
    resources: nodes = M.T('anvi_import_taxonomy_for_genes'),
    shell: w.r("anvi-import-taxonomy-for-genes -c {input.contigs} \
                                               -i {input.report} {input.hits} \
                                               -p {params.parser} >> {log} 2>&1")


rule anvi_run_hmms:
    """ Run anvi-run-hmms"""
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-anvi_run_hmms.log"
    # marking the input as ancient in order to ignore timestamps.
    input: ancient(M.get_contigs_db_path())
    # using a snakemake flag file as an output since no file is generated
    # by the rule. This flag is also an input of anvi_run_scg_taxonomy.
    output: touch(dirs_dict["CONTIGS_DIR"] + "/anvi_run_hmms-{group}.done")
    params:
        # optional command line parameters, read from the "anvi_run_hmms" section of the config
        installed_hmm_profile = M.get_rule_param("anvi_run_hmms", "--installed-hmm-profile"),
        hmm_profile_dir = M.get_rule_param("anvi_run_hmms", "--hmm-profile-dir"),
        also_scan_trnas = M.get_rule_param("anvi_run_hmms", "--also-scan-trnas"),
    threads: M.T('anvi_run_hmms')
    resources: nodes = M.T('anvi_run_hmms'),
    shell: w.r("anvi-run-hmms -c {input} \
                              -T {threads} \
                              {params.hmm_profile_dir} \
                              {params.installed_hmm_profile} \
                              {params.also_scan_trnas} >> {log} 2>&1")


rule anvi_run_pfams:
    # Run anvi-run-pfams on the contigs database of a group.
    version: 1.0
    log: os.path.join(dirs_dict["LOGS_DIR"], "{group}-anvi_run_pfams.log")
    # ancient(): timestamps of the contigs database are ignored
    input: ancient(M.get_contigs_db_path())
    # flag file, since the command itself only modifies the contigs database
    output: touch(os.path.join(dirs_dict["CONTIGS_DIR"], "anvi_run_pfams-{group}.done"))
    params:
        pfam_data_dir = M.get_rule_param('anvi_run_pfams', '--pfam-data-dir')
    threads: M.T('anvi_run_pfams')
    resources: nodes = M.T('anvi_run_pfams')
    shell: " anvi-run-pfams -c {input} {params.pfam_data_dir} -T {threads} >> {log} 2>&1"


rule anvi_run_ncbi_cogs:
    # Run anvi-run-ncbi-cogs on the contigs database of a group.
    # The rule version follows the anvi'o contigs database version.
    version: anvio.__contigs__version__
    log: dirs_dict["LOGS_DIR"] + "/{group}-anvi_run_ncbi_cogs.log"
    input:
        # ancient(): timestamps of the contigs database are ignored
        contigs = ancient(M.get_contigs_db_path())
    # flag file, since the command itself only modifies the contigs database
    output: touch(dirs_dict["CONTIGS_DIR"] + "/anvi_run_ncbi_cogs-{group}.done")
    params:
        # anvi-run-ncbi-cogs params. See anvi-run-ncbi-cogs help menu for more info.
        cog_data_dir = M.get_rule_param('anvi_run_ncbi_cogs', '--cog-data-dir'),
        sensitive = M.get_rule_param('anvi_run_ncbi_cogs', '--sensitive'),
        temporary_dir_path = M.get_rule_param('anvi_run_ncbi_cogs', '--temporary-dir-path'),
        search_with = M.get_rule_param('anvi_run_ncbi_cogs', '--search-with')
    threads: M.T('anvi_run_ncbi_cogs')
    resources: nodes = M.T('anvi_run_ncbi_cogs'),
    shell: w.r("""anvi-run-ncbi-cogs -c {input.contigs} \
                                     -T {threads} \
                                     {params.cog_data_dir} \
                                     {params.sensitive} \
                                     {params.temporary_dir_path} \
                                     {params.search_with} >> {log} 2>&1""")


rule anvi_run_scg_taxonomy:
    # Run anvi-run-scg-taxonomy on the contigs database of a group.
    # The rule version follows the anvi'o contigs database version.
    version: anvio.__contigs__version__
    log: dirs_dict["LOGS_DIR"] + "/{group}-anvi_run_scg_taxonomy.log"
    input:
        # make sure HMMs were done before running this rule
        # (this path must stay identical to the output of the anvi_run_hmms rule)
        hmms_done = ancient(dirs_dict["CONTIGS_DIR"] + "/anvi_run_hmms-{group}.done"),
        contigs = ancient(M.get_contigs_db_path())
    # flag file, since the command itself only modifies the contigs database
    output: touch(dirs_dict["CONTIGS_DIR"] + "/anvi_run_scg_taxonomy-{group}.done")
    params:
        scg_taxonomy_data_dir = M.get_rule_param('anvi_run_scg_taxonomy', '--scgs-taxonomy-data-dir'),
    threads: M.T('anvi_run_scg_taxonomy')
    resources: nodes = M.T('anvi_run_scg_taxonomy'),
    shell: w.r("""anvi-run-scg-taxonomy -c {input.contigs} \
                                        -T {threads} \
                                        {params.scg_taxonomy_data_dir} >> {log} 2>&1""")


rule anvi_run_trna_scan:
    # Scan the contigs database of a group for tRNAs. Note that the rule (and its
    # config section) is called `anvi_run_trna_scan`, while the program that is
    # executed is `anvi-scan-trnas`.
    version: anvio.__contigs__version__
    log: dirs_dict["LOGS_DIR"] + "/{group}-anvi_run_trna_scan.log"
    input:
        # ancient(): timestamps of the contigs database are ignored
        contigs = ancient(M.get_contigs_db_path())
    # flag file, since the command itself only modifies the contigs database
    output: touch(dirs_dict["CONTIGS_DIR"] + "/anvi_run_trna_scan-{group}.done")
    params:
        # FIXME not sure how to handle optional --trna-hits-file parameter
        #trna_scan_hits_file = M.get_rule_param('anvi_run_trna_scan', '--trna-hits-file'),
        trna_cutoff_score = M.get_rule_param('anvi_run_trna_scan', '--trna-cutoff-score'),
    threads: M.T('anvi_run_trna_scan')
    resources: nodes = M.T('anvi_run_trna_scan'),
    shell: w.r("""anvi-scan-trnas -c {input.contigs} \
                                  -T {threads} \
                                  {params.trna_cutoff_score} >> {log} 2>&1""")


rule anvi_get_sequences_for_gene_calls:
    # Export the amino acid sequences of the gene calls in a contigs database.
    # The resulting fasta file is the input of the `emapper` rule.
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-anvi_get_sequences_for_gene_calls.log"
    # ancient(): timestamps of the contigs database are ignored
    input: ancient(M.get_contigs_db_path())
    # temp(): the fasta file is flagged as a temporary file
    output: temp(dirs_dict["CONTIGS_DIR"] + "/{group}-contigs-aa-sequences.fa")
    threads: M.T('anvi_get_sequences_for_gene_calls')
    resources: nodes = M.T('anvi_get_sequences_for_gene_calls')
    shell: "anvi-get-sequences-for-gene-calls -c {input} -o {output} --get-aa-sequences >> {log} 2>&1"


rule emapper:
    # Run emapper.py (eggNOG mapper) on the amino acid sequences exported by
    # anvi_get_sequences_for_gene_calls.
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-emapper.log"
    input: rules.anvi_get_sequences_for_gene_calls.output
    # this is `contigs_path_without_extension` (see params) plus ".emapper.annotations",
    # presumably the suffix that emapper.py appends to its --output value (verify)
    output: dirs_dict["CONTIGS_DIR"] + "/{group}-contigs.emapper.annotations"
    # TODO: add other emapper params
    params:
        contigs_path_without_extension = dirs_dict["CONTIGS_DIR"] + "/{group}-contigs",
        # directory that contains emapper.py, taken from the config
        path_to_emapper_dir = M.get_param_value_from_config(['emapper', 'path_to_emapper_dir']),
        database = M.get_rule_param("emapper", "--database"),
        usemem = M.get_rule_param("emapper", "--usemem"),
        override = M.get_rule_param("emapper", "--override")
    threads: M.T('emapper')
    resources: nodes = M.T('emapper')
    run:
        # running emapper
        # NOTE(review): the interpreter is hard-coded to python2.7
        shell("python2.7 {params.path_to_emapper_dir}/emapper.py -i {input} --output {params.contigs_path_without_extension} " + \
                  "--cpu {threads} {params.database} {params.usemem} {params.override} >> {log} 2>&1")


rule anvi_script_run_eggnog_mapper:
    # Import the annotations produced by the `emapper` rule into the contigs
    # database of a group using anvi-script-run-eggnog-mapper.
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-anvi_script_run_eggnog_mapper.log"
    input:
        eggnog_output = rules.emapper.output,
        contigs = dirs_dict["CONTIGS_DIR"] + "/{group}-contigs.db"
    output:
        # flag file, since the commands below do not create a file that is kept
        touch(dirs_dict["CONTIGS_DIR"] + "/{group}-anvi_script_run_eggnog_mapper.done")
    params:
        use_version = M.get_rule_param("anvi_script_run_eggnog_mapper", "--use-version")
    threads: M.T('anvi_script_run_eggnog_mapper')
    resources: nodes = M.T('anvi_script_run_eggnog_mapper')
    run:
        try:
            # Adding a 'g' prefix before every gene id (the anvi'o emapper driver requires this)
            # This first command creates/truncates the log; the following ones append to it.
            shell("sed 's/^[0-9]/g&/' {input.eggnog_output} > {input.eggnog_output}.temp 2>{log}")

            shell("anvi-script-run-eggnog-mapper -c {input.contigs} --annotation {input.eggnog_output}.temp {params.use_version}  >> {log} 2>&1")
        finally:
            # The temporary file is removed even when one of the commands above fails.
            # The log must be appended to (>>) here: redirecting with a plain `2>`
            # would truncate it and discard the output of the import step.
            shell("rm -f {input.eggnog_output}.temp >> {log} 2>&1")


# generate external genomes storage
# this rule is only used when this workflow is inherited (e.g. by pangenomics or phylogenomics workflows)
rule gen_external_genome_file:
    # Write a TAB-delimited file with the columns `name` and `contigs_db_path`,
    # one line per contigs database of the groups in M.group_names.
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/gen_external_genome_file.log"
    input:
        # the flag files guarantee that the annotation of every database is complete
        annotation_done = expand(dirs_dict['CONTIGS_DIR'] + "/{sample}-annotate_contigs_database.done", sample=M.group_names),
        contigs_dbs = expand(dirs_dict["CONTIGS_DIR"] + "/{sample}-contigs.db", sample=M.group_names)
    output: M.external_genomes_file
    threads: M.T("gen_external_genome_file")
    resources: nodes = M.T("gen_external_genome_file")
    run:
        with open(output[0], 'w') as f:
            f.write("name\tcontigs_db_path\n")
            for c in input.contigs_dbs:
                # The name is the file name without its extension, with dashes turned
                # into underscores (e.g. "G01-contigs.db" -> "G01_contigs"). Only the
                # file name is used, so that dots in the directory part of the path
                # (such as "./" or "../") cannot truncate or empty the name.
                name = os.path.splitext(os.path.basename(c))[0].replace('-', '_')
                f.write("%s\t%s\n" % (name, c))


if not slave_mode:
    # `slave_mode` is False exactly when 'workflows/contigs' is found in
    # workflow.included[0], i.e. when this Snakefile is the entry point.
    # Verify that every program the workflow depends on is available. This is
    # only effective if an initial dry run takes place, which is the default
    # behavior of the `WorkflowSuperClass`, so you are most likely covered.
    M.check_workflow_program_dependencies(workflow)
