#!/usr/bin/env python

import optparse
import os
import sys

import argweaver

from rasmus import intervals
from rasmus import stats
from rasmus import util


# Command-line interface: positional SMC file name(s) plus three integer
# options describing a range of sample iterations.
o = optparse.OptionParser(
    usage="%prog SMC_FILE ...",
    description="""
Extract the local effective population size.
SMC_FILE can be a pattern where '%d' denotes the sample iteration.
""")

# One row per option: (short flag, long flag, default, help text).
# Every option is parsed as an int.
_INT_OPTIONS = (
    ("-s", "--start", 0, "Starting sample iteration to use"),
    ("-e", "--end", 5000, "Last sample iteration to use"),
    ("-d", "--step", 1, "Step size through samples"),
)
for _short, _long, _default, _help in _INT_OPTIONS:
    o.add_option(_short, _long, default=_default, type="int", help=_help)

# conf holds the parsed option values; args holds the positional arguments.
conf, args = o.parse_args()


#=============================================================================

def mle_popsize_coal_times(k, times, mintime=0):
    """Estimate a population size from a sequence of coalescence times.

    Walks through `times` in the order given.  The interval ending at the
    n-th time (counting from zero) is weighted by i*(i-1), where
    i = k - n, and each interval length is floored at `mintime`.  The
    weighted sum is then divided by 4*(k-1).

    k       -- number of lineages at time zero
    times   -- iterable of coalescence times; consumed once, not sorted
               here (pass them in increasing order)
    mintime -- minimum length credited to any single interval

    Returns a float.  Raises ZeroDivisionError when k == 1, because the
    denominator 4*k - 4 is then zero.
    """
    weighted_total = 0
    prev_time = 0
    for nevents, coal_time in enumerate(times):
        # Lineages still present during the interval ending at coal_time.
        nlineages = k - nevents
        interval = max(coal_time - prev_time, mintime)
        weighted_total += nlineages * (nlineages - 1) * interval
        prev_time = coal_time
    return weighted_total / float(4 * k - 4)


def mle_popsize_tree(tree, mintime=0):
    """Estimate a population size from the node ages of a single tree.

    Gathers data["age"] from every non-leaf node of `tree`, sorts the ages
    in increasing order, and hands them to mle_popsize_coal_times together
    with the number of leaves in the tree.

    tree    -- iterable of nodes providing is_leaf() and data["age"];
               must also provide leaves()
    mintime -- forwarded unchanged to mle_popsize_coal_times
    """
    internal_ages = [node.data["age"] for node in tree if not node.is_leaf()]
    internal_ages.sort()
    nleaves = len(tree.leaves())
    return mle_popsize_coal_times(nleaves, internal_ages, mintime)


def iter_popsize(filename, mintime=0):
    """Yield one population-size estimate per local tree in an SMC file.

    Reads `filename` with argweaver.iter_smc_file (trees parsed, SPRs
    applied).  A "REGION" record updates the current chromosome name; each
    "TREE" record produces a tuple

        (start, end, chrom, popsize)

    where start/end come from the record and popsize is
    mle_popsize_tree(tree, mintime).  The chromosome name is "chr" until
    a REGION record has been seen.
    """
    current_chrom = "chr"

    records = argweaver.iter_smc_file(
        filename, parse_trees=True, apply_spr=True)
    for record in records:
        tag = record["tag"]
        if tag == "REGION":
            current_chrom = record["chrom"]
        elif tag == "TREE":
            estimate = mle_popsize_tree(record["tree"], mintime)
            yield record["start"], record["end"], current_chrom, estimate


#=============================================================================

# Script body: resolve the input SMC files, compute a population-size
# estimate for every local tree in every file, then summarize the estimates
# over each genomic interval.

if not args:
    o.print_help()
    sys.exit(1)

filename = args[0]

if "%d" in filename:
    # Expand the pattern over the requested sample iterations.  --end is
    # documented as the *last* iteration to use, so the stop value is
    # conf.end + 1 (range() excludes its stop).  Iterations that have no
    # file on disk are skipped.
    candidates = (filename % i
                  for i in range(conf.start, conf.end + 1, conf.step))
    filenames = [fn for fn in candidates if os.path.exists(fn)]
else:
    # No pattern: treat every positional argument as an SMC file.
    filenames = args

if not filenames:
    # Nothing to read; say so instead of silently printing no rows.
    sys.stderr.write("warning: no SMC files found matching '%s'\n" % filename)


# Minimum interval length passed through to the estimator.
mintime = 0

# Each entry is (start, end, chrom, popsize); sorting orders them by start.
popsizes = []
for smc_file in filenames:
    popsizes.extend(iter_popsize(smc_file, mintime))
popsizes.sort()

# NOTE(review): intervals.iter_intersections presumably yields each maximal
# sub-interval together with the records overlapping it -- confirm against
# rasmus.intervals.
for start, end, group in intervals.iter_intersections(popsizes):
    chrom = group[0][2]
    vals = util.cget(group, 3)  # popsize column of every record in the group
    # Output columns: chrom, start, end, mean, then the .025 and .975
    # percentiles of the estimates (presumably a 95% interval -- confirm
    # that stats.percentile takes a fraction).
    util.print_row(chrom, start, end, stats.mean(vals),
                   stats.percentile(vals, .025), stats.percentile(vals, .975))
