#!/usr/bin/env python

import optparse
import os
import sys

import argweaver
import argweaver.sim

from compbio import arglib
from rasmus import util


# Command-line interface.  Every option is declared once in the table below
# and registered on the parser in the order listed, which is also the order
# in which the options appear in the --help output.
o = optparse.OptionParser(description="""
Simulate an ARG and DNA sequences according to an evolutionary
model. The model can be Coalescent with Recombination (coal_recomb),
the Sequentially Markov Coalescent (smc), or the Discretized
Sequentially Markov Coalescent (dsmc).
""")

for _flags, _settings in [
    # sample and evolutionary parameters
    (("-k", "--nseqs"),
     dict(type="int",
          help="Number of sequences to simulate")),
    (("-N", "--popsize"),
     dict(type="float", default=1e4,
          help="Effective population size (default=1e4)")),
    (("-r", "--recombrate"),
     dict(type="float", default=1.5e-8,
          help="Recombination rate (recombination/generation/site)")),
    (("-m", "--mutrate"),
     dict(type="float", default=2.5e-8,
          help="Mutation rate (mutations/generation/site)")),
    (("-L", "--seqlen"),
     dict(type="int", default=int(200e3),
          help="Sequence length")),

    # naming of the simulated sequence and of the output files
    (("-c", "--chrom"),
     dict(default="chr",
          help="Chromosome/sequence name (default='chr')")),
    (("-o", "--output"),
     dict(default="out",
          help="Output prefix (default='out')")),

    # model selection and time discretization
    (("", "--ntimes"),
     dict(type="int", default=20,
          help="Number of discretized time point (default=20)")),
    (("", "--maxtime"),
     dict(type="float", default=200e3,
          help="Maximum discretized time point (default=200e3)")),
    (("", "--model"),
     dict(default="dsmc",
          help="Simulation model: dsmc, smc, coal_recomb (default=dsmc)")),

    # switches and optional inputs/outputs
    (("", "--discretize-sites"),
     dict(action="store_true", default=False,
          help="Round recombinations and mutations to nearest integer site")),
    (("", "--infsites"),
     dict(action="store_true", default=False,
          help="Use infinite site assumption")),
    (("-R", "--recombmap"),
     dict(metavar="<recombination rate map file>",
          help="Recombination map file")),
    (("--output-mutations",),
     dict(action="store_true",
          help="Output mutations")),
]:
    o.add_option(*_flags, **_settings)

# the loop variables are only scaffolding; keep the module namespace clean
del _flags, _settings

# With no arguments at all there is nothing to simulate: show the usage
# text and exit with a failure status.
if len(sys.argv) == 1:
    o.print_help()
    sys.exit(1)

conf, args = o.parse_args()

# --nseqs has no default, so when it is omitted conf.nseqs is None and that
# None would be handed straight to the ARG samplers below.  Report it as a
# usage error (message on stderr, exit status 2) instead.
if conf.nseqs is None:
    o.error("number of sequences is required (-k/--nseqs)")


def write_mutations(stream, muts):
    """
    Write one row per mutation to `stream`.

    `muts` is an iterable of (node, parent, pos, age) entries.  Each row
    holds the node's name, the parent's name, the position and the age,
    emitted through util.print_row.
    """
    for mut in muts:
        node, parent, pos, age = mut
        util.print_row(node.name, parent.name, pos, age, out=stream)


def discretize_mutations(muts):
    """
    Lazily yield each mutation with its position converted to an integer.

    `muts` is an iterable of (node, parent, pos, age) entries.  Only `pos`
    is changed, by passing it through int(); node, parent and age are
    yielded untouched.
    """
    for mut in muts:
        node, parent, pos, age = mut
        yield (node, parent, int(pos), age)


#=============================================================================

# log model parameters
print "model parameters:"
print "  nseqs =", conf.nseqs
print "  seqlen =", conf.seqlen
print "  mutation rate =", conf.mutrate
print "  recombination rate =", conf.recombrate
print "  population size =", conf.popsize
print

if conf.recombmap:
    recombmap = util.read_delim(conf.recombmap, parse=True)
else:
    recombmap = None

if conf.model == "coal_recomb":
    print "simulating using the Coalescent with Recombination model..."
    arg = arglib.sample_arg(conf.nseqs, 2*conf.popsize, conf.recombrate,
                            start=0, end=conf.seqlen)
    if conf.discretize_sites:
        argweaver.discretize_arg_recomb(arg)
    muts = arglib.sample_arg_mutations(arg, conf.mutrate)
    if conf.discretize_sites:
        muts = list(discretize_mutations(muts))

elif conf.model == "smc":
    print "simulating using the SMC model..."
    arg = arglib.sample_arg_smc(conf.nseqs, 2*conf.popsize, conf.recombrate,
                                start=0, end=conf.seqlen)
    if conf.discretize_sites:
        argweaver.discretize_arg_recomb(arg)
    muts = arglib.sample_arg_mutations(arg, conf.mutrate)
    if conf.discretize_sites:
        muts = list(discretize_mutations(muts))

elif conf.model == "dsmc":
    print "simulating using the DSMC model..."
    times = argweaver.get_time_points(ntimes=conf.ntimes, maxtime=conf.maxtime)
    arg = argweaver.sim.sample_arg_dsmc(
        conf.nseqs, 2*conf.popsize, conf.recombrate,
        recombmap=recombmap,
        start=0, end=conf.seqlen, times=times)
    muts = argweaver.sim.sample_arg_mutations(arg, conf.mutrate, times=times)

else:
    print >>sys.stderr, "unknown simulation model:", conf.model

# simulate sequence
seqs = argweaver.sim.make_alignment(arg, muts, infsites=conf.infsites)
sites = argweaver.seqs2sites(seqs)
sites.chrom = conf.chrom

# write output
path = os.path.dirname(conf.output)
if path:
    util.makedirs(path)

print "writing output: ", conf.output + ".sites"
sites.write(conf.output + ".sites")

print "writing output: ", conf.output + ".arg"
arglib.write_arg(conf.output + ".arg", arg)

if conf.output_mutations:
    print "writing output: ", conf.output + ".mut"
    with open(conf.output + ".mut", "w") as out:
        write_mutations(out, muts)
