#!/usr/bin/env python
# -*- coding: utf-8
"""Get a block of sequence following a gene that matches a functional annotation or HMM hit

    This program takes a contigs database, and searches for a function or HMM hit. It is useful
    when you want to study which genes are present around a gene of interest. It could be helpful
    when studying genomic loci that tend to have some conserved and some variable genes.

  You want to play with it? This is how you could quickly test it:

  Download the anvi'o data pack for the infant gut data, which is here:

    https://ndownloader.figshare.com/files/8252861

  Unpack it and go into it:

    tar -zxvf INFANTGUTTUTORIAL.tar.gz && cd INFANT-GUT-TUTORIAL

  Then run the program `anvi-export-locus` in the anvi'o master this way:

    anvi-export-locus -c CONTIGS.db \
                      -O OUTPUT \
                      --use-hmm \
                      --hmm-sources Campbell_et_al \
                      --search-term Ribosomal_L27 \
                      -n 24

  This will give you a fasta file OUTPUT.fa with the Ribosomal_L27 gene and the 24 genes that follow it.

  If all you want is the sequence of a gene matching a search pattern then use:
  for HMM: anvi-get-sequences-for-hmm-hits
  for functional annotations: anvi-search-functions
"""

import sys
import anvio
import argparse
import anvio.dbops as dbops
import anvio.hmmops as hmmops
import anvio.utils as utils
import anvio.terminal as terminal
import anvio.filesnpaths as filesnpaths

from anvio.errors import ConfigError, FilesNPathsError

# Program metadata. The version is taken from the imported `anvio` package rather
# than being hard-coded here.
__author__ = "Developers of anvi'o (see AUTHORS.txt)"
__copyright__ = "Copyleft 2015-2018, the Meren Lab (http://merenlab.org/)"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"


# Module-level terminal helpers. `run` is what `main` uses for its info and warning
# messages. NOTE(review): `progress` is not referenced anywhere in the visible part
# of this file.
run = terminal.Run()
progress = terminal.Progress()


def _parse_num_genes(num_genes):
    """Convert the `--num-genes` string into a `[genes_before, genes_after]` list of integers.

    A single integer `N` is interpreted as `[0, N]`. Anything that is not one or two
    comma-separated integers raises a ConfigError.
    """
    try:
        num_genes_list = [int(x) for x in num_genes.split(',')]
    except ValueError:
        # non-integer input gets the same explanation as a list with too many items
        num_genes_list = None

    if num_genes_list is None or len(num_genes_list) > 2:
        raise ConfigError("The block size you provided, \"%s\", is not valid. "
                          "The gene block size is defined by only one or two integers for either "
                          "a block following the search match or a block preceding and following "
                          "the search match respectively." % num_genes)

    if len(num_genes_list) == 1:
        num_genes_list = [0, num_genes_list[0]]

    return num_genes_list


def _fasta_path_for_entry(output_file_prefix, entry_id):
    """Return the path of the FASTA file that holds hit number `entry_id` in `--separate-fasta` mode."""
    return output_file_prefix + '_' + str(entry_id).zfill(4) + '.fa'


def _find_matching_genes(args, contigs_db):
    """Collect the gene caller ids to export, along with what was searched.

    Returns a tuple `(matching_genes, targets, sources)`: a set of gene caller ids, and the
    two lists that end up in the `target:` and `sources:` fields of the FASTA headers. Both
    lists stay empty when the user provided gene caller ids directly.
    """
    targets = []
    sources = []

    if args.gene_caller_ids:
        matching_genes = utils.get_gene_caller_ids_from_args(args.gene_caller_ids, args.delimiter)

    elif args.use_hmm:
        # use HMM hits instead of functional annotation
        hmm_sources = set(s.strip() for s in args.hmm_sources.split(',')) if args.hmm_sources else set()

        targets.append('HMMs')
        if hmm_sources:
            # sorted, so the header does not depend on set iteration order
            sources.extend(sorted(hmm_sources))
        else:
            sources.extend(contigs_db.hmm_sources_info.keys())

        s = hmmops.SequencesForHMMHits(args.contigs_db, sources=hmm_sources)
        hmm_hits = utils.get_filtered_dict(s.hmm_hits, 'gene_name', {args.search_term})
        matching_genes = [entry['gene_callers_id'] for entry in hmm_hits.values()]

    else:
        # use functional annotation
        contigs_db.init_functions()
        targets.append('functions')
        sources.extend(contigs_db.gene_function_call_sources)

        _, search_report = contigs_db.search_for_gene_functions([args.search_term], verbose=True)
        matching_genes = [i[0] for i in search_report]

    # Multiple sources could annotate the same gene, so make sure the ids are unique
    return set(matching_genes), targets, sources


def _get_block_boundaries(contigs_db, gene_id, num_genes_list):
    """Find the first and last gene caller ids of the block around `gene_id`.

    Returns `(first_gene_of_the_block, last_gene_of_the_block, premature)`, where `premature`
    is True if the requested block had to be cut short at either end of the contig.
    """
    gene_dict = contigs_db.genes_in_contigs_dict[gene_id]

    # "before" and "after" are relative to the direction of transcription, so the offsets
    # flip sign for genes on the reverse strand. NOTE: this arithmetic presumes that gene
    # caller ids are consecutive integers within a contig.
    step = {'f': 1, 'r': -1}[gene_dict['direction']]
    gene_1 = gene_id - num_genes_list[0] * step
    gene_2 = gene_id + num_genes_list[1] * step
    first_gene_of_the_block = min(gene_1, gene_2)
    last_gene_of_the_block = max(gene_1, gene_2)

    # the first item of each entry is used as the gene caller id
    genes_in_contig_sorted = sorted(contigs_db.contig_name_to_genes[gene_dict['contig']])
    first_gene_in_contig = genes_in_contig_sorted[0][0]
    last_gene_in_contig = genes_in_contig_sorted[-1][0]

    premature = False

    if last_gene_of_the_block > last_gene_in_contig:
        # the requested block runs past the end of the contig
        last_gene_of_the_block = last_gene_in_contig
        premature = True

    if first_gene_of_the_block < first_gene_in_contig:
        # the requested block starts before the beginning of the contig
        first_gene_of_the_block = first_gene_in_contig
        premature = True

    return first_gene_of_the_block, last_gene_of_the_block, premature


def main(args):
    """Export the sequence of a block of genes around every gene that matches the search.

    `args` is the namespace built by the argument parser of this program. Matching genes are
    either given explicitly (`gene_caller_ids`) or found through `search_term`, which is
    looked up in HMM hits if `use_hmm` is set and in functional annotations otherwise. Each
    block is written as one FASTA entry, into a single `<prefix>.fa` file or, with
    `separate_fasta`, into one `<prefix>_NNNN.fa` file per hit.

    If `list_hmm_sources` is set, the available HMM sources are listed and the program exits
    without doing anything else.

    Raises ConfigError for invalid parameters, and lets FilesNPathsError from the output
    file checks propagate.
    """
    A = lambda x: getattr(args, x, None)

    if args.list_hmm_sources:
        # listing needs neither a search term nor a loaded contigs database, so do it first
        info_table = hmmops.SequencesForHMMHits(args.contigs_db).hmm_hits_info
        for source in info_table:
            t = info_table[source]
            run.info_single('%s [type: %s] [num genes: %d]' % (source, t['search_type'], len(t['genes'].split(','))))
        sys.exit(0)

    search_method = [x for x in [A("gene_caller_ids"), A("search_term")] if x is not None]
    if len(search_method) != 1:
        raise ConfigError("You must specify exactly one of the following: --gene-caller-ids or --search-term")

    # cheap validation before the contigs database is loaded
    num_genes_list = _parse_num_genes(args.num_genes)
    reverse_complement_if_necessary = not A("never_reverse_complement")

    # sanity check for the output file. in --separate-fasta mode this is the first file that
    # would actually be written; every further file is checked right before it is created.
    if args.separate_fasta:
        output_file_path = _fasta_path_for_entry(args.output_file_prefix, 1)
    else:
        output_file_path = args.output_file_prefix + '.fa'
    filesnpaths.is_output_file_writable(output_file_path, ok_if_exists=args.overwrite_output_destinations)

    contigs_db = dbops.ContigsSuperclass(args)
    contigs_db.init_contig_sequences()

    matching_genes, targets, sources = _find_matching_genes(args, contigs_db)

    unknown_gene_ids = sorted(g for g in matching_genes if g not in contigs_db.genes_in_contigs_dict)
    if unknown_gene_ids:
        raise ConfigError("The following gene caller ids are not in the contigs database: %s." \
                                        % ', '.join(str(g) for g in unknown_gene_ids))

    run.info('Matching genes', '%d genes matched your search' % len(matching_genes))

    entry_id = 0
    filtered_partial_hits = 0
    combined_fasta = None

    try:
        # sorted, so the order of entries (and the numbering of separate files) is reproducible
        for gene_id in sorted(matching_genes):
            gene_dict = contigs_db.genes_in_contigs_dict[gene_id]
            first_gene_of_the_block, last_gene_of_the_block, premature = \
                                        _get_block_boundaries(contigs_db, gene_id, num_genes_list)

            if premature and args.remove_partial_hits:
                # don't save if it's a premature result AND the user asked to remove partial hits
                filtered_partial_hits += 1
                continue

            start = contigs_db.genes_in_contigs_dict[first_gene_of_the_block]['start']
            stop = contigs_db.genes_in_contigs_dict[last_gene_of_the_block]['stop']
            block_sequence = contigs_db.contig_sequences[gene_dict['contig']]['sequence'][start:stop]

            reverse_complemented = gene_dict['direction'] == 'r' and reverse_complement_if_necessary
            if reverse_complemented:
                block_sequence = utils.rev_comp(block_sequence)

            block_header = contigs_db.a_meta['project_name'] + ' ' + \
                           '|'.join(['target:%s' % ', '.join(targets),
                                     'sources:%s' % ', '.join(sources),
                                     'query:%s' % args.search_term,
                                     'hit_contig:%s' % gene_dict['contig'],
                                     'hit_gene_callers_id:%s' % str(gene_id),
                                     'locus:%s,%s' % (str(first_gene_of_the_block), str(last_gene_of_the_block)),
                                     'nt_positions_in_contig:%s:%s' % (str(start), str(stop)),
                                     'premature:%s' % str(premature),
                                     'reverse_complemented:%s' % str(reverse_complemented)])
            fasta_entry = '>%s\n%s\n' % (block_header, block_sequence)

            entry_id += 1
            if args.separate_fasta:
                # one file per match: make sure we are allowed to write this particular file
                entry_file_path = _fasta_path_for_entry(args.output_file_prefix, entry_id)
                filesnpaths.is_output_file_writable(entry_file_path, ok_if_exists=args.overwrite_output_destinations)
                with open(entry_file_path, 'w') as entry_fasta:
                    entry_fasta.write(fasta_entry)
            else:
                if combined_fasta is None:
                    # opened only once there is something to write, so a search without
                    # any hits does not leave an empty file behind
                    combined_fasta = open(output_file_path, 'w')
                combined_fasta.write(fasta_entry)
    finally:
        if combined_fasta is not None:
            combined_fasta.close()

    if filtered_partial_hits:
        run.warning("%s hits were filtered because they were partial." % filtered_partial_hits)

    if entry_id == 0:
        run.warning('There were no hits for your search so no output file will be created.')


# Command line entry point: declare the parameters, parse them, and hand them over to `main`.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Search for a function or HMM hit and for each matching gene "
                                                 "get the sequence including a block of genes around the match. "
                                                 "The output will be written to a fasta file (or multiple files, see "
                                                 "the --separate-fasta option below). The headers of the sequences in "
                                                 "the fasta file hold some information about the gene.")

    groupA = parser.add_argument_group('Essential INPUT')
    groupA.add_argument(*anvio.A('contigs-db'), **anvio.K('contigs-db'))
    groupA.add_argument('-n', '--num-genes', type=str, required=True,
                        help="For each match (to the function, or HMM that was searched) a sequence which includes "
                             "a block of genes will be saved. The block could include either genes only in the forward "
                             "direction of the gene (defined according to the direction of transcription of the gene) "
                             "or reverse or both. If you wish to get both directions use a comma (no spaces) to define "
                             "the block. For example, \"-n 4,5\" will give you four genes before and five genes after. "
                             "Whereas, \"-n 5\" will give you five genes after (in addition to the gene that matched). "
                             "To get only genes preceding the match use \"-n 5,0\". "
                             "If the number of genes requested exceeds the length of the contig, then the output "
                             "will include the sequence until the end of the contig.")

    groupB = parser.add_argument_group('Additional essential INPUT - OPTION 1', "Search according to either HMM or functional annotations")
    groupB.add_argument('-s', '--search-term', help='Search term.')

    groupC = parser.add_argument_group('Additional essential INPUT - OPTION 2', "Search specific gene id's")
    groupC.add_argument(*anvio.A('gene-caller-ids'), **anvio.K('gene-caller-ids'))
    groupC.add_argument(*anvio.A('delimiter'), **anvio.K('delimiter'))

    groupD = parser.add_argument_group('THE OUTPUT', "Where should the output go. It will be one FASTA file with all matches "
                                                     "or one FASTA per match (see --separate-fasta)")
    groupD.add_argument(*anvio.A('output-file-prefix'), **anvio.K('output-file-prefix', {'required': True}))

    groupE = parser.add_argument_group('ADDITIONAL STUFF', "Flags and parameters you can set according to your need")
    groupE.add_argument('--separate-fasta', default=False, action='store_true', help='Split each match to a separate FASTA file.')
    groupE.add_argument('--use-hmm', default=False, action='store_true',
                        help="Use HMM hits instead of functional annotations. "
                             "If you choose this option, you must also say which HMM source to use.")
    groupE.add_argument(*anvio.A('hmm-sources'), **anvio.K('hmm-sources'))
    groupE.add_argument(*anvio.A('list-hmm-sources'), **anvio.K('list-hmm-sources'))
    groupE.add_argument(*anvio.A('overwrite-output-destinations'), **anvio.K('overwrite-output-destinations'))
    groupE.add_argument(*anvio.A('remove-partial-hits'), **anvio.K('remove-partial-hits'))
    groupE.add_argument(*anvio.A('never-reverse-complement'), **anvio.K('never-reverse-complement'))

    args = anvio.get_args(parser)

    try:
        main(args)
    except (ConfigError, FilesNPathsError) as e:
        # both kinds of error are reported to the user the same way
        print(e)
        sys.exit(-1)
