#!/usr/bin/python
# coding=utf-8
#
#The MIT License (MIT)
#
#Copyright (c) 2015 Evgeny Taranov
#
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:
#
#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.
#
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.

'''
Created on Oct 14, 2014

@author: Evgeny Taranov <allista@gmail.com>
'''

import sys

try:
    from Bio import SeqIO
    from Bio import Entrez
except ImportError as e:
    print str(e)
    print 'BioPython must be installed in your system.'
    sys.exit(1)

import os
import traceback
import logging
import argparse
from time import time, sleep
from datetime import timedelta
from collections import Counter
from itertools import chain
import csv
import re

#------------------------------------------------------------------------------#
class OutIntercepter(object):
    '''A file-like object which intercepts std-out/err.
    
    Used as a context manager: on enter both sys.stdout and sys.stderr are
    replaced by this object (which silently discards everything written to it);
    on exit the original streams are restored.
    Any exception raised inside the "with" block is suppressed; exceptions 
    other than SystemExit are reported to the original stderr first.'''
    def __init__(self):
        #original streams; set in __enter__, restored in __exit__
        self._oldout = None
        self._olderr = None
    #end def
    
    def write(self, text): pass
    
    def flush(self): pass
    
    def __enter__(self):
        self._oldout = sys.stdout
        self._olderr = sys.stderr
        #redirect *both* streams to this object
        sys.stdout = sys.stderr = self
        return self
    #end def
    
    def __exit__(self, _type, _value, _traceback):
        #report everything except a deliberate SystemExit
        if _type is not None and not issubclass(_type, SystemExit):
            #stdout is still intercepted here, so this goes through self.write
            print(_value)
            traceback.print_exception(_type, _value, _traceback, file=self._olderr)
        sys.stdout = self._oldout
        sys.stderr = self._olderr
        #returning True suppresses the exception for the caller
        return True
    #end def
#end class


class EchoLogger(OutIntercepter):
    '''Wrapper around the logging module which captures intercepted output 
    into a log file while still printing it to the original stdout.
    
    The log is written to "<name>.log". Used as a context manager; any 
    exception raised inside the "with" block is suppressed, and exceptions
    other than SystemExit are logged and printed to the original stderr.'''

    def __init__(self, name, level=logging.INFO):
        '''name  -- logger name, also used as the log file path without extension
        level -- logging level used both for the handler and for echoed text'''
        OutIntercepter.__init__(self)
        self._name      = name
        self._log       = name+'.log'
        self._level     = level
        self._logger    = logging.getLogger(name)
        self._handler   = logging.FileHandler(self._log, encoding='UTF-8')
        self._formatter = logging.Formatter('[%(asctime)s] %(message)s')
        #set log level
        self._handler.setLevel(self._level)
        self._logger.setLevel(self._level)
        #assemble pipeline
        self._handler.setFormatter(self._formatter)
        self._logger.addHandler(self._handler)
    #end def
        
    def __del__(self):
        #__init__ may have failed before the handler was created
        handler = getattr(self, '_handler', None)
        if handler is None: return
        #loggers are shared by name: detach the handler so that a later 
        #EchoLogger with the same name does not write every line twice
        logger = getattr(self, '_logger', None)
        if logger is not None: logger.removeHandler(handler)
        handler.close()
    #end def
        
    def __enter__(self):
        OutIntercepter.__enter__(self)
        self._logger.log(self._level, '=== START LOGGING ===')
        return self
    #end def
    
    def __exit__(self, _type, _value, _traceback):
        #report everything except a deliberate SystemExit
        if _type is not None and not issubclass(_type, SystemExit):
            #stdout is still intercepted here, so the message is echoed and logged
            print(_value)
            self._logger.error('Exception occured:', 
                               exc_info=(_type, _value, _traceback))
            traceback.print_exception(_type, _value, _traceback, file=self._olderr)
        sys.stdout = self._oldout
        sys.stderr = self._olderr
        self._logger.log(self._level, '=== END LOGGING ===')
        #returning True suppresses the exception for the caller
        return True
    #end def
        
    def write(self, text):
        '''Log the text as a single line (whitespace collapsed) and echo it 
        unchanged to the original stdout.'''
        log_text = ' '.join(text.split())
        #skip whitespace-only writes, e.g. the newline emitted by print
        if log_text: self._logger.log(self._level, log_text, exc_info=False)
        self._oldout.write(text)
    #end def
    
    def flush(self): self._oldout.flush()
#end class


def retry(func, error_msg, num_retries):
    '''Call func() up to num_retries times and return its first successful result.
    
    func        -- a callable without arguments
    error_msg   -- message of the RuntimeError raised when all attempts fail
    num_retries -- maximum number of attempts
    
    Every exception raised by func is printed. RuntimeError(error_msg) is 
    raised when all attempts have failed, or when num_retries is less than 
    one and func is therefore never called.'''
    for _attempt in range(num_retries):
        try:
            return func()
        except Exception as e:
            print(e)
    raise RuntimeError(error_msg)
#end def


class BatchEntrez(object):
    '''Batch retrieval of GenBank records from NCBI through Entrez.
    
    Accession numbers are extracted from description lines of sequences 
    stored in fasta files; the corresponding records are then downloaded 
    in batches and their isolation sources are written to csv files.'''
    #defaults
    RETRIES    = 3   #number of attempts for each Entrez request
    PAUSE_EACH = 100 #number of queries in a series before a pause is made
    BATCH      = 20  #number of accession numbers per query
    PAUSE      = 60  #pause between series of queries, seconds
    
    #delimiter between the parts (title, authors, journal) of a reference
    REF_DELIMITER = ":   :"
    
    #id types
    NUC     = 'nucleotide'
    PROT    = 'protein'
    WGS     = 'wgs'
#    MGA     = 'mga'
    UPROT   = 'uniprot'
    REFSEQN = 'refseqn'
    REFSEQP = 'refseqp'
    #regexp
    ID_TYPES_RE = {
                   NUC:     re.compile(r'\b([a-zA-Z]{1}\d{5}|[a-zA-Z]{2}\d{6})\b'),
                   PROT:    re.compile(r'\b([a-zA-Z]{3}\d{5})\b'),
                   WGS:     re.compile(r'\b([a-zA-Z]{4}\d{2}\d{6,8})\b'),
#                   MGA:     re.compile(r'\b([a-zA-Z]{5}\d{7})\b'),
                   UPROT:   re.compile(r'\b([OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2})\b'),
                   REFSEQN: re.compile(r'\b([ANX][CGTWSZMR]_\d+)\b'),
                   REFSEQP: re.compile(r'\b([ANYXZW]P_\d+)\b'),
                   }
    #the Entrez database each id type belongs to
    ID_TO_DB = {
                NUC:     'nucleotide',
                PROT:    'protein',
                WGS:     'nucleotide',
#                MGA:     '',
                UPROT:   'protein',
                REFSEQN: 'nucleotide',
                REFSEQP: 'protein',
                }
    
    def __init__(self, email):
        '''email -- the address assigned to Entrez.email before queries are made'''
        self._email      = email
        self._start_time = -1
    #end def
    
    @staticmethod
    def _load_sequences(filename):
        '''Return a list of SeqRecord objects parsed from a fasta file.
        
        An empty list is returned (and a message is printed) if the file 
        does not exist. ValueError is raised if the file cannot be parsed.'''
        #check if filename is an existing file
        if not os.path.isfile(filename):
            print('No such file: "%s"' % filename)
            return []
        #parse file
        try:
            return list(SeqIO.parse(filename, 'fasta'))
        except Exception as e:
            print(e)
            raise ValueError('load_sequence: unable to parse %s.' % filename)
    #end def
    
    @staticmethod
    def _get_isolation_sources(rec):
        '''Return [isolation_sources, countries] taken from the first "source" 
        feature of the record, or an empty list if there is no such feature.
        
        Both items are lists of strings; placeholders are used for missing 
        qualifiers.'''
        for f in rec.features:
            if f.type != 'source': continue
            try:
                isolation_source = f.qualifiers['isolation_source']
            except KeyError:
                print('%s: the "source" feature does not contain "isolation_source" qualifier\n' % rec.id)
                isolation_source = ['isolation source is not provided']
            country = f.qualifiers.get('country', ['-'])
            #only the first "source" feature is considered
            return [isolation_source, country]
        return []
    #end def
    
    @staticmethod
    def _init_job(files):
        '''Derive the job name and the working directory from the first file:
        the name is the part of its basename before the first dot.'''
        jobname = os.path.basename(files[0]).split('.')[0]
        dirname = os.path.dirname(files[0]) or '.'
        return jobname, dirname
    #end def
    
    @classmethod
    def _extract_id(cls, rec):
        '''Return (accession, database) for the first ID pattern that matches 
        the description of the record, or (None, None) if nothing matches.'''
        #NOTE(review): some patterns overlap (e.g. "P12345" matches both the 
        #nucleotide and the uniprot pattern), so for such IDs the result 
        #depends on the iteration order of the dict -- confirm intended priority
        for idt, id_re in cls.ID_TYPES_RE.items():
            match = id_re.search(rec.description)
            if match is not None:
                return match.group(0), cls.ID_TO_DB[idt]
        return None, None
    #end def
    
    @classmethod
    def _get_ids_for_db(cls, seq_records, database):
        '''Return the list of accessions found in descriptions of the records 
        that belong to the given database. Skipped records are reported.'''
        print('Searching for NCBI IDs in description lines of the provided sequences...')
        ids = []
        for rec in seq_records:
            _id, db = cls._extract_id(rec)
            if _id is None:
                #show the string that was actually searched
                print('No NCBI ID found in the description string:\n%s' % rec.description)
                print('Supported ID formats: %s' % ', '.join(cls.ID_TYPES_RE))
            elif db != database:
                print('%s ID belongs to the "%s" database and will be ignored.' % (_id, db))
            else: ids.append(_id)
        if not ids:
            print('No NCBI IDs were found in the provided files.\n')
        print('Done.\n')
        return ids
    #end def
    
    def _get_records_for_ids(self, ids, start, stop, database):
        '''Search the database for ids[start:stop] and download the found records.
        
        Returns a list of parsed GenBank records; the list is empty if nothing 
        was found or if the downloaded data could not be parsed.
        RuntimeError is raised (by retry) if a request keeps failing.'''
        #compose a query; upper-case OR is the documented form of the operator
        query = ' OR '.join(ids[start:stop])
        #perform the query
        num_ids = len(ids)
        print('[%.1f%%] performing query for accession numbers %d-%d of %d' 
              % (float(stop)/num_ids*100, start+1, stop, num_ids))
        results = retry(lambda : Entrez.read(Entrez.esearch(db=database, term=query, usehistory="y")),
                        'Unable to get Entrez IDs for sequences %d-%d' % (start+1, stop), self.RETRIES)
        if not results['IdList']:
            print('NCBI returned no result for the accession numbers from sequences %d-%d' 
                  % (start+1, stop))
            return []
        webenv    = results['WebEnv']
        query_key = results['QueryKey']
        #fetch genbank data for the received IDs; the history (webenv, query_key)
        #belongs to the database that was searched, so fetch from the same one
        num_results = len(results['IdList'])
        print('Downloading data...')
        data = retry(lambda : Entrez.efetch(db=database, rettype="gb", retmode="text",
                                        retstart=0, retmax=num_results, complexity=4,
                                        webenv=webenv, query_key=query_key),
                     'Unable to download data for sequences %d-%d' % (start+1, stop), self.RETRIES)
        #parse received data
        try: records = list(SeqIO.parse(data, 'gb'))
        except Exception as e: 
            print(e)
            return []
        finally: data.close()
        print('Done. Elapsed time: %s\n' % timedelta(seconds=time()-self._start_time))
        return records
    #end def
    
    def _join(self, *strings):
        '''Join non-empty strings with REF_DELIMITER.'''
        return self.REF_DELIMITER.join(s for s in strings if s)
    #end def
    
    def _format_references(self, rec):
        '''Return all the references of the record as a single string, or 
        an empty string if the record has no "references" annotation.'''
        refs = rec.annotations.get('references', [])
        return '   ###   '.join(self._join(ref.title, ref.authors, ref.journal)
                                for ref in refs)
    #end def
    
    def _plan_queries(self, num_ids):
        '''Warn the user about the expected duration of the job and return 
        the number of queries to be made in a series between pauses.'''
        #the last batch may be incomplete, so round up
        num_queries = (num_ids+self.BATCH-1)//self.BATCH
        #the class-level default is left untouched for subsequent jobs
        pause_each  = self.PAUSE_EACH
        pause_time  = 0 
        if num_queries > pause_each:
            num_pauses = num_queries//pause_each
            pause_time = num_pauses * self.PAUSE
            #distribute queries evenly among the series
            pause_each = num_queries//(num_pauses+1)+1
            print('WARNING: %d separate Entrez queries will be made.\n'
                  'To comply with NCBI rules the queries will be made\n'
                  'in series of %d with %d sec pause in between.\n' 
                  % (num_queries, pause_each, self.PAUSE))
            print('Total pause time will be:\n%s\n' % timedelta(seconds=pause_time))
        query_time = num_queries/3.0
        if query_time > 5:
            print('No more than 3 requests per second is allowed by NCBI,\n'
                  'so *minimum* time spend for your query will be:\n%s\n' % timedelta(seconds=query_time))
        if pause_time > 0 and query_time > 5:
            print('Total *minimum* estimated time:\n%s\n' % timedelta(seconds=pause_time+query_time))
            print('Note, that depending on the load of NCBI servers it\n'
                  'may take several times as much.\n')
        return pause_each
    #end def
    
    def GetIsolationSources(self, files, database, no_refs = False):
        '''Retrieve isolation sources for accessions found in the fasta files.
        
        files    -- list of paths to fasta files; job name and output directory 
                    are derived from the first one
        database -- the Entrez database to search in
        no_refs  -- if True, publication references are not written
        
        Writes "<job>-sources.csv", "<job>-source-distribution.csv" and 
        "<job>.log" next to the first file.
        Returns: 0 on success; 1 if the directory does not exist; 2 if no 
        records were loaded; 3 if no IDs were found; 4 if a query failed; 
        5 if an unexpected error occurred.'''
        self._start_time = time()
        #derive job name and get working dir from the first file
        jobname, dirname = self._init_job(files)
        if not os.path.isdir(dirname):
            print('No such directory: '+dirname)
            return 1
        #compose and run the query with logging
        with EchoLogger(os.path.join(dirname, jobname)):
            #load sequences
            seq_records = []
            for filename in files:
                try: seq_records.extend(self._load_sequences(filename))
                except ValueError: continue
            if not seq_records:
                print('No records could be loaded from the provided files.')
                return 2
            #get IDs
            ids = self._get_ids_for_db(seq_records, database)
            if not ids: return 3
            #check number of queries and warn the user
            num_ids    = len(ids)
            pause_each = self._plan_queries(num_ids)
            #setup Entrez engine
            Entrez.email = self._email
            Entrez.tool = 'BatchEntrez.GetIsolationSources'
            #perform queries in batches and write results to a csv file
            unique_sources = Counter()
            all_sources_filename = os.path.join(dirname, jobname+'-sources.csv')
            with open(all_sources_filename, 'wb') as csvfile:
                output = csv.writer(csvfile)
                #write header
                header = ['DESCRIPTION', 'ACCESSION', 'ISOLATION SOURCE', 'COUNTRY']
                if not no_refs: header.append('REFERENCES')
                output.writerow(header)
                #process IDs
                next_pause = pause_each
                for start in range(0, num_ids, self.BATCH):
                    #pause as soon as a full series of queries has been made
                    if start//self.BATCH >= next_pause:
                        print('Pausing for %d seconds...\n' % self.PAUSE)
                        sleep(self.PAUSE)
                        next_pause += pause_each
                    stop = min(start+self.BATCH, num_ids)
                    try: records = self._get_records_for_ids(ids, start, stop, database)
                    except RuntimeError as e: 
                        print(e)
                        print('Query aborted.')
                        return 4
                    for rec in records: 
                        sources = self._get_isolation_sources(rec)
                        if not sources: continue
                        row = [rec.description, rec.id]
                        row.extend(chain.from_iterable(sources))
                        if not no_refs:
                            row.append(self._format_references(rec))
                        output.writerow(row)
                        #count isolation sources case-insensitively
                        unique_sources.update(s.lower() for s in sources[0])
            #write a histogramm of isolation sources in a csv file
            unique_sources_filename = os.path.join(dirname, jobname+'-source-distribution.csv')
            print('All isolation sources with corresponding GenBank IDs were written to:\n   %s\n'
                  'Writing the source distribution histogram to:\n   %s\n' 
                  % (all_sources_filename, unique_sources_filename))
            with open(unique_sources_filename, 'wb') as csvfile:
                output = csv.writer(csvfile)
                output.writerows(unique_sources.most_common())
            print('Done.\nTotal elapsed time: %s\n' % timedelta(seconds=time()-self._start_time))
            return 0
        #this point is reached only if an exception was raised inside the 
        #"with" block and suppressed by EchoLogger.__exit__
        return 5
    #end def
#end class


#------------------------------------------------------------------------------#
if __name__ == '__main__':
    #parse command line arguments
    #NOTE: every option declared with nargs=1 is stored by argparse as 
    #a one-element list, hence the list defaults and the [0] indexing below
    parser = argparse.ArgumentParser(description='Retrieves isolation sources from NCBI '
                                     'given a set of sequences in fasta format with '
                                     'accession numbers specified in sequence description lines.')
    parser.add_argument('-e', '--email', dest='email', metavar='address@domain.com', type=str, nargs=1, required=True,
                        help='Your e-mail is required by NCBI when you use Entrez API.')
    
    parser.add_argument('-d', '--database', dest='database', metavar='database-name', type=str, nargs=1, default=['nucleotide'],
                        choices=('nucleotide', 'protein'),
                        help='The name of NCBI database to search in. Supported: nucleotide, protein. Default: nucleotide.')
    
    parser.add_argument('-p', '--pause', dest='pause', metavar='sec', type=int, nargs=1, default=[60],
                        help='Pause in seconds in between of series of Entrez queries.')
    
    parser.add_argument('-r', '--retries', dest='retries', metavar='times', type=int, nargs=1, default=[3],
                        help='Number of retries on network failure.')
    
    parser.add_argument('--no-references', dest='no_refs', action='store_true',
                        help='Do not add publication references of each sequence to the output.')
    
    parser.add_argument('files', metavar='file', type=str, nargs='+',
                        help='File(s) in fasta format with sequences and accession numbers.')
    #setup batch fetcher and run the query
    args = parser.parse_args()
    BatchEntrez.RETRIES = args.retries[0]
    BatchEntrez.PAUSE   = args.pause[0]
    #pass plain strings, not the one-element lists produced by nargs=1
    entrez = BatchEntrez(args.email[0])
    #the return code of the job becomes the exit status of the script
    sys.exit(entrez.GetIsolationSources(args.files, args.database[0], args.no_refs))
#end
