#!/usr/bin/env python
#
# Copyright John Reid 2012, 2013, 2014
#

"""
Analyses the spacing between motif occurrences.
"""

import logging

# Configure the root logger: records at INFO level and above are formatted
# with a timestamp and the level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger()  # no name given, so this is the root logger

from optparse import OptionParser
from collections import defaultdict
from cookbook.pylab_utils import pylab_context_ioff
import os
import pyicl
import numpy
import pylab
import cPickle
from stempy.spacing import plot_spacings, calc_llr_statistic, \
    count_all_pairs, add_max_distance_option
from stempy import ensure_dir_exists, html_copy_static
from stempy.scan import load_occurrences


# Build the command line parser. add_max_distance_option() registers its
# option(s) on the parser first (options.max_distance is read further down),
# then the options specific to this script are added.
parser = OptionParser()
add_max_distance_option(parser)
parser.add_option(
    "-r", "--results-dir",
    metavar="DIR",
    default='.',
    help="Look for results from PWM scan in DIR",
)
parser.add_option(
    "-t", "--log-evidence-plot-threshold",
    metavar="THRESHOLD",
    type="float",
    default=1.,
    help=(
        "Plot those spacings for which the log evidence in favour of "
        "a non-uniform distribution is at least THRESHOLD"
    ),
)
parser.add_option(
    "--ignore-homodimers",
    help="Ignore spacings between instances of the same motif.",
    action='store_true',
)
parser.add_option(
    "--report-inverse",
    help="Report separate spacing analysis for (MOTIF1, MOTIF2) and (MOTIF2, MOTIF1)",
    action='store_true',
)
options, args = parser.parse_args()

# Paths of the PWM scan output files inside the results directory.
occurrences_filename = os.path.join(
    options.results_dir, 'steme-pwm-scan.out')
seq_lengths_filename = os.path.join(
    options.results_dir, 'steme-pwm-scan.lengths')


#
# The output location is known at this point, so open a log file inside it
#
# rel_output_dir is the output directory relative to the results directory.
rel_output_dir = os.path.join('spacings', 'max-%d' % options.max_distance)
output_dir = os.path.join(options.results_dir, rel_output_dir)
ensure_dir_exists(output_dir)

# The log file is opened with mode 'w', so an existing log is overwritten.
log_filename = os.path.join(
    output_dir, 'steme-spacing-max-%d.log' % options.max_distance)
log_file_handler = logging.FileHandler(log_filename, mode='w')
log_file_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
log_file_handler.setFormatter(formatter)

# Attach the handler to the root logger.
logging.getLogger('').addHandler(log_file_handler)


#
# Read the occurrences, the sequence information and the motifs;
# the occurrences are expected to come sorted by position
#
occurrences, seq_infos, motifs = load_occurrences(options)


#
# Count the spacings over the occurrences
#
summary_counts = (
    options.max_distance,
    len(occurrences),
    len(motifs),
    len(seq_infos),
)
logger.info(
    'Examining spacings of up to %d b.p. between %d occurrences of %d motifs in %d sequences',
    *summary_counts
)
spacings = count_all_pairs(
    occurrences,
    seq_infos,
    ignore_close_to_end=True,
    options=options,
)


#
# Pickle the spacing counts into the output directory
#
spacings_pickle_filename = os.path.join(
    output_dir, 'spacings-max-%d.pickle' % options.max_distance)
logger.info('Pickling spacings to: %s', spacings_pickle_filename)
# Write in binary mode so the pickle stream is not subject to newline
# translation, and use a context manager so the file is always flushed and
# closed (the previous code left the handle open).
with open(spacings_pickle_filename, 'wb') as pickle_file:
    # The spacings are converted to a plain dict before being pickled.
    cPickle.dump(dict(spacings), pickle_file)


#
# Score every motif pair by its log evidence against a uniform
# distribution of the spacing counts
#
log_evidences = []
for (motif1, motif2), spacing_counts in spacings.iteritems():
    # Pairs without a positive total count are not scored
    if not spacing_counts.sum() > 0:
        continue
    # The reference is an array of ones with the same shape as the counts
    uniform_counts = numpy.ones_like(spacing_counts)
    llr = float(calc_llr_statistic(spacing_counts, uniform_counts))
    log_evidences.append((llr, (motif1, motif2)))

# Strongest evidence first
log_evidences = sorted(log_evidences, reverse=True)


#
# Plot the spacings of the pairs that reach the plot threshold and log a
# summary line for every pair that is reported
#
# Maps (motif1, motif2) to the relative path of its plot (without extension),
# or to '' when no plot was made for the pair.
already_examined = {}
with pylab_context_ioff():
    for rank, (log_evidence, (motif1, motif2)) in enumerate(log_evidences):
        # Unless both orientations were requested, skip a pair whose
        # inverse has been reported already
        inverse_reported = (motif2, motif1) in already_examined
        if inverse_reported and not options.report_inverse:
            continue

        spacing_counts = spacings[(motif1, motif2)]
        if log_evidence >= options.log_evidence_plot_threshold:
            # The rank is the position in the full sorted list, so skipped
            # pairs leave gaps in the numbering of the plot files
            tag = '%03d-%s-%s' % (rank, motif1, motif2)
            plot_spacings(options.max_distance, spacing_counts)
            captions = (
                (.25, .97, 'primary: ' + motif1),
                (.25, .92, 'secondary: ' + motif2),
                (.75, .97, 'ln evidence=%.1f' % log_evidence),
                (.75, .92, 'total=%d' % spacing_counts.sum()),
            )
            for x, y, caption in captions:
                pylab.figtext(x, y, caption, ha='center')
            filename = os.path.join(output_dir, tag)
            for extension in ('.png', '.eps'):
                pylab.savefig(filename + extension)
            pylab.close()
            rel_filename = os.path.join(rel_output_dir, tag)
        else:
            rel_filename = ''

        logger.info(
            'LLR:%8.1f; total:%6d; %30s - %-30s; bar chart: %s',
            log_evidence, spacing_counts.sum(), motif1, motif2, rel_filename
        )
        already_examined[(motif1, motif2)] = rel_filename



#
# Make HTML page showing spacings. This step is best-effort: if an import
# fails (e.g. jinja2 is not installed) a warning is logged and no HTML
# page is written.
#
try:
    from jinja2 import Environment, PackageLoader
    # Templates are looked up in the 'templates' directory of the stempy package
    env = Environment(loader=PackageLoader('stempy', 'templates'))
    template = env.get_template('spacing.html')

    # Copy the static info
    static_dir = os.path.join(options.results_dir, 'static')
    html_copy_static(static_dir)

    # Write the HTML into the results directory; the plot paths stored in
    # already_examined are relative to that directory
    filename = os.path.join(
        options.results_dir, 'spacings-max-%d.html' % options.max_distance)
    logger.info('Writing STEME spacing statistics as HTML to %s', filename)
    with open(filename, 'w') as f:
        variables = {
            'log_evidences': log_evidences,
            'filenames': already_examined,
        }
        f.write(template.render(**variables))

# "except X as e" is accepted by Python 2.6+ and Python 3, unlike the
# Python-2-only "except X, e" form.
except ImportError as e:
    logger.warning('Import failed, could not write HTML:\n%s', str(e))

