#!/usr/bin/env python

import optparse
import os
import sys

import argweaver

from rasmus import intervals
from rasmus import stats
from rasmus import util


# Command-line interface: one or more SMC files (or a single '%d' pattern)
# plus integer options controlling the window size and sample iterations.
o = optparse.OptionParser(
    usage="%prog SMC_FILE ...",
    description="""
Extract the average number of recombination break points in a sliding window.
SMC_FILE can be a pattern where '%d' denotes the sample iteration.
""")

# Every active option is an integer; each row is
# (short flag, long flag, default value, help text).
_INT_OPTIONS = (
    ("-w", "--window", 10000,
     "Window over which to average number of break points"),
    ("-s", "--start", 0,
     "Starting sample iteration to use"),
    ("-e", "--end", 5000,
     "Last sample iteration to use"),
    ("-d", "--step", 1,
     "Step size through samples"),
)
for _short, _long, _default, _help in _INT_OPTIONS:
    o.add_option(_short, _long, default=_default, type="int", help=_help)

# Disabled options, kept for reference; nothing in this script reads them.
#o.add_option("-r", "--region",
#             help="Only process region 'start-end'")
#o.add_option("-n", "--normalize", action="store_true",
#             help="normalize by generations")

# conf holds the option values, args the positional SMC file names/pattern.
conf, args = o.parse_args()


#=============================================================================

def iter_recombs(filename):
    """Yield ``(pos, chrom)`` for every SPR item in an SMC file.

    Items are read with ``argweaver.iter_smc_file``.  The chromosome name
    comes from the most recent ``REGION`` item seen so far; until one is
    seen the placeholder ``"chr"`` is used.
    """
    current_chrom = "chr"
    for record in argweaver.iter_smc_file(filename):
        tag = record["tag"]
        if tag == "REGION":
            # Remember the chromosome for the SPR items that follow.
            current_chrom = record["chrom"]
        elif tag == "SPR":
            yield record["pos"], current_chrom


def iter_recombs_window(filename, window):
    """Yield windowed break-point counts for one SMC file.

    The recombination positions from ``iter_recombs(filename)`` are sorted
    and passed to ``stats.iter_window`` with ``len`` as the summary
    function, so each ``(x, y)`` it produces is presumably a window
    coordinate and the number of break points in that window (confirm
    against ``rasmus.stats.iter_window``).

    Consecutive results are paired up: for each result after the first,
    the tuple ``(previous_x, x, chrom, previous_y)`` is yielded.  ``chrom``
    is the chromosome of the first (smallest-position) recombination.

    A file with no SPR records yields nothing.
    """
    recombs = sorted(iter_recombs(filename))
    if not recombs:
        # No recombinations at all: there is no chromosome to report and
        # nothing to window (indexing recombs[0] would raise IndexError).
        return

    chrom = recombs[0][1]
    positions = util.cget(recombs, 0)

    last_x = None
    last_y = None
    for x, y in stats.iter_window(positions, window, len):
        # Compare against None explicitly: a previous coordinate of 0 is a
        # valid value and must not be mistaken for "no previous window".
        if last_x is not None:
            yield last_x, x, chrom, last_y
        last_x = x
        last_y = y


#=============================================================================

# At least one SMC file (or a '%d' pattern) is required.
if not args:
    o.print_help()
    sys.exit(1)

# Resolve the list of input files.  A first argument containing '%d' is
# treated as a pattern expanded over the requested sample iterations, and
# only files that exist are kept; otherwise all arguments are used as given.
pattern = args[0]
if "%d" in pattern:
    # NOTE(review): range() stops before conf.end, while the --end help text
    # says "Last sample iteration to use" -- confirm which is intended.
    candidates = (pattern % iteration
                  for iteration in range(conf.start, conf.end, conf.step))
    filenames = [path for path in candidates if os.path.exists(path)]
else:
    filenames = args

# Pool the windowed break-point counts from every sample, ordered so they
# can be fed to the interval intersection routine.
recombs = []
for smc_path in filenames:
    recombs += iter_recombs_window(smc_path, conf.window)
recombs.sort()

# For each intersection of the pooled intervals, print one row: chromosome,
# start, end, mean count, and the 2.5% / 97.5% percentiles of the counts.
for start, end, group in intervals.iter_intersections(recombs):
    chrom = group[0][2]
    counts = util.cget(group, 3)
    lower = stats.percentile(counts, .025)
    upper = stats.percentile(counts, .975)
    util.print_row(chrom, int(start), int(end), stats.mean(counts),
                   lower, upper)
