# -*- coding: utf-8

'''
'''

import os
import anvio
import argparse
import pandas as pd
import anvio.utils as u
import anvio.workflows as w
import anvio.dbops as dbops
import anvio.terminal as terminal
import anvio.filesnpaths as filesnpaths

from anvio.errors import ConfigError, FilesNPathsError
from anvio.workflows.metapan import MetaPangenomicsWorkflow
from anvio.tables.miscdata import TableForLayerAdditionalData

__author__ = "Alon Shaiber"
__copyright__ = "Copyright 2017, The anvio Project"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "Alon Shaiber"
__email__ = "alon.shaiber@gmail.com"

run = terminal.Run()
progress = terminal.Progress()

# This Snakefile runs in "slave mode" when it was pulled in by another
# workflow. NOTE(review): this relies on workflow.included[0] being the
# Snakefile snakemake was launched with -- confirm with snakemake internals.
slave_mode= False if 'workflows/metapangenomics' in workflow.included[0] else True

if not slave_mode:
    # it is important that this comes before the include statement since
    # M is defined here and will also be used in the contigs workflow
    # (the class name must match the one imported from anvio.workflows.metapan;
    # the previous `MetaPanWorkflow` name was never imported -> NameError).
    M = MetaPangenomicsWorkflow(argparse.Namespace(config=config))
    M.init()
    dirs_dict = M.dirs_dict
    # include the workflow dependencies
    # TODO: a better way to do this would be to have a dependencies variable under self
    # and then to include all dependencies
    include: w.get_workflow_snake_file_path('pangenomics')
    include: w.get_workflow_snake_file_path('metagenomics')
    include: w.get_workflow_snake_file_path('contigs')


# loading the metapangenome_fastas.txt file (a TAB-delimited table whose path
# is taken from the 'metapangenome_fastas_txt' config key)
# NOTE(review): metapangenome_fastas_information is not referenced again in the
# visible part of this file; presumably consumed further down or elsewhere.
metapangenome_fastas_txt = M.get_param_value_from_config(['metapangenome_fastas_txt'])
metapangenome_fastas_information = pd.read_csv(metapangenome_fastas_txt, sep='\t', index_col=False)

# FIXME: sanity check of every input file should be done in the backend, and not in
#        the snakemake recipe.
# NOTE(review): `samples_information` and `samples_txt_file` are not defined in
# this file; they are expected to be provided by an included Snakefile.
if 'sample' not in samples_information.columns.values:
    raise ConfigError("You know what. This '%s' file does not look anything like\
                       a samples file." % samples_txt_file)
# get a list of the sample names
sample_names = list(samples_information['sample'])

# references_mode: the user supplies ready-made fasta files (listed in fasta_txt)
# instead of having the workflow assemble the reads.
references_mode = M.get_param_value_from_config('references_mode', repress_default=True)
fasta_txt_file = M.get_param_value_from_config('fasta_txt', repress_default=True)
if references_mode:
    # fasta_txt is mandatory in references mode; a missing file is reported
    # as a ConfigError rather than the lower-level FilesNPathsError.
    try:
        filesnpaths.is_file_exists(fasta_txt_file)
    except FilesNPathsError as e:
        raise ConfigError('In references mode you must supply a fasta_txt file.')

if not references_mode:
    # if it is reference mode then the group names have been assigned in the contigs Snakefile
    # if it is not reference mode and no groups are supplied in the samples_txt then group names are sample names
    group_names = sample_names

# supplying a fasta_txt without enabling references_mode is ambiguous, so refuse it
if fasta_txt_file and not references_mode:
    raise ConfigError("In order to use reference fasta files you must set\
                   \"'references_mode': true\" in your config file, yet\
                   you didn't, but at the same time you supplied the following\
                   fasta_txt: %s. So we don't know what to do with this\
                   fasta_txt" % fasta_txt_file)

# Collecting information regarding groups.
# Sets the module-level `group_names` and `group_sizes` used by the rules below.
# NOTE(review): `fasta_information` (used in references mode) is not defined in
# this file; presumably provided by the included contigs Snakefile.
if "group" in samples_information.columns:
    # if groups were specified then members of a groups will be co-assembled.
    group_names = list(samples_information['group'].unique())
    # creating a dictionary with groups as keys and number of samples in
    # the groups as values
    group_sizes = samples_information['group'].value_counts().to_dict()

    if references_mode:
        # sanity check to see that groups specified in samples.txt match
        # the names of fasta.
        mismatch = set(group_names) - set(fasta_information.keys())
        if mismatch:
            raise ConfigError("Group names specified in the samples.txt \
                               file must match the names of fasta \
                               in the fasta.txt file. These are the \
                               mismatches: %s" % mismatch)
        # the opposite direction (fasta entries without samples) is only a warning
        groups_in_fasta_information_but_not_in_samples_txt = set(fasta_information.keys()) - set(group_names)
        if groups_in_fasta_information_but_not_in_samples_txt:
            run.warning('The following group names appear in your fasta_txt\
                         but do not appear in your samples_txt. Maybe this is\
                         ok with you, but we thought you should know. This means\
                         that the metagenomics workflow will simply ignore these\
                         groups.')

else:
    if references_mode:
        # if the user didn't provide a group column in the samples.txt,
        # in references mode the default is 'all_against_all'.
        # NOTE(review): group_sizes is not assigned in this branch; it is
        # expected to be filled by the all_against_all check below, assuming
        # set_config_param is visible to get_param_value_from_config.
        run.warning("No groups were provided in your samples_txt,\
                     hence 'all_against_all' mode has been automatically\
                     set to True.")
        M.set_config_param('all_against_all', True)
    else:
        # if no groups were specified then each sample would be assembled
        # separately
        run.warning("No groups were specified in your samples_txt. This is fine.\
                     But we thought you should know. Any assembly will be performed\
                     on individual samples (i.e. NO co-assembly).")
        # each sample becomes its own group of size 1
        samples_information['group'] = samples_information['sample']
        group_names = list(sample_names)
        group_sizes = dict.fromkeys(group_names,1)

if M.get_param_value_from_config('all_against_all', repress_default=True):
    # in all_against_all, the size of each group is as big as the number
    # of samples.
    group_sizes = dict.fromkeys(group_names,len(sample_names))


if not references_mode and not (M.get_param_value_from_config(['anvi_script_reformat_fasta','run']) == True):
    # in assembly mode (i.e. not in references mode) we always have
    # to run reformat_fasta. The only reason for this is that
    # the megahit output is temporary, and if we dont run
    # reformat_fasta we will delete the output of meghit at the
    # end of the workflow without saving a copy.
    raise ConfigError("You can't skip reformat_fasta in assembly mode "\
                        "please change your config.json file")

# Boolean switches read from the 'run' key of each rule's config section.
# `== True` makes any non-True config value count as "off".
# by default fastq files will be zipped after qc is done
run_gzip_fastqs = M.get_param_value_from_config(['gzip_fastqs', 'run']) == True
# whether to run QC with iu-filter-quality-minoche
run_qc = M.get_param_value_from_config(['iu_filter_quality_minoche', 'run']) == True

run_import_percent_of_reads_mapped = M.get_param_value_from_config(['import_percent_of_reads_mapped', 'run']) == True

# NOTE(review): unlike its siblings this value is not compared with `== True`,
# so it keeps the raw config value (truthiness is used at L389, `== True` at L405).
run_idba_ud = M.get_param_value_from_config(['idba_ud', 'run'])

run_megahit = M.get_param_value_from_config(['megahit', 'run']) == True

# NOTE(review): `available_assemblers` is not defined in this file; presumably
# provided by an included Snakefile.
if not references_mode and not (run_idba_ud or run_megahit):
    # Sanity check for assembly mode
    # Make sure that user specified an assembler
    raise ConfigError("You didn't specify any assembler for your metagenomes, and yet\
                       your config file shows you are not using references_mode. If you\
                       already have your assemblies or reference fasta files, then maybe\
                       you need to add '\"references_mode\": true', and a '\"fasta_txt\": \
                       \"name-of-your-fasta-txt-file\"' to your config file. Otherwise,\
                       if you do mean to use assembly, please set 'run: true' for one of the,\
                       following available assemblers: %s" % ', '.join(available_assemblers))

rule metagenomics_workflow_target_rule:
    '''
        The target rule for the workflow.

        The final product of the workflow is an anvi'o merged profile directory
        for each group
    '''
    # Requests, per group: the merged PROFILE.db, the annotation "done" flag,
    # and either the QC report (when QC runs) or the contigs databases.
    input: expand("{DIR}/{group}/PROFILE.db", DIR=dirs_dict["MERGE_DIR"], group=group_names),
           contigs_annotated = expand(dirs_dict["CONTIGS_DIR"] + "/{group}-annotate_contigs_database.done", group=group_names),
           qc_report = dirs_dict["QC_DIR"] + "/qc-report.txt" if run_qc else expand(dirs_dict["CONTIGS_DIR"] + "/{group}-contigs.db", group=group_names)


rule iu_gen_configs:
    '''
        Generating a config file for each sample.

        Notice that this step is ran only once and generates the config files for all samples
    '''
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/iu_gen_configs.log"
    # the input file is marked as 'ancient' so snakemake wouldn't run it
    # just because a new path-to-raw-fastq-files.txt file was created.
    input: ancient(samples_txt_file)
    # one <sample>.ini per sample, written into QC_DIR
    output: expand("{DIR}/{sample}.ini", DIR=dirs_dict["QC_DIR"], sample=sample_names)
    params:
        dir=dirs_dict["QC_DIR"],
        # optional read-name prefixes, rendered by M.get_rule_param from the config
        r1_prefix = M.get_rule_param("iu_gen_configs", "--r1-prefix"),
        r2_prefix = M.get_rule_param("iu_gen_configs", "--r2-prefix")
    threads: M.T('iu_gen_configs')
    resources: nodes = M.T('iu_gen_configs'),
    shell: "iu-gen-configs {input} -o {params.dir} {params.r2_prefix} {params.r1_prefix} >> {log} 2>&1"


def get_raw_fastq(wildcards):
    ''' Look up the raw (pre-QC) fastq paths of a single sample.

        The 'r1' and 'r2' cells of samples_txt may each hold several paths
        separated by commas; every cell is split into a list of paths.

        Returns a dict of the form {'r1': [paths...], 'r2': [paths...]}.
    '''
    sample_rows = samples_information[samples_information["sample"] == wildcards.sample]
    raw_paths = {}
    for read in ('r1', 'r2'):
        # only the first matching row is used, as before
        raw_paths[read] = list(sample_rows[read])[0].split(',')
    return raw_paths


def get_fastq(wildcards):
    ''' Return the pair of fastq paths to use for a sample.

        Two sources are possible:
            1. The output of QC (gzipped when gzip_fastqs runs, plain otherwise).
            2. The raw paths from samples.txt, when the user skipped QC.

        Returns a dict {'r1': ..., 'r2': ...}.
    '''
    if not run_qc:
        # QC is skipped, so the reads listed in samples_txt are used as-is
        return get_raw_fastq(wildcards)

    qc_dir = dirs_dict["QC_DIR"]
    extension = ".fastq.gz" if run_gzip_fastqs else ".fastq"
    r1_paths = expand("{DIR}/{sample}-QUALITY_PASSED_R1" + extension, DIR=qc_dir, sample=wildcards.sample)
    r2_paths = expand("{DIR}/{sample}-QUALITY_PASSED_R2" + extension, DIR=qc_dir, sample=wildcards.sample)
    return {'r1': r1_paths, 'r2': r2_paths}


def input_for_qc(wildcards):
    ''' Build the input dict for the QC rule.

        Contains the sample's .ini config (wrapped in ancient() so a fresh
        timestamp alone does not trigger QC) plus its raw r1/r2 fastq paths.
    '''
    ini_path = dirs_dict["QC_DIR"] + "/%s.ini" % wildcards.sample
    return {'ini': ancient(ini_path), **get_raw_fastq(wildcards)}


rule iu_filter_quality_minoche:
    ''' Run QC using iu-filter-quality-minoche '''
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{sample}-iu_filter_quality_minoche.log"
    # input_for_qc wraps the .ini config file in ancient() so QC wouldn't run
    # just because a new config file was produced.
    input: unpack(input_for_qc)
    # NOTE(review): the command only receives the .ini file; these output paths
    # presumably match what iu-filter-quality-minoche writes -- confirm.
    output:
        r1 = dirs_dict["QC_DIR"] + "/{sample}-QUALITY_PASSED_R1.fastq",
        r2 = dirs_dict["QC_DIR"] + "/{sample}-QUALITY_PASSED_R2.fastq",
        stats = dirs_dict["QC_DIR"] + "/{sample}-STATS.txt"
    params:
        ignore_deflines = M.get_rule_param("iu_filter_quality_minoche", "--ignore-deflines"),
        visualize_quality_curves = M.get_rule_param("iu_filter_quality_minoche", "--visualize-quality-curves"),
        limit_num_pairs = M.get_rule_param("iu_filter_quality_minoche", "--limit-num-pairs"),
        print_qual_scores = M.get_rule_param("iu_filter_quality_minoche", "--print-qual-scores"),
        store_read_fate = M.get_rule_param("iu_filter_quality_minoche", "--store-read-fate")
    threads: M.T('iu_filter_quality_minoche')
    resources: nodes = M.T('iu_filter_quality_minoche'),
    shell:
        """
            iu-filter-quality-minoche {input.ini} {params.store_read_fate}\
                                      {params.print_qual_scores} {params.limit_num_pairs}\
                                      {params.visualize_quality_curves} {params.ignore_deflines} >> {log} 2>&1
        """


# Collects every sample's QC STATS file into a single TAB-delimited report
# (one row per sample, one column per statistic).
rule gen_qc_report:
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/gen_qc_report.log"
    input: expand(dirs_dict["QC_DIR"] + "/{sample}-STATS.txt", sample=sample_names)
    output: dirs_dict["QC_DIR"] + "/qc-report.txt"
    threads: M.T('gen_qc_report')
    resources: nodes = M.T('gen_qc_report')
    run:
        report_dict = {}
        # columns of the report; statistics missing from a STATS file stay 0
        report_column_headers = ['number of pairs analyzed',
             'total pairs passed',
             'total pairs passed (percent of all pairs)',
             'total pair_1 trimmed',
             'total pair_1 trimmed (percent of all passed pairs)',
             'total pair_2 trimmed',
             'total pair_2 trimmed (percent of all passed pairs)',
             'total pairs failed',
             'total pairs failed (percent of all pairs)',
             'pairs failed due to pair_1',
             'pairs failed due to pair_1 (percent of all failed pairs)',
             'pairs failed due to pair_2',
             'pairs failed due to pair_2 (percent of all failed pairs)',
             'pairs failed due to both',
             'pairs failed due to both (percent of all failed pairs)',
             'FAILED_REASON_P',
             'FAILED_REASON_P (percent of all failed pairs)',
             'FAILED_REASON_N',
             'FAILED_REASON_N (percent of all failed pairs)',
             'FAILED_REASON_C33',
             'FAILED_REASON_C33 (percent of all failed pairs)']
        for filename in input:
            # the sample name is the file name minus the "-STATS.txt" suffix
            sample = os.path.basename(filename).split("-STATS.txt")[0]
            report_dict[sample] = dict.fromkeys(report_column_headers, 0)
            with open(filename,'r') as f:
                # Each line is parsed as "<header>: <number> (<percent>% <of ...>)".
                # The first line is treated as having no percentage part.
                # NOTE(review): a line without ':' (e.g. a blank line) would raise IndexError.
                firstline = True
                for line in f.readlines():
                    s1 = line.split(':')
                    numeric_header = s1[0].strip()
                    s2 = s1[1].split('(')
                    numeric = s2[0].strip()
                    report_dict[sample][numeric_header] = numeric
                    if not firstline:
                        # rebuild "<header> (percent of ...)" to match report_column_headers
                        s3 = s2[1].split(' ')
                        percent = s3[0].strip('%')
                        percent_header = numeric_header + " (percent " + " ".join(s3[1:])
                        percent_header = percent_header.strip()
                        report_dict[sample][percent_header] = percent
                    else:
                        firstline = False
        u.store_dict_as_TAB_delimited_file(report_dict, output[0], headers= ["sample"] + report_column_headers)


def fastq_input_for_fq2fa(wildcards):
    ''' Return the pair of fastq paths that fq2fa should read for a sample.

        Unlike get_fastq, the QC branch always points at the *uncompressed*
        QC output, since fq2fa expects plain fastq files. When QC is skipped,
        the raw paths from samples_txt are returned (these may be gzipped).
    '''
    if not run_qc:
        # no qc requested: use raw input
        return get_raw_fastq(wildcards)

    qc_dir = dirs_dict["QC_DIR"]
    return {
        'r1': expand("{DIR}/{sample}-QUALITY_PASSED_R1.fastq", DIR=qc_dir, sample=wildcards.sample),
        'r2': expand("{DIR}/{sample}-QUALITY_PASSED_R2.fastq", DIR=qc_dir, sample=wildcards.sample),
    }


def input_for_fq2fa(wildcards):
    ''' Decide whether fq2fa can read the fastq files directly or needs gunzip first.

        Compressed input should only occur when QC is not part of this
        snakemake session: in-session QC output is uncompressed, and snakemake
        schedules fq2fa before the gzip rule.

        Only the first path of each read list is inspected. If exactly one of
        r1/r2 is gzipped a ConfigError is raised; if both are, the returned
        paths are redirected to the temporary outputs of gunzip_for_fq2fa.
    '''
    fastq_paths = fastq_input_for_fq2fa(wildcards)
    read_lists = (list(fastq_paths['r1']), list(fastq_paths['r2']))
    gzipped = list(filter(lambda paths: paths[0].endswith('.gz'), read_lists))

    if len(gzipped) == 2:
        run.warning("The following fastq files are compressed and will now \
                     be uncompressed using gunzip: %s." % gzipped)
        fastq_paths['r1'] = dirs_dict["QC_DIR"] + "/%s.r1.fastq" % wildcards.sample
        fastq_paths['r2'] = dirs_dict["QC_DIR"] + "/%s.r2.fastq" % wildcards.sample
    elif len(gzipped) == 1:
        raise ConfigError("Something seems very bad: one of the pair fastq files\
                           is compressed and the other one is not. This is the \
                           compressed one: %s" % gzipped[0])

    return fastq_paths


# Decompresses a sample's gzipped raw fastq pair into temporary files that
# input_for_fq2fa points fq2fa at.
rule gunzip_for_fq2fa:
    version: 1.0
    # NOTE(review): the log file is declared but the shell commands do not write to it.
    log: dirs_dict["LOGS_DIR"] + "/{sample}-gunzip_for_fq2fa.log"
    input: unpack(fastq_input_for_fq2fa)
    output:
        r1 = temp(dirs_dict["QC_DIR"] + "/{sample}.r1.fastq"),
        r2 = temp(dirs_dict["QC_DIR"] + "/{sample}.r2.fastq")
    # NOTE(review): uses w.T(config, ..., 1) while other rules use M.T(...).
    threads: w.T(config, 'gunzip_for_fq2fa', 1)
    resources: nodes = w.T(config, 'gunzip_for_fq2fa', 1)
    shell:
        """
            gunzip < {input.r1} > {output.r1}
            gunzip < {input.r2} > {output.r2}
        """


# Merges a sample's r1/r2 fastq pair into a single temporary fasta
# (used as input for co-assembly merging / idba_ud).
rule fq2fa:
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{sample}-fq2fa.log"
    input: unpack(input_for_fq2fa)
    output: temp(dirs_dict["QC_DIR"] + "/{sample}-merged-reads.fa")
    threads: M.T('fq2fa')
    resources: nodes = M.T('fq2fa')
    # {input} expands to all input files (r1 then r2), space-separated
    shell: "fq2fa --merge {input} {output} >> {log} 2>&1"


# Concatenates the merged-reads fastas of every sample in a group into one
# temporary fasta for co-assembly.
rule merge_fastas_for_co_assembly:
    version: 1.0
    # NOTE(review): the log file is declared but the shell command does not write to it.
    log: dirs_dict["LOGS_DIR"] + "/{group}-merge_fastas_for_co_assembly.log"
    input: lambda wildcards: expand("{DIR}/{sample}-merged-reads.fa", DIR=dirs_dict["QC_DIR"], sample=list(samples_information[samples_information["group"] == wildcards.group]["sample"]))
    output: temp(dirs_dict["QC_DIR"] + "/{group}-merged.fa")
    threads: M.T('merge_fastas_for_co_assembly')
    resources: nodes = M.T('merge_fastas_for_co_assembly')
    shell: "cat {input} > {output}"


# if idba_ud is used then we need to ensure the fastq files
# will only be compressed AFTER fq2fa is done
# (by default the "flag" input is just the fastq itself, i.e. no extra dependency)
flag_file_for_gzip = dirs_dict["QC_DIR"] + "/{sample}-QUALITY_PASSED_{R}.fastq"
if run_idba_ud:
    flag_file_for_gzip = rules.fq2fa.output

rule gzip_fastqs:
    ''' Compressing the quality controlled fastq files'''
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{sample}-{R}-gzip.log"
    input:
        fastq = dirs_dict["QC_DIR"] + "/{sample}-QUALITY_PASSED_{R}.fastq",
        # ordering-only dependency; not used in the shell command
        run_gzip_flag = flag_file_for_gzip
    output: dirs_dict["QC_DIR"] + "/{sample}-QUALITY_PASSED_{R}.fastq.gz"
    threads: M.T('gzip_fastqs')
    resources: nodes = M.T('gzip_fastqs'),
    # gzip replaces the .fastq in place with the .fastq.gz output
    shell: "gzip {input.fastq} >> {log} 2>&1"


# The idba_ud rule is only defined when enabled in the config.
# NOTE(review): idba_ud and megahit declare the same contigs output path, so
# enabling both would presumably make snakemake see ambiguous rules -- confirm.
if run_idba_ud == True:
    rule idba_ud:
        version: 1.0
        log: dirs_dict["LOGS_DIR"] + "/{group}-idba_ud.log"
        input:
            fasta = dirs_dict["QC_DIR"] + "/{group}-merged.fa"
        output:
            # idba_ud writes into a temporary directory; only contig.fa is kept (moved below)
            temp_dir = temp(dirs_dict["FASTA_DIR"] + "/{group}_TEMP"),
            contigs = temp(dirs_dict["FASTA_DIR"] + "/{group}/final.contigs.fa")
        params:
            mink = M.get_rule_param("idba_ud", "--mink"),
            maxk = M.get_rule_param("idba_ud", "--maxk"),
            step = M.get_rule_param("idba_ud", "--step"),
            inner_mink = M.get_rule_param("idba_ud", "--inner_mink"),
            inner_step = M.get_rule_param("idba_ud", "--inner_step"),
            prefix = M.get_rule_param("idba_ud", "--prefix"),
            min_count = M.get_rule_param("idba_ud", "--min_count"),
            min_support = M.get_rule_param("idba_ud", "--min_support"),
            seed_kmer = M.get_rule_param("idba_ud", "--seed_kmer"),
            min_contig = M.get_rule_param("idba_ud", "--min_contig"),
            similar = M.get_rule_param("idba_ud", "--similar"),
            max_mismatch = M.get_rule_param("idba_ud", "--max_mismatch"),
            min_pairs = M.get_rule_param("idba_ud", "--min_pairs"),
            no_bubble = M.get_rule_param("idba_ud", "--no_bubble"),
            no_local = M.get_rule_param("idba_ud", "--no_local"),
            no_coverage = M.get_rule_param("idba_ud", "--no_coverage"),
            no_correct = M.get_rule_param("idba_ud", "--no_correct"),
            pre_correction = M.get_rule_param("idba_ud", "--pre_correction"),
        threads: M.T('idba_ud')
        resources: nodes = M.T('idba_ud')
        run:
            cmd = "idba_ud -o {output.temp_dir} --read {input.fasta} --num_threads {threads} " + \
                  "{params.mink} {params.maxk} {params.step} {params.inner_mink} " + \
                  "{params.inner_step} {params.prefix} {params.min_count} " + \
                  "{params.min_support} {params.seed_kmer} {params.min_contig} " + \
                  "{params.similar} {params.max_mismatch} {params.min_pairs} " + \
                  "{params.no_bubble} {params.no_local} {params.no_coverage} " + \
                  "{params.no_correct} {params.pre_correction} >> {log} 2>&1"
            shell(cmd)
            shell("mv {output.temp_dir}/contig.fa {output.contigs} >> {log} 2>&1")


def _split_comma_separated_paths(entries):
    ''' Flatten samples_txt path cells, each of which may list several files
        separated by commas, into one flat list of paths (same splitting rule
        as get_raw_fastq).
    '''
    return [path for entry in entries for path in entry.split(',')]


def input_for_megahit(wildcards):
    ''' Creating a dictionary containing the paths to the input fastq files.
        The reason we can't use get_fastq is because we need to get
        the fastq files for all samples that belong to a group, and
        get_fastq only gives the pair of fastq file for one sample.

        Returns {'r1': [paths...], 'r2': [paths...]} for every sample in
        `wildcards.group`. Without QC, comma-separated raw paths from
        samples_txt are split into individual files so snakemake tracks each
        one (previously a multi-file cell was passed as one bogus path).
    '''
    group_rows = samples_information[samples_information["group"] == wildcards.group]

    if run_qc:
        # by default, use the output of the qc
        extension = ".fastq.gz" if run_gzip_fastqs else ".fastq"
        group_samples = list(group_rows["sample"])
        r1 = expand("{DIR}/{sample}-QUALITY_PASSED_R1" + extension, DIR=dirs_dict["QC_DIR"], sample=group_samples)
        r2 = expand("{DIR}/{sample}-QUALITY_PASSED_R2" + extension, DIR=dirs_dict["QC_DIR"], sample=group_samples)
    else:
        # no qc: use the raw paths, one list entry per file
        r1 = _split_comma_separated_paths(group_rows['r1'])
        r2 = _split_comma_separated_paths(group_rows['r2'])
    return {'r1': r1, 'r2': r2}


# The megahit rule is only defined when enabled in the config.
if M.get_param_value_from_config(['megahit', 'run']) == True:
    rule megahit:
        '''
            Assembling fastq files using megahit.

            All files created by megahit are stored in a temporary folder,
            and only the fasta file is kept for later analysis.
        '''
        version: 1.0
        log: dirs_dict["LOGS_DIR"] + "/{group}-megahit.log"
        input: unpack(input_for_megahit)
        params:
            # the minimum length for contigs (smaller contigs will be discarded)
            min_contig_len = M.get_rule_param("megahit", "--min-contig-len"),
            min_count = M.get_rule_param("megahit", "--min-count"),
            k_min = M.get_rule_param("megahit", "--k-min"),
            k_max = M.get_rule_param("megahit", "--k-max"),
            k_step = M.get_rule_param("megahit", "--k-step"),
            k_list = M.get_rule_param("megahit", "--k-list"),
            no_mercy = M.get_rule_param("megahit", "--no-mercy"),
            no_bubble = M.get_rule_param("megahit", "--no-bubble"),
            merge_level = M.get_rule_param("megahit", "--merge-level"),
            prune_level = M.get_rule_param("megahit", "--prune-level"),
            prune_depth = M.get_rule_param("megahit", "--prune-depth"),
            low_local_ratio = M.get_rule_param("megahit", "--low-local-ratio"),
            max_tip_len = M.get_rule_param("megahit", "--max-tip-len"),
            no_local = M.get_rule_param("megahit", "--no-local"),
            kmin_1pass = M.get_rule_param("megahit", "--kmin-1pass"),
            presets = M.get_rule_param("megahit", "--presets"),
            memory = M.get_rule_param("megahit", "--memory"),
            mem_flag = M.get_rule_param("megahit", "--mem-flag"),
            use_gpu = M.get_rule_param("megahit", "--use-gpu"),
            gpu_mem = M.get_rule_param("megahit", "--gpu-mem"),
            keep_tmp_files = M.get_rule_param("megahit", "--keep-tmp-files"),
            tmp_dir = M.get_rule_param("megahit", "--tmp-dir"),
            _continue = M.get_rule_param("megahit", "--continue"),
            verbose = M.get_rule_param("megahit", "--verbose"),
        # Notice that megahit requires a directory to be specified as
        # output. If the directory already exists then megahit will not
        # run. To avoid this, megahit writes into a temporary directory, and
        # once megahit is done running the contigs fasta is moved
        # to the final location.
        output:
            temp_dir = temp(dirs_dict["FASTA_DIR"] + "/{group}_TEMP"),
            contigs = temp(dirs_dict["FASTA_DIR"] + "/{group}/final.contigs.fa")
        threads: M.T('megahit')
        resources: nodes = M.T('megahit'),
        # NOTE(review): an earlier comment described this as a shadow rule, but no
        # `shadow:` directive is set. Extra megahit files live in the temp()
        # directory, which snakemake may not delete if the run fails midway.
        run:
            # megahit takes multiple read files as one comma-separated argument
            r1 = ','.join(input.r1)
            r2 = ','.join(input.r2)

            cmd = "megahit -1 %s -2 %s " % (r1, r2) + \
                "-o {output.temp_dir} " + \
                "-t {threads} " + \
                "{params.min_contig_len} {params.min_count} {params.k_min} " + \
                "{params.k_max} {params.k_step} {params.k_list} {params.no_mercy} " + \
                "{params.no_bubble} {params.merge_level} {params.prune_level} " + \
                "{params.prune_depth} {params.low_local_ratio} {params.max_tip_len} " + \
                "{params.no_local} {params.kmin_1pass} {params.presets} {params.memory} " + \
                "{params.mem_flag} {params.use_gpu} {params.gpu_mem} {params.keep_tmp_files} " + \
                "{params.tmp_dir} {params._continue} {params.verbose} >> {log} 2>&1"
            print("Running: %s" % cmd)
            shell(cmd)
            shell("mv {output.temp_dir}/final.contigs.fa {output.contigs} >> {log} 2>&1")



rule bowtie_build:
    """ Run bowtie-build on the contigs fasta"""
    # TODO: consider runnig this as a shadow rule
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-bowtie_build.log"
    # NOTE(review): get_fasta is not defined in this file; presumably it comes
    # from the included contigs Snakefile.
    input: get_fasta
    # bowtie2-build writes several index files sharing the prefix but with
    # different suffixes (.1-.4.bt2 and .rev.1-.rev.2.bt2); all are listed here.
    output:
        o1 = expand(dirs_dict["MAPPING_DIR"] + "/{group}/{group}-contigs" + '.{i}.bt2', i=[1,2,3,4], group="{group}"),
        o2 = expand(dirs_dict["MAPPING_DIR"] + "/{group}/{group}-contigs" + '.rev.{i}.bt2', i=[1,2], group="{group}")
    params:
        prefix = dirs_dict["MAPPING_DIR"] + "/{group}/{group}-contigs"
    threads: M.T('bowtie_build')
    resources: nodes = M.T('bowtie_build'),
    shell: "bowtie2-build {input} {params.prefix} >> {log} 2>&1"


def input_for_bowtie(wildcards):
    ''' Assemble bowtie's inputs: the bowtie-build index files (so mapping
        waits for indexing) together with the sample's r1/r2 fastq paths.
    '''
    return {'build_output': rules.bowtie_build.output, **get_fastq(wildcards)}


rule bowtie:
    """ Run mapping with bowtie2"""
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-{sample}-bowtie.log"
    input: unpack(input_for_bowtie)
    # setting the output as temp, since we only want to keep the bam file.
    output:
        sam = temp(dirs_dict["MAPPING_DIR"] + "/{group}/{sample}.sam")
    params:
        # NOTE(review): params.dir is not used by the shell command below.
        dir = dirs_dict["MAPPING_DIR"] + "/{sample}",
        bowtie_build_prefix = rules.bowtie_build.params.prefix,
        # extra bowtie2 arguments passed through verbatim from the config
        additional_params = M.get_param_value_from_config(["bowtie", "additional_params"])
    threads: M.T('bowtie')
    resources: nodes = M.T('bowtie'),
    shell:
        """
            bowtie2 --threads {threads} -x {params.bowtie_build_prefix} -1 {input.r1} -2 {input.r2} \
            {params.additional_params} -S {output.sam} >> {log} 2>&1
        """


rule samtools_view:
    """ sort sam file with samtools and create a RAW.bam file"""
    # NOTE(review): the command runs `samtools view -bS` (SAM -> BAM conversion);
    # no sorting happens here despite the docstring.
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-{sample}-samtools_view.log"
    input: rules.bowtie.output.sam
    params: additional_params = M.get_param_value_from_config(["samtools_view", "additional_params"])
    # output as temp. we only keep the final bam file
    output: temp(dirs_dict["MAPPING_DIR"] + "/{group}/{sample}-RAW.bam")
    threads: M.T('samtools_view')
    resources: nodes = M.T('samtools_view'),
    shell: "samtools view -bS {input} -o {output} {params.additional_params} >> {log} 2>&1"


rule anvi_init_bam:
    """ run anvi-init-bam on RAW bam file to create a bam file ready for anvi-profile"""
    version: 1.0 # later we can decide if we want the version to use the version of anvi'o
    log: dirs_dict["LOGS_DIR"] + "/{group}-{sample}-anvi_init_bam.log"
    input: rules.samtools_view.output
    # the .bai index is declared as an output next to the bam, although only
    # the bam path is passed to anvi-init-bam
    output:
        bam = dirs_dict["MAPPING_DIR"] + "/{group}/{sample}.bam",
        bai = dirs_dict["MAPPING_DIR"] + "/{group}/{sample}.bam.bai"
    threads: M.T('anvi_init_bam')
    resources: nodes = M.T('anvi_init_bam'),
    shell: "anvi-init-bam {input} -o {output.bam} >> {log} 2>&1"


# Warn when the user overrides anvi-profile's --sample-name in the config
# (i.e. a non-None value that differs from the default config's value).
sample_name = M.get_param_value_from_config(['anvi_profile', '--sample-name'], repress_default=True)
if sample_name != M.default_config['anvi_profile']['--sample-name'] and sample_name is not None:
    run.warning('You chose to set the "--sample-name" for your profile databases\
                 in the config file to %s. You are welcomed to do that, but at your own\
                 risk. Just so you know, by default the sample name would match\
                 the name defined either in the samples_txt, by choosing to provide\
                 a different name, it means that all your profile databases would have\
                 the same name, unless you incloded "{sample}" in the name you provided\
                 but even then, we did not test that option and we are not sure it would\
                 work...' % sample_name)


def get_cluster_contigs_param(wildcards):
    """ helper function to sort out whether to cluster contigs for a single profile database

        If '--cluster-contigs' is set for anvi_profile in the config file, the user's
        choice wins (after a warning). Otherwise contigs are clustered only for groups
        of size 1, i.e. profiles that will not be merged with other profiles.

        Parameters
        ==========
        wildcards: the snakemake wildcards of the anvi_profile job (uses `group`).

        Returns
        =======
        The command-line fragment to hand to anvi-profile ('' for no clustering).
    """
    user_value = M.get_param_value_from_config(['anvi_profile', '--cluster-contigs'], repress_default=True)
    if user_value:
        # the message contains a %s placeholder, so the configured value must be interpolated
        run.warning('You chose to set the value for --cluster-contigs as %s. You can do that if you\
                     choose to, but just so you know, if you don\'t provide a value then \
                     the workflow would automatically cluster contigs for profile databases\
                     that are not merged (i.e. profiles that belong to groups of size 1).' % (user_value,))
        return M.get_rule_param('anvi_profile', '--cluster-contigs')

    # if profiling to individual assembly then clustering contigs
    # see --cluster-contigs in the help menu of anvi-profile
    return '--cluster-contigs' if group_sizes[wildcards.group] == 1 else ''


rule anvi_profile:
    """ run anvi-profile on the bam file of a sample against the contigs database of its group"""
    # setting the rule version to be as the version of the profile database of anvi'o
    version: anvio.__profile__version__
    log: dirs_dict["LOGS_DIR"] + "/{group}-{sample}-anvi_profile.log"
    input:
        bam = dirs_dict["MAPPING_DIR"] + "/{group}/{sample}.bam",
        # marking the contigs.db as ancient in order to ignore timestamps.
        contigs = ancient(dirs_dict["CONTIGS_DIR"] + "/{group}-contigs.db")
    output:
        profile = dirs_dict["PROFILE_DIR"] + "/{group}/{sample}/PROFILE.db",
        runlog = dirs_dict["PROFILE_DIR"] + "/{group}/{sample}/RUNLOG.txt"
    params:
        # directory handed to anvi-profile -o; PROFILE.db and RUNLOG.txt above live in it
        output_dir = dirs_dict["PROFILE_DIR"] + "/{group}/{sample}",
        # a function: snakemake calls it with this job's wildcards to decide per group
        cluster_contigs = get_cluster_contigs_param,
        # the remaining params are anvi-profile options read from the config file
        # through M.get_rule_param and pasted verbatim into the command line below
        # (presumably '' when an option is not set -- TODO confirm)
        sample_name = M.get_rule_param("anvi_profile", "--sample-name"),
        overwrite_output_destinations = M.get_rule_param("anvi_profile", "--overwrite-output-destinations"),
        report_variability_full = M.get_rule_param("anvi_profile", "--report-variability-full"),
        skip_SNV_profiling = M.get_rule_param("anvi_profile", "--skip-SNV-profiling"),
        profile_SCVs = M.get_rule_param("anvi_profile", "--profile-SCVs"),
        description = M.get_rule_param("anvi_profile", "--description"),
        skip_hierarchical_clustering = M.get_rule_param("anvi_profile", "--skip-hierarchical-clustering"),
        distance = M.get_rule_param("anvi_profile", "--distance"),
        linkage = M.get_rule_param("anvi_profile", "--linkage"),
        min_contig_length = M.get_rule_param("anvi_profile", "--min-contig-length"),
        min_mean_coverage = M.get_rule_param("anvi_profile", "--min-mean-coverage"),
        min_coverage_for_variability = M.get_rule_param("anvi_profile", "--min-coverage-for-variability"),
        contigs_of_interest = M.get_rule_param("anvi_profile", "--contigs-of-interest"),
        queue_size = M.get_rule_param("anvi_profile", "--queue-size"),
        write_buffer_size = M.get_rule_param("anvi_profile", "--write-buffer-size"),
    threads: M.T('anvi_profile')
    resources: nodes = M.T('anvi_profile'),
    # stdout and stderr of anvi-profile are appended to the rule's log file
    shell:
        """
            anvi-profile -i {input.bam} -c {input.contigs} -o {params.output_dir} \
                             {params.cluster_contigs} {params.min_contig_length} \
                             {params.sample_name} -T {threads} {params.overwrite_output_destinations} \
                             {params.profile_SCVs} {params.report_variability_full} \
                             {params.skip_SNV_profiling} {params.description} \
                             {params.skip_hierarchical_clustering} {params.distance} \
                             {params.linkage} {params.min_mean_coverage} \
                             {params.min_coverage_for_variability} {params.contigs_of_interest} \
                             {params.queue_size} {params.write_buffer_size} >> {log} 2>&1
        """


def input_for_anvi_merge(wildcards):
    '''
        Return the PROFILE.db paths that rule anvi_merge should merge for
        the group named in `wildcards.group`.

        A function is used as the rule input so the user can pick between
        the "all against all" mode (the 'all_against_all' config option:
        every sample is merged for every group) and the "normal" mode
        (only the samples that belong to the group are mapped, profiled
        and merged together). See the documentation to learn more about
        the difference between these modes.
    '''
    if M.get_param_value_from_config(['all_against_all']):
        chosen_samples = samples_information['sample']
    else:
        belongs_to_group = samples_information['group'] == wildcards.group
        chosen_samples = samples_information[belongs_to_group]['sample']

    profile_template = dirs_dict["PROFILE_DIR"] + "/{group}/{sample}/PROFILE.db"
    return expand(profile_template,
                  sample=list(chosen_samples),
                  group=wildcards.group)


def percent_of_reads_mapped_imported_flag_input(wildcards):
    '''
        Build the `percent_of_reads_mapped_imported_flag` input of rule
        anvi_merge for the group named in `wildcards.group`.

        Normally this is the list of "import_percent_of_reads_mapped.done"
        flag files of the samples being merged (all samples in "all against
        all" mode, otherwise only the group's samples). When the module-level
        switch `run_import_percent_of_reads_mapped` is off, the group's
        contigs database (marked ancient) is returned instead; presumably
        because the flag files are only produced when that step runs.
    '''
    flag_template = dirs_dict["PROFILE_DIR"] + "/{group}/{sample}/import_percent_of_reads_mapped.done"

    if M.get_param_value_from_config(['all_against_all']):
        group_samples = list(samples_information['sample'])
    else:
        in_group = samples_information['group'] == wildcards.group
        group_samples = list(samples_information[in_group]['sample'])

    flags = expand(flag_template, sample=group_samples, group=wildcards.group)

    if run_import_percent_of_reads_mapped:
        return flags
    return ancient(dirs_dict["CONTIGS_DIR"] + "/%s-contigs.db" % wildcards.group)


def create_fake_output_files(_message, output):
    """Write `_message` plus a trailing newline into every path in `output`.

    Rules that end up with nothing real to produce call this so that the
    output files snakemake expects still exist; each file is overwritten and
    holds only the informative message for the user.
    """
    for path in output:
        with open(path, 'w') as handle:
            handle.writelines((_message, '\n'))


def remove_empty_profile_databases(profiles, group):
    '''Remove profiles that recruited zero reads from the metagenome.

       Parameters
       ==========
       profiles: paths to PROFILE.db files to check.
       group: name of the group whose contigs database the profiles were
              created with (only used in messages).

       Returns
       =======
       The non-empty profile paths, de-duplicated, in their original order
       (keeping the order stable makes the anvi-merge command deterministic).
    '''
    empty_profiles = []
    progress.new("Checking for empty profile databases")
    try:
        for p in profiles:
            _, data = TableForLayerAdditionalData(argparse.Namespace(profile_db=p)).get(['total_reads_mapped'])

            # only the first layer's value is inspected; presumably a single
            # (unmerged) profile database holds exactly one layer -- TODO confirm
            if not next(iter(data.values()))['total_reads_mapped']:
                # this profile is empty so we can't include it in the merged profile.
                empty_profiles.append(p)
    finally:
        # always clear the progress line, even if reading a profile fails
        progress.end()

    empty = set(empty_profiles)
    # dict.fromkeys de-duplicates while keeping first-seen order (a plain set
    # difference also de-duplicated, but in an arbitrary order)
    profiles = [p for p in dict.fromkeys(profiles) if p not in empty]

    if not profiles:
        # if there are no profiles to merge then notify the user
        run.warning('It seems that all your profiles are empty for the \
                     contigs database: %s-contigs.db. And so cannot be merged.' \
                     % group)

    run.info('Number of non-empty profile databases', len(profiles))
    run.info('Number of empty profile databases', len(empty_profiles))
    if len(empty_profiles) > 0:
        run.info('The following databases are empty: ', empty_profiles)

    return profiles


rule anvi_merge:
    '''
        Create a merged profile database.

        If there are multiple non-empty profiles mapped to the same contigs
        database, then merges these profiles with anvi-merge. When there is
        nothing to merge (all profiles empty, a single-sample group, or only
        one non-empty profile), informative placeholder files are written in
        place of the expected outputs so that the workflow can continue; the
        message in them points to the single profile database when there is
        one.
    '''
    version: anvio.__profile__version__
    log: dirs_dict["LOGS_DIR"] + "/{group}-anvi_merge.log"
    # The input are all profile databases that belong to the same group
    input:
        # marking the contigs.db as ancient in order to ignore timestamps.
        contigs = ancient(dirs_dict["CONTIGS_DIR"] + "/{group}-contigs.db"),
        profiles = input_for_anvi_merge,
        percent_of_reads_mapped_imported_flag = percent_of_reads_mapped_imported_flag_input
    output:
        profile = dirs_dict["MERGE_DIR"] + "/{group}/PROFILE.db",
        runlog = dirs_dict["MERGE_DIR"] + "/{group}/RUNLOG.txt"
    threads: M.T('anvi_merge')
    resources: nodes = M.T('anvi_merge'),
    params:
        output_dir = dirs_dict["MERGE_DIR"] + "/{group}",
        name = "{group}",
        profile_dir = dirs_dict["PROFILE_DIR"] + "/{group}",
        # anvi-merge options read from the config file through M.get_rule_param
        sample_name = M.get_rule_param("anvi_merge", "--sample-name"),
        description = M.get_rule_param("anvi_merge", "--description"),
        skip_hierarchical_clustering = M.get_rule_param("anvi_merge", "--skip-hierarchical-clustering"),
        enforce_hierarchical_clustering = M.get_rule_param("anvi_merge", "--enforce-hierarchical-clustering"),
        distance = M.get_rule_param("anvi_merge", "--distance"),
        linkage = M.get_rule_param("anvi_merge", "--linkage"),
        skip_concoct_binning = M.get_rule_param("anvi_merge", "--skip-concoct-binning"),
        overwrite_output_destinations = M.get_rule_param("anvi_merge", "--overwrite-output-destinations"),
    run:
        # using run instead of shell so we can choose the appropriate shell command.
        # In accordance with: https://bitbucket.org/snakemake/snakemake/issues/37/add-complex-conditional-file-dependency#comment-29348196

        # remove empty profile databases
        input.profiles = remove_empty_profile_databases(input.profiles, wildcards.group)

        if not input.profiles:
            # there are no profiles to merge.
            # this should only happen if all profiles were empty.
            _message = "Nothing to merge for %s. This should " \
                       "only happen if all profiles were empty " \
                       "(the empty profile databases are listed in " \
                       "the output of the workflow, so you can check " \
                       "there to see if that is indeed the case). " \
                       "This file was created just so that your workflow " \
                       "would continue with no error (snakemake expects " \
                       "to find these output files and if we don't create " \
                       "them, then it will be upset). As we see it, " \
                       "there is no reason to throw an error here, since " \
                       "you mapped your metagenome to some fasta files " \
                       "and you got your answer: whatever you have in " \
                       "your fasta file is not represented in your  " \
                       "metagenomes. Feel free to contact us if you think " \
                       "that this is our fault. sincerely, Meren Lab" \
                       % wildcards.group
            # creating the expected output files for the rule
            create_fake_output_files(_message, output)

        elif group_sizes[wildcards.group] == 1:
            # for individual assemblies, create a symlink to the profile database
            #shell("ln -s {params.profile_dir}/*/* -t {params.output_dir} >> {log} 2>&1")
            #shell("touch -h {params.profile_dir}/*/*")

            # Still waiting to get an answer on this issue:
            # https://groups.google.com/d/msg/snakemake/zU_wkfZ7YCs/GZP0Z_RoAgAJ
            # Until then, I will just create fake file so snakemake is happy
            _message = "Only one file was profiled with %s so there " \
                       "is nothing to merge. But don't worry, you can " \
                       "still use anvi-interactive with the single profile " \
                       "database that is here: %s" \
                       % (wildcards.group, input.profiles[0])
            create_fake_output_files(_message, output)

        elif len(input.profiles) == 1:
            # if only one sample is not empty, but the group size was
            # bigger than 1 then it means that --cluster-contigs was
            # not performed during anvi-profile.
            _message = "Only one sample had reads recruited to %s " \
                       "and hence merging could not occur." \
                       % wildcards.group
            create_fake_output_files(_message, output)

        else:
            # every configured anvi-merge option is passed on; previously
            # description/clustering/distance/linkage were read but dropped
            shell("anvi-merge {input.profiles} -o {params.output_dir} -c {input.contigs} \
                   {params.sample_name} {params.description} \
                   {params.skip_hierarchical_clustering} {params.enforce_hierarchical_clustering} \
                   {params.distance} {params.linkage} {params.skip_concoct_binning} \
                   {params.overwrite_output_destinations} >> {log} 2>&1")


rule import_percent_of_reads_mapped:
    '''
        Add "total_num_reads" and "percent_mapped" columns to the layers
        additional data of a single-sample profile database, using the read
        count reported in the sample's bowtie log. If the bowtie log is
        missing, nothing is imported and empty output files are created so
        that the workflow can continue.
    '''
    version: 1.0
    log: dirs_dict["LOGS_DIR"] + "/{group}-{sample}-import_percent_of_reads_mapped.log"
    input:
        profile = rules.anvi_profile.output.profile
    output:
        flag = touch(dirs_dict["PROFILE_DIR"] + "/{group}/{sample}/import_percent_of_reads_mapped.done"),
        total_num_reads_txt = temp(dirs_dict["PROFILE_DIR"] + "/{group}/{sample}/total-num-reads.txt"),
        layers_additional_data = temp(dirs_dict["PROFILE_DIR"] + "/{group}/{sample}/layers-additional-data.txt"),
        layers_additional_data_updated = temp(dirs_dict["PROFILE_DIR"] + "/{group}/{sample}/layers-additional-data-updated.txt")
    params:
        bowtie_log = dirs_dict["LOGS_DIR"] + "/{group}-{sample}-bowtie.log"
    threads: M.T('import_percent_of_reads_mapped')
    resources: nodes = M.T('import_percent_of_reads_mapped')
    run:
        if not filesnpaths.is_file_exists(params.bowtie_log, dont_raise=True):
            run.warning('We expected to find the log file for bowtie here: %s\
                         but it is not there. This means we can\'t find out what \
                         percentage of short reads mapped to the fasta file. \
                         If you are confused, feel free to contact us. \
                         This file will still be created: %s, but don\'t let \
                         it fool you, nothing was really imported.' % (params.bowtie_log, output.flag))
            for f in output:
                # creating fake output files for the rule
                # this is done just so snakemake thinks that this rule
                # was executed succesfully.
                shell("touch %s" %f)

        else:
            # The sample name comes straight from the wildcards; the old approach of
            # splitting the bowtie log path on "-" broke whenever LOGS_DIR or the
            # group name contained a "-".
            # NOTE(review): the awk step below assumes column 2 of the exported
            # layers table is the mapped-read count and column 4 the
            # total_num_reads imported here -- TODO confirm against anvi-export-misc-data.
            shell("""
                    echo -e "sample\ttotal_num_reads" > {output.total_num_reads_txt}
                    # print one line with the sample name and the number of reads
                    # (grepped from the bowtie log; the count is doubled, presumably
                    # because the reads are paired)
                    echo -e "{wildcards.sample}\t`grep 'reads; of these:' {params.bowtie_log} | awk '{{print $1 * 2}}'`" >> {output.total_num_reads_txt}

                    # import this into the profile database. this part is simply
                    # a 'hack' so we can get a file with both 'mapper reads' and
                    # 'total reads' columns in the next step.
                    anvi-import-misc-data {output.total_num_reads_txt} \
                                          -p {input.profile} \
                                          -t layers \
                                          --just-do-it

                    # get the layers additional data table, which has two columns
                    # now.
                    anvi-export-misc-data -p {input.profile} \
                                          -t layers \
                                          -o {output.layers_additional_data}

                    # do an awk one-liner to add a third column, 'percent mapped'.
                    awk '{{if(NR==1) \
                             {{print $0 "\tpercent_mapped"}} \
                          else \
                              {{print $0 "\t" ($2 * 100 / $4)}}}}' \
                        {output.layers_additional_data} > {output.layers_additional_data_updated}

                    # import the new file with three columns
                    anvi-import-misc-data {output.layers_additional_data_updated} \
                                          -p {input.profile} \
                                          -t layers \
                                          --just-do-it
                """)

# slave_mode (set at the top of this file) is True when this Snakefile was
# pulled in via `include:` from another workflow; presumably the including
# workflow then runs its own dependency check.
if not slave_mode:
    # check if all program dependencies are met. for this line to be effective,
    # there should be an initial dry run step (which is the default behavior of
    # the `WorkflowSuperClass`, so you are most likely covered).
    M.check_workflow_program_dependencies(workflow)
