#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys

import anvio
import anvio.db as db
import anvio.dbops as dbops
import anvio.tables as t
import anvio.terminal as terminal

from anvio.errors import ConfigError, FilesNPathsError
from anvio.utils import gen_gexf_network_file


# Module metadata. The version is not hard-coded here; it is taken from the
# installed `anvio` package.
__author__ = "A. Murat Eren"
__copyright__ = "Copyright 2015, The anvio Project"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"


# Module-wide terminal helpers shared by everything below: `run` is used for
# the summary lines (`run.info`), `progress` for the transient status updates
# (`progress.new` / `progress.update` / `progress.end`).
run = terminal.Run()
progress = terminal.Progress()


class NetworkDescriptonSamples:
    """Generate GEXF network description files that link samples to genes.

    Usage, as done by the command-line entry point of this file: create an
    instance, set `profile_db_path` and `contigs_db_path`, call `init()`, then
    call `generate_genes_network()` and/or `generate_functions_network()`.
    Output files are written into the directory of the profile database.
    """

    def __init__(self):
        # Both paths are expected to be set by the caller before `init()`.
        self.profile_db_path = None
        self.contigs_db_path = None
        self.use_named_functions_only = False  # not read anywhere in this class

        # `labels` maps gene caller id -> function label. The whole dict is
        # passed to `gen_gexf_network_file` as `unit_mapping_dict`.
        self.functions = {'labels': {}}

        # Resolve a file name against the directory of the profile database.
        self.P = lambda x: os.path.join(os.path.dirname(self.profile_db_path), x)


    def init(self):
        """Open both databases and load everything the network generators need.

        After a successful call the instance has `profile_db`, `contigs_db`,
        `genes_dict`, `samples`, `genes`, `genes_in_contigs` and `samples_dict`
        populated, and `functions['labels']` holds one entry per gene.

        Raises:
            ConfigError: if the profile database is not flagged as merged, or
                if it has no gene coverages table.
        """
        dbops.is_profile_db_and_contigs_db_compatible(self.profile_db_path, self.contigs_db_path)

        self.profile_db = db.DB(self.profile_db_path, t.profile_db_version)
        self.contigs_db = db.DB(self.contigs_db_path, t.contigs_db_version)

        table_names = self.profile_db.get_table_names()

        if not int(self.profile_db.get_meta_value('merged')):
            raise ConfigError("The profile database describes a single run. Current implementation of this\
                               program restricts its use to merged runs. Sorry :/")

        if t.gene_coverages_table_name not in table_names:
            raise ConfigError("There is no '%s' table in the profile database. This does not really make\
                                      any sense :/" % t.gene_coverages_table_name)

        progress.new('Init')
        try:
            progress.update('Reading genes table from %s ...' % self.profile_db_path)
            self.genes_dict = self.profile_db.get_table_as_dict(t.gene_coverages_table_name)

            # Unique, sorted sample and gene identifiers seen in the coverage table.
            coverage_entries = self.genes_dict.values()
            self.samples = sorted({entry['sample_id'] for entry in coverage_entries})
            self.genes = sorted({entry['gene_callers_id'] for entry in coverage_entries})

            progress.update('Reading annotations table from %s  ...' % self.contigs_db_path)
            self.genes_in_contigs = self.contigs_db.get_table_as_dict(t.genes_in_contigs_table_name)

            # read functions.
            progress.update('Sifting through genes with non-hypothetical functions ...')

            for gene in self.genes:
                # FIXME: The following is a temporary workaround, and the proper way to address this
                #        is to read the functions for each gene from the appropriate resource:
                self.functions['labels'][gene] = 'Unknown'

            self.samples_dict = self.get_samples_dict(self.samples, self.genes_dict)
        finally:
            # Always close the progress line, so that a failure above does not
            # leave a dangling status message in front of the error output.
            progress.end()

        run.info('init', 'Genes database initialized for %d genes in %d samples.' % (len(self.genes), len(self.samples)))


    def generate_genes_network(self):
        """Write the sample/gene network to 'SAMPLE-GENE-NETWORK.gexf'.

        Requires `init()` to have been called. The file is created in the
        directory of the profile database.
        """
        network_desc_output_path = self.P('SAMPLE-GENE-NETWORK.gexf')
        progress.new('Processing')
        try:
            progress.update('Generating network description for %d genes across %d samples ... ' % (len(self.genes), len(self.samples_dict)))
            gen_gexf_network_file(self.genes, self.samples_dict, network_desc_output_path)
        finally:
            # Close the progress line even if the file generation fails.
            progress.end()
        run.info('network with genes', network_desc_output_path)


    def generate_functions_network(self):
        """Write the sample/function network to 'SAMPLE-FUNCTION-NETWORK.gexf'.

        Requires `init()` to have been called. Only genes that have an entry
        in `functions['labels']` are included, and `functions` is passed along
        as the unit mapping. The file is created in the directory of the
        profile database.
        """
        genes_with_functions = sorted(self.functions['labels'])
        network_desc_output_path = self.P('SAMPLE-FUNCTION-NETWORK.gexf')
        progress.new('Processing')
        try:
            progress.update('Generating network description for %d genes w/functions across %d samples ... ' % (len(genes_with_functions), len(self.samples_dict)))
            gen_gexf_network_file(genes_with_functions, self.samples_dict,
                                  network_desc_output_path, unit_mapping_dict=self.functions)
        finally:
            # Close the progress line even if the file generation fails.
            progress.end()
        run.info('network with functions', network_desc_output_path)


    def get_samples_dict(self, samples, table, unit = 'gene_callers_id'):
        """Regroup a flat coverage table into a per-sample dictionary.

        Args:
            samples: iterable of sample ids; each one gets an entry in the
                result, even if no row of `table` refers to it.
            table: dict whose values are rows with the keys 'sample_id',
                'mean_coverage' and the key named by `unit`.
            unit: row key that identifies the unit within a sample.

        Returns:
            dict of the form {sample_id: {unit_id: mean_coverage}}.

        Raises:
            KeyError: if a row refers to a sample id that is not in `samples`.
        """
        samples_dict = {sample: {} for sample in samples}
        for row in table.values():
            samples_dict[row['sample_id']][row[unit]] = row['mean_coverage']
        return samples_dict


if __name__ == '__main__':
    import argparse

    # Command-line interface: both options are declared through the shared
    # argument definitions provided by `anvio.A` / `anvio.K`.
    parser = argparse.ArgumentParser(description='Generate a network description file')
    for option_name in ('profile-db', 'contigs-db'):
        parser.add_argument(*anvio.A(option_name), **anvio.K(option_name))

    args = parser.parse_args()

    try:
        network = NetworkDescriptonSamples()
        network.profile_db_path = args.profile_db
        network.contigs_db_path = args.contigs_db

        # Load the data first, then write both network files.
        network.init()
        network.generate_genes_network()
        network.generate_functions_network()
    except (ConfigError, FilesNPathsError) as error:
        # Print the message and exit with a distinct status per error type:
        # -1 for ConfigError, -2 for FilesNPathsError.
        print(error)
        sys.exit(-1 if isinstance(error, ConfigError) else -2)
