#!/usr/bin/env python

import optparse
import os
import sys

import argweaver
from argweaver import argweaverc

from rasmus import intervals
from rasmus import stats
from rasmus import util


# Command-line interface: one or more SMC files (or a single '%d' pattern)
# plus options controlling sample selection, time discretization and the
# averaging window.
o = optparse.OptionParser(
    usage="%prog SMC_FILE ...",
    description="""
Estimate the recombination rate using an ARG or samples of ARGs.
SMC_FILE can be a pattern, where '%d' denotes the sample iteration.
""")

# Each row: (option strings, value type, default value, help text).
_OPTION_SPECS = [
    (("-w", "--window"), "int", 10000,
     "Window over which to average number of break points"),
    (("-s", "--start"), "int", 0,
     "Starting sample iteration to use"),
    (("-e", "--end"), "int", 5000,
     "Last sample iteration to use"),
    (("-d", "--step"), "int", 1,
     "Step size through samples"),
    (("-t", "--ntimes"), "int", 20,
     "Number of discretized time points"),
    (("", "--maxtime"), "float", 180e3,
     "Maximum time point"),
]
for _flags, _type, _default, _help in _OPTION_SPECS:
    o.add_option(*_flags, type=_type, default=_default, help=_help)

# conf holds the parsed option values; args holds the positional arguments.
conf, args = o.parse_args()

#=============================================================================

# At least one SMC file (or filename pattern) is required.
if not args:
    o.print_help()
    sys.exit(1)

# Get all ARG filenames.
# Only the first positional argument is checked for a '%d' pattern; when it
# contains one, the remaining positional arguments are not used.
filename = args[0]
if "%d" in filename:
    filenames = []
    # range() stops before conf.end, so iteration conf.end itself is not
    # included.  Iterations whose file does not exist are skipped silently.
    for i in range(conf.start, conf.end, conf.step):
        fn = filename % i
        if os.path.exists(fn):
            filenames.append(fn)
else:
    filenames = args

# Without any input file the region bounds read from the SMC headers are
# never assigned and the script would fail later with a NameError; stop
# here with an explicit message instead.
if not filenames:
    sys.stderr.write("error: no SMC files found for '%s'\n" % filename)
    sys.exit(1)

# Get time points.
times = argweaver.get_time_points(ntimes=conf.ntimes, maxtime=conf.maxtime)

#=============================================================================

# Read tree branch lengths for all ARGs.
# Each entry of treelens is one file's list of (start, end, treelen) tuples,
# one tuple per local tree block.
treelens = []
for filename in filenames:
    # Progress message: name of the file being read.
    sys.stderr.write("%s\n" % filename)

    # chrom/region_start/region_end are reassigned for every file, so the
    # values from the last file read are the ones used further down.
    smc = argweaver.SMCReader(filename)
    chrom, region_start, region_end = (
        smc.header["chrom"], smc.header["start"], smc.header["end"])

    trees = argweaverc.read_local_trees(filename, times, len(times))
    try:
        ntrees = argweaverc.get_local_trees_ntrees(trees)

        # Preallocated per-tree buffers handed to argweaverc, which is
        # expected to fill them in place.
        treelen = [0.0] * ntrees
        starts = [0] * ntrees
        ends = [0] * ntrees
        argweaverc.get_treelens(trees, times, len(times), treelen)
        argweaverc.get_local_trees_blocks(trees, starts, ends)

        # Materialize the zip: treelens is traversed more than once later.
        treelens.append(list(zip(starts, ends, treelen)))
    finally:
        # Always release the local trees, even if one of the calls above
        # raises.
        argweaverc.delete_local_trees(trees)

# Get average tree lengths and recombination positions.
# Pool the (start, end, treelen) blocks of every sample and sort them before
# handing them to the interval intersection iterator.
pooled_blocks = sorted(util.concat(*treelens))

# One [start, end, mean tree length] entry per intersection; the mean is
# taken over the third field of the blocks grouped in that intersection.
avg_treelens = []
for seg_start, seg_end, members in intervals.iter_intersections(
        pooled_blocks):
    seg_mean = stats.mean(util.cget(members, 2))
    avg_treelens.append([seg_start, seg_end, seg_mean])

# Histogram of block start positions pooled over all samples; it is looked
# up by position when the per-site rate is computed.
block_starts = [[block[0] for block in sample_blocks]
                for sample_blocks in treelens]
recombs = util.hist_dict(util.concat(*block_starts))

# Calculate recombination rate.
# Sweep the region one position at a time while walking avg_treelens with a
# forward-only index.  Each emitted entry is [p, p+1, rate], where rate is
# the start count at p divided by the segment's mean tree length and by the
# number of samples.
recomb_rate = []
nsegments = len(avg_treelens)
nsamples = float(len(treelens))
i = 0
for p in range(region_start, region_end, 1):
    # Skip segments that end before the current position.
    while i < nsegments and avg_treelens[i][1] < p:
        i += 1
    if i >= nsegments:
        continue

    seg_start, seg_end, seg_treelen = avg_treelens[i]
    if not (seg_start <= p <= seg_end):
        continue

    rate = recombs.get(p, 0) / seg_treelen / nsamples
    recomb_rate.append([p, p + 1, rate])

    # NOTE(review): the index advances after every emitted entry, so each
    # segment contributes at most one position -- confirm this is intended.
    i += 1


# Print one row per window: chromosome, window bounds, and the mean of the
# per-position rates that fall in recomb_rate[lowi:highi].
starts = util.cget(recomb_rate, 0)
for lowi, highi, low, high in stats.iter_window_index(starts, conf.window):
    # Windows whose index range is empty produce no output.
    if highi <= lowi:
        continue
    # NOTE(review): the slice excludes index highi -- confirm whether
    # iter_window_index reports an inclusive or exclusive upper index.
    window_rates = util.cget(recomb_rate[lowi:highi], 2)
    util.print_row(chrom, low, high, stats.mean(window_rates))
