#!/usr/bin/env python
import argparse
import re
import sys
from profileNJ.TreeLib import *
from profileNJ.PolytomySolver import *
from profileNJ.TreeLib.SupportUtils import timeit
from sys import stdout
import time

# Raise the interpreter's recursion limit far above its default value.
# NOTE(review): presumably required because the tree routines used below
# recurse deeply on large trees -- confirm against profileNJ before lowering.
sys.setrecursionlimit(1000000)


class GTreeSolver():
    """Bind a gene tree and a species tree to one polytomy-resolution solver.

    The solver is selected by ``mode``:

    * ``"psolver"`` -> ``PolySolver.GeneTreeSolver``
    * ``"linzz"``   -> ``SingleSolver.LinPolySolver``
    * ``"dynzz"``   -> ``SingleSolver.DynPolySolver``
    * ``"notung"``  -> ``SingleSolver.NotungSolver``
    * anything else -> ``SingleSolver.Dynamiq2``

    The constructed solver is stored in ``self.solver`` and the gene-to-species
    LCA mapping in ``self.lcamap``.
    """

    def __init__(self, genetree, specietree, mode="psolver", dupcost=1, losscost=1):
        """Compute the LCA mapping and instantiate the solver for ``mode``.

        genetree   -- gene tree to resolve.
        specietree -- species tree used for the reconciliation.
        mode       -- name of the algorithm (see the class docstring).
        dupcost    -- cost of a duplication.
        losscost   -- cost of a loss.
        """
        self.genetree = genetree
        self.specietree = specietree
        self.mode = mode
        self.dupcost = dupcost
        self.losscost = losscost

        if mode == "psolver":
            # For this solver the LCA mapping is computed first and the
            # internal nodes of both trees are labelled by the solver itself.
            self.lcamap = TreeUtils.lcaMapping(genetree, specietree, False)
            self.solver = PolySolver.GeneTreeSolver(
                genetree, specietree, self.lcamap, dupcost, losscost)
            for tree in (genetree, specietree):
                self.solver.labelInternalNodes(tree)
            # The solver's `use_dp` flag is switched on only when the two
            # event costs differ.
            self.solver.use_dp = bool(dupcost != losscost)
            return

        # Every other mode labels the species tree's internal nodes before
        # the LCA mapping is computed.
        specietree.label_internal_node()
        self.lcamap = TreeUtils.lcaMapping(genetree, specietree, False)
        lcamap = self.lcamap

        if mode == 'linzz':
            # LinPolySolver is built without any cost arguments.
            self.solver = SingleSolver.LinPolySolver(
                genetree, specietree, lcamap)
        elif mode == 'dynzz':
            self.solver = SingleSolver.DynPolySolver(
                genetree, specietree, lcamap, dupcost, losscost)
        elif mode == "notung":
            # NOTE(review): losscost is passed before dupcost here, unlike the
            # other solvers -- confirm this matches NotungSolver's signature.
            self.solver = SingleSolver.NotungSolver(
                genetree, specietree, lcamap, losscost, dupcost)
        else:
            # Any remaining mode falls back to Dynamiq2.
            self.solver = SingleSolver.Dynamiq2(
                genetree, specietree, lcamap, dupcost, losscost)

    # For a fair comparison between algorithms, this is the method whose
    # running time should be measured.
    @timeit
    def solvePolytomies(self, nsol):
        """Return the resolved tree(s) as a list.

        In ``psolver`` mode the solver is asked for ``nsol`` solutions and a
        trailing ``";"`` is appended to each of them.  In every other mode a
        single reconstruction is returned and ``nsol`` is ignored.

        The method is wrapped by ``timeit``; the caller in this file unpacks
        the decorated call as ``(elapsed, solutions)``.
        """
        if self.mode != 'psolver':
            # The original author notes that the current Notung DP
            # implementation does not support multiple solutions.
            return [self.solver.reconstruct()]
        return [newick + ";" for newick in self.solver.solvePolytomies(nsol)]


# ---------------------------------------------------------------------------
# Command-line driver.
#
# 1. Parse the arguments.
# 2. Load the gene and species trees and attach a species to each gene leaf
#    (from a map file, from a separator, or from the leaf name itself).
# 3. Resolve the polytomies with the requested algorithm.
# 4. Write every solution to the output file (or stdout) and, on request,
#    print its duplication/loss counts and reconciliation cost.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='PolySolver : A tool to resolve polytomies in a genetree')
parser.add_argument('-s', '--spectree', dest='specnw',
                    help="Name of the file containing the species newick tree.", required=True)
parser.add_argument('-S', '--sMap', dest='smap',
                    help="Gene to species map. Use the standard format.")
parser.add_argument('-g', '--genetree', dest='genenw',
                    help="Name of the file containing the gene newick tree.", required=True)
parser.add_argument('--sep', dest='gene_sep',
                    help="Specify a gene separator if you're are not using a smap")
parser.add_argument('--losscost', type=float, default=1,
                    dest='losscost', help="Specify the losses cost")
parser.add_argument('--dupcost', type=float, default=1,
                    dest='dupcost', help="Specify the duplication cost")
parser.add_argument('--nsol', type=int, default=1,
                    dest='nsol', help="Number of solution to output")
parser.add_argument('--spos', dest='spos', default='prefix',
                    help="Gene position when you have specified a separator. Default value is prefix")
parser.add_argument('-o', '--output', dest='outfile',
                    help="Name of your output files with the corrected tree. The resolutions are printed on stdout if omitted.")
parser.add_argument('--mode', dest='mode', default="psolver", choices=[
                    'psolver', 'linzz', 'dynzz', 'notung', 'dynzz2'], help="Algorithm to use.")
parser.add_argument('--showcost', dest='showcost', action='store_true',
                    help="Use this to show the reconciliated cost at the end. By default, only the resolved tree is shown")
args = parser.parse_args()

genetree = TreeClass(args.genenw)
specietree = TreeClass(args.specnw)
speciemap = None
if args.smap:
    # Each non-blank line of the map file is "<gene pattern> <species>".
    # The gene pattern is used as a case-insensitive regular expression; a
    # bare '*' wildcard is rewritten to the regex '.*'.
    speciemap = {}
    regexmap = {}
    # NOTE: 'rU' is kept for universal-newline handling under Python 2; the
    # 'U' flag was removed in Python 3.11, so use plain 'r' if porting.
    with open(args.smap, 'rU') as INPUT:
        for line in INPUT:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # failing on the two-value unpacking below.
                continue
            g, s = line.split()
            if '*' in g and '.*' not in g:
                g = g.replace('*', '.*')
            g_regex = re.compile(g, re.IGNORECASE)
            regexmap[g_regex] = s
    # Assign to each leaf the species of a pattern matching its name (when
    # several patterns match, the last one iterated wins).
    for leaf in genetree:
        for key, value in regexmap.items():
            if key.match(leaf.name):
                speciemap[leaf.name] = value

    genetree.set_species(speciesMap=speciemap)

elif args.gene_sep:
    genetree.set_species(sep=args.gene_sep, pos=args.spos)

else:
    # No map and no separator: the leaf name is used as the species name.
    genetree.set_species(use_fn=lambda x: x.name)

gsolver = GTreeSolver(genetree, specietree, args.mode,
                      args.dupcost, args.losscost)

# The decorated call is unpacked as (elapsed seconds, list of solutions).
# The first value is deliberately not named `time`, which would rebind the
# `time` module imported at the top of the file.
elapsed, solutions = gsolver.solvePolytomies(args.nsol)

# Only close the stream when it is a file opened here; using stdout as a
# context manager would close sys.stdout on exit.
OUTPUT = open(args.outfile, 'w') if args.outfile else stdout
try:
    for sol in solutions:
        OUTPUT.write(sol + "\n")

        if args.showcost:
            t = TreeClass(sol)
            # Single formatted string so the message is printed as text (not
            # as a tuple) under the Python 2 print statement as well.
            print("Number of leaves %d" % len(t))
            if args.mode in ['dynzz', 'dynzz2', 'linzz']:
                t.set_species(sep=args.gene_sep, pos=args.spos,
                              speciesMap=speciemap)
            else:
                t.set_species(use_fn=lambda x: x.name)
            # Reload the species tree from file for every solution before
            # recomputing the mapping used to count duplications and losses.
            specietree = TreeClass(args.specnw)
            lcamap = TreeUtils.lcaMapping(t, specietree, False)
            dup, loss = TreeUtils.computeDL(t, lcamap)
            print("Number of duplication : %d\nNumber of loss : %d\nReconciliation cost : %.2f" % (
                dup, loss, dup * args.dupcost + loss * args.losscost))

    print("\n==>%d solutions were obtained in %.3fs using algorithm : %s" %
          (len(solutions), elapsed, args.mode))
finally:
    if OUTPUT is not stdout:
        OUTPUT.close()
