#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import numpy as numpy

import anvio
import anvio.panops as panops
import anvio.terminal as terminal
import anvio.summarizer as summarizer
import anvio.filesnpaths as filesnpaths

from anvio.errors import ConfigError, FilesNPathsError


# Module metadata. The version is taken from the installed anvio package.
__author__ = "A. Murat Eren"
__copyright__ = "Copyright 2016, The anvio Project"
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"


# Shared terminal helpers: `run` is used by main() to print key/value status
# lines, `progress` to report the current step.
run = terminal.Run()
progress = terminal.Progress()
# Alias for terminal.pretty_print; not referenced in the visible part of this file.
pp = terminal.pretty_print


def main(args):
    """Classify the genes of internal genomes by their coverage in the environment
    and summarize those classes per protein cluster.

    Every gene of every internal genome is assigned one of three classes:

      - 'NA'  : the genome is not detected above `min_detection` in any metagenome,
                so nothing can be said about its genes.
      - 'EDG' : the total coverage of the gene across samples is less than the median
                total coverage of all genes in the genome multiplied by
                `fraction_of_median_coverage` (i.e., not a part of the environmental core).
      - 'ECG' : everything else (i.e., a part of the environmental core).

    Two TAB-delimited files are written:

      - `<output_file_prefix>-GENES.txt`: one line per gene (genome, gene_callers_id, class).
      - `<output_file_prefix>.txt`: one line per protein cluster, with the number of
        EDG, ECG, and NA genes it contains joined by ';'.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide `fraction_of_median_coverage`, `min_detection`, `output_file_prefix`,
        `pan_db`, and `genomes_storage`. The object is also passed as is to
        `panops.GenomeStorage`, which reads whatever else it needs from it.

    Raises
    ------
    ConfigError
        If the genomes that originate from the same profile database do not all use
        the same collection.
    FilesNPathsError
        Presumably raised by `filesnpaths.is_output_file_writable` when an output path
        is not writable (the script entry point catches this type).
    """

    fraction_of_median_coverage = float(args.fraction_of_median_coverage)
    min_detection = float(args.min_detection)
    output_summary_path = args.output_file_prefix + '.txt'
    output_per_gene_summary_path = args.output_file_prefix + '-GENES.txt'

    # both files are written at the end of a long computation, so make sure both of
    # them are writable before doing any of the heavy lifting:
    filesnpaths.is_output_file_writable(output_summary_path)
    filesnpaths.is_output_file_writable(output_per_gene_summary_path)

    progress.new('Initializing stuff')
    progress.update('Loading gneome descriptions')
    G = panops.GenomeStorage(args)
    G.load_genomes_descriptions(skip_functions=True)

    progress.update('Loading files for pangenome')
    ARGS = summarizer.ArgsTemplateForSummarizerClass()
    ARGS.pan_db = args.pan_db
    ARGS.genomes_storage = args.genomes_storage
    ARGS.skip_check_collection_name = True
    ARGS.skip_init_functions = True
    pan_summary = summarizer.PanSummarizer(ARGS)
    progress.end()

    run.info("Internal genomes", "%d found." % len(G.internal_genome_names))
    run.info('Fraction of median coverage for core genes', fraction_of_median_coverage)
    run.info('Min detection of a genome in at last one metagenome', min_detection)

    progress.new('Initializing internal genomes')

    # to not initialize things over and over again, genomes are grouped by the profile
    # database they come from:
    profile_db_to_genome_names = G.get_unique_profile_db_path_to_internal_genome_name_dict()

    # a limitation to enforce before we start: a single summary is initialized per profile
    # database below, hence all genomes from the same profile database must be described
    # in the same collection.
    for profile_db_key in profile_db_to_genome_names:
        collection_names = {G.genomes[genome_name]['collection_id'] for genome_name in profile_db_to_genome_names[profile_db_key]}
        if len(collection_names) != 1:
            # close the progress line so it does not linger next to the error message
            progress.end()
            raise ConfigError("You have to have the same collection for each bin originate from the same profile db.")

    progress.update('Computing gene presence data ...')
    gene_presence_in_the_environment_dict = {}

    # the context manager guarantees that the file is closed even if something goes
    # wrong in the middle of the computation.
    with open(output_per_gene_summary_path, 'w') as per_gene_output_file:
        per_gene_output_file.write('genome\tgene_callers_id\tclass\n')

        for profile_db_key in profile_db_to_genome_names:
            genome_names = profile_db_to_genome_names[profile_db_key]

            # all genomes in this group share the profile db and the collection (see the
            # check above), so the first one is used to learn about the databases.
            representative_genome = G.genomes[genome_names[0]]
            collection_name = representative_genome['collection_id']

            ARGS = summarizer.ArgsTemplateForSummarizerClass()
            ARGS.profile_db = representative_genome['profile_db_path']
            ARGS.contigs_db = representative_genome['contigs_db_path']
            ARGS.collection_name = collection_name

            summary = summarizer.ProfileSummarizer(ARGS)
            summary.init_collection_profile(collection_name)

            for genome_name in genome_names:
                # for each genome, first we will see whether it is detected in at least one metagenome
                detection_across_metagenomes = summary.collection_profile[genome_name]['detection']
                not_enough_detection = not any(detection_across_metagenomes[m] > min_detection for m in detection_across_metagenomes)

                gene_presence_in_the_environment_dict[genome_name] = {}
                split_names_of_interest = G.get_split_names_of_interest_for_internal_genome(G.genomes[genome_name])

                genome_bin_summary = summarizer.Bin(summary, genome_name, split_names_of_interest)
                gene_coverages_across_samples = genome_bin_summary.get_gene_coverages_across_samples_dict()

                # at this point we have all the genes in the genome bin. what we need is to characterize
                # their detection. first, sum the coverage of each gene across all samples:
                sum_gene_coverages_across_samples = {gene_callers_id: sum(gene_coverages_across_samples[gene_callers_id].values())
                                                     for gene_callers_id in gene_coverages_across_samples}

                # then identify the median of those sums, and the coverage a gene must reach to be
                # considered a part of the environmental core:
                median_coverage_across_samples = numpy.median(list(sum_gene_coverages_across_samples.values()))
                min_coverage_for_core = median_coverage_across_samples * fraction_of_median_coverage

                # now we decide whether a gene found in this genome is also found in the environment,
                # store that information in `gene_presence_in_the_environment_dict`, and report it.
                for gene_caller_id in sum_gene_coverages_across_samples:
                    if not_enough_detection:
                        _class = 'NA'
                    elif sum_gene_coverages_across_samples[gene_caller_id] < min_coverage_for_core:
                        _class = 'EDG'
                    else:
                        _class = 'ECG'

                    gene_presence_in_the_environment_dict[genome_name][gene_caller_id] = _class
                    per_gene_output_file.write('%s\t%d\t%s\n' % (genome_name, gene_caller_id, _class))

    progress.update('Computing ratio of genes present/absent per PC data ...')
    gene_status_frequencies_in_PC = {}
    for pc_name in pan_summary.protein_clusters:
        status = {'EDG': 0, 'ECG': 0, 'NA': 0}
        for genome_name in pan_summary.protein_clusters[pc_name]:
            for gene_caller_id in pan_summary.protein_clusters[pc_name][genome_name]:
                # NOTE(review): this assumes every genome and gene in the pan database was
                # classified above (i.e., is an internal genome); otherwise a KeyError is
                # raised here -- confirm whether that can happen.
                status[gene_presence_in_the_environment_dict[genome_name][gene_caller_id]] += 1
        gene_status_frequencies_in_PC[pc_name] = status

    progress.update('Storing the output')
    categories = ['EDG', 'ECG', 'NA']
    with open(output_summary_path, 'w') as output:
        output.write('pc_name\tDetection!%s\n' % (';'.join(categories)))
        for pc_name in gene_status_frequencies_in_PC:
            freqs = ';'.join([str(gene_status_frequencies_in_PC[pc_name][cat]) for cat in categories])
            output.write('%s\t%s\n' % (pc_name, freqs))

    progress.end()

    run.info('PC summary file', output_summary_path)
    run.info('All gene classification', output_per_gene_summary_path)


if __name__ == '__main__':
    import argparse

    # Command line interface. Arguments are organized into groups so that `--help`
    # shows related parameters together.
    parser = argparse.ArgumentParser(description="Quantify the detection of genes in genomes in metagenomes to identify the environmental\
                                                  core. This is a helper script for anvi'o metapangenomic workflow.")
    groupA = parser.add_argument_group("INTERNAL GENOMES", "Genome bins stored in an anvi'o profile databases as collections.")
    groupA.add_argument('-i', '--internal-genomes', metavar = 'FILE', default = None,
                        help = "A four-column TAB-delimited flat text file. This file should be identical to the internal\
                                genomes file you used for your pangenomics analysis.")

    # both thresholds that drive the classification live in the same group:
    groupB = parser.add_argument_group("CRITERION FOR DETECTION", "This is tricky. What we want to do is to identify genes that are\
                                        occurring uniformly across samples.")
    groupB.add_argument('--fraction-of-median-coverage', metavar="FLOAT", default=0.25, type=float, help="The value set here\
                        will be used to remove a gene if its total coverage across environments is less than the median coverage\
                        of all genes multiplied by this value. The default is 0.25, which means, if the median total coverage of\
                        all genes across all samples is 100X, then, a gene with a total coverage of less than 25X across all\
                        samples will be assumed not a part of the 'environmental core'.")
    groupB.add_argument('--min-detection', metavar="FLOAT", default=0.50, type=float, help="For this entire thing to work, the\
                        genome you are focusing on should be detected in at least one metagenome. If that is not the case, it would\
                        mean that you do not have any sample that represents the niche for this organism (or you do not have enough\
                        depth of coverage) to investigate the detection of genes in the environment. By default, this script requires\
                        at least '0.5' of the genome to be detected in at least one metagenome. This parameter allows you to change\
                        that. 0 would mean no detection test required, 1 would mean the entire genome must be detected.")


    groupC = parser.add_argument_group("PAN GENOME INFO", "Files for the pangenome.")
    groupC.add_argument(*anvio.A('pan-db'), **anvio.K('pan-db', {'required': False}))
    groupC.add_argument(*anvio.A('genomes-storage'), **anvio.K('genomes-storage', {'required': False}))

    groupD = parser.add_argument_group("OUTPUT", "Give it a nice prefix. This prefix will be used to generate the additional\
                                                  data matrix for your display, and an extra file for per-gene classification.")
    groupD.add_argument(*anvio.A('output-file-prefix'), **anvio.K('output-file-prefix', {'required': True}))

    args = parser.parse_args()

    # known error types are printed without a traceback, and the script exits
    # with a non-zero status.
    try:
        main(args)
    except (ConfigError, FilesNPathsError) as e:
        print(e)
        sys.exit(-1)
