#!/usr/bin/env python

# CoreTracker Copyright (C) 2016  Emmanuel Noutahi
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import logging
import operator
import os
import re
import sqlite3
import time
from collections import defaultdict

from Bio import SeqIO
from Bio import SeqIO, Entrez
from Bio.Data import CodonTable
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

from coretracker.coreutils import CoreFile
from coretracker.coreutils.mtgenes import revmtgenes


def regexp(expr, item):
    """Implement the SQL ``REGEXP`` operator for an sqlite connection.

    Args:
        expr: regular expression pattern (a string).
        item: value of the column being tested; ``None`` when the column
            holds SQL NULL.

    Returns:
        True when `expr` matches somewhere in `item`, False otherwise
        (including when `item` is None).
    """
    if item is None:
        # NULL columns reach Python as None; they can never match and
        # re.search(None) would raise TypeError and abort the whole query.
        return False
    return re.search(expr, item) is not None


def taxonIDLookup(taxonID, maxCheck=3):
    """Look up a taxon ID in the NCBI taxonomy database through Entrez.

    A failed request is retried after a 2 second pause; at most
    ``maxCheck + 1`` attempts are made in total.

    Args:
        taxonID: taxonomy identifier handed to ``Entrez.efetch``.
        maxCheck: number of retries allowed after the first failed attempt.

    Returns:
        A tuple ``(scientific_name, lineage, tax_id, mito_code_name)``
        built from the first record returned, where ``lineage`` is the
        record's 'Lineage' names as a tuple in reversed order.
        ``None`` when every attempt failed or no record came back.
    """
    resultsDownload = None
    for attempt in range(maxCheck + 1):
        try:
            handleDownload = Entrez.efetch(
                db="taxonomy", id=taxonID, retmode="xml")
            try:
                resultsDownload = Entrez.read(handleDownload)
            finally:
                # close the handle even when parsing the answer fails
                handleDownload.close()
            break
        except Exception:
            # 'Exception' rather than a bare except, so that Ctrl-C still
            # interrupts the retry loop
            if attempt >= maxCheck:
                print("!!!!!!Unreachable. Returning nothing.")
                return None
            if attempt == 0:
                print("!!!Server error - retrying...")
            time.sleep(2)

    if not resultsDownload:
        return None

    record = resultsDownload[0]
    scientificName = record['ScientificName']
    lineage = tuple(reversed(record['Lineage'].split("; ")))
    taxId = record['TaxId']
    mitoCode = record['MitoGeneticCode']['MGCName']
    return (scientificName, lineage, taxId, mitoCode)


def commonLookup(spName, maxCheck=3):
    """Search the NCBI taxonomy database for a species name.

    A failed request is retried after a 2 second pause; at most
    ``maxCheck + 1`` attempts are made in total. Only the first id of the
    search result is used, which may not be the intended taxon.

    Args:
        spName: name used as the Entrez search term.
        maxCheck: number of retries allowed after the first failed attempt.

    Returns:
        ``(taxid, info)`` where ``info`` is the value returned by
        ``taxonIDLookup(taxid)``, or ``None`` when the search failed, found
        nothing, or the taxon lookup returned nothing.
    """
    resultsSearch = None
    for attempt in range(maxCheck + 1):
        try:
            handleSearch = Entrez.esearch(db="taxonomy", term=spName)
            try:
                resultsSearch = Entrez.read(handleSearch)
            finally:
                # close the handle even when parsing the answer fails
                handleSearch.close()
            break
        except Exception:
            # 'Exception' rather than a bare except, so that Ctrl-C still
            # interrupts the retry loop
            if attempt >= maxCheck:
                print("!!!!!!Unreachable. Returning nothing.")
                return None
            if attempt == 0:
                print("!!!Server error checking " + spName + " - retrying...")
            time.sleep(2)

    if resultsSearch and resultsSearch['IdList']:
        taxid = resultsSearch['IdList'][0]
        # NOTE(review): 'sciname' is the whole taxonIDLookup result, not
        # just a name - confirm this is the shape callers expect.
        sciname = taxonIDLookup(taxid)
        if sciname:
            return (taxid, sciname)
    return None


def cladeSpecies(cladeName, maxCheck=3):
    """Find the species-rank taxa that belong to a clade.

    Searches the NCBI taxonomy database with the term
    ``<cladeName>[subtree] AND species[rank]`` and resolves every id of the
    result with ``taxonIDLookup``. A failed search is retried after a
    2 second pause; at most ``maxCheck + 1`` attempts are made in total.

    Args:
        cladeName: name of the clade.
        maxCheck: number of retries allowed after the first failed attempt.

    Returns:
        A list of ``(taxid, info)`` tuples, ``info`` being the value
        returned by ``taxonIDLookup(taxid)``. The list is empty when the
        search failed or nothing could be resolved.
    """
    # NOTE(review): no 'retmax' is passed to esearch, so presumably only the
    # server's default number of ids comes back - confirm for large clades.
    searchTerm = cladeName + '[subtree] AND species[rank]'
    resultsSearch = None
    for attempt in range(maxCheck + 1):
        try:
            handleSearch = Entrez.esearch(db="taxonomy", term=searchTerm)
            try:
                resultsSearch = Entrez.read(handleSearch)
            finally:
                # close the handle even when parsing the answer fails
                handleSearch.close()
            break
        except Exception:
            # 'Exception' rather than a bare except, so that Ctrl-C still
            # interrupts the retry loop
            if attempt >= maxCheck:
                print("!!!!!!Unreachable. Returning nothing.")
                return []
            if attempt == 0:
                print("!!!Server error checking " +
                      cladeName + " - retrying...")
            time.sleep(2)

    output = []
    if resultsSearch and resultsSearch['IdList']:
        for spId in resultsSearch['IdList']:
            sciname = taxonIDLookup(spId)
            if sciname:
                output.append((spId, sciname))
    return output


def isInt(value):
    """Tell whether `value` (typically a string) can be converted
    to an int.

    Returns True when ``int(value)`` succeeds and False when it raises
    ValueError.
    """
    try:
        int(value)
    except ValueError:
        return False
    else:
        return True


def db_lookup(dbname, speclist, complete=True):
    """Search the 'genbank' table of an sqlite database for organisms.

    Every entry of `speclist` is used as an alternative of one regular
    expression that is matched against the 'organism' column. Only the
    first row found for each distinct organism is kept.

    Args:
        dbname: path of the sqlite database.
        speclist: iterable of species names (regular expression syntax is
            interpreted, entries are not escaped).
        complete: when true, restrict the search to rows whose
            'complete_genome' column is 1.

    Returns:
        A list of ``(seqid, organism)`` tuples; empty when `speclist`
        holds no usable entry.
    """
    # An empty alternative would match every row of the table, so empty
    # entries are dropped and an empty list short-circuits to "nothing".
    patterns = [spec for spec in speclist if spec]
    if not patterns:
        return []

    query = "SELECT seqid, organism FROM genbank WHERE organism REGEXP ?"
    if complete:
        query += " AND complete_genome = 1"

    result = []
    seen_organisms = set()
    conn = sqlite3.connect(dbname)
    try:
        conn.create_function("REGEXP", 2, regexp)
        for seqid, organism in conn.execute(query, ["|".join(patterns)]):
            if organism not in seen_organisms:
                seen_organisms.add(organism)
                result.append((seqid, organism))
    finally:
        # 'with connection' only manages the transaction; close explicitly
        conn.close()
    return result


def check_species_list(specielist, namedb, taxid=False, clade=False, complete=True):
    """Resolve every entry of `specielist` to a curated species entry.

    When `namedb` is given the whole list is resolved with ``db_lookup``.
    Otherwise each entry is resolved through the lookup functions of this
    module: ``taxonIDLookup`` for integer entries when `taxid` is set,
    else ``commonLookup`` and, if that fails and `clade` is set,
    ``cladeSpecies`` on the first word of the entry.

    Args:
        specielist: iterable of species names or taxon ids (strings).
        namedb: path of the sqlite name database, or a false value to use
            the online lookups.
        taxid: treat integer-looking entries as taxon ids.
        clade: fall back to a clade search for names that are not found.
        complete: forwarded to ``db_lookup`` (complete genomes only).

    Returns:
        A list of tuples. With `namedb` these are the ``(seqid, organism)``
        tuples of ``db_lookup``; for taxon ids ``(entry, scientific_name)``;
        otherwise whatever ``commonLookup`` / ``cladeSpecies`` returned.
        Entries that could not be resolved are logged and left out.
    """
    if namedb:
        # lookup using database
        return db_lookup(namedb, specielist, complete)

    curated_spec_list = []
    for spec in specielist:
        spechecked = None
        if taxid and isInt(spec):
            taxinfo = taxonIDLookup(spec)
            # the lookup returns None on failure, so guard before indexing
            if taxinfo and taxinfo[0]:
                spechecked = (spec, taxinfo[0])
        else:
            # try common lookup
            spechecked = commonLookup(spec)
            if not spechecked and clade:
                words = spec.split()
                if words:
                    # then clade lookup, after a short pause between the
                    # two requests
                    time.sleep(1)
                    spechecked = cladeSpecies(words[0])

        if spechecked and isinstance(spechecked, list):
            curated_spec_list.extend(spechecked)
        elif spechecked:
            curated_spec_list.append(spechecked)
        else:
            logging.warning("Species not found : %s" % spec)
    return curated_spec_list


def get_poss_event_diff(gene, prot, features, gcode=1):
    """Compare a gene with its annotated protein (unfinished).

    NOTE(review): this function is a work in progress. It locates the most
    plausible start offset of `gene` with respect to `prot` and attempts a
    translation of `gene`, but it does not use `features` and returns
    nothing.

    Args:
        gene: nucleotide sequence of the gene (expected to be a CDS).
        prot: protein sequence used as template; skipped when empty.
        features: currently unused.
        gcode: id of the genetic code table to use.

    Returns:
        None.
    """

    codontable = CodonTable.unambiguous_dna_by_id[gcode]
    ct_dict = codontable.forward_table

    def get_possible_start(gene, potent_prot):
        # Find the start offset in the gene that best agrees with the
        # template protein. The gene is expected to be a cds;
        # frameshifting is accepted, but not in the first 5 aas, so only
        # those leading amino acids are compared.
        # NOTE(review): the original computed max(len(gene) / 3, 5), which
        # would read past the end of the gene; min() is assumed to be the
        # intent - confirm.
        max_aa = min(len(gene) // 3, len(potent_prot), 5)

        # check start from 10 first position
        bestmatch = (0, 0)
        for pos in range(10):
            # stop codons and truncated codons are absent from the forward
            # table: .get() makes them count as a mismatch
            matcher = sum(
                ct_dict.get(
                    str(gene[pos + i * 3:pos + (i + 1) * 3]).upper()) == potent_prot[i]
                for i in range(max_aa))
            if matcher == max_aa:
                # every compared amino acid agrees, no need to go further
                bestmatch = (pos, matcher)
                break
            # max() returns its first argument on ties, so the earliest
            # offset with the highest score is kept
            bestmatch = max(bestmatch, (pos, matcher),
                            key=operator.itemgetter(1))

        return bestmatch[0]

    best_starting_pos = 0
    if prot:
        best_starting_pos = get_possible_start(gene, prot)

    # TODO: the start offset and the translation below are not used yet
    try:
        potential_prot = gene.translate(gcode)
    except Exception:
        logging.debug("Translation failed with genetic code %s" % gcode)


def _complete_partial_stop(feature, seq):
    """Complete a partial stop codon described by 'transl_except'.

    Returns ``(completed_seq, n_A)`` when the first 'transl_except'
    qualifier of `feature` describes a 'TERM' amino acid and appending the
    derived number of 'A' makes the length a multiple of 3; None otherwise
    (the sequence is then left untouched).
    """
    try:
        polyaterm = feature.qualifiers['transl_except'][0]
        pos_range, aa_term = polyaterm.strip(')').strip('(').split(',')
    except (KeyError, IndexError, ValueError):
        # no such qualifier, or not of the form "(pos:...,aa:...)"
        return None
    pos_range = pos_range.strip().split(':')[-1]
    aa_term = aa_term.strip().split(':')[-1]
    if 'TERM' not in aa_term.upper():
        return None
    # a single position gives 2 'A' to add, a 'start..end' range gives 1
    n_A = (len(pos_range.split('..')) % 2) + 1
    completed = seq + 'A' * n_A
    if len(completed) % 3 != 0:
        return None
    return completed, n_A


def extract_genome(speclist, genelist, records, revmtgenes, getprot=False, gcode=1):
    """Extract the coding genes of a list of species.

    For each species, the CDS features carrying a 'gene' qualifier are
    extracted from its record, except those whose gene name contains 'orf'
    while the product is described as 'hypothetical'. Gene names are
    lower-cased and renamed through `revmtgenes` when listed there. Only
    the first occurrence of a gene is kept for each species.

    Args:
        speclist: iterable of ``(record_id, species_name)`` tuples.
        genelist: lower-case gene names to keep; every gene is kept when
            empty.
        records: mapping from record id to a record exposing 'features'
            and 'seq'.
        revmtgenes: mapping from alternative gene names to the name to use.
        getprot: currently unused.
        gcode: genetic code used to translate a gene when the feature has
            no 'translation' qualifier and no 'transl_table' was seen for
            the species.

    Returns:
        ``(gene2spec, prot2spec, spec_code_map)``: two dicts mapping a gene
        name to the list of nucleotide / protein SeqRecord of each species,
        and a dict mapping a species name to the translation table found
        in its features.

    Raises:
        KeyError: when a record id of `speclist` is not in `records`.
    """
    spec_code_map = {}
    gene2spec = defaultdict(list)
    prot2spec = defaultdict(list)
    wanted_genes = set(genelist) if genelist else None

    def is_hypothetical_orf(qualifiers):
        # this should filter hypothetical protein that we do not want;
        # 'product' is optional, a feature without it is never filtered
        if not any('orf' in gname.lower() for gname in qualifiers['gene']):
            return False
        return 'hypothetical' in " ".join(qualifiers.get('product', [])).lower()

    # fetch gene and protein from the records and add them to the list
    for (s_id, spec) in speclist:
        genes_seen = set()
        specaff = spec.replace('.', '').replace('-', '_').replace(' ', '_')
        curr_seq = records[s_id]
        for f in curr_seq.features:
            if f.type.upper() != 'CDS' or 'gene' not in f.qualifiers:
                continue
            if is_hypothetical_orf(f.qualifiers):
                continue

            sgene = f.qualifiers['gene'][0].lower()
            if sgene in revmtgenes:
                sgene = revmtgenes[sgene].lower()
            if wanted_genes is not None and sgene not in wanted_genes:
                continue

            seq = f.extract(curr_seq.seq)
            if len(seq) % 3 != 0:
                completed = _complete_partial_stop(f, seq)
                if completed is not None:
                    # adding polyA to complete mRNA
                    seq, n_A = completed
                    logging.info("FIXED : partial termination for the following gene: %s - %s | %s %d A added" %
                                 (sgene, spec, len(seq), n_A))
                else:
                    logging.warning("Possible frame-shifting in the following gene : %s - %s | %d (%d)" %
                                    (sgene, spec, len(seq), len(seq) % 3))
            elif 'N' in seq:
                logging.warning("Sequence with undefined nucleotide : %s - %s | %d" %
                                (sgene, spec, len(seq)))

            try:
                spec_code_map[spec] = int(f.qualifiers['transl_table'][0])
            except (KeyError, IndexError, TypeError, ValueError):
                # no usable translation table on this feature
                pass

            if sgene in genes_seen:
                # this is to ensure that the same gene is not added
                # multiple time
                continue
            genes_seen.add(sgene)

            translation = f.qualifiers.get("translation")
            if translation and translation[0]:
                protseq = Seq(translation[0])
            else:
                # no annotated protein: translate the gene ourselves
                protseq = seq.translate(
                    table=spec_code_map.get(spec, gcode))

            gene2spec[sgene].append(
                SeqRecord(seq, id=specaff, name=specaff))
            prot2spec[sgene].append(
                SeqRecord(protseq, id=specaff, name=specaff))
    return gene2spec, prot2spec, spec_code_map

# Command line entry point. Two sub-commands are available:
#   build   : index a genbank file and, optionally, fill an sqlite database
#             used for fast species name searching
#   extract : extract the genes of a list of species into corefile format
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Extract genome from genbank file into corefile format.")

    sub_parser = parser.add_subparsers(help="Commands", dest='command')
    # Python 3 makes sub-commands optional by default; without this a
    # missing command would fall into the 'extract' branch below and fail
    # on a missing attribute instead of printing the usage.
    sub_parser.required = True
    build_parser = sub_parser.add_parser(
        'build', help="Build database for fast species searching")
    build_parser.add_argument("--namedb", '-db', dest='database',
                              help="Temporary database for fast name searching. Only sqlite database supported")
    build_parser.add_argument("--indexdb", '-ib', dest='indexdb',
                              required=True, help="Index filename")
    build_parser.add_argument("--genbank", "-G", required=True,
                              dest='genbank', help="Your genbank file")

    extract_parser = sub_parser.add_parser(
        'extract', help="Extract genome from your data")
    extract_parser.add_argument("--genbank", "-G", required=True,
                                dest='genbank', help="Your genbank file")
    extract_parser.add_argument("--indexdb", '-ib', dest='indexdb',
                                required=True, help="Index filename")
    extract_parser.add_argument("--namedb", '-db', dest='database',
                                help="Use build to build a temporary database for fast name searching. Http request will be done if this is not provided")
    extract_parser.add_argument("--taxid", dest='taxid', action='store_true',
                                help="Taxid instead of common name is provided")
    extract_parser.add_argument("--table", dest='gcode', default=1, type=int,
                                help="Genetic code table to use, 1 if not provided")
    extract_parser.add_argument("--complete", dest='complete',
                                action='store_true', help="Complete genome only")
    extract_parser.add_argument("--clade", dest='clade', action='store_true',
                                help="Use this to look for clade name if some of your specs are clades instead")
    extract_parser.add_argument("--genelist", dest='genelist',
                                help="List of genes. If not provided, all coding sequences in the mapping file will be exported")
    extract_parser.add_argument("--speclist", dest='speclist', required=True,
                                help="A file containing a list of species or a specie name")
    extract_parser.add_argument("--outfile", '-o', dest='output',
                                default='output.core', help="Output file name")

    args = parser.parse_args()
    if args.command == 'build':
        # in this case, create the databases
        records = SeqIO.index_db(args.indexdb, args.genbank, format="genbank")
        if args.database:
            schema = '''
                create table IF NOT EXISTS genbank (
                    id integer primary key autoincrement not null,
                    accession text not null unique,
                    seqid text not null unique,
                    gi text,
                    complete_genome bit default 0,
                    source text,
                    organism text,
                    lineage text
                );
                '''

            # this will be re-designed if needed
            conn = sqlite3.connect(args.database)
            try:
                # 'with conn' commits on success and rolls back on error;
                # closing the connection is done in the 'finally' below
                with conn:
                    conn.executescript(schema)
                    # fetch the records one at a time by key, so that the
                    # whole index is never loaded at once
                    for key in records:
                        rec = records[key]
                        annot = rec.annotations
                        desc = rec.description
                        gi = None
                        try:
                            # use the first accession when no gi is set
                            gi = annot['gi'] if 'gi' in annot else annot['accessions']
                            if isinstance(gi, list):
                                gi = gi[0]
                        except (KeyError, IndexError):
                            gi = 'null'

                        conn.execute("insert into genbank (accession, seqid, gi, complete_genome, \
                        source, organism, lineage) values (?, ?, ?, ?, ?, ?, ?)",
                                     [rec.name, rec.id, gi,
                                      ('complete genome' in desc),
                                      annot['source'], annot['organism'],
                                      ">".join(annot['taxonomy'])])

                    conn.commit()
                    print("%s elements inserted in %s" %
                          (conn.total_changes, args.database))
            finally:
                conn.close()
    else:
        # we are actually trying to extract a genome from a list of spec
        records = SeqIO.index_db(args.indexdb, args.genbank, format="genbank")
        speclist = []
        try:
            with open(args.speclist) as Slist:
                for line in Slist:
                    line = line.strip()
                    if line and not line.startswith('#'):
                        speclist.append(line)
        except IOError:
            # not a readable file: the argument is a species name itself
            speclist.append(args.speclist)

        speclist = check_species_list(set(speclist), args.database,
                                      args.taxid, args.clade, args.complete)
        genelist = []
        if args.genelist:
            with open(args.genelist) as Glist:
                for line in Glist:
                    line = line.strip()
                    # skip blank lines too: an empty gene name would make
                    # the gene filter active while matching no gene
                    if line and not line.startswith('#'):
                        genelist.append(line.lower())
        gene2spec, prot2spec, spec_code = extract_genome(
            speclist, genelist, records, revmtgenes, gcode=args.gcode)
        CoreFile.write_corefile(gene2spec, args.output)
        CoreFile.write_corefile(prot2spec, args.output + "_prot")

        print("\n---------------------\n+++ Genetic code : \n")
        for spec, code in spec_code.items():
            print("%s ==> %s" % (spec, str(code)))
        print('\n---------------------\n+++ Number of genomes with genes : \n')
        for g, spec in gene2spec.items():
            print("%s ==> %d specs" % (g, len(spec)))
