#!/usr/bin/env python
import argparse
import linecache
from profileNJ.PolytomySolver import *
from profileNJ.TreeLib import TreeUtils, TreeClass, params
import sys
import time
from operator import itemgetter
from math import log10
from multiprocessing.pool import Pool
from multiprocessing import cpu_count
import hashlib

"""
ProfileNJ

Given a non-binary gene tree, a species tree and a gene distance matrix, ProfileNJ outputs all possible
binarizations of the gene tree that minimize the duplication+loss score. The algorithm uses a clustering
criterion such as NJ or UPGMA to join the 'nearest' subtrees first. If the tree is treated as unrooted (using the -r argument),
the program tries every possible root (this takes a while), either to find the one that yields the lowest
DL-score after correction (-r best) or to output every rooted correction (-r all).

The output file contains the newick trees of all the solutions, one per line.

The output format is the following
>Tree 1; dup=<duplications> loss=<losses> m_cost=<cost>
newick solution1
newick solution2
>Tree 2; dup=<duplications> loss=<losses> m_cost=<cost>
newick solution1
newick solution2

"""

# Weight applied to log10(number of trees) when the size of the worker pool
# is computed.
PROCESSES_CHOOSER = 1.2

# Number of CPUs reported by multiprocessing; a single CPU is assumed on
# platforms where the count cannot be determined.
try:
    CPU_COUNT = cpu_count()
except NotImplementedError:
    CPU_COUNT = 1


class SmartFormatter(argparse.ArgumentDefaultsHelpFormatter):
    """Help formatter that honours hand-made line breaks on request.

    A help text starting with the marker ``C|`` is split on its own line
    breaks (marker removed) instead of being re-wrapped; every other help
    text is wrapped by ArgumentDefaultsHelpFormatter as usual.
    """

    def _split_lines(self, text, width):
        raw_marker = 'C|'
        if not text.startswith(raw_marker):
            return argparse.ArgumentDefaultsHelpFormatter._split_lines(
                self, text, width)
        # Same behaviour as RawTextHelpFormatter._split_lines.
        return text[len(raw_marker):].splitlines()


class Output(object):
    """Line-oriented writer that targets a file or the standard output."""

    def __init__(self, file=None):
        # A falsy ``file`` (None, empty string) selects the standard output;
        # otherwise the named file is opened for writing (truncating it).
        self.out = open(file, 'w') if file else sys.stdout

    def write(self, line):
        """Write ``line`` followed by a newline character."""
        self.out.write('%s\n' % line)

    def close(self):
        """Close the underlying file; the standard output is left open."""
        if self.out is sys.stdout:
            return
        self.out.close()

    @staticmethod
    def error(message):
        """Print ``message`` on stderr and exit with status 1."""
        sys.stderr.write("Error: %s\n" % message)
        sys.exit(1)


def parallelize(parallele, count):
    """Return the number of worker processes to use for ``count`` trees.

    :param parallele: truthy when parallel execution was requested.
    :param count: number of trees to process.
    :return: 1 when parallel execution is disabled, otherwise
        ``CPU_COUNT + PROCESSES_CHOOSER * log10(count)`` rounded to the
        nearest integer.
    """
    # Test the flag itself: the function object ``parallelize`` is always
    # truthy, which would make the sequential fallback unreachable.
    if not parallele:
        return 1
    # log10 raises for values below 1; treat an empty workload like a single
    # tree so the result is simply CPU_COUNT.
    return int(round(CPU_COUNT + (PROCESSES_CHOOSER * log10(max(count, 1)))))

# Accepted values for the -r/--reroot option.
reroot_option = ['none', 'all', 'best']

# Command-line interface. SmartFormatter lets help texts that start with
# 'C|' keep their own line breaks.
parser = argparse.ArgumentParser(
    description='Polytomy solver with multiple solutions.', formatter_class=SmartFormatter)

# Input files; argparse opens each of them for reading.
parser.add_argument('-s', '--sFile', type=argparse.FileType('r'), dest='specietree',
                    help="Name of the file containing the species newick tree.", required=True)
parser.add_argument('-S', '--sMap', type=argparse.FileType('r'),
                    dest='smap', help="Gene to species map. Use the standard format.")
parser.add_argument('-g', '--gFile', type=argparse.FileType('r'), dest='genetree',
                    help="Name of the file containing the gene newick tree.", required=True)
parser.add_argument('-d', '--dist', type=argparse.FileType('r'), dest='distfile',
                    help="Name of the file containing the distances between each pair of genes (The gene set should be the same for the leaf set of the genetree).")

# Output destination and gene tree selection.
parser.add_argument('-o', '--output', dest='outfile',
                    help="Name of your output files with the corrected tree. When batch is specified, each corrected genetree will be printed in the appropriate output file. The genetree is printed on stdout if omitted.")
parser.add_argument('-gl', '--gLine', type=int, dest='gline',
                    help="Index of the line in the gene tree file that corresponds to the current gene tree, starting by 1. When the flag batch is used, each line of gFile will be used as genetree", default=1)

# Rooting mode; the help text uses the 'C|' marker to keep its layout.
parser.add_argument('-r', '--reroot', choices=reroot_option, dest='reroot', default='none',
                    help='''C|Enable/Disable root mode.\n\tnone: disable reroot mode, correct the input polytomies and return the result.\n\tall: enable reroot mode, reroot the genetree at each node and return all polytomy corrected version for each rooted tree.\n\tbest: enable reroot mode, rerrot the genetree at each node and return all polytomy corrected version for the rooted tree with the smallest Dup-Lost score (First one if not unique).\n\n''')

# Distance handling and leaf-name parsing.
parser.add_argument('-n', '-nf', '--nnflag', action='store_true',
                    dest='nflag', help="Treat negative distances as large distances.")
parser.add_argument('--sep', dest='gene_sep',
                    help="Gene-Specie separator for each leaf name in the genetree. PolytomySolver will guess by default in a very limited list of special character. ***';;' is not a valid separator for the newick format! IF YOUR SEPARATOR IS \";;\", DON'T USE THIS FLAG. THE PROGRAM WILL AUTOMATICALLY GUESS. ***")
parser.add_argument('--spos', dest='spos', default="prefix", choices=("prefix", "postfix"),
                    help="The position of the specie name according to the separator. Supported option are prefix and postfix")
# NOTE(review): args.mval does not appear to be read anywhere in the visible
# script -- confirm whether it should be forwarded to the solver.
parser.add_argument('--nflagval', type=float, dest='mval', default=1e305,
                    help="Set largest value in the distance matrix. Entries on the main diagonal and negative values will be replaced by mValue.")

# Solver behaviour: clustering method and limits on the number of solutions.
parser.add_argument('-c', '--cluster', choices=['nj', 'upgma', 'rand'], default='nj',
                    help="C|Set the clustering methods.\n\tupgma: UPGMA (Unweighted Pair Group Method with Arithmetic Mean) clustering algo.\n\tnj: neighbor joining clustering method, (slower).\n\trand: A random clustering method (should be faster than upgma).\n\n")
parser.add_argument('--slimit', type=int, dest="sol_limit", default=30,
                    help="Set the max number of solution per genetree. Possible values are -1 to return all the solution or n, n>0 for a specific number of solution. Setting this argument to -1 is computationally expensive.")
parser.add_argument('--plimit', type=int, default=-1, dest="path_limit",
                    help="Set the max number of solution for each polytomy in the genetree. Possible values are -1 to explore all the solution or n, n>0 for a specific number of solution. Setting this argument to -1 is also computationally expensive.")

# Miscellaneous switches.
parser.add_argument(
    '-v', action='store_true', dest='verbose', help=" Output verbosity")
parser.add_argument(
    '--try_hard', action='store_true', dest='tryhard', help="Try correcting errors due to common mistake before raising an exception")
parser.add_argument('--batch', action='store_true', dest='batch',
                    help=" Use this flag to enable batch mode. In batch mode, gLine value is discarded, --dist should be a file whose line link to the distance matrix of the genetree at the same line number in your genetree file")
parser.add_argument('--seuil', type=float, dest="seuil",
                    help="Branch contraction threshold, when the tree is binary. Use only when the tree is binary.")
parser.add_argument('--cap', dest='cap', action='store_true',
                    help="Capitalize the species name of the genetree leaves to match each species. Almost all functions are case sensitive.")
parser.add_argument('--parallelize', dest='parallele',
                    action='store_true', help="Use parallisation")
parser.add_argument('--firstbest', dest='firstbest', action='store_true',
                    help="Only output solution for the first root with the best dl score.")

# Duplication/loss costs: --cost and --scost are mutually exclusive and both
# optional.
cost_group = parser.add_mutually_exclusive_group(required=False)
cost_group.add_argument('--cost', type=float, nargs=2, dest='costdl',
                        help="D L : 2 float values, duplication and loss cost in this order")
cost_group.add_argument('--scost', dest='sdlcost', type=argparse.FileType('r'),
                        help="A file which contain the DL cost for each species. If a specie is not found in this list, its DL cost will be set to default value (1)")

parser.add_argument('--internalcost', dest='idlcost', choices=['mean', 'default'], default='default',
                    help="internal Node duplication and loss cost. Set this value to mean in order to take the mean duplication and loss from the leaf; or set it to default (1) or  value obtained from `cost`.")

# Parse sys.argv as soon as the module is executed; argparse exits the
# process when a required option is missing or a value is invalid.
args = parser.parse_args()

# Get list of species
sp = TreeClass(args.specietree.name)
spleave = sp.get_leaf_names()

# Default duplication and loss costs, used for every species without an
# explicit entry.
defdup, defloss = 1, 1
losscost, dupcost = {}, {}

# --cost D L overrides both defaults (argparse guarantees exactly 2 floats).
if args.costdl:
    defdup, defloss = args.costdl

# --scost: one "<species[,species...]> <dup cost> <loss cost>" entry per line.
# Costs are stored under the sha384 hex digest of the sorted, comma-joined
# species names.
if args.sdlcost:
    for dline in args.sdlcost:
        fields = dline.split()
        if not fields:
            # Blank lines carry no cost information.
            continue
        # Only the parsing itself is guarded, so that unrelated failures are
        # reported as they are instead of being silently discarded.
        try:
            species, spdup, sploss = fields
            spdup, sploss = float(spdup), float(sploss)
        except ValueError:
            raise ValueError(
                "Check your cost file. Leaf under an internal node should only be separated by ','. No space tolerance")
        species_key = ",".join(sorted(species.split(",")))
        # hashlib requires bytes on Python 3; on Python 2 ``str`` already is
        # ``bytes`` and is hashed unchanged.
        if not isinstance(species_key, bytes):
            species_key = species_key.encode("utf-8")
        # A dedicated name is used for the digest so that ``sp`` keeps
        # referring to the species tree.
        digest = hashlib.sha384(species_key).hexdigest()
        dupcost[digest] = spdup
        losscost[digest] = sploss

params.set(dupcost, losscost, (defdup, defloss), args.idlcost)

# --gLine is a 1-based line index, so 0 is rejected along with negative
# values (linecache.getline returns '' for line 0).
if args.gline < 1:
    raise Exception("gLine must be > 0")

# ``dists`` holds one distance source per gene tree; when no --dist file is
# given, ``infer_dist_from_br`` stays True and ``dists`` stays None.
dists = None
infer_dist_from_br = True
if args.batch:
    # Batch mode: every line of gFile is a gene tree and every line of the
    # --dist file is the distance source of the tree at the same line number.
    gtrees = [line.strip() for line in args.genetree]
    if args.distfile:
        dists = [line.strip() for line in args.distfile]
        infer_dist_from_br = False
else:
    # Single-tree mode: only line --gLine of gFile is used.
    gtree_line = linecache.getline(args.genetree.name, args.gline)
    if not gtree_line.strip():
        # linecache.getline never raises; it returns '' when the line does
        # not exist, so the problem is reported here.
        raise Exception("No gene tree found at line %s of '%s'" %
                        (args.gline, args.genetree.name))
    gtrees = [gtree_line]
    if args.distfile:
        dists = [args.distfile.name]
        infer_dist_from_br = False


def _report_solutions(outlog, tree_index, polysolution, specietree,
                      distance_matrix, node_order, show_matrix):
    """Write the header and the newick of every solution of one rooted tree.

    :param outlog: Output instance receiving the lines.
    :param tree_index: 1-based index of the rooted tree, printed in the header.
    :param polysolution: non-empty list of solution trees; the DL figures of
        the header are computed on a copy of the first one.
    :param specietree: species tree used for the LCA mapping.
    :param distance_matrix: rows of distances, printed when ``show_matrix``.
    :param node_order: labels printed above the distance matrix.
    :param show_matrix: when truthy, print the distance matrix on stdout.
    """
    # Copy, in order to not change the solution newick for export
    f_tree = polysolution[0].copy(method='simplecopy')
    # The return value is not needed here; the call is kept because it
    # presumably annotates f_tree for the two computations below -- TODO confirm.
    TreeUtils.lcaMapping(f_tree, specietree)
    dl_cost = TreeUtils.computeDLScore(f_tree)
    dl_count = TreeUtils.computeDL(f_tree)
    if show_matrix:
        print("Distance matrix used: \n")
        print("  ".join(node_order))
        for row in distance_matrix:
            print("[" + " ".join(["%.3f" % y for y in row]) + "]\n")
        print("--------------------------------------\n")
    outlog.write('>Tree %s; dup=%s loss=%s m_cost=%s' %
                 (tree_index, dl_count[0], dl_count[1], sum(dl_cost)))
    for tree in polysolution:
        outlog.write(tree.write(format=9))  # , features=["species"]))


# Solver options, passed by keyword in both the sequential and the parallel
# mode so that the two calls cannot disagree on argument order.
solver_options = dict(verbose=args.verbose, path_limit=args.path_limit,
                      method=args.cluster, sol_limit=args.sol_limit)
# The distance matrix is only printed in verbose mode when no --dist file
# was supplied.
show_matrix = args.verbose and infer_dist_from_br

gtree_number = 0
for gtree in gtrees:
    # With several gene trees, each one gets its own numbered output file.
    if args.outfile is not None:
        suffix = gtree_number + 1 if len(gtrees) > 1 else ""
        outlog = Output("%s%s" % (args.outfile, suffix))
    else:
        outlog = Output()
    start_time = time.time()

    try:
        # bad idea but used to decrease the number of arguments: True is
        # passed in place of a distance source when none was supplied.
        cur_dist = True if infer_dist_from_br else dists[gtree_number]

        # In batch mode the map is passed by file name; --sMap is optional,
        # so None must be passed through untouched.
        smap = args.smap
        if args.batch and smap is not None:
            smap = smap.name

        oritree, specietree, distance_matrix, node_order = TreeUtils.polySolverPreprocessing(
            gtree, args.specietree.name, cur_dist, specie_pos=args.spos, capitalize=args.cap,
            gene_sep=args.gene_sep, nFlag=args.nflag, smap=smap, errorproof=args.tryhard)
        tree_list = [oritree]

        if args.seuil:
            # if we want a contraction, do it before
            for tree in tree_list:
                tree.contract_tree(seuil=args.seuil)

        reroot_mode = args.reroot.lower()
        if reroot_mode == 'all':
            tree_list.extend(oritree.reroot())

        elif reroot_mode == 'best':
            tree_list.extend(oritree.reroot())
            dl_costs = [
                Multipolysolver.computePolytomyReconCost(
                    genetree, specietree, verbose=False)
                for genetree in tree_list]

            # Keep every rooting whose cost equals the minimum, up to a small
            # tolerance for floating point noise.
            best_dl = min(dl_costs)
            tree_list = [genetree for genetree, cost in zip(tree_list, dl_costs)
                         if abs(cost - best_dl) < 0.0000001]

            if args.firstbest:
                tree_list = tree_list[:1]

        if args.parallele:
            # parallelized version
            # The number of process should be a function of the number of
            # tree to solve
            tree_count = len(tree_list)
            multipoly_para = tree_count if tree_count <= CPU_COUNT else parallelize(
                args.parallele, tree_count)
            pool = Pool(processes=multipoly_para)
            finished = False
            try:
                # Submit every tree with polytomies; None marks trees that
                # need no solving and are reported unchanged.
                pending = []
                for genetree in tree_list:
                    if genetree.has_polytomies():
                        pending.append(pool.apply_async(
                            solvePolytomy,
                            (genetree, specietree, distance_matrix, node_order),
                            solver_options))
                    else:
                        pending.append(None)

                # Results are collected in submission order.
                for tree_index, (genetree, async_result) in enumerate(
                        zip(tree_list, pending), 1):
                    if async_result is None:
                        polysolution = [genetree]
                    else:
                        polysolution = async_result.get()
                    _report_solutions(outlog, tree_index, polysolution, specietree,
                                      distance_matrix, node_order, show_matrix)
                finished = True
            finally:
                # Release the workers even when a task failed.
                if finished:
                    pool.close()
                else:
                    pool.terminate()
                pool.join()

        else:
            # Not paralellized mode
            for count, genetree in enumerate(tree_list, 1):
                if genetree.has_polytomies():
                    polysolution = solvePolytomy(
                        genetree, specietree, distance_matrix, node_order,
                        **solver_options)
                else:
                    polysolution = [genetree]
                _report_solutions(outlog, count, polysolution, specietree,
                                  distance_matrix, node_order, show_matrix)

        end_time = time.time()
    finally:
        # Close the output file even when the correction fails.
        outlog.close()

    gtree_number += 1

    para_mode = ""
    if args.parallele:
        para_mode = "with parallelization (CPU_COUNT=%s)" % (CPU_COUNT)
    # In single-tree mode the processed line is --gLine, not the loop counter.
    line_number = gtree_number if args.batch else args.gline
    print("\nEND  profileNJ on file: '%s', line %s in %f s, %s" %
          (args.genetree.name, line_number, (end_time - start_time), para_mode))

linecache.clearcache()
