#!/usr/bin/env python

import optparse
import random

import argweaver

from rasmus import util


# Command-line interface.  Every option is optional; the help strings
# describe how the rest of this script uses each value.
o = optparse.OptionParser()
o.add_option("-s", "--sites", metavar="SITES_FILE", default="-",
             help="sites file to read [default: %default]")
o.add_option("-o", "--output", metavar="OUTPUT_FILE", default="-",
             help="sites file to write [default: %default]")
o.add_option("-r", "--region", metavar="START-END",
             help="keep only sites with START <= position <= END "
                  "(default: the region of the input sites file)")
o.add_option("-n", "--nseqs", metavar="NUMBER_OF_SEQS", type="int",
             help="use a random subset of sequences")
o.add_option("--both-haps", action="store_true",
             help="use both haplotypes for each individual")
o.add_option("-k", "--keep-invariant", action="store_true",
             help="keep invariant sites")
o.add_option("-i", "--ind", metavar="INDIVIDUALS",
             help="file of sequence names to keep "
                  "(ignored when --nseqs is given)")

# Parse the process command line: `conf` holds the option values and `args`
# any leftover positional arguments.
conf, args = o.parse_args()

#=============================================================================


def is_variant(col):
    """Return True if the column `col` contains at least two distinct values.

    `col` is a sequence of per-sequence values for one site (this script
    passes a string of bases).  An empty or single-element column has nothing
    to differ from and is therefore reported as not variant.
    """
    # Guard against an empty column (e.g. no sequences selected), which
    # would otherwise fail when reading the first element.
    if not col:
        return False
    first = col[0]
    return any(base != first for base in col)


# read data
sites = argweaver.read_sites(conf.sites)


# get individuals
# `ind` holds the indices (into sites.names) of the sequences to keep.
# list() keeps the result shuffle-able even where range() is lazy.
ind = list(range(sites.nseqs()))
if conf.nseqs:
    if conf.both_haps:
        # Assumes sequence names look like "<individual>_1" and
        # "<individual>_2": dropping the last two characters yields the
        # individual.  Pick nseqs // 2 random individuals and keep both of
        # their haplotypes.
        dips = util.unique(x[:-2] for x in sites.names)
        random.shuffle(dips)
        # floor division so the slice bound is an integer in Python 2 and 3
        dips = dips[:conf.nseqs // 2]
        # a set gives constant-time membership tests in the filter below
        names = set(x + suffix for x in dips for suffix in ("_1", "_2"))
        ind = [i for i, name in enumerate(sites.names) if name in names]
    else:
        random.shuffle(ind)
        ind = ind[:conf.nseqs]

elif conf.ind:
    names2 = set(util.read_strings(conf.ind))
    # Validate with an explicit exception instead of `assert`, which is
    # stripped under `python -O`; report every unknown name at once.
    missing = sorted(names2.difference(sites.names))
    if missing:
        raise ValueError("names not found in sites file: " +
                         ", ".join(missing))
    ind = [i for i in ind if sites.names[i] in names2]

if conf.region:
    # a list (not a lazy map object) so region[0] / region[1] can be indexed
    region = [int(x) for x in conf.region.split("-")]
    if len(region) != 2:
        o.error("--region must have the form START-END")
else:
    region = sites.region

# make subset
sites2 = argweaver.Sites(names=util.mget(sites.names, ind), chrom=sites.chrom,
                         region=region)

for i, bases in sites:
    # keep only positions inside the (inclusive) region
    if region[0] <= i <= region[1]:
        bases = "".join(util.mget(bases, ind))
        # invariant columns are dropped unless --keep-invariant was given
        if conf.keep_invariant or is_variant(bases):
            sites2.append(i, bases)

# write new sites
argweaver.write_sites(conf.output, sites2)
