#!/home/hancock/V-Py/bin/python
"""
 Tool for making residual images with Aegean tables as input
"""
__author__ = 'Paul Hancock'
__version__ = 'v0.2.5'
__date__ = '2016-09-08'

from AegeanTools import catalogs, wcs_helpers, fitting
from astropy.io import fits
import logging
import numpy as np
from optparse import OptionParser
import sys

# global constants
# Conversion factor from a Gaussian's full width at half maximum (FWHM) to its
# standard deviation: sigma = FWHM / (2 * sqrt(2 * ln(2))), i.e. roughly 0.4247.
# make_model multiplies the pixel-space source widths by this factor before
# evaluating the Gaussian.
FWHM2CC = 1 / (2 * np.sqrt(2 * np.log(2)))


def load_sources(filename):
    """
    Read a catalogue file and convert its rows into a list of sources.

    The table is loaded with catalogs.load_table and converted with
    catalogs.table_to_source_list; the number of sources read is logged at
    INFO level.

    @param filename: name of the catalogue file to read
    @return: list of OutputSource objects
    """
    table = catalogs.load_table(filename)
    source_list = catalogs.table_to_source_list(table)
    n_sources = len(source_list)
    logging.info("read {0} sources from {1}".format(n_sources, filename))
    return source_list


def make_model(sources, shape, wcshelper):
    """
    Build a model image by summing one elliptical Gaussian per source.

    Each source is converted to pixel coordinates with
    wcshelper.sky2pix_ellipse. Sources whose centre lies outside the image are
    skipped. Every remaining source is evaluated only on a rectangular patch
    of pixels around its centre (clipped to the image) and added to the model.

    @param sources: a list of AegeanTools.models.SimpleSource objects
    @param shape: the shape of the input (and output) image
    @param wcshelper: an AegeanTools.wcs_helpers.WCSHelper object corresponding to the input image
    @return: a float32 numpy array of the given shape containing the model
    """

    # Model array
    m = np.zeros(shape, dtype=np.float32)
    # half-size of the evaluation patch, in units of the projected source size
    factor = 5

    i_count = 0
    for src in sources:
        # NOTE(review): a/b are divided by 3600, presumably arcsec -> degrees; confirm
        xo, yo, sx, sy, theta = wcshelper.sky2pix_ellipse([src.ra, src.dec], src.a/3600, src.b/3600, src.pa)
        phi = np.radians(theta)

        # skip sources that have a center that is outside of the image
        # (x is compared with the first array axis and y with the second,
        # matching the m[x, y] indexing used below)
        if not 0 < xo < shape[0]:
            continue
        if not 0 < yo < shape[1]:
            continue

        # pixels over which this model is calculated: the half-extent of the
        # rotated ellipse along each axis, scaled by `factor`
        xoff = factor*(abs(sx*np.cos(phi)) + abs(sy*np.sin(phi)))
        xmin = xo - xoff
        xmax = xo + xoff

        yoff = factor*(abs(sx*np.sin(phi)) + abs(sy*np.cos(phi)))
        ymin = yo - yoff
        ymax = yo + yoff

        # clip to the image size
        ymin = max(np.floor(ymin), 0)
        ymax = min(np.ceil(ymax), shape[1])

        xmin = max(np.floor(xmin), 0)
        xmax = min(np.ceil(xmax), shape[0])

        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.debug("Source ({0},{1})".format(src.island, src.source))
            logging.debug(" xo, yo: {0} {1}".format(xo, yo))
            logging.debug(" sx, sy: {0} {1}".format(sx, sy))
            logging.debug(" theta, phi: {0} {1}".format(theta, phi))
            logging.debug(" xoff, yoff: {0} {1}".format(xoff, yoff))
            logging.debug(" xmin, xmax, ymin, ymax: {0}:{1} {2}:{3}".format(xmin, xmax, ymin, ymax))

        # positions for which we want to make the model
        # The bounds are already whole numbers (floor/ceil above); casting them
        # to int makes mgrid return integer arrays that can index `m` directly.
        # This avoids map(int, ...), which yields an unusable iterator on
        # Python 3 and is slow on Python 2.
        x, y = np.mgrid[int(xmin):int(xmax), int(ymin):int(ymax)]
        x = x.ravel()
        y = y.ravel()

        # TODO: understand why xo/yo -1 is needed
        model = fitting.elliptical_gaussian(x, y, src.peak_flux, xo-1, yo-1, sx*FWHM2CC, sy*FWHM2CC, theta)
        # each (x, y) pair occurs once in the grid, so fancy-index += is safe
        m[x, y] += model
        i_count += 1
    logging.info("modeled {0} sources".format(i_count))
    return m


def make_residual(fitsfile, catalog, rfile, mfile=None, add=False):
    """
    Make a residual image by removing (or adding) catalogued sources from an image.

    The catalogue is read with load_sources, a model image is built with
    make_model using the WCS from the primary HDU header, and the model is
    subtracted from (or, with add=True, added to) the primary HDU data. The
    result is written to rfile, overwriting any existing file. If mfile is
    given, the model image is written there as well, reusing the input header.

    @param fitsfile: Input fits image filename
    @param catalog: Input catalog filename of a type supported by Aegean
    @param rfile: Residual image filename
    @param mfile: model image filename
    @param add: add the model instead of subtracting it
    @return: None
    """
    source_list = load_sources(catalog)

    # Open the input inside a context manager so the file handle is released
    # even if modelling or writing raises.
    with fits.open(fitsfile) as hdulist:
        data = hdulist[0].data
        header = hdulist[0].header

        wcshelper = wcs_helpers.WCSHelper.from_header(header)

        model = make_model(source_list, data.shape, wcshelper)

        if add:
            residual = data + model
        else:
            residual = data - model

        # Reuse the input HDU list so the output keeps the original header.
        # NOTE(review): newer astropy releases rename `clobber` to `overwrite`;
        # kept as-is to match the astropy version this script targets - confirm
        # before upgrading astropy.
        hdulist[0].data = residual
        hdulist.writeto(rfile, clobber=True)
        logging.info("wrote residual to {0}".format(rfile))
        if mfile is not None:
            hdulist[0].data = model
            hdulist.writeto(mfile, clobber=True)
            logging.info("wrote model to {0}".format(mfile))
    return


if __name__ == "__main__":
    # Command line entry point: parse the options, configure logging, check
    # that the three mandatory filenames were supplied, then build the
    # residual (and optionally the model) image.
    parser = OptionParser(usage="usage: %prog -c input.vot -f image.fits -r residual.fits [-m model.fits]")

    # (flags, settings) for every option, listed in the order shown by --help
    option_table = [
        (("-c", "--catalog"),
         dict(dest='catalog', default=None,
              help="Catalog in a format that Aegean understands.")),
        (("-f", "--fitsimage"),
         dict(dest='fitsfile', default=None,
              help="Input fits file.")),
        (("-r", "--residual"),
         dict(dest='rfile', default=None,
              help="Output residual fits file.")),
        (('-m', "--model"),
         dict(dest='mfile', default=None,
              help="Output model file [optional].")),
        (('--add',),
         dict(dest='add', default=False, action='store_true',
              help="Add components instead of subtracting them.")),
        (('--debug',),
         dict(dest='debug', action='store_true', default=False,
              help="Debug mode.")),
    ]
    for flags, settings in option_table:
        parser.add_option(*flags, **settings)

    options, args = parser.parse_args()

    if options.debug:
        logging_level = logging.DEBUG
    else:
        logging_level = logging.INFO
    logging.basicConfig(level=logging_level, format="%(process)d:%(levelname)s %(message)s")
    logging.info("This is AeRes {0}-({1})".format(__version__, __date__))

    # Mandatory options, checked in this order; the first missing one aborts
    # the run with exit status 1.
    required = (
        (options.catalog, "input catalog is required"),
        (options.fitsfile, "input fits file is required"),
        (options.rfile, "output residual filename is required"),
    )
    for value, complaint in required:
        if value is None:
            logging.error(complaint)
            sys.exit(1)

    logging.info("Using {0} and {1} to make {2}".format(options.fitsfile, options.catalog, options.rfile))
    if options.mfile is not None:
        logging.info(" and writing model to {0}".format(options.mfile))
    make_residual(options.fitsfile, options.catalog, options.rfile,
                  mfile=options.mfile, add=options.add)
    sys.exit(0)
