#!/usr/bin/env python
# -*- coding: utf-8

import os
import sys

import anvio
import anvio.utils as utils
import anvio.terminal as terminal
import anvio.filesnpaths as filesnpaths

import IlluminaUtils.lib.fastqlib as u

from anvio.errors import ConfigError, FilesNPathsError

__author__ = "Developers of anvi'o (see AUTHORS.txt)"
__copyright__ = "Copyleft 2015-2019, the Meren Lab (http://merenlab.org/)"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"


# Module-level terminal helpers shared by the code below: `run` is used to
# report key/value information lines, `progress` drives the progress display
# while reads are scanned and files are written, and `pp` is anvi'o's
# pretty_print helper (applied to read counts in progress messages).
run = terminal.Run()
progress = terminal.Progress()
pp = terminal.pretty_print


def _find_hits(input_fastq_file_paths, match_sequence, min_remainder_length, stop_after):
    """Scan FASTQ files for reads that contain `match_sequence` in either orientation.

    Every read is tested as-is and as its reverse complement. An orientation
    counts as a hit when the match sequence is found in it and at least
    `min_remainder_length` nucleotides follow the match.

    Parameters
    ==========
    input_fastq_file_paths : list of str
        FASTQ files to scan. Paths ending with '.gz' are read as compressed.
    match_sequence : str
        The sequence to search for.
    min_remainder_length : int
        Minimum number of nucleotides that must follow the match.
    stop_after : int or None
        When truthy, scanning stops as soon as this many hits are collected.

    Returns
    =======
    (hits, read_counter) : tuple
        `hits` is a list of `(match_position, sequence)` tuples, `read_counter`
        is the number of reads that were inspected.
    """

    hits = []
    read_counter = 0
    match_sequence_length = len(match_sequence)
    num_input_files = len(input_fastq_file_paths)

    def limit_reached():
        return bool(stop_after) and len(hits) >= stop_after

    progress.new("Tick tock")
    progress.update('...')
    for file_number, input_fastq_file_path in enumerate(input_fastq_file_paths, start=1):
        # no need to even open the remaining files once we have enough hits
        if limit_reached():
            break

        input_fastq = u.FastQSource(input_fastq_file_path, compressed=input_fastq_file_path.endswith('.gz'))

        while not limit_reached() and input_fastq.next(raw=True):
            read_counter += 1

            if read_counter % 5000 == 0:
                progress.update('File %d of %d / Reads: %s / Hits: %d' % (file_number, num_input_files, pp(read_counter), len(hits)))

            read_sequence = input_fastq.entry.sequence
            for seq in (read_sequence, utils.rev_comp(read_sequence)):
                pos = seq.find(match_sequence)

                if pos < 0:
                    continue

                remainder_length = len(seq) - (pos + match_sequence_length)
                if remainder_length < min_remainder_length:
                    continue

                hits.append((pos, seq))

                if anvio.DEBUG:
                    # the progress line has to be torn down and rebuilt around
                    # the print, otherwise the two would overwrite each other
                    progress.end()
                    print("\n%s [%s] %s" % (seq[:pos], match_sequence, seq[pos + match_sequence_length:]))
                    progress.new("Tick tock")
                    progress.update('File %d of %d / Reads: %s / Hits: %d' % (file_number, num_input_files, pp(read_counter), len(hits)))

                # a read may match in both orientations; check here so that the
                # second orientation can't push us past the requested limit
                if limit_reached():
                    break

    progress.end()

    return hits, read_counter


def _write_raw_hits(output_file_path, sample_name, hits):
    """Write every hit, untrimmed, as a FASTA entry named `<sample_name>_<counter>`."""

    with open(output_file_path, 'w') as output:
        for counter, (_, seq) in enumerate(hits, start=1):
            output.write('>%s_%.05d\n%s\n' % (sample_name, counter, seq))


def _write_trimmed_and_gapped_hits(trimmed_output_file_path, gapped_output_file_path, sample_name, hits,
                                   min_seq_length, max_seq_length):
    """Write hits starting from the match: gap-padded to `max_seq_length`, and cut to `min_seq_length`."""

    with open(trimmed_output_file_path, 'w') as trimmed, open(gapped_output_file_path, 'w') as gapped:
        for counter, (pos, seq) in enumerate(hits, start=1):
            # everything upstream of the match is dropped, so all entries
            # start at the same position in both files
            padded = seq[pos:].ljust(max_seq_length, '-')

            gapped.write('>%s_%.05d\n%s\n' % (sample_name, counter, padded))
            trimmed.write('>%s_%.05d\n%s\n' % (sample_name, counter, padded[:min_seq_length]))


def main(args):
    """Find short reads that contain a given sequence and report them as FASTA files.

    `args` is an argparse-like namespace. The following attributes are read
    (a missing attribute is treated as None):

        fastq_files            list of FASTQ file paths (required)
        match_sequence         sequence to look for, case-insensitive (required)
        sample_name            prefix for output file names and FASTA deflines
        output_directory       where to write results (default: current directory)
        min_remainder_length   minimum read length after the match (default: 60)
        report_raw             if truthy, only the raw hits file is generated
        stop_after             if truthy, stop once this many hits are found

    Three files may be generated in the output directory:
    `<sample>-<match>-HITS.fa` (full reads), and unless `report_raw` is set,
    `<sample>-<match>-HITS-WITH-GAPS.fa` (reads from the match on, padded with
    '-' to the longest one) and `<sample>-<match>-HITS-TRIMMED.fa` (reads from
    the match on, cut to the shortest one).

    Raises ConfigError when no FASTQ file or no match sequence is given. The
    process exits via `sys.exit()` when there are no hits, or after the raw
    file is written when `report_raw` is set.
    """

    A = lambda x: args.__dict__[x] if x in args.__dict__ else None
    input_fastq_file_paths = A('fastq_files')
    match_sequence = A('match_sequence')
    sample_name = A('sample_name')
    output_directory_path = A('output_directory') or os.path.abspath('.')
    report_raw = A('report_raw')
    stop_after = A('stop_after')

    # `or 60` would be wrong here: 0 is a legitimate value for the user to
    # ask for, and it must not be silently replaced with the default
    min_remainder_length = A('min_remainder_length')
    if min_remainder_length is None:
        min_remainder_length = 60

    if not input_fastq_file_paths:
        raise ConfigError("You must provide at least one FASTQ file.")

    if not match_sequence:
        raise ConfigError("You must provide a match sequence to look for.")

    for input_fastq_file_path in input_fastq_file_paths:
        filesnpaths.is_file_exists(input_fastq_file_path)

    utils.check_sample_id(sample_name)
    match_sequence = match_sequence.upper()
    filesnpaths.gen_output_directory(output_directory_path)

    run.info('Input FASTQ Files', ', '.join(input_fastq_file_paths))
    run.info('Sample name', sample_name)
    run.info('Match sequence', match_sequence)
    run.info('Min remainder length', min_remainder_length)
    run.info('Output directory', output_directory_path)
    run.info('Report raw', report_raw)
    if stop_after:
        run.info('Stop after', stop_after, mc='red')

    hits, read_counter = _find_hits(input_fastq_file_paths, match_sequence, min_remainder_length, stop_after)

    run.info('Total number of reads analyzed', read_counter, nl_before=1)

    if not hits:
        run.info_single('No hits were found :/', mc='red', nl_before=1, nl_after=1)
        sys.exit()

    match_sequence_length = len(match_sequence)
    seq_lengths_after_match = [len(seq) - (pos + match_sequence_length) for pos, seq in hits]
    shortest_after_match = min(seq_lengths_after_match)
    longest_after_match = max(seq_lengths_after_match)

    run.info('Total number of hits found', len(hits), mc='green')
    run.info('Shortest after match', shortest_after_match)
    run.info('Longest after match', longest_after_match)
    run.info('Average length after match', sum(seq_lengths_after_match) / len(hits))

    output_prefix = os.path.join(output_directory_path, '%s-%s' % (sample_name, match_sequence))

    progress.new("Generating the raw hits file")
    progress.update('...')
    output_file_path = output_prefix + '-HITS.fa'
    _write_raw_hits(output_file_path, sample_name, hits)
    progress.end()

    run.info('Raw output', output_file_path)

    if report_raw:
        sys.exit()

    progress.new("Generating the fancy hits files")
    progress.update('...')
    trimmed_output_file_path = output_prefix + '-HITS-TRIMMED.fa'
    gapped_output_file_path = output_prefix + '-HITS-WITH-GAPS.fa'
    _write_trimmed_and_gapped_hits(trimmed_output_file_path,
                                   gapped_output_file_path,
                                   sample_name,
                                   hits,
                                   min_seq_length=shortest_after_match + match_sequence_length,
                                   max_seq_length=longest_after_match + match_sequence_length)
    progress.end()

    run.info('Trimmed output', trimmed_output_file_path)
    run.info('Gapped output', gapped_output_file_path)


if __name__ == '__main__':
    import argparse

    # Help texts are built with implicit string concatenation rather than
    # backslash continuations inside the literals, so the source indentation
    # does not end up inside the strings.
    parser = argparse.ArgumentParser(description="You give this program one or more FASTQ files and a short sequence, "
                                                 "and it returns all short reads from the FASTQ file that match it. "
                                                 "The purpose of this is to get back short reads that may be extending into "
                                                 "hypervariable regions of genomes, resulting in a decreased mappability of "
                                                 "short reads in the metagenome given a reference. You often see those areas "
                                                 "of genomes as significant dips in coverage, and in most cases with a large "
                                                 "number of SNVs. When you provide the downstream conserved sequence, this "
                                                 "program allows you to take a better look at those regions at the short "
                                                 "read level without any mapping.")

    parser.add_argument('fastq_files', nargs='+', help='One or more FASTQ formatted files', metavar='FASTQ_FILES')
    parser.add_argument('--match-sequence', required=True, metavar='SHORT SEQUENCE',
                        help="Short sequence to look for..")
    parser.add_argument('-m', '--min-remainder-length', metavar='INT', type=int, default=60,
                        help="Minimum length of the remainder of the read after the match. If your short read "
                             "is XXXMMMMMMYYYYYYYYYYYYYY, where M indicates nucleotides of matching sequence, "
                             "min remainder length is len(Y).")
    parser.add_argument('-s', '--sample-name', required=True, metavar='NAME',
                        help="A short sample name (use a single word without spaces or fancy chars)")
    parser.add_argument('-O', '--output-directory', metavar='PATH',
                        help="Output directory for results to be stored. The default is the current "
                             "working directory.")
    parser.add_argument('--report-raw', action="store_true",
                        help="Just report them raw. Don't bother trimming.")
    parser.add_argument('--stop-after', metavar='INT', type=int, default=0,
                        help="Stop after X number of hits because who needs data.")

    args = anvio.get_args(parser)

    # Expected, user-facing failures are printed without a traceback and are
    # told apart by their exit status: -1 for ConfigError, -2 for
    # FilesNPathsError.
    try:
        main(args)
    except ConfigError as e:
        print(e)
        sys.exit(-1)
    except FilesNPathsError as e:
        print(e)
        sys.exit(-2)
