#!/usr/bin/env python
#
# Agalma - Tools for processing gene sequence data and automating workflows
# Copyright (c) 2012-2014 Brown University. All rights reserved.
#
# This file is part of Agalma.
#
# Agalma is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Agalma is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Agalma.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import ast
import csv
import os
import re
import subprocess
import sys
import tempfile
from collections import defaultdict
from operator import attrgetter
from glob import glob
from agalma import database
import biolite.database
from biolite import catalog
from biolite import diagnostics
from biolite import utils
from biolite import workflows

def speciestree(id, run_id):
	"""
	"""
	if not run_id and id:
		try:
			run_id = diagnostics.lookup_prev_run(id, None, "speciestree").run_id
			utils.info("found speciestree run", run_id)
		except AttributeError:
			utils.info("no speciestree run found for id", id)
	if run_id:
		values = diagnostics.lookup(run_id, diagnostics.EXIT)
		try:
			tree = values["bipartition"]
			print '"%s",' % open(tree).readline().rstrip()
		except KeyError:
			utils.die("couldn't find RAxML bipartition in speciestree run", run_id)
		return tree
	else:
		print "null,"


def annotate_genetrees(raxml_dir, speciestree):
	"""
	Root and reconcile the trees with Notung, then reformat to standard
	newick format.

	Writes a batch file `trees.txt` in the current directory whose first line
	is `speciestree` and whose remaining lines are the
	`RAxML_bipartitions.*` files found in `raxml_dir`, runs `notung` on it
	with a fresh temporary output directory, and relays Notung's stderr
	progress output to this process's stderr.

	This is a generator: it yields one rewritten tree string for each
	`*.rooting.0` file that Notung left in the temporary directory. The
	temporary directory (including `notung.log`) is not removed.
	"""
	utils.info("finding RAxML bipartitions in", raxml_dir)
	bipartitions = glob(os.path.join(raxml_dir, "RAxML_bipartitions.*"))
	utils.info("found", len(bipartitions))
	with open("trees.txt", 'w') as f:
		print >>f, speciestree
		for tree in bipartitions:
			print >>f, tree
	utils.info("running Notung to root and reconcile genetrees...")
	tmpdir = tempfile.mkdtemp()
	# Keep the log handle in a with-block so it is closed once Notung is done.
	with open(os.path.join(tmpdir, "notung.log"), "w") as log:
		p = subprocess.Popen([
			"notung", "-b", "trees.txt", "--root", "--absfilenames", "--nolosses",
			"--speciestag", "prefix", "--treeoutput", "notung",
			"--silent", "--progressbar", "--outputdir", tmpdir],
			stdout=log,
			stderr=subprocess.PIPE)
		for line in p.stderr:
			# Redraw the progress output on a single terminal line.
			sys.stderr.write('\r')
			sys.stderr.write(line.rstrip().rpartition('\n')[2])
		# Reap the child so it does not linger and so its status is known.
		status = p.wait()
	sys.stderr.write('\n')
	if status != 0:
		utils.info("Notung exited with status", status)
	utils.info("parsing Notung output...")
	# Node with a branch length and an NHX comment holding D= and B= fields:
	# becomes "<node>-<D>-<B>:<length>".
	with_support = re.compile(r"(n\d+):([\d\.]+)\[&&NHX.*?:D=(\w):B=(\d+)[\d\.]+\]")
	# Node with a branch length and only a D= field: "<node>-<D>:<length>".
	without_support = re.compile(r"(n\d+):([\d\.]+)\[&&NHX.*?:D=(\w)\]")
	# Final node right before the terminating ';': "<node>-<D>;".
	root_node = re.compile(r"(n\d+)\[&&NHX.*?:D=(\w)\];$")
	# Any bracketed comment that is still left is dropped.
	leftover_comment = re.compile(r"\[.+?\]")
	for path in glob(os.path.join(tmpdir, "*.rooting.0")):
		with open(path) as f:
			newick = f.readline().rstrip()
		newick = with_support.sub(r"\1-\3-\4:\2", newick)
		newick = without_support.sub(r"\1-\3:\2", newick)
		newick = root_node.sub(r"\1-\2;", newick)
		yield leftover_comment.sub("", newick)


def genetrees(id, run_id, speciestree):
	"""
	"""
	if not run_id and id:
		try:
			run_id = diagnostics.lookup_prev_run(id, None, "genetree").run_id
			utils.info("found genetree run", run_id)
		except AttributeError:
			utils.info("no genetree run found for id", id)
	if run_id:
		try:
			raxml_dir = diagnostics.lookup(run_id, diagnostics.EXIT)["raxml_dir"]
		except KeyError:
			utils.die("couldn't find RAxML output for genetree run", run_id)
		print '["%s"],' % '","'.join(annotate_genetrees(raxml_dir, speciestree))
	else:
		print "null,"


def parse_rsem(run_ids):
	"""
	Collect per-gene expected counts and effective lengths from the RSEM
	output of each expression run in `run_ids`.

	For every run, the tab-delimited file stored under the "genes" key of the
	run's EXIT diagnostics is read. Only rows with a positive
	"expected_count" are used.

	Returns a dict mapping each "gene_id" to a tuple `(counts, lengths)`.
	Both are lists of strings with one slot per run, in `run_ids` order:
	`counts` defaults to "0" and `lengths` defaults to the quoted string
	'"NA"' wherever a run has no positive value for that gene.

	Calls utils.die if a run has no "genes" entry.
	"""
	nruns = len(run_ids)
	genes = {}
	for i, run_id in enumerate(run_ids):
		values = diagnostics.lookup(run_id, diagnostics.EXIT)
		try:
			rsem = values["genes"]
		except KeyError:
			utils.die("could not find RSEM output for expression run", run_id)
		# Close each results file when done instead of leaking one handle per run.
		with open(rsem, 'rb') as f:
			for row in csv.DictReader(f, delimiter='\t'):
				count = float(row["expected_count"])
				if count > 0.0:
					gene = row["gene_id"]
					if gene not in genes:
						genes[gene] = (["0"] * nruns, ['"NA"'] * nruns)
					genes[gene][0][i] = '%g' % count
					length = float(row["effective_length"])
					if length > 0.0:
						genes[gene][1][i] = '%g' % length
	return genes


def read_count(run_id):
	"""
	Look up the "nreads" value that the
	"expression.quantify.rsem-calculate-expression" entity recorded in the
	diagnostics table for `run_id`.

	Returns the first column of the first matching row. Calls utils.die when
	the query matches nothing.
	"""
	query = """
		SELECT value
		FROM   diagnostics
		WHERE  run_id=? AND entity="expression.quantify.rsem-calculate-expression" AND attribute="nreads";"""
	cursor = biolite.database.execute(query, (run_id,))
	match = cursor.fetchone()
	if not match:
		utils.die("no read counts found for run", run_id)
	return match[0]


def counts_table(load_id, run_ids):

	utils.info("lookuping up read counts for load id", load_id)
	read_counts = map(read_count, run_ids)
	print utils.indent(3, '"read_count": ["%s"],' % '","'.join(read_counts))

	utils.info("looking up sequence types for load id", load_id)
	sequences = {}
	sql = """
		SELECT   gene, sequence_id, genome_type, molecule_type, blast_hit, max(expression)
		FROM     sequences
		WHERE    run_id=?
		GROUP BY gene;"""
	for row in database.execute(sql, (load_id,)):
		sequences[row[0]] = map(str, row[1:])

	genes = parse_rsem(run_ids)

	print utils.indent(3, '"gene": ["%s"],' % '","'.join(
			sequences[gene][0] for gene in genes))
	print utils.indent(3, '"genome_type": ["%s"],' % '","'.join(
			sequences[gene][1] for gene in genes))
	print utils.indent(3, '"molecule_type": ["%s"],' % '","'.join(
			sequences[gene][2] for gene in genes))
	print utils.indent(3, '"blast_hit": ["%s"],' % '","'.join(
			sequences[gene][3] for gene in genes))
	print utils.indent(3, '"count": [[%s]],' % '],['.join(
			','.join(genes[gene][0]) for gene in genes))
	print utils.indent(3, '"length": [[%s]]' % '],['.join(
			','.join(genes[gene][1]) for gene in genes))


def expression(id, homologize_id, load_ids):
	"""
	Export expression levels into JSON format file
	"""

	if not load_ids:
		if not homologize_id and id:
			try:
				homologize_id = diagnostics.lookup_prev_run(id, None, "homologize").run_id
				utils.info("found homologize run", homologize_id)
			except AttributeError:
				utils.info("no homologize run for id", id)
		if homologize_id:
			values = diagnostics.lookup(homologize_id, diagnostics.INIT)
			try:
				load_ids = map(int, ast.literal_eval(values["load_ids"]))
			except KeyError:
				utils.info("no load ids for homologize run", homologize_id)
	if not load_ids:
		utils.die("could not identify load ids")

	load_ids = frozenset(load_ids)
	taxa = {}

	for load_id in load_ids:
		utils.info("looking up load id", load_id)
		run = diagnostics.lookup_run(load_id)
		assert run.name == "load"
		taxa[load_id] = catalog.select(run.id)
		utils.info(
			"found catalog id", run.id, "(%s)" % taxa[load_id].species)

	# Find expression runs that

	utils.info("looking up expression runs")
	tables = defaultdict(list)
	sql = """
		SELECT runs.run_id, diagnostics.value
		FROM   runs JOIN diagnostics ON runs.run_id=diagnostics.run_id
		WHERE  runs.hidden=0 AND runs.done=1 AND runs.name="expression" AND
		       diagnostics.entity="%s" AND
		       diagnostics.attribute="load_id";
		""" % diagnostics.EXIT
	for row in biolite.database.execute(sql):
		load_id = int(row[1])
		if load_id in load_ids:
			run = diagnostics.lookup_run(row[0])
			tables[load_id].append(run)
			taxa[run.run_id] = catalog.select(run.id)

	utils.info(
		"found expression runs:\n",
		'\n '.join(
			"%s (%s): [%s]" % (
				taxa[load_id].id, taxa[load_id].species,
				','.join(map(str, [run.run_id for run in tables[load_id]])))
			 for load_id in tables))

	print "{"

	for i, load_id in enumerate(tables, start=1):
		run_ids = map(attrgetter("run_id"), tables[load_id])
		records = map(taxa.get, run_ids)
		print utils.indent(2, '"%s": {' % taxa[load_id].id)
		print utils.indent(3, '"species_name": "%s",' % taxa[load_id].species)
		for key in ("library_id", "treatment", "individual", "sample_prep"):
			values = map(str, map(attrgetter(key), records))
			if not values:
				utils.die("missing", key, "values for", taxa[load_id].species)
			print utils.indent(3, '"%s": ["%s"],' % (key, '","'.join(values)))
		counts_table(load_id, run_ids)
		if i == len(tables):
			print utils.indent(2, "}")
		else:
			print utils.indent(2, "},")

	print utils.indent(1, "}")


if __name__ == '__main__':
	parser = argparse.ArgumentParser(description="""
		Generates a JSON file containing expression tables, gene trees and a
		species tree for downstream analysis in R.""")
	parser.add_argument('--id', '-i', help="""
		phylogeny catalog ID for looking up latest genetree and homologize
		runs""")
	parser.add_argument('--speciestree', help="""
		used the specified RUN_ID instead of looking up the latest speciestree
		run for the --id""")
	parser.add_argument('--genetree', help="""
		used the specified RUN_ID instead of looking up the latest genetree
		run for the --id""")
	parser.add_argument('--homologize', metavar="RUN_ID", help="""
		used the specified RUN_ID instead of looking up the latest homologize
		run for the --id""")
	parser.add_argument('--load', type=int, nargs='+', metavar="RUN_ID", help="""
		use the specified RUN_IDs instead of looking them up from the loads in
		the latest homologize run""")
	args = parser.parse_args()
	print "{"
	print utils.indent(1, '"speciestree":'),
	tree = speciestree(args.id, args.speciestree)
	print utils.indent(1, '"genetrees":'),
	genetrees(args.id, args.genetree, tree)
	print utils.indent(1, '"expression":'),
	expression(args.id, args.homologize, args.load)
	print "}"

# vim: noexpandtab sw=4 ts=4
