#!/usr/bin/env python
#
# Agalma - Tools for processing gene sequence data and automating workflows
# Copyright (c) 2012-2017 Brown University. All rights reserved.
#
# This file is part of Agalma.
#
# Agalma is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Agalma is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Agalma.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import numpy as np
import os
import sys

from collections import defaultdict
from operator import itemgetter

from agalma import config
from biolite import catalog
from biolite import diagnostics
from biolite import report
from biolite import utils

import matplotlib.pyplot as pyplot
from matplotlib.ticker import FixedLocator

def sec_to_hhmmss(sec):
	"""
	Format a duration given in seconds as H:MM:SS, or as M:SS when it is
	shorter than one hour (e.g. 3725 -> '1:02:05', 65 -> '1:05', 5 -> '0:05').

	Each field is truncated with int(), so fractional seconds are dropped.
	"""
	hours = int(sec / 3600)
	rest = sec - hours*3600
	minutes = int(rest / 60)
	seconds = int(rest - minutes*60)
	if not hours:
		# Minutes are printed unpadded; a zero minute count gives '0:SS'.
		return '%d:%.2d' % (minutes, seconds)
	return '%d:%.2d:%.2d' % (hours, minutes, seconds)


# Bar/legend colors for calls listed in the report table, taken from the
# ColorBrewer2 qualitative "Set3" scheme (10 of its colors). Calls too short
# to be listed are drawn in gray ('0.5') instead.
palette = tuple("""
	#8DD3C7 #FFED6F #BEBADA #FB8072 #80B1D3
	#FDB462 #B3DE69 #FCCDE5 #BC80BD #CCEBC5
	""".split())


# HTML page for the report. resource_report() fills it with %-formatting
# using three values, in order: the species (or catalog ID) label, the
# comma-joined run-ID range string, and the HTML of the per-call table.
# Literal percent signs in the page are therefore written as '%%'.
# The page references profile.svg and css/bootstrap.min.css by relative
# path; resource_report() writes profile.svg into the same output directory
# and calls report.copy_css() (presumably providing the stylesheet).
template = """<html>
<head>
<title>BioLite Resource Profile</title>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
<link href="css/bootstrap.min.css" rel="stylesheet"/>
<style>
@media print {
body {-webkit-print-color-adjust: exact;}
}
.container {
width: 1152px;			/* width of figure */
margin: 30px auto;
}
.table {
font-size: 14px;
font-family: Monaco,Menlo,Consolas,"Courier New",monospace;
}
.table td, .table th {
text-align:right;
}
.table td.left, .table th.left {
text-align:left;
}
</style>
</head>
<body>
<div class="container">
<h2>Resource Usage for <em>%s</em> <small>(Runs %s)</small></h2>
<object data="profile.svg" type="image/svg+xml"></object>
<h4>Calls longer than 1%% of total runtime</h4>
%s
</div>
</body>
</html>
"""

# Name of the '<diagnostics.EXIT>.profile' entry. resource_report() removes
# it from each run's '*.profile' results so that it is not plotted or
# tabulated as an individual call.
exit_name = "%s.profile" % diagnostics.EXIT


def resource_report(outdir, id, run_ids):
	"""
	Write a resource-usage report for one or more diagnostics runs into
	`outdir`: profile.svg (the plots), profile.html (page plus table), and
	the stylesheet copied by report.copy_css().

	outdir  -- directory to write the report into
	id      -- optional catalog ID; when set, all of its non-hidden runs are
	           prepended to `run_ids`; otherwise the catalog ID is looked up
	           from the first entry of `run_ids`
	run_ids -- list of run IDs (ints or numeric strings)

	Exits through utils.die("no runs found") when no run IDs are available.

	Every '*.profile' diagnostics entry of each run (except `exit_name`)
	becomes one bar, and the bars are laid end to end along a wall-time axis.
	The top panel shows user CPU time divided by wall time ("parallelism").
	The bottom panel shows each call's memory, with a dashed running maximum.
	Calls longer than 1% of the total wall time are colored from `palette`
	and listed in the HTML table; shorter calls are drawn in gray.
	"""
	if id:
		run_ids = [run.id for run in diagnostics.lookup_runs(id, hidden=False)] + run_ids
	elif run_ids:
		# Only index run_ids when it is non-empty; an empty list falls
		# through to the "no runs found" error below.
		id = diagnostics.lookup_run(run_ids[0]).catalog_id

	if not run_ids:
		utils.die("no runs found")

	try:
		species = catalog.select(id).species
	except AttributeError:
		species = id

	# Gather the per-call profiles of every run, in run order.
	entities = []
	profiles = []
	for run_id in run_ids:
		d = diagnostics.lookup_like(int(run_id), '*.profile')
		if d:
			d.pop(exit_name, None)
			entities.extend(d.keys())
			profiles.extend(d.values())
	size = len(entities)

	widths = np.zeros(size)      # wall time of each call (s)
	mem = np.zeros(size)         # memory of each call, in GB per the axis label
	smart_mem = []               # human-readable memory strings for the table
	time = np.zeros(size + 1)    # time[i] = start of call i; time[-1] = total
	cput = np.zeros(size)        # user CPU time / wall time
	syst = np.zeros(size)        # system CPU time / wall time
	hwm = np.zeros(size + 2)     # hwm[i+1] = max memory over calls 0..i
	colors = ['0.5'] * size      # gray unless the call is listed in the table

	stage_end = {}                    # stage -> end time of its last call
	stage_calls = defaultdict(list)   # stage -> [(call index, call name)]

	for i, (entity, profile) in enumerate(zip(entities, profiles)):
		widths[i] = float(profile.get('walltime', 0.0))
		time[i+1] = time[i] + widths[i]
		# A call's stage is the first two dot-separated parts of its key.
		stage = '.'.join(entity.split('.')[:2])
		stage_end[stage] = time[i+1]
		# Keep the call index so the table can look up this call's data
		# directly, regardless of how calls are ordered or interleaved.
		stage_calls[stage].append((i, profile['name']))
		m = profile.get('mem', profile.get('maxrss', 0))
		# Dividing by 2**20 yields GB if the value is in KB (TODO confirm unit).
		mem[i] = float(m) / 1048576.0
		smart_mem.append(utils.human_readable_size(m, 1))
		# Calls without wall time have no meaningful CPU ratio; leave them
		# at 0 instead of dividing by zero (which yields nan/inf).
		if widths[i] > 0:
			cput[i] = float(profile.get('usertime', 0.0)) / widths[i]
			syst[i] = float(profile.get('systime', 0.0)) / widths[i]
		hwm[i+1] = max(hwm[i], mem[i])

	# Stages ordered by the time their last call finished; the position in
	# this list is the stage number shown in the table and on the plot.
	stages = sorted(stage_end.items(), key=itemgetter(1))
	total = time[-1]

	# Create table of calls longer than 1% of the total wall time.

	row_fmt = '<tr><td>%d</td><td class="left"><span style="color:%s !important">&#x25A3;</span> %s.%s</td><td>%s</td><td>%.0f%%</td><td>%0.f%%</td><td>%s</td></tr>'
	top_table = []
	for j, (stage, _) in enumerate(stages):
		for i, name in stage_calls[stage]:
			walltime = widths[i]
			if walltime > 0.01 * total:
				colors[i] = palette[len(top_table) % len(palette)]
				top_table.append(row_fmt % (
					j, colors[i], stage, name, sec_to_hhmmss(walltime),
					100*cput[i], 100*syst[i], smart_mem[i]))

	# Create ticks for x axis: stage numbers centered on each stage (top
	# panel) and the wall time at each stage boundary (bottom panel).

	xtick_stages = ([], [])
	xtick_times = ([], [])
	prev_tick = 0
	for i, (_, tick) in enumerate(stages):
		xtick_stages[0].append(prev_tick + 0.5 * (tick - prev_tick))
		xtick_times[0].append(tick)
		# Don't print labels that are too close together
		if tick - prev_tick < total * 0.01:
			xtick_stages[1].append('')
			xtick_times[1].append('')
		else:
			xtick_stages[1].append(str(i))
			xtick_times[1].append(sec_to_hhmmss(tick))
		prev_tick = tick

	# Repeat the final high-water mark so the dashed line reaches time[-1].
	hwm[-1] = hwm[-2]

	fig = pyplot.figure(figsize=(12,6), dpi=72)

	# CPU plot

	axes = fig.add_subplot(211)

	axes.xaxis.set_label_position('top')
	axes.xaxis.set_ticks_position('top')
	axes.set_xlim((0, total))
	axes.xaxis.set_major_locator(FixedLocator(xtick_stages[0]))
	axes.set_xticklabels(xtick_stages[1], fontsize=9, family='Arial', weight='bold', color='0.5')
	axes.xaxis.set_minor_locator(FixedLocator(xtick_times[0]))
	# Booleans, not 'off' strings: newer matplotlib treats strings as truthy.
	axes.tick_params(which='both', top=False, bottom=False)
	for t in axes.xaxis.get_ticklines():
		t.set_visible(False)
	axes.grid(axis='x', which='minor')
	axes.set_xlabel('Stage #', fontsize=9, family='Arial', weight='bold', color='0.5')

	ymax = 16.0
	ymin = -0.02 * ymax
	axes.set_ylim((ymin, ymax))
	axes.set_ylabel('Parallelism')

	# Bars start slightly below zero so even zero-height calls show a sliver.
	cput -= ymin
	axes.bar(time[:-1], cput, width=widths, color=colors, linewidth=0, bottom=ymin)
	axes.axhline(color='k')

	# Memory plot

	axes = fig.add_subplot(212)

	axes.set_xlim((0, total))
	axes.set_xticks(xtick_times[0])
	axes.set_xticklabels(xtick_times[1], va='top', rotation=-90, fontsize=8)
	axes.tick_params(direction='out', top=False)
	axes.grid(axis='x')
	axes.set_xlabel('Wall Time (HH:MM:SS)', fontsize=8)

	# Fall back to a unit range when no memory was recorded, avoiding a
	# singular (0, 0) y-limit.
	ymax = 1.05 * hwm[-1] if hwm[-1] > 0 else 1.0
	ymin = -0.02 * ymax
	axes.set_ylim((ymin, ymax))
	axes.set_ylabel('Peak Memory (GB)')

	mem -= ymin
	axes.bar(time[:-1], mem, width=widths, color=colors, linewidth=0, bottom=ymin)
	axes.plot(time, hwm[1:], 'k--')
	axes.axhline(color='k')

	fig.tight_layout()
	fig.savefig(os.path.join(outdir, 'profile.svg'))
	# Release the figure; pyplot otherwise keeps it alive for the process.
	pyplot.close(fig)

	stage_html = [
		'<table class="table table-striped table-condensed">',
		'<tr><th>#</th><th class="left">Stage / Call</th><th>Runtime</th><th>User CPU%</th><th>System CPU%</th><th>Peak Memory</th></tr>',
		'\n'.join(top_table),
		'</table>']

	html = template % (
		species,
		','.join(utils.number_range([int(r) for r in run_ids])),
		'\n'.join(stage_html))
	with open(os.path.join(outdir, 'profile.html'), 'w') as f:
		f.write(html)
	report.copy_css(outdir)


if __name__ == '__main__':
	# Command-line entry point: an output directory, an optional catalog ID,
	# and any number of explicit run IDs are passed to resource_report().
	# argparse also runs the string default './' through utils.safe_mkdir.
	cli = argparse.ArgumentParser(description="""
		Generates an HTML report showing resource usage across the specified
		RUN_IDs.""")
	cli.add_argument(
		'--outdir', '-o',
		default='./',
		type=utils.safe_mkdir,
		help="""
  		write HTML output to OUTDIR [default: ./]""")
	cli.add_argument(
		'--id', '-i',
		metavar='CATALOG_ID',
		help="""
		include all run IDs associated with the BioLite CATALOG_ID""")
	cli.add_argument(
		'run_ids',
		metavar='RUN_ID',
		nargs='*',
		help="""
		include the specified list of run IDs""")
	opts = cli.parse_args()
	resource_report(opts.outdir, opts.id, opts.run_ids)

# vim: noexpandtab ts=4 sw=4
