#!/usr/bin/env python

import optparse
import os
import sys

import argweaver

from rasmus import intervals
from rasmus import stats
from rasmus import util


# Command-line interface.  The three integer options choose which sample
# iterations get substituted into SMC_FILE when it is a '%d' pattern.
o = optparse.OptionParser(
    description="""
Estimate the TMRCA using an ARG or samples of ARGs.
SMC_FILE can be a pattern, where '%d' denotes the sample iteration.
""",
    usage="%prog SMC_FILE ...")

# (short flag, long flag, default, help text) for each integer option.
# They are registered in this order so the --help listing keeps its layout.
for _short, _long, _default, _help in (
        ("-s", "--start", 0, "Starting sample iteration to use"),
        ("-e", "--end", 5000, "Last sample iteration to use"),
        ("-d", "--step", 1, "Step size through samples")):
    o.add_option(_short, _long, type="int", default=_default, help=_help)

# conf carries the parsed option values (conf.start, conf.end, conf.step);
# args carries the remaining positional arguments (the SMC file names).
conf, args = o.parse_args()


#=============================================================================

def iter_tmrcas(filename):
    """
    Iterate over the root ages of the local trees stored in an SMC file.

    The file is read with argweaver.iter_smc_file (trees parsed, SPRs
    applied).  For every record tagged "TREE" a tuple

        (start - 1, end, chrom, root_age)

    is yielded, where root_age is the "age" entry of the tree root's data.
    chrom is taken from the most recent "REGION" record and is "chr" until
    one has been seen.

    NOTE(review): the start - 1 shift presumably converts a 1-based
    inclusive start into a 0-based half-open interval -- confirm against
    the SMC format.
    """
    current_chrom = "chr"
    records = argweaver.iter_smc_file(filename, parse_trees=True,
                                      apply_spr=True)

    for record in records:
        tag = record["tag"]

        if tag == "REGION":
            current_chrom = record["chrom"]

        if tag == "TREE":
            region_start = record["start"] - 1
            region_end = record["end"]
            root_age = record["tree"].root.data["age"]
            yield (region_start, region_end, current_chrom, root_age)

#=============================================================================

# At least one SMC file (or file pattern) is required.
if len(args) == 0:
    o.print_help()
    sys.exit(1)

# Resolve the list of SMC files to read.
pattern = args[0]
if "%d" in pattern:
    # --end is documented as the *last* sample iteration to use, so the
    # stop bound is pushed one step past it (in the direction of travel)
    # to make the range inclusive of conf.end.
    stop = conf.end + (1 if conf.step > 0 else -1)

    # Iterations that have no file on disk are skipped on purpose: not
    # every iteration number in the range needs to have been written out.
    filenames = []
    for i in range(conf.start, stop, conf.step):
        fn = pattern % i
        if os.path.exists(fn):
            filenames.append(fn)
else:
    # No pattern: every positional argument is taken as a file name.
    filenames = args

# A pattern that matches nothing is almost certainly a user error; report it
# instead of silently producing no output.
if not filenames:
    sys.stderr.write("error: no SMC files found matching '%s'\n" % pattern)
    sys.exit(1)


# Collect (start, end, chrom, tmrca) for every local tree of every sample.
# Tuples sort by start, then end, which is the order passed on to
# intervals.iter_intersections below.
tmrcas = []
for smc_file in filenames:
    tmrcas.extend(iter_tmrcas(smc_file))
tmrcas.sort()

# One output row per intersection: chrom, start, end, mean TMRCA and the
# 2.5% / 97.5% percentiles of the TMRCAs in that group.
# NOTE(review): iter_intersections presumably yields the sub-intervals over
# which the set of overlapping regions is constant -- confirm against
# rasmus.intervals.
for start, end, group in intervals.iter_intersections(tmrcas):
    chrom = group[0][2]
    vals = util.cget(group, 3)
    util.print_row(chrom, start, end, stats.mean(vals),
                   stats.percentile(vals, .025), stats.percentile(vals, .975))
