#!/usr/bin/env python

from profileNJ.PolytomySolver import *
from ete3 import PhyloTree
from profileNJ.TreeLib.SupportUtils import read_trees
from profileNJ.TreeLib import TreeUtils, TreeClass, params
import argparse
import sys
import time
import re

# Rendering relies on ete3's drawing classes, which may fail to import.
# GRAPHICAL_ACCESS records whether they are available so the rest of the
# script can skip the rendering step instead of crashing.
GRAPHICAL_ACCESS = True
try:
    from ete3 import TreeStyle, NodeStyle, TextFace, faces, AttrFace
except ImportError:
    # `except ImportError:` (no `, e` binding) is valid on both Python 2 and
    # Python 3; the exception object was never used.
    GRAPHICAL_ACCESS = False
    print("Warning : PyQt not installed")


def prefix_specie_parser(node, speclist):
    """Return the species from ``speclist`` that ``node.name`` starts with.

    The comparison is case-insensitive. When several species names are a
    prefix of the node name, the longest one is returned (the first of them
    in ``speclist`` order if several share that length).

    :param node: object exposing a ``name`` string attribute.
    :param speclist: iterable of candidate species names.
    :returns: the matching entry of ``speclist`` (original casing).
    :raises ValueError: if no species name is a prefix of ``node.name``.
    """
    leaf_name = node.name.lower()
    matches = [sp for sp in speclist if leaf_name.startswith(sp.lower())]
    if not matches:
        message = ("Cannot set specie for this node : " +
                   node.name + "\nRename its name in the newick file")
        print(message)
        # max() on an empty list would also raise ValueError, but with a
        # message unrelated to the actual problem; raise the explicit one.
        raise ValueError(message)
    return max(matches, key=len)


def suffix_specie_parser(node, speclist):
    """Return the species from ``speclist`` that ``node.name`` ends with.

    The comparison is case-insensitive. When several species names are a
    suffix of the node name, the longest one is returned (the first of them
    in ``speclist`` order if several share that length).

    :param node: object exposing a ``name`` string attribute.
    :param speclist: iterable of candidate species names.
    :returns: the matching entry of ``speclist`` (original casing).
    :raises ValueError: if no species name is a suffix of ``node.name``.
    """
    leaf_name = node.name.lower()
    matches = [sp for sp in speclist if leaf_name.endswith(sp.lower())]
    if not matches:
        message = ("Cannot set specie for this node : " +
                   node.name + "\nRename its name in the newick file")
        print(message)
        # max() on an empty list would also raise ValueError, but with a
        # message unrelated to the actual problem; raise the explicit one.
        raise ValueError(message)
    return max(matches, key=len)


def get_specie_parser(spos):
    """Select the species-lookup function for a species position.

    ``"prefix"`` selects ``prefix_specie_parser``; any other value selects
    ``suffix_specie_parser``.
    """
    if spos == "prefix":
        return prefix_specie_parser
    return suffix_specie_parser


def speciemapper(speclist):
    """Build a short lowercase label from one species name or a list of them.

    A list yields the first two characters of each lowercased name joined
    with ``" | "``; anything else is sliced to its first two items and
    lowercased.
    """
    if not isinstance(speclist, list):
        return speclist[0:2].lower()

    short_names = []
    for name in speclist:
        short_names.append(name.lower()[0:2])
    return " | ".join(short_names)


# Command-line interface: a top-level parser with two sub-commands,
# 'run' (reconcile gene tree(s)) and 'smap' (generate a gene-to-species map).
parser = argparse.ArgumentParser(
    description='Tree reconciliation by lcamapping.',
)
sub_parser = parser.add_subparsers(help="Available option")

# ---- 'run' sub-command: reconciliation options ----
recon_parser = sub_parser.add_parser('run', help="Do Reconciliation")

# Input files
recon_parser.add_argument(
    '-s', '--sFile',
    dest='specietree',
    required=True,
    help="Specie tree in newick format",
)
recon_parser.add_argument(
    '-S', '--sMap',
    dest='smap',
    help="Gene to species map. Use the standard format.",
)
recon_parser.add_argument(
    '-g', '--gFile',
    dest='genetree',
    required=True,
    help="Gene tree in newick format.",
)

# How species names are extracted from gene tree leaf names
recon_parser.add_argument(
    '--sep',
    dest='gene_sep',
    help=("Gene-Specie separator for each leaf name in the genetree. "
          "The program will guess by default. But you should provide it"),
)
recon_parser.add_argument(
    '--spos',
    dest='spos',
    default="prefix",
    help="The position of the specie name according to the separator.",
)
recon_parser.add_argument(
    '--cap',
    dest='cap',
    action='store_true',
    help=("Capitalize the species name of the genetree leaves to match each "
          "species. Almost all functions are case sensitive."),
)

# Output and display options
recon_parser.add_argument(
    '--display_losses',
    dest='losses',
    action='store_true',
    help="Display losses in the reconciliation",
)
recon_parser.add_argument(
    '--output', '--out',
    dest='output',
    default="tree",
    help="Output an image of the reconciliated tree.",
)
recon_parser.add_argument(
    "--outform",
    dest="outform",
    default="pdf",
    help="Accepted format are svg|pdf|png",
)
recon_parser.add_argument(
    '--export_orthoxml',
    dest='orthoxml',
    action='store_true',
    help=("Export reconciliated tree to export_orthoxml. "
          "Losses are not expected"),
)
recon_parser.add_argument(
    '--show_branch_tag',
    dest='branch_tag',
    action='store_true',
    help="Show branch length and support",
)

# Execution mode flags
recon_parser.add_argument(
    '--verbose', '-v',
    action='store_true',
    help="Verbosity",
)
recon_parser.add_argument(
    '--reroot', '-r',
    action='store_true',
    help=("Reroot in order to display all root reconciliation result. "
          "Not Available in batch mode !!"),
)
recon_parser.add_argument(
    '--debug',
    action='store_true',
    help="Debugging ( test, will be removed )",
)
recon_parser.add_argument(
    '--batch',
    action='store_true',
    help="Batch mode, use profileNJ file output",
)

# ---- 'smap' sub-command: gene-to-species map generation ----
smap_parser = sub_parser.add_parser('smap', help="Generate smap")
smap_parser.add_argument(
    '-s', '--sFile',
    dest='specietree',
    required=True,
    help="Specie tree in newick format",
)
smap_parser.add_argument(
    '--smapfile',
    dest='smapfile',
    default="genetospecie.smap",
    help="Gene to species map output file",
)

# ---------------------------------------------------------------------------
# Script body: parse the command line, then either reconcile the gene tree(s)
# ('run' sub-command) or generate a gene-to-species map ('smap' sub-command).
# ---------------------------------------------------------------------------
args = parser.parse_args()
specietree = TreeClass(args.specietree)
# Only the 'run' sub-parser defines the 'genetree' destination, so its
# presence in the namespace distinguishes the two sub-commands.
runmode = 'genetree' in args
if(runmode):

    # By default the -g value is handed to TreeClass as-is; batch mode reads
    # several trees from it, reroot mode expands it into its rerooted variants.
    gnewicks = [args.genetree]
    if(args.batch):
        gnewicks = read_trees(args.genetree)

    elif(args.reroot):
        nt = TreeClass(args.genetree)
        gnewicks = [t.write() for t in nt.edge_reroot(unroot=True)]

    # The smap file does not depend on the gene tree, so it is parsed (and its
    # patterns compiled) once here rather than once per gene tree.
    regexmap = {}
    if(args.smap):
        with (open(args.smap, 'rU') if isinstance(args.smap, basestring) else args.smap) as INPUT:
            for line in INPUT:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                g, s = fields
                # Turn a shell-like '*' wildcard into its regex equivalent,
                # unless the pattern already contains '.*'.
                if ('*') in g and '.*' not in g:
                    g = g.replace('*', '.*')
                g_regex = re.compile(g, re.IGNORECASE)
                regexmap[g_regex] = s

    # Accumulates a header line plus the input newick for every tree; written
    # to "<output>.trees" at the end when --reroot is set.
    nw_sav = []
    for i, genewick in enumerate(gnewicks, 1):
        genetree = TreeClass(genewick, format=1)
        speciemap = {}

        rec_output = args.output + "_%d_reconcilied.nwk" % i
        events_output = args.output + "_%d.events" % i

        # Map every node yielded by iterating the gene tree whose name matches
        # an smap pattern to the corresponding species.
        if regexmap:
            for leaf in genetree:
                for key, value in regexmap.items():
                    if key.match(leaf.name):
                        speciemap[leaf.name] = value

        if args.gene_sep:
            genetree.set_species(
                speciesMap=speciemap, sep=args.gene_sep, capitalize=args.cap, pos=args.spos)
        else:
            # No separator given: use a lookup function based on the species
            # tree leaf names instead.
            genetree.set_species(speciesMap=speciemap, sep=args.gene_sep, capitalize=args.cap,
                                 pos=args.spos, use_fn=get_specie_parser(args.spos), speclist=specietree.get_leaf_name())

        mapping = TreeUtils.lcaMapping(genetree, specietree)
        dup, loss = TreeUtils.computeDLScore(genetree)

        TreeUtils.reconcile(genetree, mapping, args.losses, speciemapper)

        # display tree if verbose
        if(args.verbose):
            print(genetree.get_ascii(attributes=[
                  'species', 'name'], show_internal=False))

        # Export tree to extended newick format
        genetree.write(outfile=rec_output, format=0, format_root_node=True)

        # Rendering settings (skipped when the ete3 drawing classes could not
        # be imported)
        if(GRAPHICAL_ACCESS):
            ts = TreeStyle()
            ts.show_leaf_name = True

            if(args.branch_tag):
                ts.show_branch_length = True
                ts.show_branch_support = True

            spec_style = NodeStyle()
            spec_style["shape"] = "circle"
            spec_style["size"] = 10
            spec_style["fgcolor"] = "forestgreen"

            dup_style = NodeStyle()
            dup_style["shape"] = "square"
            dup_style["size"] = 8
            dup_style["fgcolor"] = "darkred"

            loss_style = NodeStyle()
            # Gray dashed branch lines
            loss_style["hz_line_type"] = 1
            loss_style["hz_line_color"] = "#777777"
            loss_style["shape"] = "circle"
            loss_style["fgcolor"] = "#777777"
            # Set the size on loss_style itself: assigning it to spec_style
            # here would shrink speciation nodes from 10 to 8 and leave the
            # loss style without an explicit size.
            loss_style["size"] = 8

            leaf_style = NodeStyle()
            leaf_style["size"] = 0
            leaf_style["vt_line_type"] = 0
            leaf_style["hz_line_type"] = 0
            leaf_style["hz_line_color"] = "#000000"

            # Apply node style: a base style by leaf/internal status, then
            # overridden for duplication (type > 0) and lost nodes.
            for n in genetree.traverse():
                if n.is_leaf():
                    n.set_style(leaf_style)
                else:
                    n.set_style(spec_style)
                if(n.type > 0):
                    n.set_style(dup_style)
                elif(n.type == TreeClass.LOST):
                    n.set_style(loss_style)

            def layout(node):
                # Draw names manually so that lost nodes get a gray label.
                if node.type == TreeClass.LOST:
                    faces.add_face_to_node(
                        AttrFace("name", fgcolor="#777777"), node, 0)
                elif node.is_leaf():
                    faces.add_face_to_node(AttrFace("name"), node, 0)

            # Save tree as figure
            ts.legend_position = 4
            ts.branch_vertical_margin = 20
            ts.margin_right = 50
            ts.margin_top = 20
            # Names are drawn by layout(), so disable the built-in leaf names.
            ts.show_leaf_name = False
            ts.layout_fn = layout
            duptxt = TextFace("Duplications : %d" %
                              dup, penwidth=10, fsize=12, ftype="Verdana")
            losstxt = TextFace("Losses : %d" %
                               loss, penwidth=10, fsize=12, ftype="Verdana")
            duptxt.margin_right = 30
            losstxt.margin_right = 30
            duptxt.margin_top = 20
            duptxt.margin_bottom = 10
            ts.legend.add_face(duptxt, column=0)
            ts.legend.add_face(TextFace(""), column=0)
            ts.legend.add_face(losstxt, column=0)
            genetree.render("%s_%d.%s" % (
                args.output, i, args.outform), dpi=400, tree_style=ts)

        all_events = genetree.get_events(include_lost=args.losses)

        # Write the event summary for this tree
        with open(events_output, 'w') as OUT:
            OUT.write(
                'DUPLICATIONS : ' + str(dup) + "\tLOSSES : " + str(loss) + "\n")
            for ev in all_events:
                if(ev.etype == 'S'):
                    OUT.write('    ORTHOLOGY RELATIONSHIP: ' + ', '.join(
                        ev.orthologs[0]) + "    <====>    " + ', '.join(ev.orthologs[1]) + "\n")
                elif(ev.etype == 'D'):
                    OUT.write('    PARALOGY RELATIONSHIP: ' + ', '.join(
                        ev.paralogs[0]) + "    <====>    " + ', '.join(ev.paralogs[1]) + "\n")
                elif(ev.etype == 'L'):
                    OUT.write('    LOSSES: ' + ev.losses + "\n")

        header = '>tree %d ; (%d) mcost=%d - dup=%d - lost=%d\n' % (
            i, i, dup + loss, dup, loss)
        nw_sav.append(header)
        nw_sav.append(genewick + "\n")

        # export tree to orthoXML format
        if(args.orthoxml):
            # Use a context manager so the output file is flushed and closed.
            with open(args.output + "_%dorthoxml" % i, 'w') as orthoxml_handle:
                TreeUtils.exportToOrthoXML(genetree, handle=orthoxml_handle)

        ################### TEST  purposes #######################

        if(args.debug):
            # Cross-check against ete3's own PhyloTree reconciliation, using
            # the first three lowercase characters of each name as species.
            def spnaming(node):
                return node.name[:3].lower()

            gtree = PhyloTree(genewick, sp_naming_function=spnaming)

            sptree = PhyloTree(args.specietree)
            for node in sptree.iter_leaves():
                node.name = node.name[:3].lower()

            recon_tree, events = gtree.reconcile(sptree)
            dup2cost = 0
            print(recon_tree)
            print("Orthology and Paralogy relationships:")
            for ev in events:
                if ev.etype == "S":
                    print('ORTHOLOGY RELATIONSHIP:', ','.join(
                        ev.inparalogs), "<====>", ','.join(ev.orthologs))
                elif ev.etype == "D":
                    dup2cost += 1
                    print('PARALOGY RELATIONSHIP:', ','.join(
                        ev.inparalogs), "<====>", ','.join(ev.outparalogs))

            dupcost, lostcost = 0, 0

            for node in genetree.traverse():
                if node.type > 0:
                    dupcost += 1
                elif node.type == TreeClass.LOST:
                    lostcost += 1

            print("Phylonode dup : ", dup2cost, "Reconcile dup: ",
                  dupcost, "Reconcile losses: ", lostcost)

    if args.reroot:
        with open(args.output + ".trees", 'w+') as g_IN:
            g_IN.writelines(nw_sav)

else:
    TreeUtils.generateSmap(specietree, output=args.smapfile)
