#!/usr/bin/env python
# -*- coding: utf-8

import os
import numpy
import argparse
from collections import Counter
from numpy import log2 as log

import anvio
import anvio.tables as t
import anvio.dbops as dbops
import anvio.dictio as dictio
import anvio.terminal as terminal


__author__ = "A. Murat Eren"
__copyright__ = "Copyright 2015, The anvio Project"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"


# Module-level terminal helpers: `progress` reports per-split progress while
# the splits are analyzed, and `run` prints informational lines and warnings.
pp = terminal.pretty_print
progress = terminal.Progress()
run = terminal.Run(width = 30)

# Maps the pair of competing nucleotides at a position to a one-character
# state. Homozygous pairs map to the lowercase nucleotide, and both orderings
# of a mixed pair map to the same lowercase IUPAC ambiguity letter. 'NN' is
# the state used for a position without coverage.
states_dict = {'AA': 'a',
               'TT': 't',
               'CC': 'c',
               'GG': 'g',
               'NN': 'n',
               'AT': 'w',
               'TA': 'w',
               'AC': 'm',
               'CA': 'm',
               'CT': 'y',
               'TC': 'y',
               'AG': 'r',
               'GA': 'r',
               'CG': 's',
               'GC': 's',
               'GT': 'k',
               'TG': 'k'}

# every state character that can appear in a column of the matrix
valid_chars = set(states_dict.values())

# Inverse of `states_dict`: state character -> the distinct nucleotides it
# stands for ('a' -> 'A', 'w' -> 'AT'). The letters are sorted so that the
# result does not depend on set iteration order, which changes between
# interpreter runs when string hash randomization is enabled.
reverse_states = {state: ''.join(sorted(set(pair))) for pair, state in states_dict.items()}


def entropy(l):
    '''Return the un-weighted Shannon entropy (log base 2) of the state characters in `l`.

       Only characters listed in `valid_chars` contribute. `l` must not be empty:
       the relative frequencies are computed by dividing by `len(l)`.
    '''
    # a tiny constant keeps log() finite for characters that never occur in `l`
    pseudo_count = 0.0000000000000000001

    frequencies = (l.count(char) * 1.0 / len(l) + pseudo_count for char in valid_chars)

    return -(sum(frequency * log(frequency) for frequency in frequencies))


class ContigVariability:
    '''Collects the variable nucleotide positions of one split across its layers.

       `split_data` is read as split_data[layer_name][key][position], where key is one
       of 'competing_nucleotides', 'variability' or 'coverage'. `split_sequence` is
       indexed by the same positions.
    '''
    def __init__(self, split_name, split_data, split_sequence, num_positions_from_each_split = 2):
        self.split_name = split_name
        self.split_data = split_data
        self.split_sequence = split_sequence
        self.num_positions_from_each_split = num_positions_from_each_split

        # layers are always visited in sorted order; layer_index gives the slot of
        # a layer in the per-position lists ('coverages', 'contents', 'avatars')
        self.layer_names = sorted(self.split_data.keys())
        self.layer_index = {name: index for index, name in enumerate(self.layer_names)}

        self.positions_dict = {}
        self.positions_selected = None


    def TOP(self, key, num):
        '''Return at most `num` of the currently selected positions, ordered by their
           `key` value in `positions_dict` from largest to smallest. Ties are broken
           in favor of the larger position.
        '''
        ranked = sorted(self.positions_selected,
                        key = lambda pos: (self.positions_dict[pos][key], pos),
                        reverse = True)
        return ranked[:num]


    def populate_positions_dict(self):
        '''Build one record for every position that has competing nucleotides in at
           least one layer, then drop the positions whose entropy is (nearly) zero.
        '''
        # in how many layers does each position show competing nucleotides?
        occurrences = Counter()
        for layer in self.split_data.values():
            for pos in layer['competing_nucleotides'].keys():
                occurrences[pos] += 1

        # start every record with its occurrence count and empty placeholders
        for pos, occurrence in occurrences.items():
            self.positions_dict[pos] = {'occurrence': occurrence,
                                        'variability': 0.0,
                                        'mean_coverage': 0.0,
                                        'coverages': [],
                                        'contents': [],
                                        'avatars': [],
                                        'entropy': 0.0 }

        # complete the records layer by layer
        for pos, record in self.positions_dict.items():
            for layer_name in self.layer_names:
                layer = self.split_data[layer_name]
                coverage = layer['coverage'][pos]

                record['variability'] += layer['variability'][pos]
                record['coverages'].append(coverage)

                competing = layer['competing_nucleotides']
                if pos in competing:
                    content = competing[pos]
                elif coverage:
                    # covered, but not variable in this layer: the nucleotide of
                    # the split sequence stands in for both competing nucleotides
                    nucleotide = self.split_sequence[pos]
                    content = nucleotide + nucleotide
                else:
                    # no coverage in this layer
                    content = 'NN'

                record['contents'].append(content)
                record['avatars'].append(states_dict[content])

            avatars = record['avatars']
            distinct_avatars = set(avatars)
            # give every distinct state of this position an integer id starting from 1
            avatar_conversion = dict(zip(distinct_avatars, range(1, len(distinct_avatars) + 1)))

            record['identities'] = [avatar_conversion[avatar] for avatar in avatars]
            record['mean_coverage'] = numpy.mean(record['coverages'])
            record['entropy'] = entropy(''.join(avatars))

        # positions where all layers agree carry no information; remove them
        uninformative = [pos for pos, record in self.positions_dict.items() if record['entropy'] < 0.00000001]
        for pos in uninformative:
            del self.positions_dict[pos]

        self.positions_selected = list(self.positions_dict)


    def position_selection_heuristics(self):
        '''Normalize the number of positions across splits by keeping only the top
           `self.num_positions_from_each_split` positions of this split.

           The selection narrows down in three rounds: the highest entropy positions
           (three times the target number), then the most variable among those (two
           times the target number), then the best covered among those.
        '''
        rounds = (('entropy', 3), ('variability', 2), ('mean_coverage', 1))
        for key, factor in rounds:
            self.positions_selected = self.TOP(key, self.num_positions_from_each_split * factor)


    def analyze_split_summary(self):
        '''Populate the position records and then select the positions to report.'''
        self.populate_positions_dict()
        self.position_selection_heuristics()


    def text_report(self):
        '''Print, for every selected position, one line per layer with the content,
           its state character, the variability and the coverage.
        '''
        run.warning('', lc = 'crimson')

        for pos in self.positions_selected:
            record = self.positions_dict[pos]
            run.info('%d (entropy: %f)' % (pos, record['entropy']), None, header = True)

            for layer_name in self.layer_names:
                slot = self.layer_index[layer_name]
                content = record['contents'][slot]
                variability = self.split_data[layer_name]['variability'][pos]
                run.info(layer_name, '%s %s %f %d' % (content, states_dict[content], variability, record['coverages'][slot]))


class VariabilityWrapper:
    '''Gathers the selected variable positions of many splits and writes them as a
       TAB-delimited matrix with one row per layer and one column per position.

       `args` is the parsed command line namespace. `analyze` reads `summary_dict`,
       `splits_of_interest`, `contigs_db`, `num_positions_from_each_split`,
       `min_scatter` and `output_file` from it.
    '''
    def __init__(self, args):
        self.args = args
        self.samples_dict = None
        self.splits = []
        self.split_positions = {}
        self.layer_names = []
        self.units = []


    def analyze(self):
        '''Process every split of interest and write the output matrix.'''
        # use the namespace given to the constructor, so the class does not depend
        # on module-level names created by the command line entry point
        args = self.args

        run_files_path = os.path.dirname(os.path.abspath(args.summary_dict))
        summary_index = dictio.read_serialized_object(args.summary_dict)
        num_positions_from_each_split = args.num_positions_from_each_split
        min_scatter = args.min_scatter

        # one split name per line; blank lines are ignored
        with open(args.splits_of_interest) as splits_file:
            splits = [line.strip() for line in splits_file]
        splits = [split_name for split_name in splits if split_name]
        num_splits = len(splits)
        run.info('Splits', '%d splits found' % (num_splits))

        contigs_db = dbops.ContigsDatabase(args.contigs_db)
        splits_info = contigs_db.db.get_table_as_dict(t.splits_info_table_name)
        contig_sequences = contigs_db.db.get_table_as_dict(t.contig_sequences_table_name)
        contigs_db.disconnect()

        progress.new('Analyzing splits')
        for split_number, split_name in enumerate(splits, 1):
            progress.update('%d of %d' % (split_number, num_splits))

            # the summary index maps a split name to a file stored next to the index
            d = dictio.read_serialized_object(os.path.join(run_files_path, summary_index[split_name]))

            # cut the sequence of the split out of its parent contig
            split_info = splits_info[split_name]
            parent_sequence = contig_sequences[split_info['parent']]['sequence']
            split_sequence = parent_sequence[split_info['start']:split_info['end']]

            self.add_split(split_name, d, split_sequence, num_positions_from_each_split)

        progress.update('Generating output ...')
        self.create_TAB_delim_file(args.output_file, min_scatter)
        progress.end()
        run.info('Output matrix', args.output_file)


    def add_split(self, split_name, split_data, split_sequence, num_positions_from_each_split = 2):
        '''Analyze one split and record the state of every layer at each position
           selected for it. Each selected position becomes one unit (a column of
           the matrix) named '<split_name>_pos_<position>'.
        '''
        v = ContigVariability(split_name, split_data, split_sequence, num_positions_from_each_split)
        v.analyze_split_summary()

        self.splits.append(split_name)
        self.split_positions[split_name] = v.positions_selected

        # the layers of the first split define the rows of the matrix
        if not self.samples_dict:
            self.samples_dict = {layer_name: {} for layer_name in v.layer_names}
            self.layer_names = v.layer_names

        for pos in v.positions_selected:
            unit = '%s_pos_%d' % (split_name, pos)
            self.units.append(unit)

            avatars = v.positions_dict[pos]['avatars']
            for layer_name in v.layer_names:
                self.samples_dict[layer_name][unit] = avatars[v.layer_index[layer_name]]


    def _units_passing_min_scatter(self, min_scatter):
        '''Return the units to report, in their original order. With `min_scatter`
           above 1, a unit is dropped when its second most frequent state occurs in
           fewer than `min_scatter` layers.
        '''
        if not min_scatter or min_scatter <= 1:
            return list(self.units)

        units = []
        for unit in self.units:
            states = Counter(self.samples_dict[layer_name][unit] for layer_name in self.layer_names)
            two_most_common = states.most_common(2)

            if len(two_most_common) > 1 and two_most_common[1][1] < min_scatter:
                continue

            units.append(unit)

        return units


    def create_TAB_delim_file(self, path, min_scatter = 1):
        '''Write the matrix to `path`: a header line with the unit names, then one
           line per layer with the nucleotide(s) observed for each unit.
        '''
        units = self._units_passing_min_scatter(min_scatter)

        # the context manager closes the file even if a lookup below fails
        with open(path, 'w') as output:
            output.write('\t'.join(['layers'] + units) + '\n')
            for layer_name in self.layer_names:
                layer_states = self.samples_dict[layer_name]
                row = [str(reverse_states[layer_states[unit]]) for unit in units]
                output.write('\t'.join([layer_name] + row) + '\n')




##############################################################################

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Generate Variability Matrix')

    parser.add_argument('summary_dict', metavar = 'SUMMARY_DICT', default = None,
                        help = 'Summary file')

    # arguments with shared anvio definitions, in the order they are registered;
    # an optional second item overrides parts of the shared definition
    shared_arguments = [('contigs-db',),
                        ('splits-of-interest', {'required': True}),
                        ('samples-of-interest',),
                        ('num-positions-from-each-split',),
                        ('min-scatter',),
                        ('min-ratio-of-competings-nts',),
                        ('output-file', {'default': 'variability.txt'})]

    for spec in shared_arguments:
        parser.add_argument(*anvio.A(spec[0]), **anvio.K(*spec))

    args = parser.parse_args()

    w = VariabilityWrapper(args)
    w.analyze()
