#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Calculates pN/pS, the metagenomic analogy of dN/dS
Citations: doi:10.1038/nature11711, doi:10.7717/peerj.2959 (and references therein)
"""

import os
import sys
import time
import argparse

import numpy as np
import pandas as pd

import anvio

import anvio.dbops as dbops
import anvio.utils as utils
import anvio.terminal as terminal
import anvio.filesnpaths as filesnpaths

from anvio.errors import ConfigError, FilesNPathsError
from anvio.variabilityops import variability_engines, VariabilityData


__author__ = "Developers of anvi'o (see AUTHORS.txt)"
__copyright__ = "Copyleft 2015-2018, the Meren Lab (http://merenlab.org/)"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "Evan Kiefl"
__email__ = "kiefl.evan@gmail.com"


progress = terminal.Progress()
run = terminal.Run()


def load_variability(args):
    """Load the SAAV and SCV variability tables and stack them into one table.

    For each engine ('AA' read from `args.saav_table`, then 'CDN' read from
    `args.scv_table`) a `VariabilityData` object is built from `args`, its
    `filter_data` method is called for the 'departure_from_consensus' and
    'coverage' criteria, an 'engine' column holding the engine name is added,
    and the data is reduced to the gene, sample and engine columns.

    Side effects: `args.columns_to_load`, `args.engine` and
    `args.variability_profile` are overwritten on the passed-in namespace.

    Returns the concatenation (CDN rows first, then AA rows, with a fresh
    integer index) of the two per-engine tables.
    """
    kept_columns = ['corresponding_gene_call', 'sample_id', 'engine']
    per_engine = {}

    progress.new("Loading variability")

    # the two filter criteria need these columns in addition to the ones we keep
    args.columns_to_load = ['corresponding_gene_call', 'sample_id', 'coverage', 'departure_from_consensus']

    # (engine, path to its table, label shown in the progress message)
    sources = (
        ('AA', args.saav_table, 'SAAVs'),
        ('CDN', args.scv_table, 'SCVs'),
    )

    for engine, table_path, label in sources:
        progress.update("Working on {}".format(label))

        # VariabilityData is configured through attributes on args
        args.engine = engine
        args.variability_profile = table_path
        variability = VariabilityData(args)

        # filter first by departure from consensus, then by coverage
        for criterion in ('departure_from_consensus', 'coverage'):
            variability.filter_data(criterion=criterion, verbose=False)

        # remember which engine each row came from
        variability.data['engine'] = engine

        # drop every column that is not needed downstream
        variability.data = variability.data[kept_columns]

        per_engine[engine] = variability

    progress.update('Combining both engine data types')
    combined = pd.concat([per_engine['CDN'].data, per_engine['AA'].data], ignore_index=True)

    progress.end()
    return combined


def load_contigs_db(args):
    """Return a `dbops.ContigsSuperclass` built from `args` with contig sequences initialized.

    `filesnpaths.is_file_exists` is called on `args.contigs_db` first. The
    superclass is given non-verbose `Run` and `Progress` objects.
    """
    filesnpaths.is_file_exists(args.contigs_db)

    quiet_run = terminal.Run(verbose=False)
    quiet_progress = terminal.Progress(verbose=False)

    contigs_super = dbops.ContigsSuperclass(args, r=quiet_run, p=quiet_progress)
    contigs_super.init_contig_sequences()

    return contigs_super


def report(args, pNpS, SAAV, sSCV):
    """Write the three result tables into `args.output_dir` and announce completion.

    Each table is written through its `to_csv` method as a tab-separated file
    whose index column is labelled 'gene_callers_id'.
    """
    # (file name, table), in the order the files are written
    outputs = (
        ("pN_pS_ratio.txt", pNpS),
        ("sSCV_counts.txt", sSCV),
        ("SAAV_counts.txt", SAAV),
    )

    for file_name, table in outputs:
        destination = os.path.join(args.output_dir, file_name)
        table.to_csv(destination, sep="\t", index=True, index_label="gene_callers_id")

    message = "Done! Contents have been output to the directory '{}'.".format(args.output_dir)
    run.info_single(message, nl_before=1, nl_after=1)


def calculate_pN_pS_ratio(args):
    """Compute pN/pS for every gene in every sample and write the result tables.

    Steps:
      1. Check/create `args.output_dir`.
      2. Load the contigs database and the combined SAAV + SCV table.
      3. For each (gene, sample) pair count the rows labelled 'AA' (SAAVs) and
         'CDN' (SCVs). The synonymous SCV count (sSCV) is taken as
         SCVs - SAAVs, and the raw ratio is SAAVs / sSCVs.
      4. Multiply each gene's ratios by that gene's
         (synonymous potential / non-synonymous potential) to obtain pN/pS.
      5. Hand the pN/pS, SAAV and sSCV tables to `report`.

    A (gene, sample) pair with fewer than `args.minimum_num_variants` SCVs gets
    NaN as its ratio. A pair with SAAVs but no synonymous SCVs gets inf, and a
    pair with no variants at all (reachable only when the minimum is <= 0)
    gets NaN.

    Exceptions raised by the helpers called here are not caught.
    """
    # gen output
    filesnpaths.check_output_directory(args.output_dir)
    filesnpaths.gen_output_directory(args.output_dir)

    # load contigs db and variability tables
    contigs_db = load_contigs_db(args)
    var_dataframe = load_variability(args)

    # only genes and samples that have at least one SCV / SAAV are known to the program
    genes = var_dataframe['corresponding_gene_call'].unique()
    samples = var_dataframe['sample_id'].unique()

    # creation of SAAV count table, sSCV count table, and SAAV_to_sSCV ratio table.
    # each dict maps sample -> list with one entry per gene, in the order of `genes`
    SAAV = {sample: [] for sample in samples}
    sSCV = {sample: [] for sample in samples}
    SAAV_to_sSCV = {sample: [] for sample in samples}

    progress.new("Calculating SAAV and sSCV counts")
    for i, gene in enumerate(genes):
        temp = var_dataframe[var_dataframe['corresponding_gene_call'] == gene]
        for sample in samples:
            counts = temp.loc[temp['sample_id'] == sample, 'engine'].value_counts()

            # an engine with no rows for this gene/sample is absent from `counts`,
            # so every lookup goes through .get() with a default of 0
            num_SAAVs = counts.get('AA', 0)
            num_SCVs = counts.get('CDN', 0)
            num_sSCVs = num_SCVs - num_SAAVs # fully synonymous only

            SAAV[sample].append(num_SAAVs)
            sSCV[sample].append(num_sSCVs)

            if num_SCVs < args.minimum_num_variants:
                SAAV_to_sSCV[sample].append(np.nan)
            else:
                # divide as numpy floats so that x/0 yields inf and 0/0 yields nan
                # instead of raising; both warnings are expected and silenced
                with np.errstate(divide='ignore', invalid='ignore'):
                    SAAV_to_sSCV[sample].append(np.float64(num_SAAVs) / np.float64(num_sSCVs))

        if i % 100 == 0:
            progress.update('Processed {} of {} genes.'.format(i, len(genes)))
    progress.end()

    # make tables dataframes (rows: genes, columns: samples)
    sSCV = pd.DataFrame(sSCV, index=genes)
    SAAV = pd.DataFrame(SAAV, index=genes)
    SAAV_to_sSCV = pd.DataFrame(SAAV_to_sSCV, index=genes)

    # calc potential synonymity for each gene
    potential_synonymity_ratio = np.zeros(len(genes))
    for i, gene in enumerate(genes):
        gene_call = contigs_db.genes_in_contigs_dict[gene]
        codon_list_for_gene = utils.get_list_of_codons_for_gene_call(gene_call, contigs_db.contig_sequences)
        syn_potential, non_syn_potential = utils.get_synonymous_and_non_synonymous_potential(codon_list_for_gene)
        potential_synonymity_ratio[i] = syn_potential / non_syn_potential

    # normalize the SAAV_to_sSCV ratio by the potential synonymity ratio to make pN/pS.
    # axis=0 scales each gene's row by that gene's ratio
    pNpS = SAAV_to_sSCV.multiply(potential_synonymity_ratio, axis=0)

    report(args, pNpS, SAAV, sSCV)


if __name__ == '__main__':
    # Command-line entry point: build the argument parser, parse the arguments
    # and run the pN/pS calculation.
    parser = argparse.ArgumentParser(description='Calculate pN/pS, the metagenomic analogy of dN/dS, \
                                                  for each gene in each sample from a SAAV table \
                                                  and an SCV table')

    groupV = parser.add_argument_group('VARIABILITY', 'Two variability tables generated from \
                                                       anvi-gen-variability-profile are required. One \
                                                       of SAAVs (generated with --engine AA) and one \
                                                       of SCVs (generated with --engine CDN). They \
                                                       must be generated with the same profile database and \
                                                       the exact same set of genes in the contigs \
                                                       database. To be safe, it is highly recommended you \
                                                       use the same settings during both commands except \
                                                       for changing --engine AA to --engine CDN and the \
                                                       output filename.')
    groupO = parser.add_argument_group('OUTPUT',      'The output of this program is a folder \
                                                       directory with several tables.')
    groupE = parser.add_argument_group('TUNABLES',    'Successfully tune one or more of these \
                                                       parameters to unlock the "Advanced anvi\'an" \
                                                       achievement.')

    # both tables are always read by load_variability, so they are mandatory
    groupV.add_argument('-a', '--saav-table', required=True, help='Filepath to the SAAV table.')
    groupV.add_argument('-b', '--scv-table', required=True, help='Filepath to the SCV table.')
    groupV.add_argument(*anvio.A('contigs-db'), **anvio.K('contigs-db',
                                                     {'help':'Filepath to the contigs database used \
                                                      to generate variability tables.'}))
    groupE.add_argument(*anvio.A('min-departure-from-consensus'), **anvio.K('min-departure-from-consensus',
                                                     {'default': 0.1, 'help': 'Variants (either SCVs or \
                                                       SAAVs) will be ignored if they have a departure \
                                                       from consensus less than this value. Note: Keep \
                                                       in mind you may have already supplied this parameter \
                                                       during anvi-gen-variability-profile. Default is 0.1.'}))
    groupE.add_argument('-i', '--minimum-num-variants', default=4, type=int, required=False, help='Ignore genes \
                                                       with less than this number of single codon \
                                                       variants. This avoids being impressed by \
                                                       pN/pS values of infinite, when really all \
                                                       that happened was a gene had 1 SAAV and 0 \
                                                       synonymous SCVs. The default is 4 to ensure \
                                                       some level of statistical importance.')
    groupE.add_argument('-m', '--min-coverage', default=30, type=int, required=False,
                                                       help='If the coverage value at a codon \
                                                       is less than this amount, any SAAVs or SCVs \
                                                       associated with it will be ignored.')
    groupO.add_argument(*anvio.A('output-dir'), **anvio.K('output-dir', {'required':True}))

    args = anvio.get_args(parser)

    # both error types are reported the same way: print the message, exit with status 1
    try:
        calculate_pN_pS_ratio(args)
    except (ConfigError, FilesNPathsError) as e:
        print(e)
        sys.exit(1)
