#!/usr/bin/env python


import optparse
import os
import sys

import argweaver

from compbio import phylo, arglib
from rasmus import intervals
from rasmus import treelib
from rasmus import util


# Command-line interface.  Parsing happens at module level, so `conf` (the
# option values) and `args` (the positional filenames) are globals that the
# rest of the script reads directly.
o = optparse.OptionParser(
    usage="%prog ARG_FILES ...",
    description="""
Determines consensus local trees from a set of ARG samples. ARGs can be
specified in *.arg or *.smc file formats. If the first filename contains
'%d', it is treated as a pattern and expanded over the sample iterations
selected by --start, --end and --step; only files that exist are used.
""")
o.add_option("-s", "--start", default=0, type="int",
             help="Starting sample iteration to use")
# The iterations are generated with range(start, end, step), so the value
# given here is itself never used.
o.add_option("-e", "--end", default=5000, type="int",
             help="Stop before this sample iteration (exclusive)")
o.add_option("-d", "--step", default=1, type="int",
             help="Step size through samples")
o.add_option("-n", "--ntrees", default=-1, type="int",
             help="Number of evenly spaced positions at which to extract "
                  "local trees (0 or less: use every local tree)")
o.add_option("-b", "--binary", action="store_true",
             help="Ensure each local tree is binary")
o.add_option("-r", "--region",
             help="Extract trees only from region 'start-end'")

conf, args = o.parse_args()

#=============================================================================


def treelen(tree):
    """Return the sum of the `dist` attribute over every item in `tree`.

    `tree` is anything iterable whose items expose a `dist` attribute.
    An empty iterable gives 0.
    """
    total = 0
    for node in tree:
        total = total + node.dist
    return total


def iter_trees(filename, attr, region=None):
    """Yield one (start, end, chrom, tree) tuple per TREE record of an SMC file.

    The file is read through argweaver.iter_smc_file with parse_trees=False,
    so each tree is passed along exactly as that reader returns it.

    Side effects on `attr` (a dict shared with the caller):
      * a NAMES record stores its names under attr["names"]
      * a REGION record stores [chrom, start - 1, end] under attr["region"]

    Every start coordinate is lowered by one before being reported
    (presumably a 1-based to 0-based conversion -- confirm against the SMC
    format).  Until a REGION record is seen the chromosome is reported
    as "chr".
    """
    chrom = "chr"
    records = argweaver.iter_smc_file(filename, parse_trees=False,
                                      region=region)

    for record in records:
        tag = record["tag"]

        if tag == "TREE":
            yield record["start"]-1, record["end"], chrom, record["tree"]
        elif tag == "NAMES":
            attr["names"] = record["names"]
        elif tag == "REGION":
            # remember the chromosome for all following TREE records
            chrom = record["chrom"]
            attr["region"] = [chrom, record["start"]-1, record["end"]]


def iter_trees_arg(filename, attr, region=None):
    """Yield one (start, end, "chr", (arg, position)) tuple per recombination
    block of the ARG stored in `filename`.

    The ARG is loaded with arglib.read_arg; when `region` is given, the ARG
    is first passed through arglib.smcify_arg with region[0] and region[1].

    Side effects on `attr` (a dict shared with the caller):
      * attr["names"]  -- list of the ARG's leaf names
      * attr["region"] -- ["chr", arg.start, arg.end]

    The tree itself is not extracted here: the last tuple element is the
    pair (arg, midpoint of the block) which get_tree() later resolves into
    a marginal tree.  The chromosome name is always the literal "chr".
    """
    arg = arglib.read_arg(filename)
    if region:
        region_start = region[0]
        region_end = region[1]
        arg = arglib.smcify_arg(arg, region_start, region_end)

    attr["names"] = list(arg.leaf_names())
    attr["region"] = ["chr", arg.start, arg.end]

    for block in arglib.iter_recomb_blocks(arg):
        left = block[0]
        right = block[1]
        # deferred tree: the ARG plus the block midpoint to sample it at
        deferred = (arg, (left + right) / 2.0)
        yield (int(left), int(right), "chr", deferred)


def get_tree(tree_repr):
    """Turn a tree representation into a tree object.

    Two representations are accepted:
      * a string, which is parsed with treelib.parse_newick
      * an (arg, position) pair, for which the marginal tree of `arg` at
        `position` is extracted, stripped of single lineages and converted
        with its get_tree() method

    In the (arg, position) case the nodes are renamed afterwards: every
    non-leaf node gets a fresh name from tree.new_name(), then each leaf is
    renamed to its index in arg.leaf_names() (presumably so the leaf names
    line up with the integer leaf ids of SMC trees -- confirm).
    """
    if not isinstance(tree_repr, basestring):
        arg, pos = tree_repr
        leaf_names = list(arg.leaf_names())

        marginal = arg.get_marginal_tree(pos)
        arglib.remove_single_lineages(marginal)
        tree = marginal.get_tree()

        # internal nodes first, then leaves -- the order is kept as is
        for node in list(tree):
            if node.is_leaf():
                continue
            tree.rename(node.name, tree.new_name())

        for index, leaf_name in enumerate(leaf_names):
            tree.rename(leaf_name, index)

        return tree

    return treelib.parse_newick(tree_repr)


def cons_smc(trees, attr, positions=None):
    """Generate the records of a consensus SMC from a collection of trees.

    Records are dicts with a "tag" key, yielded in this order:
      1. NAMES  -- carries attr["names"]
      2. REGION -- chrom/start/end from attr["region"], start raised by one
      3. TREE   -- one per interval produced by
                   intervals.iter_intersections(trees), start raised by one

    For each interval the tree representations (element 3 of every grouped
    item) are resolved with get_tree() and merged with
    phylo.consensus_majority_rule(rooted=True).  If the global option
    conf.binary is set the consensus is passed to phylo.ensure_binary_tree.

    When `positions` is a non-empty sequence, only intervals that contain
    the current position are emitted; the cursor into `positions` only moves
    forward, so this relies on both intervals and positions being ascending.
    An empty or missing `positions` keeps every interval.

    Each processed interval is also logged as "start end" on stderr.
    """
    yield {"tag": "NAMES", "names": attr["names"]}

    region = attr["region"]
    yield {"tag": "REGION", "chrom": region[0],
           "start": region[1]+1, "end": region[2]}

    cursor = 0

    for start, end, group in intervals.iter_intersections(trees):
        if positions:
            last = len(positions) - 1
            while cursor < last and positions[cursor] < start:
                cursor += 1
            sample = positions[cursor]
            if sample < start or sample >= end:
                # no sampling position falls inside this interval
                continue
        sys.stderr.write("%s %s\n" % (start, end))

        local_trees = [get_tree(rep) for rep in util.cget(group, 3)]
        consensus = phylo.consensus_majority_rule(local_trees, rooted=True)
        treelib.assert_tree(consensus)

        if conf.binary:
            phylo.ensure_binary_tree(consensus)

        # hand out fresh internal node names, starting above the number of
        # leaves so they cannot clash with the leaf names
        consensus.nextname = max(consensus.nextname, len(attr["names"]))
        for node in list(consensus):
            if node.is_leaf():
                continue
            consensus.rename(node.name, consensus.new_name())

        yield {"tag": "TREE", "start": start+1, "end": end,
               "tree": consensus}


#=============================================================================

# Driver: gather the local trees of every input ARG sample, then write the
# consensus SMC to stdout.  Progress and diagnostics go to stderr.
if not args:
    o.print_help()
    sys.exit(1)

# Optional restriction to a region given as 'start-end'.
region = None
if conf.region:
    try:
        region = [int(part) for part in conf.region.split("-")]
    except ValueError:
        region = None
    if region is None or len(region) != 2:
        # o.error() prints the usage plus this message and exits
        o.error("invalid region '%s', expected 'start-end'" % conf.region)


# get filenames
filename = args[0]
if "%d" in filename:
    # expand the pattern over the requested iterations; conf.end itself is
    # excluded by range().  Missing files are skipped silently.
    candidates = [filename % i
                  for i in range(conf.start, conf.end, conf.step)]
    filenames = [fn for fn in candidates if os.path.exists(fn)]
else:
    filenames = args

if not filenames:
    sys.stderr.write("error: no input files found matching '%s'\n" % filename)
    sys.exit(1)

# Collect (start, end, chrom, tree_repr) tuples from all samples.  The
# readers also record the leaf names and the region in `attr`.
attr = {}
trees = []
for filename in filenames:
    sys.stderr.write("%s\n" % filename)
    if filename.endswith(".arg"):
        reader = iter_trees_arg
    else:
        reader = iter_trees
    trees.extend(reader(filename, attr, region))
trees.sort()

# Both headers are needed to write the output; without them the code below
# would fail with a bare KeyError.
if "names" not in attr or "region" not in attr:
    sys.stderr.write("error: input contains no NAMES/REGION information\n")
    sys.exit(1)

# Only sub-sample when a positive tree count was requested.  The default of
# -1 (and 0) means "keep every local tree".
positions = None
if conf.ntrees > 0:
    start = attr["region"][1]
    end = attr["region"][2]
    # at least 1: asking for more trees than there are sites would give a
    # zero step, which range() rejects
    step = max(1, int((end - start) / conf.ntrees))
    positions = list(range(start + step // 2, end, step))
    sys.stderr.write("%s %s\n" % (step, positions))

smc = cons_smc(trees, attr, positions=positions)
argweaver.write_smc(sys.stdout, smc)
