#!/usr/bin/env python
# -*- coding: utf-8
"""Return frequencies of amino acids in a gene

   Takes a bunch of BAM files, and a unique gene caller ID to count
   AA linkmer frequencies"""

import sys

from collections import Counter

import anvio
import anvio.utils as utils
import anvio.bamops as bamops
import anvio.terminal as terminal
import anvio.constants as constants
import anvio.filesnpaths as filesnpaths

from anvio.errors import ConfigError, FilesNPathsError
from anvio.dbops import ContigsSuperclass


# Module-level metadata for this program. The version is not tracked here; it
# is taken from the `anvio` package.
__author__ = "Developers of anvi'o (see AUTHORS.txt)"
__copyright__ = "Copyleft 2015-2018, the Meren Lab (http://merenlab.org/)"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"
# Names of the artifacts this program consumes and produces.
# NOTE(review): presumably read by anvi'o's program documentation tooling -- confirm.
__requires__ = ['contigs-db', 'bam-file',]
__provides__ = ['codon-frequencies-txt', 'aa-frequencies-txt',]


class ReportCodonFrequencies:
    """Report codon or amino acid frequencies for gene calls in a contigs database.

    The source of the frequencies depends on whether `args` carries a BAM file:

      - without a BAM file, residues are counted from the gene sequences of the
        contigs database, one output row per gene (see `from_contigs_db`).
      - with a BAM file, frequencies are computed for a single gene call from
        the BAM file, one output row per codon position (see `from_BAM_file`).

    All the work, including writing the output file, is done at instantiation.
    """

    def __init__(self, args, run=None, progress=None):
        """Validate the arguments, then compute and store the frequencies.

        Parameters
        ==========
        args : argparse.Namespace-like object
            The following attributes are read (a missing attribute counts as
            None): `contigs_db`, `bam_file`, `gene_caller_id`,
            `return_AA_frequencies_instead`, `percent_normalize`,
            `merens_codon_normalization`, `output_file`.
        run : terminal.Run, optional
            Reporter for messages; a new one is created when not provided.
        progress : terminal.Progress, optional
            Progress display; a new one is created when not provided.

        Raises
        ======
        ConfigError
            If mutually exclusive flags are combined, or the requested gene
            caller id is not in the contigs database.
        """
        self.args = args
        # Create the defaults per instance rather than once at definition time.
        self.run = terminal.Run() if run is None else run
        self.progress = terminal.Progress() if progress is None else progress

        A = lambda x: args.__dict__.get(x)
        self.contigs_db_path = A('contigs_db')
        # The misspelled attribute name is kept so that any existing code
        # reading it keeps working.
        self.congits_db_path = self.contigs_db_path
        self.bam_file_path = A('bam_file')
        self.gene_caller_id = A('gene_caller_id')
        self.return_AA_frequencies_instead = A('return_AA_frequencies_instead')
        self.percent_normalize = A('percent_normalize')
        self.merens_codon_normalization = A('merens_codon_normalization')
        self.output_file_path = A('output_file')

        filesnpaths.is_output_file_writable(self.output_file_path)

        if self.bam_file_path:
            filesnpaths.is_file_exists(self.bam_file_path)

        if self.merens_codon_normalization and self.percent_normalize:
            raise ConfigError("You can't use both `--merens-codon-normalization` and `--percent-normalize`. Please\
                               read the help menu and pick one (you will get bonus points if you pick meren's\
                               normalization because why not)")

        if self.merens_codon_normalization and self.return_AA_frequencies_instead:
            raise ConfigError("The flag `--merens-codon-normalization` is only relevant if you are working with codon\
                               frequencies :/")

        self.c = ContigsSuperclass(args)

        # Compare against None explicitly: a gene caller id of 0 is falsy and
        # would otherwise be treated as "no gene caller id was given".
        if self.gene_caller_id is not None:
            if self.gene_caller_id not in self.c.genes_in_contigs_dict:
                raise ConfigError("Your contigs database named '%s' does not know anythinga bout the gene caller id\
                                   '%s' :/" % (self.c.a_meta['project_name'], str(self.gene_caller_id)))

            # Always a set, so both code paths below can treat it the same way.
            self.gene_caller_ids = set([self.gene_caller_id])
        else:
            self.gene_caller_ids = set(self.c.genes_in_contigs_dict.keys())

        # `items` are the residue columns of the output, in reporting order.
        if self.return_AA_frequencies_instead:
            self.items = sorted(set(constants.codon_to_AA.values()))
        else:
            self.items = [codon for amino_acid in constants.AA_to_codons
                                for codon in constants.AA_to_codons[amino_acid]]

        if self.bam_file_path:
            self.from_BAM_file()
        else:
            self.from_contigs_db()


    @staticmethod
    def _percent_normalize(counts):
        """Return a new Counter with each count as a percentage (3 decimals) of the total."""
        total = sum(counts.values())
        # An empty `counts` yields an empty Counter; no division takes place.
        return Counter({residue: round(count * 100.0 / total, 3) for residue, count in counts.items()})


    @staticmethod
    def _normalize_within_synonymous_codons(counts, aa_to_codons):
        """Percent normalize `counts` in place within each group of codons in `aa_to_codons`.

        `aa_to_codons` maps an amino acid to the list of its codons. Groups whose
        codons sum up to zero are left untouched.
        """
        for codons in aa_to_codons.values():
            group_total = sum(counts[codon] for codon in codons)

            if not group_total:
                continue

            for codon in codons:
                counts[codon] = round(counts[codon] * 100.0 / group_total, 3)


    def from_contigs_db(self):
        """Count residues per gene from the contigs database and store them as a TAB-delimited file.

        Partial gene calls are skipped (and reported with a warning). Counts are
        optionally normalized depending on `percent_normalize` and
        `merens_codon_normalization`.

        Raises
        ======
        ConfigError
            If there is no gene left to report.
        """
        if self.gene_caller_id is not None:
            self.c.init_contig_sequences(gene_caller_ids_of_interest=self.gene_caller_ids)
        else:
            self.c.init_contig_sequences()

        residue_frequencies = {}
        partial_genes_skipped = set()

        F = utils.get_list_of_AAs_for_gene_call if self.return_AA_frequencies_instead else utils.get_list_of_codons_for_gene_call

        for gene_callers_id in self.gene_caller_ids:
            gene_call = self.c.genes_in_contigs_dict[gene_callers_id]

            if gene_call['partial']:
                partial_genes_skipped.add(gene_callers_id)
                continue

            residue_frequencies[gene_callers_id] = Counter(F(gene_call, self.c.contig_sequences))

        if self.percent_normalize:
            residue_frequencies = dict((gene_callers_id, self._percent_normalize(counts))
                                            for gene_callers_id, counts in residue_frequencies.items())
        elif self.merens_codon_normalization:
            for counts in residue_frequencies.values():
                self._normalize_within_synonymous_codons(counts, constants.AA_to_codons)

        if len(partial_genes_skipped):
            self.run.warning("%d of %d genes were skipped and will not be in the final report since they were\
                             'partial' gene calls." % (len(partial_genes_skipped), len(self.gene_caller_ids)))

        if not len(residue_frequencies):
            raise ConfigError("Anvi'o has no residue frequencies to work with :(")

        utils.store_dict_as_TAB_delimited_file(residue_frequencies, self.output_file_path, headers=['gene_callers_id'] + self.items)
        self.run.info('Output file', self.output_file_path)


    def from_BAM_file(self):
        """Compute per-codon-position frequencies for a single gene from the BAM file and store them.

        Raises
        ======
        ConfigError
            If there is not exactly one gene caller id to work with, or if the
            gene call is partial.
        """
        if len(self.gene_caller_ids) != 1:
            raise ConfigError("If you are working with a BAM file, you must unfortunately declare a single gene caller\
                               ID since this anvi'o class is not yet smart to be able to deal with multiple gene calls when\
                               working with BAM files. Please send us an e-mail if you need this feature, and we will\
                               implement it for you!")

        # `gene_caller_ids` is a set, so it can't be indexed.
        gene_callers_id = next(iter(self.gene_caller_ids))
        bam_file_object = bamops.BAMFileObject(self.bam_file_path).get()
        self.c.init_contig_sequences(gene_caller_ids_of_interest=set([gene_callers_id]))

        gene_call = self.c.genes_in_contigs_dict[gene_callers_id]

        if gene_call['partial']:
            raise ConfigError('This seems to be a partial gene call. Sorry :/ Partial gene calls are tricky to work with\
                               since not every gene caller is able to report appropriate frame within to reliably report\
                               codons.')

        self.run.info('Working with', 'Amino acids' if self.return_AA_frequencies_instead else 'Codons')

        contig_sequence = self.c.contig_sequences[gene_call['contig']]['sequence']

        codon_frequencies = bamops.CodonFrequencies()

        d = {}

        self.progress.new('Busy code is busy')
        try:
            # The parentheses matter: `%` binds tighter than the conditional
            # expression, and without them the message would be just 'codon'.
            self.progress.update('Generating %s frequencies dict ...' % ('amino acid' if self.return_AA_frequencies_instead else 'codon'))
            frequencies_dict = codon_frequencies.process_gene_call(bam_file_object,
                                                                   gene_call,
                                                                   contig_sequence,
                                                                   None,
                                                                   self.return_AA_frequencies_instead)

            self.progress.update('Working on the output ...')
            for codon_order in frequencies_dict:
                entry = frequencies_dict[codon_order]
                d[codon_order] = {'reference': entry['reference'], 'coverage': entry['coverage'],
                                  'contig_name': gene_call['contig'], 'start': gene_call['start'],
                                  'stop': gene_call['stop'], 'direction': gene_call['direction']}

                for item in self.items:
                    d[codon_order][item] = entry['frequencies'][item]

            header = ['codon_order_in_gene', 'contig_name', 'start', 'stop', 'direction', 'reference', 'coverage'] + self.items
            self.progress.update('Storing output ...')
            utils.store_dict_as_TAB_delimited_file(d, self.output_file_path, header)
        finally:
            # Close the progress display even if something above fails.
            self.progress.end()

        # Report from instance state; a global `args` exists only when this
        # file is run as a script.
        self.run.info('Output', self.output_file_path)


if __name__ == '__main__':
    # Command line entry point: build the argument parser, then hand the parsed
    # arguments to ReportCodonFrequencies, which does all the work.
    import argparse

    parser = argparse.ArgumentParser(description='Get amino acid or codon frequencies of genes in a contigs database, or or get the same information\
                                                  considering the make up of the environmnet captured in a BAM file through the mapped reads.\
                                                  What is the difference? The first one will give you AA or CDN frequencies given a set of genes. In\
                                                  contrast, the latter will also take a BAM file, and will give you AA or CDN frequencies of the\
                                                  environment. If you want to know the frequencies of residues of a "gene" you do not need a BAM\
                                                  file. If you want the frequencies you observe in actual read recruitment results, then you need\
                                                  provide a BAM file. The latter will give you frequencies of CDN or AA "linkmers" an anvio trick\
                                                  that, for each codon position, enables you to work with only the short reads that fully cover the\
                                                  codon to ensure molecular linkage between bases (well, assuming there are no sequencing errors).')

    groupA = parser.add_argument_group('INPUT DATABASE', 'The contigs database. Clearly those genes must be read from somewhere.')
    groupA.add_argument(*anvio.A('contigs-db'), **anvio.K('contigs-db'))

    groupB = parser.add_argument_group('BAM FILE?', "Optional BAM file to recover CDN/AA frequencies among the mapped reads. It is OK to not\
                                                  provide one. Read the help menu of the program for more information.")
    # Register the option on its group (not on the parser) so that it is listed
    # under the 'BAM FILE?' heading in the help output.
    groupB.add_argument('-b', '--bam-file', metavar = 'INPUT_BAM', default = None, help = 'Sorted and indexed BAM file to analyze.')


    groupC = parser.add_argument_group('COMMON OPTIONALS', "Important things to read never end. Stupid science.")
    groupC.add_argument(*anvio.A('gene-caller-id'), **anvio.K('gene-caller-id', {'help': "OK. You can declare a single gene caller ID if you wish, in\
                                                                which case anvi'o would only return results for a single gene call. If you don't declare\
                                                                anything, well, you must be prepared to brace yourself if you are working with a very\
                                                                large contigs database with hundreds of thousands of genes."}))

    groupC.add_argument(*anvio.A('return-AA-frequencies-instead'), **anvio.K('return-AA-frequencies-instead'))
    groupC.add_argument(*anvio.A('output-file'), **anvio.K('output-file', {'required': True}))


    groupD = parser.add_argument_group('NON-BAM OPTIONALS', "Optional flags that will only be taken into consideration if you are not working with BAM file.")
    groupD.add_argument('--percent-normalize', default=False, action="store_true", help = "Instead of actual counts, report percent-normalized\
                                                                frequencies per gene (because you are too lazy to do things the proper way in R).")
    groupD.add_argument('--merens-codon-normalization', default=False, action="store_true", help = "This is a flag to percent normalize codon frequenies within those\
                                                                that encode for the same amino acid. It is different from the flag --percent-normalize, since it\
                                                                does not percent normalize frequencies of codons within a gene based on all codon frequencies. Clearly\
                                                                this flag is not applicable if you wish to work with boring amino acids. WHO WORKS WITH AMINO ACIDS\
                                                                ANYWAY.")

    args = anvio.get_args(parser)

    # Known anvi'o errors are printed without a traceback, and each kind gets
    # its own exit status.
    try:
        ReportCodonFrequencies(args)
    except ConfigError as e:
        print(e)
        sys.exit(-1)
    except FilesNPathsError as e:
        print(e)
        sys.exit(-2)
