#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""A script to export FASTA files from contigs databases and compute ANI."""

import os
import sys
import shutil
import hashlib
import argparse

import anvio
import anvio.db as db
import anvio.utils as utils
import anvio.terminal as terminal
import anvio.clustering as clustering
import anvio.summarizer as summarizer
import anvio.filesnpaths as filesnpaths
import anvio.genomedescriptions as genomedescriptions

from anvio.drivers import pyani
from anvio.tables.miscdata import TableForLayerAdditionalData
from anvio.tables.miscdata import TableForLayerOrders

from anvio.errors import ConfigError, FilesNPathsError
import anvio.errors


# Program metadata (authorship, license, version and maintainer contact).
__author__ = "Developers of anvi'o (see AUTHORS.txt)"
__copyright__ = "Copyleft 2015-2018, the Meren Lab (http://merenlab.org/)"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "Ozcan Esen"
__email__ = "ozcanesen@gmail.com"
# NOTE(review): these look like the names of the artifacts this program consumes
# and produces -- presumably read by tooling outside this file; confirm.
__requires__ = ['external-genomes', 'internal-genomes', 'pan-db']
__provides__ = ['ani-txt', 'ani']

# Module-level terminal reporter; main() uses it for its info and warning messages.
run = terminal.Run()

def restore_names_in_dict(input_dict, conversion_dict):
    """
    Builds a new dictionary from `input_dict` in which every key that is
    found in `conversion_dict` (i.e., a hash) is swapped for the genome
    name it maps to. Keys that are unknown to `conversion_dict` are kept
    as they are.

    Values that are dictionaries themselves go through the same treatment
    recursively; any other value is carried over untouched.
    """
    def _restored(entry):
        # only nested dictionaries may contain further hashed keys
        if isinstance(entry, dict):
            return restore_names_in_dict(entry, conversion_dict)
        return entry

    return {(conversion_dict[hashed] if hashed in conversion_dict else hashed): _restored(entry)
            for hashed, entry in input_dict.items()}


def main(args):
    """Export one FASTA file per genome, run pyANI on them, and report the results.

    The steps, in order:

      1. Sanity check the filtering parameters, the output directory, and (if given)
         the pan database path.
      2. Write the sequences of every internal / external genome into a temporary
         directory. Files are named after the SHA-256 hash of the genome name, and
         the hashes are converted back to genome names in the pyANI results.
      3. If `--min-alignment-fraction` is set, set the values of weak genome-genome
         hits to 0 in every report.
      4. Cluster every report, and store each of them as a TAB-delimited file and a
         newick file in the output directory.
      5. If a pan database is given, add every report to it as layer additional data
         and as a layer order.
      6. Remove the temporary directory, unless anvi'o runs in debug mode.

    Parameters
    ==========
    args : argparse.Namespace
        The attributes read directly in this function are `min_alignment_fraction`,
        `significant_alignment_length`, `output_dir`, `pan_db`, `linkage`, `distance`,
        and `method`. The same object is also passed as is to `GenomeDescriptions`,
        `PyANI`, and `TableForLayerOrders`.

    Raises
    ======
    ConfigError
        If a parameter does not make sense, if the pyANI results lack the data the
        requested filters need, if clustering a report fails, or if the ANI results
        contain genomes that are not in the pan database.
    """
    # some quick checks
    if args.min_alignment_fraction < 0 or args.min_alignment_fraction > 1:
        raise ConfigError("The minimum alignment fraction must be a value between 0.0 and 1.0. %.2f does\
                           not work for anvi'o :/" % args.min_alignment_fraction)

    if args.significant_alignment_length and args.significant_alignment_length < 0:
        raise ConfigError("You missed concept :/ Alignment length can't be smaller than 0.")

    if args.significant_alignment_length and not args.min_alignment_fraction:
        raise ConfigError("Using the --significant-alignment-length parameter Without the --min-alignment-fraction\
                           parameter does not make any sense. But how could you know that unless you read the help\
                           menu? Yep. Anvi'o is calling the bioinformatics police on you :(")

    filesnpaths.check_output_directory(args.output_dir, ok_if_exists=False)

    # make sure the pan database is what it claims to be BEFORE the expensive pyANI run,
    # rather than finding it out once everything else is done.
    if args.pan_db:
        utils.is_pan_db(args.pan_db)

    genome_desc = genomedescriptions.GenomeDescriptions(args, run = terminal.Run(verbose=False))
    genome_desc.load_genomes_descriptions(skip_functions=True, init=False)

    program = pyani.PyANI(args)

    temp_dir = filesnpaths.get_temp_directory_path()

    run.info("Genomes found", len(genome_desc.genomes))
    run.info("Temporary FASTA output directory", temp_dir)
    run.info("Output directory", os.path.abspath(args.output_dir))

    hash_to_name = {}
    genome_names = set()

    for genome_name in genome_desc.genomes:
        genome = genome_desc.genomes[genome_name]
        genome_names.add(genome_name)
        contigs_db_path = genome['contigs_db_path']

        # FASTA files are named after the hash of the genome name; `hash_to_name`
        # is used to restore the actual names in the results later.
        hash_for_output_file = hashlib.sha256(genome_name.encode('utf-8')).hexdigest()
        hash_to_name[hash_for_output_file] = genome_name
        fasta_path = os.path.join(temp_dir, hash_for_output_file + '.fa')

        if 'bin_id' in genome:
            # Internal genome
            summary_args = argparse.Namespace(profile_db=genome['profile_db_path'],
                                              contigs_db=contigs_db_path,
                                              collection_name=genome['collection_id'],
                                              quick=True)

            summary = summarizer.ProfileSummarizer(summary_args, r=terminal.Run(verbose=False))
            summary.init()

            bin_summary = summarizer.Bin(summary, genome['bin_id'])

            with open(fasta_path, 'w') as fasta:
                fasta.write(bin_summary.get_bin_sequence())
        else:
            # External genome
            utils.export_sequences_from_contigs_db(contigs_db_path, fasta_path)

    results = program.run_command(temp_dir)
    results = restore_names_in_dict(results, hash_to_name)

    if anvio.DEBUG:
        import json
        for report in results:
            run.warning(None, header=report)
            print(json.dumps(results[report], indent=2))

    # let's go through the results dictionary to remove very weak hits.
    # in this list we will keep the tuples of genome-genome associations
    # that need to be set to zero in all result dicts:
    genome_hits_to_zero = []

    if args.min_alignment_fraction:
        if 'alignment_coverage' not in results:
            raise ConfigError("You asked anvi'o to remove weak hits through the --min-alignment-fraction\
                               parameter, but the results dictionary does not contain any information about\
                               alignment fractions :/ These are the items anvi'o found instead: '%s'. Please let a\
                               developer know about this if this doesn't make any sense." % (', '.join(results.keys())))

        if args.significant_alignment_length is not None and 'alignment_lengths' not in results:
            raise ConfigError("The pyANI results do not contain any alignment lengths data. Perhaps the method you\
                               used for pyANI does not produce such data. Well. That's OK. But then you can't use the\
                               --significant-alignment-length parameter :/")

        coverage = results['alignment_coverage']

        # alignment lengths are only consulted when --significant-alignment-length is in
        # use (in which case the check above guarantees they exist), so a missing report
        # must not break a run that does not need it.
        lengths = results.get('alignment_lengths')

        those_that_anvio_wanted_to_remove = 0
        those_that_saved_by_significant_length_param = 0

        # NOTE: every pair of genomes is visited in both directions and the condition
        # below is symmetric, hence each pair contributes twice to the counters and
        # to `genome_hits_to_zero`.
        for g1 in coverage:
            for g2 in coverage:
                if g1 == g2:
                    continue

                is_weak_hit = float(coverage[g1][g2]) < args.min_alignment_fraction or \
                              float(coverage[g2][g1]) < args.min_alignment_fraction

                if not is_weak_hit:
                    continue

                those_that_anvio_wanted_to_remove += 1

                if args.significant_alignment_length and \
                   min(float(lengths[g1][g2]), float(lengths[g2][g1])) > args.significant_alignment_length:
                    those_that_saved_by_significant_length_param += 1
                else:
                    genome_hits_to_zero.append((g1, g2))

        if genome_hits_to_zero:
            # the first weak hit serves as an example in the warning below
            g1, g2 = genome_hits_to_zero[0]

            if those_that_saved_by_significant_length_param:
                additional_msg = "By the way, anvi'o saved %d weak hits becasue they were longer than the length of %d nts you\
                                  specified using the --significant-alignment-length parameter. " % \
                                        (those_that_saved_by_significant_length_param, args.significant_alignment_length)
            else:
                additional_msg = ""

            run.warning("THIS IS VERY IMPORTANT! You asked anvi'o to remove any hits between two genomes if the hit\
                         was produced by a weak alignment (which you defined as alignment fraction less than '%.2f').\
                         Anvi'o found %d hits between your %d genomes, and is about to set percent identity estimates emerging\
                         from those weak hits to 0 in all results. For instance, one of your genomes, '%s', was %.3f\
                         identical to '%s', another one of your genomes, but the aligned fraction of %s to %s was only %.3f\
                         and was below your threshold, and so the percent identity between them will be ignored for all\
                         downstream reports you will find in anvi'o tables and visualizations. %sAnvi'o kindly invites you\
                         to carefully think about potential implications of discarding hits based on an arbitrary alignment\
                         fraction, but does not judge you because it is not perfect either." % \
                                                (args.min_alignment_fraction, those_that_anvio_wanted_to_remove, len(coverage), g1,
                                                 float(results['percentage_identity'][g1][g2]), g2, g1, g2,
                                                 float(coverage[g1][g2]), additional_msg))

        elif those_that_saved_by_significant_length_param:
            run.warning("THIS IS VERY IMPORTANT! You asked anvi'o to remove any hits between two genomes if the hit\
                          was produced by a weak alignment (which you defined as an alignment fraction less than '%.2f').\
                          Anvi'o found %d hits between your %d genomes, but the --significant-alignment-length parameter\
                          saved them all, because each one of them were longer than %d nts. So your filters kinda cancelled\
                          each other out. Just so you know." % \
                                                (args.min_alignment_fraction, those_that_anvio_wanted_to_remove, len(coverage),
                                                 args.significant_alignment_length))

        # time to zero those values out:
        for report_name in results:
            for g1, g2 in genome_hits_to_zero:
                results[report_name][g1][g2] = 0
                results[report_name][g2][g1] = 0

    clusterings = {}
    for report_name in results:
        try:
            clusterings[report_name] = clustering.get_newick_tree_data_for_dict(results[report_name],
                                                                                linkage=args.linkage,
                                                                                distance=args.distance)
        except Exception as e:
            # `Exception` rather than a bare except, so Ctrl-C and sys.exit() are not
            # turned into a ConfigError; the original error is kept as the cause.
            raise ConfigError("Bad news :/ Something went wrong while anvi'o was processing the output for\
                               '%s'. You can find the offending file if you search for the output file in\
                               the temporary output directory '%s'." % (report_name, os.path.join(temp_dir, 'output'))) from e

    os.mkdir(args.output_dir)
    for report_name in results:
        output_path_for_report = os.path.join(args.output_dir, args.method + '_' + report_name)

        utils.store_dict_as_TAB_delimited_file(results[report_name], output_path_for_report + '.txt')
        with open(output_path_for_report + '.newick', 'w') as f:
            f.write(clusterings[report_name])

    if args.pan_db:
        # the pan database was already validated via `utils.is_pan_db` at the top
        pan = db.DB(args.pan_db, anvio.__pan__version__)

        try:
            db_genome_names = set()
            for meta_key in ('external_genome_names', 'internal_genome_names'):
                # genome names are stored as comma-separated values; skip empty entries
                db_genome_names.update(name for name in pan.get_meta_value(meta_key).strip().split(',') if name)
        finally:
            # we only needed the genome names; don't keep the connection open while
            # the misc data tables below are written through their own classes.
            pan.disconnect()

        found_only_in_db = db_genome_names.difference(genome_names)
        found_only_in_results = genome_names.difference(db_genome_names)

        if found_only_in_results:
            raise ConfigError("Some genome names found in ANI result, does not seem to be exists in the pan database. \
                Here are the list of them: " + ", ".join(found_only_in_results))

        if found_only_in_db:
            run.warning("Some genome names found in pan database, does not seem to be exists in the ANI report. \
                anvi'o will add the ones that found in database anyway, but here is the list of missing ones: \
                " + ", ".join(found_only_in_db))

        run.warning(None, header="MISC DATA MAGIC FOR YOUR PAN DB", lc='green')
        for report_name in results:
            target_data_group = 'ANI_%s' % (report_name)
            l_args = argparse.Namespace(pan_db=args.pan_db, just_do_it=True, target_data_group=target_data_group)
            TableForLayerAdditionalData(l_args, r=terminal.Run(verbose=False)).add(results[report_name], list(results[report_name].keys()))

            # NOTE(review): the layer order is added with the program-wide `args` rather
            # than `l_args` -- presumably on purpose; confirm.
            TableForLayerOrders(args, r=terminal.Run(verbose=False)).add({'ANI_' + report_name: {'data_type': 'newick',
                                                                  'data_value': clusterings[report_name]}})

            run.info_single("Additional data and order for %s are now in pan db" % target_data_group.replace('_', ' '), mc='green')

    if anvio.DEBUG:
        run.warning("The temp directory, %s, is kept. Please don't forget to clean it up\
                     later" % temp_dir, header="Debug")
    else:
        run.info_single('Cleaning up the temp directory (you can use `--debug` if you would\
                         like to keep it for testing purposes)', nl_before=1, nl_after=1)
        shutil.rmtree(temp_dir)

if __name__ == '__main__':
    # Command line interface: build the parser, let anvi'o resolve the arguments,
    # and run main(). Arguments declared through anvio.A / anvio.K are the ones
    # anvi'o defines centrally; the rest are specific to this program.
    parser = argparse.ArgumentParser(description="Export sequences from external genomes and compute\
                                     ANI. If Pan Database is given anvi'o will write computed output\
                                     to misc data tables of Pan Database.")

    groupA = parser.add_argument_group('INPUT OPTIONS', "Tell anvi'o what you want.")
    groupA.add_argument(*anvio.A('internal-genomes'), **anvio.K('internal-genomes'))
    groupA.add_argument(*anvio.A('external-genomes'), **anvio.K('external-genomes'))

    groupB = parser.add_argument_group('OUTPUT OPTIONS', "Tell anvi'o where to store your results.")
    groupB.add_argument(*anvio.A('output-dir'), **anvio.K('output-dir', {'required': True }))
    groupB.add_argument(*anvio.A('pan-db'), **anvio.K('pan-db', {'required': False, 'help': "This is\
                        totally ooptional, but very useful when applicable. If you are running ANI for\
                        genomes for which you already have an anvi'o pangeome, then you can show where\
                        the pan database is and anvi'o would automatically add the results into the\
                        misc data tables of your pangenome. Those data can then be shown as ANI heatmaps\
                        on the pan interactive interface through the 'layers' tab."}))

    groupC = parser.add_argument_group('FILTERING HITS', "These filters can save lives. Ask Luke.")
    groupC.add_argument('--min-alignment-fraction', type=float, default=0.0, help="In some cases you may get\
                        high raw ANI estimates between two genomes that have nothing to do with each other\
                        simply because only a small fraction of their content may be aligned. This filter will\
                        eliminate ANI scores between two genomes if the alignment fraction is less than you\
                        deem trustable. When you set a value, anvi'o will go through the ANI results, and set\
                        percent identity scores between two genomes to 0 if the alignment fraction *between either\
                        of them* is less than the parameter described here. The default is 0.0, so every hit is\
                        reported, but you can go rebel, and choose any value between 0.0 and 1.0.")
    groupC.add_argument('--significant-alignment-length', type=int, default=None, help="So --min-alignment-fraction\
                        discards any hit that is coming from alignments that represent shorter fractions of genomes,\
                        but what if you still don't want to miss an alignment that is longer than an X number of\
                        nucleotides regardless of what fraction of the genome it represents? Well, this parameter is\
                        to recover things that may be lost due to --min-alignment-fraction parameter. Let's say,\
                        if you set --min-alignment-fraction to '0.05', and this parameter to '5000', anvi'o will keep\
                        hits from alignments that are longer than 5000 nts, EVEN IF THEY REPRESENT less than 5%% of \
                        a given genome pair. Basically if --min-alignment-fraction is your shield to protect yourself\
                        from incoming garbage, --significant-alignment-length is your chopstick to pick out those that\
                        may be interesting, and you are a true warrior here.")

    groupD = parser.add_argument_group('pyANI METHOD', "Tell anvi'o to tell pyANI what method you wish to use.")
    groupD.add_argument('--method', default='ANIb', type=str, help="Method for pyANI. The default is %(default)s.\
                         You must have the necessary binary in path for whichever method you choose. According to\
                         the pyANI help for v0.2.7 at https://github.com/widdowquinn/pyani, the method 'ANIm' uses\
                         MUMmer (NUCmer) to align the input sequences. 'ANIb' uses BLASTN+ to align 1020nt fragments\
                         of the input sequences. 'ANIblastall': uses the legacy BLASTN to align 1020nt fragments\
                         Finally, 'TETRA': calculates tetranucleotide frequencies of each input sequence",\
                         choices=['ANIm', 'ANIb', 'ANIblastall', 'TETRA'])

    groupE = parser.add_argument_group('HIERARCHICAL CLUSTERING', "Once pyANI is done with its magic, it reports ANI\
                         between genomes as distance matrix files, which can be clustered into nice looking dendrograms\
                         to display the relationships between genomes nicely (in the anvi'o interface and elsewhere).\
                         Here you can set the distance metric and the linkage algorithm for that.")
    groupE.add_argument(*anvio.A('distance'), **anvio.K('distance', {'help': 'The distance metric for the hierarchical \
                         clustering. The default is "%(default)s".'}))
    groupE.add_argument(*anvio.A('linkage'), **anvio.K('linkage', {'help': 'The linkage method for the hierarchical \
                         clustering. The default is "%(default)s".'}))

    groupF = parser.add_argument_group('OTHER IMPORTANT STUFF', "Yes. You're almost done.")
    groupF.add_argument(*anvio.A('num-threads'), **anvio.K('num-threads'))
    groupF.add_argument(*anvio.A('just-do-it'), **anvio.K('just-do-it'))
    groupF.add_argument(*anvio.A('log-file'), **anvio.K('log-file'))

    args = anvio.get_args(parser)

    # both anvi'o error types are reported the same way: print the message and
    # exit with a non-zero status.
    try:
        main(args)
    except (ConfigError, FilesNPathsError) as e:
        print(e)
        sys.exit(-1)
