#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import random
import argparse
import configparser

import anvio
import anvio.terminal as terminal

from anvio.errors import ConfigError, FilesNPathsError


# Module metadata. The version is not maintained here; it mirrors the version
# of the imported anvio package.
__author__ = "A. Murat Eren"
__copyright__ = "Copyright 2015, The anvio Project"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"


# Module-level terminal helpers from anvio.terminal. Note that main() creates
# its own local `run` and `progress` instances; `pp` is what main() uses to
# format the counters shown in its progress messages.
run = terminal.Run()
progress = terminal.Progress()
pp = terminal.pretty_print

# Bases in the fixed order main() cycles through when it introduces an SNV:
# the reference base is replaced by the next entry in this list, wrapping
# around from the last entry to the first.
bases = ['A', 'T', 'C', 'G']


class Configuration:
    """Validated, typed view of a parsed configuration file.

    The `general` section must provide `output_file`, `short_read_length`,
    `coverage` and `contig`. Every other section describes one SNV: the
    section name is the 0-based position in the contig, and its `ratio` option
    is a number in (0, 1].

    Attributes:
        output_file: value of `general/output_file`.
        short_read_length: `general/short_read_length` as an int.
        coverage: `general/coverage` as an int.
        contig: `general/contig` as a string.
        SNVs: list of `(location, ratio)` tuples, in the order the sections
            appear in the configuration.

    Raises:
        ConfigError: if a ratio is outside (0, 1], or if an SNV position is
            negative or not inside the contig.
    """

    def __init__(self, config):
        """Read and validate every setting from `config`.

        Args:
            config: a `configparser.ConfigParser`-like object exposing
                `get(section, option)` and `sections()`.
        """
        self.output_file = config.get('general', 'output_file')
        self.short_read_length = int(config.get('general', 'short_read_length'))
        self.coverage = int(config.get('general', 'coverage'))
        self.contig = str(config.get('general', 'contig'))

        self.SNVs = []

        contig_length = len(self.contig)

        for section in config.sections():
            if section == 'general':
                continue

            location = int(section)
            # Read from the parser that was passed in; relying on a
            # module-level variable breaks any use of this class outside of
            # the script's entry point.
            ratio = float(config.get(section, 'ratio'))

            if ratio > 1:
                raise ConfigError('Ratio cannot be more than 1 (error for SNV location %d)' % location)
            if ratio <= 0:
                raise ConfigError('Ratio cannot be 0 or less (error for SNV location %d)' % location)

            # A negative position would index the contig from its end and
            # never match any read, so the SNV would be silently dropped.
            if location < 0:
                raise ConfigError('SNV position cannot be negative (error for SNV location %d)' % location)

            if location >= contig_length:
                raise ConfigError('SNV position at %d for a contig that is %d nts long? Really?' % (location, contig_length))

            self.SNVs.append((location, ratio))


def main(config):
    """Simulate short reads from `config.contig` and write them as FASTA.

    Reads of `config.short_read_length` nucleotides are cut from random start
    positions of the contig. For every `(location, ratio)` pair in
    `config.SNVs`, a random subset of the reads covering `location` (its size
    is `ratio` times the number of covering reads, rounded) gets the base at
    that position replaced with the next entry in the module-level `bases`
    list. The reads are then written to `config.output_file`, one FASTA entry
    per read, with the defline `>id|start:N|stop:N|num_SNVs:N`.

    Args:
        config: an object exposing `output_file`, `short_read_length`,
            `coverage`, `contig` and `SNVs` (see `Configuration`).

    Raises:
        ConfigError: if the short read length is not between 1 and the contig
            length, or if the contig base at an SNV position is not one of
            the entries in `bases`.
    """
    run = terminal.Run(width = 15)
    progress = terminal.Progress()

    sequences = {}

    x = config.short_read_length
    c = config.coverage
    L = len(config.contig)

    # Fail with a readable message instead of a ZeroDivisionError (x == 0) or
    # an 'empty range' ValueError from random.randint (x > L).
    if x <= 0 or x > L:
        raise ConfigError('Short read length must be between 1 and the length of the contig (%d), but it is %d.' % (L, x))

    # Floor division on purpose: under Python 3 `L / x * c` is a float, which
    # range() refuses. `L // x * c` is the integer value the expression was
    # written to produce.
    av_num_short_reads_needed = L // x * c

    progress.new('Generating short reads')

    for i in range(av_num_short_reads_needed):
        if (i + 1) % 100 == 0:
            progress.update('Entry %s of %s ...' % (pp(i), pp(av_num_short_reads_needed)))

        start_pos = random.randint(0, L - x)

        sequences[i] = {'sequence': config.contig[start_pos:start_pos + x],
                        'start': start_pos,
                        'stop': start_pos + x,
                        'num_SNVs': 0}

    progress.end()

    progress.new('Introducing SNVs')

    for snv_location, ratio in config.SNVs:
        progress.update('Working on location %d with ratio of %.2f' % (snv_location, ratio))

        current_base = config.contig[snv_location].upper()
        if current_base not in bases:
            # Close the progress line before reporting the problem.
            progress.end()
            raise ConfigError("The contig has '%s' at SNV location %d, but only %s can be mutated." % (current_base, snv_location, ', '.join(bases)))

        # The variant is always the next base in `bases`, wrapping around.
        new_base = bases[(bases.index(current_base) + 1) % len(bases)]

        # Reads whose [start, stop) interval covers the SNV position.
        matching_entries = [entry_id for entry_id, e in sequences.items()
                            if e['start'] <= snv_location < e['stop']]

        entries_to_mutate = random.sample(matching_entries, int(round(ratio * len(matching_entries))))

        for entry_id in entries_to_mutate:
            entry = sequences[entry_id]
            offset = snv_location - entry['start']
            entry['sequence'] = entry['sequence'][:offset] + new_base + entry['sequence'][offset + 1:]
            entry['num_SNVs'] += 1

    # The context manager guarantees the FASTA file is flushed and closed,
    # even if a write fails.
    with open(config.output_file, 'w') as output:
        for entry_id, s in sequences.items():
            output.write('>%s\n' % '|'.join(['%d' % entry_id, 'start:%d' % s['start'], 'stop:%d' % s['stop'], 'num_SNVs:%d' % s['num_SNVs']]))
            output.write('%s\n' % s['sequence'])

    progress.end()

    run.info('Fasta output', config.output_file)


if __name__ == '__main__':
    # Command line interface: one positional argument, the configuration file.
    parser = argparse.ArgumentParser(description='Generate short reads from contigs')
    parser.add_argument('configuration', metavar='CONFIG_FILE', help='Configuration file')
    args = parser.parse_args()

    user_config = configparser.ConfigParser()
    user_config.read(args.configuration)

    # Known anvi'o errors are printed and turned into a non-zero exit status:
    # -1 for ConfigError, -2 for FilesNPathsError. Anything else propagates.
    try:
        config = Configuration(user_config)
        main(config)
    except (ConfigError, FilesNPathsError) as e:
        print(e)
        sys.exit(-1 if isinstance(e, ConfigError) else -2)
