#!/usr/bin/env python
"""

Output format
-------------

The output is tab-delimited and contains the following columns:
1. chromosome
2. start coordinate of allele (0-index, end-exclusive)
3. end coordinate of allele (0-index, end-exclusive)
4. derived allele frequency
5. derived allele count
6. derived allele
7. average age
8. average age lower
9. average age upper
10,12,... lower age in sample
11,13,... upper age in sample

"""


from itertools import izip
import os
import optparse
import random
import sys

import argweaver
import argweaver.smc

from rasmus import stats
from rasmus import treelib
from rasmus import util


# Command-line interface.  Options are parsed at module level because this
# file is a stand-alone script: 'conf' receives the option values and 'args'
# the positional arguments (SMC_FILE and SITES_FILE).
o = optparse.OptionParser(
    usage="%prog SMC_FILE SITES_FILE",
    description="""
Estimate allele ages using an ARG or samples of ARGs.  SMC_FILE can
be a pattern, where '%d' denotes the sample iteration.
""")
o.add_option("-s", "--start", default=0, type="int",
             help="Starting sample iteration to use")
# The sample iterations are enumerated with range(start, end, step), so the
# END value itself is never used.  The help text states this explicitly.
o.add_option("-e", "--end", default=5000, type="int",
             help="Stop before this sample iteration (exclusive)")
o.add_option("-d", "--step", default=1, type="int",
             help="Step size through samples")
o.add_option("-r", "--region",
             help="Only process region 'start-end'")
# With --sample a single age is drawn per allele instead of reporting the
# average and the per-sample bounds.
o.add_option("", "--sample", action="store_true",
             help="Sample allele age uniformly along branches")
conf, args = o.parse_args()


def iter_allele_ages(smc, sites, region=None, maxage=1000000):
    """
    Generate one age-bound record for every bi-allelic site covered by the
    local trees read from 'smc'.

    smc    -- iterable of SMC items; must also expose a 'header' mapping with
              "start", "end", "chrom" and "names" entries
    sites  -- alignment object providing nseqs(), iter_region(), get_minor()
              and get()
    region -- two-element [start, end] used to filter local trees; defaults
              to the span recorded in the SMC header
    maxage -- upper bound reported when the allele's node has no parent

    Each yielded list contains, in order: chromosome, pos-1, pos, derived
    allele frequency, number of leaves below the allele's node, the value
    of sites.get() for the first derived sequence, the age of that node
    (lower bound) and the age of its parent or 'maxage' (upper bound).
    """
    if region is None:
        region = [smc.header["start"], smc.header["end"]]
    total_seqs = float(sites.nseqs())
    chrom = smc.header["chrom"]
    names = smc.header["names"]

    for item in smc:
        # Only local-tree records carry the information needed here.
        if item["tag"] != "TREE":
            continue

        tree_start = item["start"]
        tree_end = item["end"]

        # Ignore local trees that do not touch the requested region.
        if not util.overlap(tree_start, tree_end, region[0], region[1]):
            continue

        # Label the leaves of the local tree with the sequence names so
        # that they can be looked up by the names the alignment reports.
        local_tree = item["tree"]
        argweaver.smc.rename_tree(local_tree, names)

        for pos, col in sites.iter_region(tree_start, tree_end):
            # Sites with more than two distinct states are not handled.
            if len(set(col)) > 2:
                continue

            # First guess: the minor allele is the derived one.
            carriers = sites.get_minor(pos)
            node = treelib.lca([local_tree[x] for x in carriers])

            # An LCA at the root means the guess was wrong; try the
            # complementary set of sequences instead.
            if node == local_tree.root:
                carriers = [name for name in names if name not in carriers]
                node = treelib.lca([local_tree[x] for x in carriers])

            # Still at the root: neither allele fits on a single branch.
            if node == local_tree.root:
                print >>sys.stderr, (
                    "warning: noncompatiable site %d." % pos)

            frequency = len(carriers) / total_seqs
            leaf_count = len(node.leaves())
            allele = sites.get(pos, carriers[0])
            lower_age = node.data["age"]
            if node.parent:
                upper_age = node.parent.data["age"]
            else:
                upper_age = maxage

            yield [chrom, pos-1, pos, frequency, leaf_count, allele,
                   lower_age, upper_age]

        # Undo the renaming so the tree is left labelled by index again.
        for index, name in enumerate(names):
            local_tree.rename(name, index)


#=============================================================================
# Parse arguments

# Exactly two positional arguments are required: the SMC file (or pattern)
# and the sites file.
if len(args) != 2:
    o.print_help()
    sys.exit(1)

smc_file, sites_file = args

if "%d" in smc_file:
    # Expand the pattern over the requested sample iterations and keep only
    # the files that actually exist.  range() stops before conf.end.
    smc_files = []
    for iteration in range(conf.start, conf.end, conf.step):
        candidate = smc_file % iteration
        if os.path.exists(candidate):
            smc_files.append(candidate)
else:
    smc_files = [smc_file]

# With no SMC file there is nothing to iterate over and the script would
# silently produce empty output; report the problem instead.
if not smc_files:
    sys.stderr.write("error: no SMC files found for '%s'\n" % smc_file)
    sys.exit(1)


# Get region of interest, given on the command line as 'start-end'.
if conf.region:
    region = [int(bound) for bound in conf.region.split("-")]
else:
    region = None

# Read sites and ARGs.
sites = argweaver.read_sites(sites_file, region)
smcs = [argweaver.SMCReader(f, parse_trees=True, apply_spr=True)
        for f in smc_files]

# One lazy row generator per ARG sample.
age_iters = [iter_allele_ages(smc, sites, sites.region) for smc in smcs]

# Average ages across samples for each allele.  The generators are advanced
# in lockstep, so every element of 'rows' is taken to describe the same site.
for rows in izip(*age_iters):
    # Columns 1-6 (position, frequency, count, allele) come from the first
    # sample; the remaining two entries of each row are its age bounds.
    info = rows[0][:6]
    bounds = [row[6:8] for row in rows]
    lows = util.cget(bounds, 0)
    tops = util.cget(bounds, 1)
    low = stats.mean(lows)
    top = stats.mean(tops)
    avg = stats.mean((a+b)/2.0 for a, b in zip(lows, tops))

    if conf.sample:
        # Draw one age uniformly between the averaged bounds; the
        # per-sample bounds are not printed in this mode.
        mid = random.uniform(low, top)
        row = info + [mid, low, top]
    else:
        row = info + [avg, low, top] + util.flatten(bounds)
    util.print_row(*row)
