#!/usr/bin/env python
# -*- coding: utf-8
"""
Calculates pN/pS, the metagenomic analogy of dN/dS
Citations: doi:10.1038/nature11711, doi:10.7717/peerj.2959 (and references therein)
"""

import os
import sys
import time
import argparse

import numpy as np
import pandas as pd

import anvio

import anvio.dbops as dbops
import anvio.utils as utils
import anvio.terminal as terminal
import anvio.filesnpaths as filesnpaths

from anvio.errors import ConfigError, FilesNPathsError
from anvio.variabilityops import variability_engines, VariabilityData


__author__ = "Developers of anvi'o (see AUTHORS.txt)"
__copyright__ = "Copyleft 2015-2018, the Meren Lab (http://merenlab.org/)"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "Evan Kiefl"
__email__ = "kiefl.evan@gmail.com"
# this text is handed to argparse as the program description (see the bottom of the file)
__description__ = ("This program calculates for each gene the ratio of pN/pS (the metagenomic analogy "
                   "of dN/dS; see doi:10.1038/nature11711 and doi:10.7717/peerj.2959) based on metagenomic "
                   "read recruitment, however, unlike standard pN/pS calculations, it relies on "
                   "codons rather than nucleotides for accurate estimations of synonymity.")


# module-wide terminal helpers shared by the functions below for progress
# updates, warnings, and informational messages
progress = terminal.Progress()
run = terminal.Run()


def load_variability(args, contigs_db):
    """Load the SAAV and SCV tables and return them combined in a single dataframe.

    The table at `args.saav_table` is read with engine 'AA' and the table at `args.scv_table`
    with engine 'CDN'. Each one is initialized as a `VariabilityData` object, filtered first
    by departure from consensus and then by coverage, labeled with its engine, and reduced to
    the columns needed downstream. Rows that belong to gene calls flagged as 'partial' in
    `contigs_db.genes_in_contigs_dict` are removed, and a warning lists the removed gene calls.

    Note that `args` is modified along the way: `columns_to_load`, `engine`, and
    `variability_profile` are set on it before it is handed to `VariabilityData`.

    Returns a dataframe with the columns 'corresponding_gene_call', 'sample_id', and 'engine'
    (SCV rows first, then SAAV rows, with a fresh integer index).
    """
    progress.new("Loading variability")

    columns_to_keep = ['corresponding_gene_call', 'sample_id', 'engine']
    args.columns_to_load = ['corresponding_gene_call', 'sample_id', 'coverage', 'departure_from_consensus']

    # (engine, path to the table of that engine, name shown in progress messages)
    table_sources = (('AA', args.saav_table, 'SAAVs'),
                     ('CDN', args.scv_table, 'SCVs'))

    data_per_engine = {}
    for engine, table_path, label in table_sources:
        progress.update("Working on {}".format(label))

        # init variability table as VariabilityData class object
        args.engine = engine
        args.variability_profile = table_path
        variability = VariabilityData(args)

        # filter by departure_from_consensus first, then by coverage
        for criterion in ('departure_from_consensus', 'coverage'):
            variability.filter_data(criterion=criterion, verbose=False)

        # label every row with the engine it comes from
        variability.data['engine'] = engine

        # optimize speed by subsetting to only necessary columns
        variability.data = variability.data[columns_to_keep]

        data_per_engine[engine] = variability.data

    progress.update('Combining both engine data types')
    # concat the two engine-specific dataframes into a master dataframe
    df = pd.concat([data_per_engine['CDN'], data_per_engine['AA']], ignore_index=True)

    # identify gene calls that are partial
    partial_gene_caller_ids = []
    for gene_callers_id in df['corresponding_gene_call'].unique():
        if contigs_db.genes_in_contigs_dict[gene_callers_id]['partial']:
            partial_gene_caller_ids.append(gene_callers_id)

    if partial_gene_caller_ids:
        belongs_to_partial_gene = df['corresponding_gene_call'].isin(partial_gene_caller_ids)
        df.drop(df[belongs_to_partial_gene].index, inplace=True)

        progress.reset()
        run.warning("%d of your gene calls were 'partial', and were removed from downstream analyses. Here is the complete list of "
                    "gene calls that were removed: '%s'." % (len(partial_gene_caller_ids), ', '.join(map(str, partial_gene_caller_ids))))

    progress.end()

    return df


def load_contigs_db(args):
    """Return a `ContigsSuperclass` instance for `args.contigs_db` with contig sequences initialized.

    The path is first passed through `filesnpaths.is_file_exists`. The superclass is created
    with non-verbose `Run` and `Progress` objects.
    """
    filesnpaths.is_file_exists(args.contigs_db)

    quiet_run = terminal.Run(verbose=False)
    quiet_progress = terminal.Progress(verbose=False)

    contigs_db = dbops.ContigsSuperclass(args, r=quiet_run, p=quiet_progress)
    contigs_db.init_contig_sequences()

    return contigs_db


def report(args, pNpS, SAAV, sSCV):
    """Write the pN/pS, sSCV count, and SAAV count tables into `args.output_dir`.

    Each table is stored as a TAB-delimited file whose index column is labeled
    'gene_callers_id'. A closing message with the output directory is printed at the end.
    """
    # (table, name of the file it is written to), in the order they are written
    outputs = ((pNpS, "pN_pS_ratio.txt"),
               (sSCV, "sSCV_counts.txt"),
               (SAAV, "SAAV_counts.txt"))

    for table, file_name in outputs:
        output_path = os.path.join(args.output_dir, file_name)
        table.to_csv(output_path, sep="\t", index=True, index_label="gene_callers_id")

    done_message = "Done! Contents have been output to the directory '{}'.".format(args.output_dir)
    run.info_single(done_message, nl_before=1, nl_after=1)


def _count_variants(var_dataframe, gene_caller_ids, samples, minimum_num_variants):
    """Tabulate SAAV counts, synonymous SCV counts, and their ratio for every gene/sample pair.

    Parameters
    ==========
    var_dataframe : pandas.DataFrame
        Must contain the columns 'corresponding_gene_call', 'sample_id', and 'engine', where
        the engine of a row is 'AA' for a SAAV and 'CDN' for an SCV
    gene_caller_ids : array-like
        Gene calls to tabulate. They become the row index of the returned tables
    samples : array-like
        Samples to tabulate. They become the columns of the returned tables
    minimum_num_variants : int
        The ratio of a gene/sample pair with fewer SCVs than this is reported as NaN

    Returns
    =======
    (SAAV, sSCV, SAAV_to_sSCV) : tuple of pandas.DataFrame
        Number of SAAVs, number of SCVs that are not SAAVs, and the ratio of the two
    """
    SAAV = {sample: [] for sample in samples}
    sSCV = {sample: [] for sample in samples}
    SAAV_to_sSCV = {sample: [] for sample in samples}

    progress.new("Calculating SAAV and sSCV counts")
    for num_processed, gene_callers_id in enumerate(gene_caller_ids):
        gene_rows = var_dataframe[var_dataframe['corresponding_gene_call'] == gene_callers_id]

        for sample in samples:
            counts = gene_rows.loc[gene_rows['sample_id'] == sample, 'engine'].value_counts()

            # a gene may have no variants of a given engine in a given sample, in which case
            # the engine is absent from `counts`, hence .get() with a default for both
            num_saavs = counts.get('AA', 0)
            num_scvs = counts.get('CDN', 0)
            num_synonymous_scvs = num_scvs - num_saavs # fully synonymous only

            SAAV[sample].append(num_saavs)
            sSCV[sample].append(num_synonymous_scvs)

            if num_scvs < minimum_num_variants:
                SAAV_to_sSCV[sample].append(np.nan)
            else:
                # numpy floats so that x/0 quietly becomes inf and 0/0 quietly becomes nan
                # (numpy reports the former as 'divide' and the latter as 'invalid')
                with np.errstate(divide='ignore', invalid='ignore'):
                    SAAV_to_sSCV[sample].append(np.float64(num_saavs) / np.float64(num_synonymous_scvs))

        if num_processed % 100 == 0:
            progress.update('Processed {} of {} genes.'.format(num_processed, len(gene_caller_ids)))
    progress.end()

    # make tables dataframes
    sSCV = pd.DataFrame(sSCV, index=gene_caller_ids)
    SAAV = pd.DataFrame(SAAV, index=gene_caller_ids)
    SAAV_to_sSCV = pd.DataFrame(SAAV_to_sSCV, index=gene_caller_ids)

    return SAAV, sSCV, SAAV_to_sSCV


def _get_potential_synonymity_ratios(contigs_db, gene_caller_ids):
    """Return the synonymous / non-synonymous potential ratio of each gene.

    Returns a tuple of a numpy array with one ratio per entry of `gene_caller_ids` (in the
    same order), and the total number of ambiguous codons reported across those genes.
    """
    potential_synonymity_ratio = np.zeros(len(gene_caller_ids))
    total_ambiguous_codons = 0

    for index, gene_callers_id in enumerate(gene_caller_ids):
        gene_call = contigs_db.genes_in_contigs_dict[gene_callers_id]

        codon_list_for_gene = utils.get_list_of_codons_for_gene_call(gene_call, contigs_db.contig_sequences)
        syn_potential, non_syn_potential, num_ambiguous_codons = utils.get_synonymous_and_non_synonymous_potential(codon_list_for_gene)

        # keep track of the codons that were reported as ambiguous so the user can be
        # told about them
        total_ambiguous_codons += num_ambiguous_codons

        potential_synonymity_ratio[index] = syn_potential / non_syn_potential

    return potential_synonymity_ratio, total_ambiguous_codons


def calculate_pN_pS_ratio(args):
    """Calculate pN/pS per gene and sample, and write the resulting tables to `args.output_dir`.

    Loads the contigs database and the two variability tables described by `args`, counts
    SAAVs and synonymous SCVs for every gene in every sample, and multiplies their ratio by
    the synonymous / non-synonymous potential ratio of the gene. Gene/sample pairs with fewer
    than `args.minimum_num_variants` SCVs get NaN. The pN/pS table and the two count tables
    are then passed to `report`.
    """
    # gen output
    filesnpaths.check_output_directory(args.output_dir)
    filesnpaths.gen_output_directory(args.output_dir)

    # load contigs db and variability tables
    contigs_db = load_contigs_db(args)
    var_dataframe = load_variability(args, contigs_db)

    # only genes and samples that have at least one SCV / SAAV are known to the program
    gene_caller_ids = var_dataframe['corresponding_gene_call'].unique()
    samples = var_dataframe['sample_id'].unique()

    # creation of SAAV count table, sSCV count table, and SAAV_to_sSCV ratio table
    SAAV, sSCV, SAAV_to_sSCV = _count_variants(var_dataframe, gene_caller_ids, samples, args.minimum_num_variants)

    # calc potential synonymity for each gene
    potential_synonymity_ratio, total_ambiguous_codons = _get_potential_synonymity_ratios(contigs_db, gene_caller_ids)

    if total_ambiguous_codons:
        run.warning("Please note that %d codons among your %d genes included nucleotides "
                    "that were not unambiguous (i.e., 'N's or any other bases than A, T, "
                    "C, or G) and were excluded from synonymity calculations." % (total_ambiguous_codons, len(gene_caller_ids)))

    # normalize the SAAV_to_sSCV ratio by the potential synonymity ratio to make pN/pS
    pNpS = SAAV_to_sSCV.multiply(potential_synonymity_ratio, axis=0)

    report(args, pNpS, SAAV, sSCV)


if __name__ == '__main__':
    # command line interface: input tables, tunable filters, and the output directory
    parser = argparse.ArgumentParser(description=__description__)

    # inputs: the two variability tables and the contigs database they were generated from
    variability_group = parser.add_argument_group('VARIABILITY', 'You provide two variability tables generated with anvi-gen-variability-profile:\
                                        one for SAAVs (generated with --engine AA) and for SCVs (generated with --engine CDN). They \
                                        must be generated with the same profile and contigs database pair. To be safe, we recommended\
                                        you use the same settings during both commands except for changing --engine AA to --engine CDN\
                                        and the output filename.')
    variability_group.add_argument('-a', '--saav-table', help='Filepath to the SAAV table.', metavar='SAAV_FILE')
    variability_group.add_argument('-b', '--scv-table', help='Filepath to the SCV table.', metavar='SCV_FILE')
    variability_group.add_argument(*anvio.A('contigs-db'), **anvio.K('contigs-db', {'help':'Filepath to the contigs database used \
                                                      to generate variability tables.'}))

    # filters applied to variants and genes before ratios are calculated
    tunables_group = parser.add_argument_group('TUNABLES', "Successfully tune one or more of these parameters to unlock the badge 'Advanced anvian'.")
    tunables_group.add_argument(*anvio.A('min-departure-from-consensus'), **anvio.K('min-departure-from-consensus', {'default': 0.1, 'help': \
                            'Variants (either SCVs or SAAVs) will be ignored if they have a departure from consensus less than this \
                            value. Note: Keep in mind you may have already supplied this parameter during anvi-gen-variability-profile.\
                            The default value is "%(default).2f".'}))
    tunables_group.add_argument('-i', '--minimum-num-variants', default=4, type=int, required=False, help='Ignore genes with less than this number\
                            of single codon variants. This avoids being impressed by pN/pS values of infinite, when in reality the gene had\
                            a single SAAV and no synonymous SCVs. The default is %(default)d to ensure a default value with some level of \
                            statistical importance.')
    tunables_group.add_argument('-m', '--min-coverage', default=30, type=int, required=False, help='If the coverage value at a codon is less than \
                            this amount, any SAAVs or SCVs associated with it will be ignored. The default is %(default)d.')

    # where the resulting tables are written
    output_group = parser.add_argument_group('OUTPUT', 'The output of this program is a folder directory with several tables.')
    output_group.add_argument(*anvio.A('output-dir'), **anvio.K('output-dir', {'required':True}))

    args = anvio.get_args(parser)

    # both anvi'o error types are reported the same way: print the message, exit with 1
    try:
        calculate_pN_pS_ratio(args)
    except (ConfigError, FilesNPathsError) as e:
        print(e)
        sys.exit(1)
