#!/usr/bin/env python
#
# Copyright John Reid 2016
#

"""
Reads in a MEME motif file, applies some changes and writes it out.
"""

import logging
# Show INFO-and-above messages, each prefixed with a timestamp and its level.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)

# The root logger; used by the helper functions and the script body below.
logger = logging.getLogger()

import os
import sys
from optparse import OptionParser

import stempy
from stempy.meme_parse import do_parse_and_extract, write_pwms


def print_heading(out, heading, underline):
    """Write an underlined heading followed by a blank line.

    :param out: File-like object with a ``write`` method.
    :param heading: Text of the heading.
    :param underline: String repeated once per character of the heading
        (ignoring leading/trailing whitespace) to form the underline.
    """
    # Use write() rather than the Python 2 only ``print >>out`` form so the
    # function behaves identically under Python 2 and Python 3.
    out.write(heading + '\n')
    out.write(underline * len(heading.strip()) + '\n')
    out.write('\n')

def formatname(x):
    """Return a copy of a motif whose name is rebuilt from the name format.

    :param x: An ``(index, motif)`` pair, as produced by ``enumerate``.
    :returns: The motif with ``name`` replaced.

    The template in ``options.name_format`` may refer to the fields
    ``{name}``, ``{idx}`` and ``{alt}``. The new name is logged at INFO level.
    """
    position, original = x
    template = options.name_format
    renamed = template.format(
        name=original.name,
        idx=position,
        alt=original.altname,
    )
    logger.info(renamed)
    return original._replace(name=renamed)

def formataltname(x):
    """Return a copy of a motif whose altname is rebuilt from the altname format.

    :param x: An ``(index, motif)`` pair, as produced by ``enumerate``.
    :returns: The motif with ``altname`` replaced.

    The template in ``options.altname_format`` may refer to the fields
    ``{name}``, ``{idx}`` and ``{alt}``. The new altname is logged at INFO
    level.
    """
    position, original = x
    template = options.altname_format
    renamed = template.format(
        name=original.name,
        idx=position,
        alt=original.altname,
    )
    logger.info(renamed)
    return original._replace(altname=renamed)

def addpseudocount(motif):
    """Return a copy of *motif* with a pseudocount mixed into its letter probabilities.

    The probabilities are scaled by the motif's ``nsites``, the value of
    ``options.add_pseudocount`` is added to every entry and the result is
    rescaled.
    """
    probs = motif.letter_probs
    counts = probs.values * probs.nsites + options.add_pseudocount
    # Every entry is divided by the total of the first row only.
    # NOTE(review): presumably all rows share the same total - confirm.
    rescaled = counts / counts[0].sum()
    return motif._replace(letter_probs=probs._replace(values=rescaled))

# Command line: optional transformations followed by the motifs filename.
usage = 'USAGE: %prog [options] <motifs filename>'
parser = OptionParser(usage=usage)

# Template for new motif names; left as None when names are not to be changed.
parser.add_option(
    "-f", "--name-format",
    metavar="NAME",
    help="Replace names using format NAME",
    default=None)

# Template for new alternative names; left as None when they are not to be
# changed.
parser.add_option(
    "-a", "--altname-format",
    metavar="ALTNAME",
    help="Replace alternative names using ALTNAME",
    default=None)

# Amount used by addpseudocount(); a float that defaults to zero.
parser.add_option(
    "-p", "--add-pseudocount",
    metavar="PSEUDOCOUNT",
    help="Add a PSEUDOCOUNT to the letter probabilities",
    type=float,
    default=0.)

options, args = parser.parse_args()


#
# Parse the motifs
#
# The first positional argument names the motifs file. Report a missing
# argument through the parser (usage message, exit status 2) instead of
# failing with an IndexError.
if not args:
    parser.error('No motifs filename given')
motifs_filename = args[0]
logger.info('Parsing motifs: %s', motifs_filename)
# Read the whole file and close it before parsing starts.
with open(motifs_filename) as motifs_file:
    motifs_text = motifs_file.read()
extracted = do_parse_and_extract(motifs_text)

# Every stage below builds a complete list, so the per-motif log messages
# appear directly after their stage's heading on both Python 2 (eager map)
# and Python 3 (lazy map).
outputmotifs = list(extracted.motifs)

#
# Apply name formatting
#
if options.name_format is not None:
    logger.info('Applying name format: %s', options.name_format)
    outputmotifs = [formatname(pair) for pair in enumerate(outputmotifs)]

#
# Apply altname formatting
#
if options.altname_format is not None:
    logger.info('Applying altname format: %s', options.altname_format)
    outputmotifs = [formataltname(pair) for pair in enumerate(outputmotifs)]

#
# Apply pseudocounts
#
# NOTE: the option's default is 0. rather than None, so this stage runs even
# when no pseudocount was requested on the command line.
if options.add_pseudocount is not None:
    logger.info('Applying pseudocount: %f', options.add_pseudocount)
    outputmotifs = [addpseudocount(motif) for motif in outputmotifs]

#
# Write the motifs
#
logger.info('Writing motifs')
extracted = extracted._replace(motifs=outputmotifs)
write_pwms(sys.stdout, extracted)
