#!/usr/bin/env python
#
# Agalma - Tools for processing gene sequence data and automating workflows
# Copyright (c) 2012-2017 Brown University. All rights reserved.
#
# This file is part of Agalma.
#
# Agalma is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Agalma is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Agalma.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import ast
import csv
import json
import os
import subprocess
import sys
import tempfile
from collections import defaultdict
from glob import glob
from operator import attrgetter

from agalma import config
from agalma import database
from biolite import catalog
from biolite import diagnostics
from biolite import utils
from biolite import workflows


def read_count(expression_run_id):
	"""
	Parse the total number of reads used to calculate expression out of the
	diagnostics for `expression_run_id`.

	Returns the `value` column of the first matching diagnostics row, exactly
	as it is stored in the database. When no row matches, the problem is
	reported through `utils.die`.
	"""
	# String literals are single-quoted: double-quoted text is an identifier
	# in standard SQL and is only treated as a string by SQLite as a legacy
	# fallback when no column of that name exists.
	sql = """
		SELECT value
		FROM   diagnostics
		WHERE  run_id=? AND
		       entity='expression.calculate.rsem-calculate-expression' AND
		       attribute='nreads';"""
	row = database.execute(sql, (expression_run_id,)).fetchone()
	if not row:
		utils.die("no read counts found for expression run", expression_run_id)
	return row[0]


def aggregate_counts(expression_run_ids):
	"""
	Collect the expected counts of every gene across several expression runs.

	Returns a dict mapping each gene to a list with one entry per run, in the
	order of `expression_run_ids`. Entries are the count formatted with "%g";
	runs that did not report the gene keep the placeholder '"NA"'.
	"""
	width = len(expression_run_ids)
	table = {}
	for column, run_id in enumerate(expression_run_ids):
		rows = database.select_expression(run_id)
		for gene, value in rows:
			cells = table.get(gene)
			if cells is None:
				# First sighting of this gene: start with a full row of
				# placeholders so that earlier runs read as missing.
				cells = ['"NA"'] * width
				table[gene] = cells
			cells[column] = "%g" % value
	return table


def counts_table(id, sequence_run_id, expression_run_ids):

	utils.info(
		"looking up read counts for expression runs",
		str(expression_run_ids))
	read_counts = map(read_count, expression_run_ids)
	print utils.indent(3, '"read_count": [%s],' % ",".join(read_counts))

	utils.info(
		"looking up sequence types for sequence run",
		sequence_run_id)
	version = database.latest_genes_version(id)
	sequences = {}
	sql = """
		SELECT   genes.gene,
		         models.id,
		         models.genome_type, 
		         models.molecule_type,
		         models.blast_title,
		         max(models.confidence)
		FROM     agalma_genes AS genes, agalma_models AS models
		  ON     genes.model_id=models.id
		WHERE    models.run_id=? AND genes.version=?
		GROUP BY genes.gene;"""
	for row in database.execute(sql, (sequence_run_id, version)):
		sequences[row[0]] = row

	utils.info(
		"looking up expected counts for expression runs", 
		str(expression_run_ids))
	counts = aggregate_counts(expression_run_ids)
	genes = sorted(counts)

	print utils.indent(3, '"gene": ["%s"],' % '","'.join(genes))
	print utils.indent(3, '"sequence_id": [%s],' % ','.join(
			str(sequences[gene]["id"]) for gene in genes))
	print utils.indent(3, '"genome_type": ["%s"],' % '","'.join(
			sequences[gene]["genome_type"] for gene in genes))
	print utils.indent(3, '"molecule_type": ["%s"],' % '","'.join(
			sequences[gene]["molecule_type"] for gene in genes))
	print utils.indent(3, '"blast_title": ["%s"],' % '","'.join(
			utils.none_to_empty(sequences[gene]["blast_title"]) for gene in genes))
	print utils.indent(3, '"confidence": [%s],' % ','.join(
			"%g" % sequences[gene]["max(models.confidence)"] for gene in genes))
	print utils.indent(3, '"count": [[%s]]' % '],['.join(
			','.join(counts[gene]) for gene in genes))


def expression(id, sequence_run_ids):
	"""
	Export expression levels into JSON format file
	"""

	sequence_run_ids = frozenset(sequence_run_ids)
	taxa = {}

	for sequence_run_id in sequence_run_ids:
		utils.info("looking up sequence run", sequence_run_id)
		run = diagnostics.lookup_run(sequence_run_id)
		assert run.name in ("transcriptome", "assemble", "import")
		taxa[sequence_run_id] = catalog.select(run.catalog_id)
		utils.info(
			"found catalog id", run.id, "(%s)" % taxa[sequence_run_id].species)

	utils.info("looking up expression runs")
	expression_runs = defaultdict(list)
	sql = """
		SELECT runs.id, diagnostics.value
		FROM   runs JOIN diagnostics ON runs.id=diagnostics.run_id
		WHERE  runs.hidden=0 AND runs.done=1 AND runs.name="expression" AND
		       diagnostics.entity="%s" AND
		       diagnostics.attribute="reference_id";
		""" % diagnostics.EXIT
	for row in database.execute(sql):
		sequence_run_id = int(row[1])
		if sequence_run_id in sequence_run_ids:
			run = diagnostics.lookup_run(row[0])
			expression_runs[sequence_run_id].append(run)
			taxa[run.id] = catalog.select(run.catalog_id)

	utils.info(
		"found expression runs:\n",
		'\n '.join(
			"%s (%s): [%s]" % (
				taxa[run_id].id, taxa[run_id].species,
				','.join(map(str, [run.id for run in expression_runs[run_id]])))
			 for run_id in expression_runs))

	print "{"

	for i, sequence_run_id in enumerate(expression_runs, start=1):
		expression_run_ids = map(attrgetter("id"), expression_runs[sequence_run_id])
		records = map(taxa.get, expression_run_ids)
		print utils.indent(2, '"%s": {' % taxa[sequence_run_id].id)
		print utils.indent(3, '"species_name": "%s",' % taxa[sequence_run_id].species)
		for key in ("library_id", "treatment", "individual", "sample_prep"):
			values = map(str, map(attrgetter(key), records))
			if not values:
				utils.die("missing", key, "values for", taxa[sequence_run_id].species)
			print utils.indent(3, '"%s": ["%s"],' % (key, '","'.join(values)))
		counts_table(id, sequence_run_id, expression_run_ids)
		if i == len(expression_runs):
			print utils.indent(2, "}")
		else:
			print utils.indent(2, "},")

	print utils.indent(1, "}")


if __name__ == "__main__":
	parser = argparse.ArgumentParser(description="""
		Generates a JSON file containing expression tables, gene trees and a
		species tree for downstream analysis in R.""")
	parser.add_argument("--id", "-i", help="""
		phylogeny catalog ID for looking up latest genes version""")
	parser.add_argument("--speciestree", help="""
		include the specified newick file as the species tree""")
	parser.add_argument("--speciestree_numbered", help="""
		include the specified newick file as the numbered species tree
		from phyldog""")
	parser.add_argument("--genetrees", nargs="+", help="""
		include the specified newick files as the gene trees""")
	parser.add_argument("--sequences", required=True, type=int, nargs="+", metavar="RUN_ID", help="""
		use the specified assemble/import RUN_IDs""")
	args = parser.parse_args()
	print "{"
	if args.speciestree:
		tree = open(args.speciestree).read().strip()
		print utils.indent(1, '"speciestree": "%s",' % tree)
	if args.speciestree_numbered:
		tree = open(args.speciestree_numbered).read().strip()
		print utils.indent(1, '"speciestree_numbered": "%s",' % tree)
	if args.genetrees:
		print utils.indent(1, '"genetrees": ["%s"],' % '","'.join(
			open(tree).read().strip() for tree in args.genetrees))
	print utils.indent(1, '"expression":'),
	expression(args.id, args.sequences)
	print "}"

# vim: noexpandtab sw=4 ts=4
