#!/usr/bin/env python
# -*- coding: utf-8

import os
import sys

import anvio
import anvio.utils as utils
import anvio.terminal as terminal
import anvio.filesnpaths as filesnpaths

import IlluminaUtils.lib.fastqlib as u

from anvio.errors import ConfigError, FilesNPathsError

# Program metadata. `__version__` is taken from the imported `anvio` package, and
# `__description__` is the text given to argparse as the program description in
# the `__main__` section of this file.
__author__ = "Developers of anvi'o (see AUTHORS.txt)"
__copyright__ = "Copyleft 2015-2019, the Meren Lab (http://merenlab.org/)"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"
# NOTE: the adjacent string literals below are concatenated into one string; each
# piece ends with a space so that words do not run together.
__description__ = ("You give this program one or more FASTQ files and a short sequence, "
                   "and it returns all short reads from the FASTQ file that matches to it. "
                   "The purpose of this is to get back short reads that may be extending into "
                   "hypervariable regions of genomes, resulting a decreased mappability of "
                   "short reads in the metagenome given a reference. You often see those areas "
                   "of genomes as significant dips in coverage, and in most cases with a large "
                   "number of SNVs. When you provide the downstream conserved sequence, this "
                   "program allows you to take a better look at those regions at the short "
                   "read level without any mapping.")


# Module-level terminal helpers shared by main():
#   run      -> used for the `run.info(...)` / `run.info_single(...)` reports
#   progress -> used for the `new` / `update` / `end` progress messages
#   pp       -> applied to the read counter before it is shown in progress messages
run = terminal.Run()
progress = terminal.Progress()
pp = terminal.pretty_print


def _locate_match(seq, match_sequence, match_sequence_rc):
    """Locate `match_sequence` in a read, falling back to the opposite strand.

    Parameters
    ==========
    seq : str
        Read sequence as it comes out of the FASTQ entry.
    match_sequence : str
        The sequence to look for.
    match_sequence_rc : str
        Reverse complement of `match_sequence`.

    Returns
    =======
    (pos, seq, found_in_rc) : tuple
        `pos` is the index of the first occurrence of `match_sequence` in the
        returned `seq`, or a negative number when there is no match. If the read
        only carries `match_sequence_rc`, the returned `seq` is
        `utils.rev_comp(seq)` and `found_in_rc` is True, so callers can treat
        both strands the same way.
    """

    pos = seq.find(match_sequence)
    if pos >= 0:
        return pos, seq, False

    if seq.find(match_sequence_rc) < 0:
        return -1, seq, False

    # the read carries the reverse complement of the match: flip the whole read
    # and search again so that `pos` refers to the flipped read
    seq = utils.rev_comp(seq)
    return seq.find(match_sequence), seq, True


def main(args):
    """Collect the short reads that carry a given sequence and report them as FASTA.

    Every read of every input FASTQ file is searched for the match sequence (and
    its reverse complement, in which case the read is reverse complemented).
    A read is kept only if at least `min_remainder_length` nucleotides follow the
    match. Kept reads are written to `<sample>-<match>-HITS.fa`, and unless
    `report_raw` is set, also to a `-HITS-WITH-GAPS.fa` file (reads start at the
    match and are padded with '-' to the longest one) and a `-HITS-TRIMMED.fa`
    file (the same reads cut down to the shortest one).

    Parameters
    ==========
    args : argparse.Namespace-like object
        Attributes read (missing ones are treated as None): `fastq_files`,
        `match_sequence`, `sample_name`, `output_directory` (defaults to the
        current working directory), `min_remainder_length` (defaults to 60 when
        missing; an explicit 0 is honored), `report_raw`, `stop_after`.

    Raises
    ======
    ConfigError
        If no (or an empty) match sequence is given. Errors raised by the
        `utils` / `filesnpaths` checks called below propagate unchanged.
    SystemExit
        Via `sys.exit()` when no hits are found, or right after the raw output
        is written when `report_raw` is set.
    """

    A = lambda x: args.__dict__.get(x)
    input_fastq_file_paths = A('fastq_files')
    match_sequence = A('match_sequence')
    sample_name = A('sample_name')
    output_directory_path = A('output_directory') or os.path.abspath('.')
    report_raw = A('report_raw')
    stop_after = A('stop_after')

    # fall back to the default only when the value is missing. using `or` here
    # would silently turn an explicit request for 0 into 60.
    min_remainder_length = A('min_remainder_length')
    if min_remainder_length is None:
        min_remainder_length = 60

    for input_fastq_file_path in input_fastq_file_paths:
        filesnpaths.is_file_exists(input_fastq_file_path)

    utils.check_sample_id(sample_name)

    # an empty match sequence would "match" every read at position 0, and a
    # missing one would crash on `.upper()` below with an unhelpful error
    if not match_sequence:
        raise ConfigError("You must provide a non-empty match sequence for this program to work.")

    match_sequence = match_sequence.upper()
    match_sequence_rc = utils.rev_comp(match_sequence)
    filesnpaths.gen_output_directory(output_directory_path)

    run.info('Input FASTQ Files', ', '.join(input_fastq_file_paths))
    run.info('Sample name', sample_name)
    run.info('Match sequence', match_sequence)
    run.info('Min remainder length', min_remainder_length)
    run.info('Output directory', output_directory_path)
    run.info('Report raw', report_raw)
    if stop_after:
        run.info('Stop after', stop_after, mc='red')

    read_counter = 0
    hit_counter = 0
    hit_in_rc = 0
    sequences = []
    match_sequence_length = len(match_sequence)
    num_input_files = len(input_fastq_file_paths)
    status_msg = 'File %d of %d / Reads: %s / Hits: %d (in RC: %d)'

    progress.new("Tick tock")
    progress.update('...')
    for file_number, input_fastq_file_path in enumerate(input_fastq_file_paths, start=1):
        # once we have enough hits there is no reason to open the remaining files
        if stop_after and hit_counter >= stop_after:
            break

        input_fastq = u.FastQSource(input_fastq_file_path, compressed=input_fastq_file_path.endswith('.gz'))

        while input_fastq.next(raw=True):
            read_counter += 1

            if read_counter % 10000 == 0:
                progress.update(status_msg % (file_number, num_input_files, pp(read_counter), hit_counter, hit_in_rc))

            pos, seq, found_in_RC = _locate_match(input_fastq.entry.sequence, match_sequence, match_sequence_rc)

            if pos < 0:
                continue

            # not enough sequence after the match to be useful
            if len(seq) - (pos + match_sequence_length) < min_remainder_length:
                continue

            sequences.append((pos, seq), )
            hit_counter += 1

            # count reverse complement hits only once the read is accepted, so
            # the 'in RC' number is always a subset of the number of hits
            if found_in_RC:
                hit_in_rc += 1

            if anvio.DEBUG:
                progress.end()
                print("\n%s| %s [%s] %s" % ('RC ' if found_in_RC else '   ', seq[:pos], match_sequence, seq[pos+match_sequence_length:]))
                progress.new("Tick tock")
                progress.update(status_msg % (file_number, num_input_files, pp(read_counter), hit_counter, hit_in_rc))

            if stop_after and hit_counter >= stop_after:
                break

    progress.end()

    run.info('Total number of reads analyzed', read_counter, nl_before=1)

    if not sequences:
        run.info_single('No hits were found :/', mc='red', nl_before=1, nl_after=1)
        sys.exit()

    # number of nucleotides that follow the match in each hit
    seq_lengths_after_match = [len(seq) - (pos + match_sequence_length) for pos, seq in sequences]
    run.info('Total number of hits found', hit_counter, mc='green')
    run.info('Shortest after match', min(seq_lengths_after_match))
    run.info('Longest after match', max(seq_lengths_after_match))
    run.info('Average length after match', sum(seq_lengths_after_match) / len(sequences))

    progress.new("Generating the raw hits file")
    progress.update('...')
    output_file_path = os.path.join(output_directory_path, '%s-%s-HITS.fa' % (sample_name, match_sequence))
    with open(output_file_path, 'w') as output:
        for counter, (pos, seq) in enumerate(sequences, start=1):
            output.write('>%s_%.05d\n%s\n' % (sample_name, counter, seq))
    progress.end()

    run.info('Raw output', output_file_path)

    if report_raw:
        sys.exit()

    progress.new("Generating the fancy hits files")
    progress.update('...')
    trimmed_output_file_path = os.path.join(output_directory_path, '%s-%s-HITS-TRIMMED.fa' % (sample_name, match_sequence))
    gapped_output_file_path = os.path.join(output_directory_path, '%s-%s-HITS-WITH-GAPS.fa' % (sample_name, match_sequence))
    with open(trimmed_output_file_path, 'w') as trimmed, open(gapped_output_file_path, 'w') as gapped:
        # both lengths include the match itself since the reported sequences
        # start at the first nucleotide of the match
        max_seq_length = max(seq_lengths_after_match) + match_sequence_length
        min_seq_length = min(seq_lengths_after_match) + match_sequence_length

        for counter, (pos, seq) in enumerate(sequences, start=1):
            sequence = seq[pos:]

            # pad shorter hits with gaps so every entry is as long as the longest
            sequence = sequence + '-' * (max_seq_length - len(sequence))
            gapped.write('>%s_%.05d\n%s\n' % (sample_name, counter, sequence))

            # cut every hit down to the length of the shortest one
            trimmed.write('>%s_%.05d\n%s\n' % (sample_name, counter, sequence[:min_seq_length]))
    progress.end()

    run.info('Trimmed output', trimmed_output_file_path)
    run.info('Gapped output', gapped_output_file_path)


if __name__ == '__main__':
    # command line entry point: declare the arguments main() reads, then run it
    import argparse
    parser = argparse.ArgumentParser(description=__description__)

    parser.add_argument('fastq_files', nargs='+', help='One or more FASTQ formatted files', metavar='FASTQ_FILES')
    parser.add_argument('--match-sequence', required=True, metavar='SHORT SEQUENCE',
                        help="Short sequence to look for..")
    parser.add_argument('-m', '--min-remainder-length', metavar='INT', type=int, default=60,
                        help="Minimum length of the remainder of the read after the match. If your short read\
                              is XXXMMMMMMYYYYYYYYYYYYYY, where M indicates nucleotides of matching sequence,\
                              min remainder length is len(Y). Default is %(default)d.")
    parser.add_argument('-s', '--sample-name', required=True, metavar='NAME',
                        help="A short sample name (use a single word without spaces or fancy chars)")
    parser.add_argument('-O', '--output-directory', metavar='PATH',
                        help="Output directory for results to be stored. The default is the current\
                              working directory.")
    parser.add_argument('--report-raw', action="store_true",
                        help="Just report them raw. Don't bother trimming.")
    # 0 (the default) means 'no limit': main() only applies the limit when the
    # value is truthy
    parser.add_argument('--stop-after', metavar='INT', type=int, default=0,
                        help="Stop after X number of hits because who needs data.")

    args = anvio.get_args(parser)

    # the two error types imported from anvio.errors are reported by printing
    # them, and are told apart by the exit code (-1 vs. -2)
    try:
        main(args)
    except ConfigError as e:
        print(e)
        sys.exit(-1)
    except FilesNPathsError as e:
        print(e)
        sys.exit(-2)
