#!/usr/bin/env python
# -*- coding: utf-8
"""Report codon or amino acid frequencies of genes in a contigs database.

   Takes a contigs database and, optionally, a single gene caller ID, and
   writes per-gene codon (or amino acid) frequencies to a TAB-delimited file."""

import sys

from collections import Counter

import anvio
import anvio.utils as utils
import anvio.bamops as bamops
import anvio.terminal as terminal
import anvio.constants as constants
import anvio.filesnpaths as filesnpaths

from anvio.errors import ConfigError, FilesNPathsError
from anvio.dbops import ContigsSuperclass


# Program metadata. `__version__` mirrors the installed anvi'o version, and
# `__description__` is reused below as the argparse description of this program.
# NOTE(review): `__requires__` / `__provides__` are presumably read by anvi'o's
# program documentation tooling -- confirm before reformatting these lines.
__author__ = "Developers of anvi'o (see AUTHORS.txt)"
__copyright__ = "Copyleft 2015-2018, the Meren Lab (http://merenlab.org/)"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"
__requires__ = ['contigs-db']
__provides__ = ['codon-frequencies-txt', 'aa-frequencies-txt',]
__description__ = ("Get amino acid or codon frequencies of genes in a contigs database.")


class ReportCodonFrequencies:
    """Compute and report per-gene codon (or amino acid) frequencies.

    Instantiating the class does all of the work: the arguments are
    validated, a `ContigsSuperclass` is initialized from `args`, and
    `process()` is called to compute the frequencies and write the report.

    Attributes read from `args` (each is `None` when missing from `args`):
        gene_caller_id: a single gene call to report. When `None`, every
            gene call known to the contigs database is reported.
        return_AA_frequencies_instead: report amino acids rather than codons.
        percent_normalize: report each residue as a percentage of all
            residues counted in the gene.
        merens_codon_normalization: report each codon as a percentage of the
            codons that are listed under the same amino acid in
            `constants.AA_to_codons`.
        output_file: path of the report to write.

    Raises:
        ConfigError: if both normalization flags are set, if
            `merens_codon_normalization` is combined with
            `return_AA_frequencies_instead`, if the requested gene caller id
            is not in the contigs database, or if no gene is left to report.
    """

    def __init__(self, args, run=terminal.Run(), progress=terminal.Progress()):
        self.args = args
        self.run = run
        self.progress = progress

        A = lambda x: args.__dict__.get(x)
        # `bam_file` is read here but not used anywhere else in this class; the
        # attribute is kept so that the set of instance attributes is unchanged.
        self.bam_file_path = A('bam_file')
        self.gene_caller_id = A('gene_caller_id')
        self.return_AA_frequencies_instead = A('return_AA_frequencies_instead')
        self.percent_normalize = A('percent_normalize')
        self.merens_codon_normalization = A('merens_codon_normalization')
        self.output_file_path = A('output_file')

        # check the output path before doing any of the more expensive work
        filesnpaths.is_output_file_writable(self.output_file_path)

        if self.merens_codon_normalization and self.percent_normalize:
            raise ConfigError("You can't use both `--merens-codon-normalization` and `--percent-normalize`. Please "
                              "read the help menu and pick one (you will get bonus points if you pick meren's "
                              "normalization because why not)")

        if self.merens_codon_normalization and self.return_AA_frequencies_instead:
            raise ConfigError("The flag `--merens-codon-normalization` is only relevant if you are working with codon "
                              "frequencies :/")

        self.c = ContigsSuperclass(args)

        # Compare against None rather than relying on truthiness: an id of 0 is
        # falsy, and a truthiness test would silently treat it as "no id given"
        # and report every gene instead of validating and reporting that one.
        if self.gene_caller_id is not None:
            if self.gene_caller_id not in self.c.genes_in_contigs_dict:
                raise ConfigError("Your contigs database named '%s' does not know anything about the gene caller id "
                                  "'%s' :/" % (self.c.a_meta['project_name'], str(self.gene_caller_id)))

            self.gene_caller_ids = [self.gene_caller_id]
        else:
            self.gene_caller_ids = set(self.c.genes_in_contigs_dict.keys())

        # `self.items` holds the residue names that become the report columns
        if self.return_AA_frequencies_instead:
            self.items = sorted(set(constants.codon_to_AA.values()))
        else:
            self.items = []
            for amino_acid in constants.AA_to_codons:
                self.items.extend(constants.AA_to_codons[amino_acid])

        self.process()


    def process(self):
        """Count residues for every gene of interest, normalize, and write the report.

        Gene calls flagged as 'partial' are skipped and summarized in a single
        warning. Raises `ConfigError` when no gene is left to report.
        """

        # same None-vs-falsy distinction as in `__init__`: a gene caller id of 0
        # must still restrict the contig sequences to that single gene
        if self.gene_caller_id is not None:
            self.c.init_contig_sequences(gene_caller_ids_of_interest=self.gene_caller_ids)
        else:
            self.c.init_contig_sequences()

        residue_frequencies = {}
        partial_genes_skipped = set()

        if self.return_AA_frequencies_instead:
            F = utils.get_list_of_AAs_for_gene_call
        else:
            F = utils.get_list_of_codons_for_gene_call

        for gene_callers_id in self.gene_caller_ids:
            gene_call = self.c.genes_in_contigs_dict[gene_callers_id]

            if gene_call['partial']:
                partial_genes_skipped.add(gene_callers_id)
                continue

            # A Counter reads back 0 for residues that never occur in a gene --
            # presumably what lets every column in `self.items` be filled in
            # when the table is written (confirm against the utils function).
            residue_frequencies[gene_callers_id] = Counter(F(gene_call, self.c.contig_sequences))

        if self.percent_normalize:
            # each residue as a percentage of all residues counted in the gene
            for gene_callers_id, counts in residue_frequencies.items():
                total = sum(counts.values())
                residue_frequencies[gene_callers_id] = Counter({residue: round(count * 100.0 / total, 3)
                                                                for residue, count in counts.items()})
        elif self.merens_codon_normalization:
            # each codon as a percentage of the codons grouped under the same amino acid
            for counts in residue_frequencies.values():
                for amino_acid in constants.AA_to_codons:
                    codons_of_interest = constants.AA_to_codons[amino_acid]
                    codons_of_interest_total = sum(counts[codon] for codon in codons_of_interest)

                    # nothing to normalize if the gene uses none of these codons
                    if not codons_of_interest_total:
                        continue

                    for codon in codons_of_interest:
                        counts[codon] = round(counts[codon] * 100.0 / codons_of_interest_total, 3)

        if partial_genes_skipped:
            self.run.warning("%d of %d genes were skipped and will not be in the final report since they were "
                            "'partial' gene calls." % (len(partial_genes_skipped), len(self.gene_caller_ids)))

        if not residue_frequencies:
            raise ConfigError("Anvi'o has no residue frequencies to work with :(")

        utils.store_dict_as_TAB_delimited_file(residue_frequencies, self.output_file_path, headers=['gene_callers_id'] + self.items)
        self.run.info('Output file', self.output_file_path)


if __name__ == '__main__':
    # Command-line entry point: build the argument parser, then hand the
    # parsed arguments to ReportCodonFrequencies, which does all the work
    # from within its constructor.
    import argparse

    parser = argparse.ArgumentParser(description=__description__)

    groupA = parser.add_argument_group('INPUT DATABASE', 'The contigs database. Clearly those genes must be read from somewhere.')
    groupA.add_argument(*anvio.A('contigs-db'), **anvio.K('contigs-db'))

    groupC = parser.add_argument_group('OPTIONALS', "Important things to read never end. Stupid science.")
    groupC.add_argument(*anvio.A('gene-caller-id'), **anvio.K('gene-caller-id', {'help': "OK. You can declare a single gene caller ID if you wish, in\
                                                                which case anvi'o would only return results for a single gene call. If you don't declare\
                                                                anything, well, you must be prepared to brace yourself if you are working with a very\
                                                                large contigs database with hundreds of thousands of genes."}))

    groupC.add_argument(*anvio.A('return-AA-frequencies-instead'), **anvio.K('return-AA-frequencies-instead'))
    groupC.add_argument(*anvio.A('output-file'), **anvio.K('output-file', {'required': True}))


    # the two normalization flags are mutually exclusive; that is enforced in
    # ReportCodonFrequencies.__init__ rather than by argparse
    groupC.add_argument('--percent-normalize', default=False, action="store_true", help = "Instead of actual counts, report percent-normalized\
                                                                frequencies per gene (because you are too lazy to do things the proper way in R).")
    groupC.add_argument('--merens-codon-normalization', default=False, action="store_true", help = "This is a flag to percent normalize codon frequencies within those\
                                                                that encode for the same amino acid. It is different from the flag --percent-normalize, since it\
                                                                does not percent normalize frequencies of codons within a gene based on all codon frequencies. Clearly\
                                                                this flag is not applicable if you wish to work with boring amino acids. WHO WORKS WITH AMINO ACIDS\
                                                                ANYWAYS.")

    args = anvio.get_args(parser)

    # each error type gets its own exit code
    try:
        ReportCodonFrequencies(args)
    except ConfigError as e:
        print(e)
        sys.exit(-1)
    except FilesNPathsError as e:
        print(e)
        sys.exit(-2)
