#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Licensed under a MIT style license - see LICENSE.rst

""" Plot a single BOSS spectrum.
"""

from __future__ import division, print_function

import argparse

import os.path
import math
import re

import numpy as np
import numpy.ma
import matplotlib.pyplot as plt
import astropy.table

import bossdata

from six import iteritems, string_types


def auto_range(range_str, abs_lo, abs_hi, default_limits, data_lo=None, data_hi=None):
    """
    Interpret an axis range specification.

    Each limit in a range of the form ``[ <lo> : <hi> ]`` can be:

    * empty, in which case the corresponding default limit is used,
    * a plain number, which is used as-is, or
    * a percentage such as ``5%``.

    Percentage lo/hi limits are interpreted as percentiles when data_lo/data_hi is
    provided and the limit value is strictly between 0% and 100%.  Otherwise, the
    limit value is calculated as::

        abs_lo + (abs_hi - abs_lo) * limit/100

    Args:
        range_str(str): A range specification of the form [ <lo> : <hi> ].
        abs_lo(float): The absolute lower bound for data on this axis.
        abs_hi(float): The absolute upper bound for data on this axis.
        default_limits(str): A default range specification of the form [ <lo> : <hi> ].
            Anything that is not a string is unpacked directly as a (lo, hi) pair,
            e.g. the result of an axis ``get_xlim()`` call.
        data_lo(numpy.ndarray): An array of values whose percentile is used to
            calculate lo limits when the limit value is 0-100%.
        data_hi(numpy.ndarray): An array of values whose percentile is used to
            calculate hi limits when the limit value is 0-100%. Defaults to
            data_lo when not provided.

    Returns:
        tuple: The pair (lo, hi) of axis limits, with lo < hi.

    Raises:
        ValueError: A range does not have the format [<lo>:<hi>], a limit cannot be
            interpreted as a finite number, or the resulting limits do not
            satisfy lo < hi.
    """
    if data_hi is None:
        data_hi = data_lo

    def _limit_parser(limit, default_limit, data):
        # Convert a single limit string into a finite float.
        try:
            if limit.endswith('%'):
                limit_percent = float(limit.rstrip('%'))
                if data is not None and limit_percent > 0 and limit_percent < 100:
                    value = np.percentile(data, limit_percent)
                else:
                    # Values outside (0,100)% extrapolate linearly from the absolute bounds.
                    value = abs_lo + (abs_hi - abs_lo) * limit_percent / 100.0
            elif limit == '':
                value = default_limit
            else:
                value = float(limit)
            if math.isinf(value) or math.isnan(value):
                raise ValueError()
            return value
        except ValueError:
            raise ValueError('Invalid limit specified ({}).'.format(limit))

    # A raw string keeps the regex escapes (\s, \[, \]) out of Python's own
    # string-escape processing.
    range_pattern = re.compile(r'\s*\[\s*(.*?)\s*:\s*(.*?)\s*\]\s*')

    if isinstance(default_limits, string_types):
        parsed = range_pattern.match(default_limits)
        if not parsed:
            # Report the default specification that failed, not the user's range.
            raise ValueError('Default range {} does not have the format [<lo>:<hi>].'.format(
                default_limits))
        def_lo, def_hi = parsed.groups()
        default_lo = _limit_parser(def_lo, 0, data_lo)
        default_hi = _limit_parser(def_hi, 0, data_hi)
    else:  # Assume e.g. limits from get_xlim
        default_lo, default_hi = default_limits

    # The range must have the format [ <lo> : <hi> ]
    parsed = range_pattern.match(range_str)
    if not parsed:
        raise ValueError('Range {} does not have the format [<lo>:<hi>].'.format(range_str))
    lo, hi = parsed.groups()
    new_lo = _limit_parser(lo, default_lo, data_lo)
    new_hi = _limit_parser(hi, default_hi, data_hi)
    if new_hi <= new_lo:
        raise ValueError('Must have min < max in range {}.'.format(range_str))
    return new_lo, new_hi


def print_mask_summary(label, mask_values):
    """Print a summary of the SPPIXMASK bits that are set in a pixel mask.

    Prints a single 'No pixels masked.' line when no value in mask_values is
    non-zero. Otherwise prints a header followed by one line per entry of the
    summary returned by bossdata.bits.summarize_bitmask_values, giving a count
    and a bit name.

    Args:
        label(str): Text used to identify this mask in the summary header.
        mask_values: Pixel mask values to summarize (presumably a numpy array
            of SPPIXMASK bitmask integers; confirm against callers).
    """
    # Nothing to summarize when every mask value is zero.
    if not np.any(mask_values):
        print('No pixels masked.')
        return

    print('{0} pixel mask summary:'.format(label))
    counts_by_name = bossdata.bits.summarize_bitmask_values(
        bossdata.bits.SPPIXMASK, mask_values)
    for name, count in iteritems(counts_by_name):
        print('{0:5d} {1}'.format(count, name))


def main():
    """Plot a single BOSS spectrum selected with command-line options.

    Parses the command line, loads the requested spectrum (downloading files
    through the bossdata mirror if necessary), optionally saves the spectrum
    data, then draws the plot and saves and/or displays it.

    Returns:
        int or None: -1 after printing a message for invalid options or a
            failure to load or save files, -2 for an invalid axis range
            specification, 0 when there is no valid data to plot, and None
            on normal completion.
    """
    # Initialize and parse command-line arguments.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Plot a single BOSS spectrum.')
    parser.add_argument(
        '--verbose', action='store_true', help='Provide verbose output.')
    parser.add_argument(
        '--plate', type=int, default=6641, metavar='PLATE',
        help='Plate number of spectrum to plot.')
    parser.add_argument(
        '--mjd', type=int, default=None, metavar='MJD',
        help='MJD of plate observation to use (can be omitted if only one value is possible)')
    parser.add_argument(
        '--fiber', type=int, default=30, metavar='FIBER',
        help='Fiber number identifying the spectrum of the requested PLATE-MJD to plot.')
    parser.add_argument(
        '--exposure', type=int, default=None, metavar='EXP',
        help='Exposure sequence number starting from 0, or plot the coadd if not set.')
    parser.add_argument(
        '--band', type=str, choices=['blue', 'red', 'both'],
        default='both', help='band to use when plotting a single exposure.')
    parser.add_argument(
        '--allow-mask', type=str, default=None,
        help='SPPIXMASK bit names to allow in valid data. Separate multiple names with |.')
    parser.add_argument(
        '--show-invalid', action='store_true',
        help='Include pixels with invalid data in the plot (implies --scatter).')
    parser.add_argument(
        '--platefile', action='store_true',
        help='Plot the combined spectrum from an spPlate file.')
    parser.add_argument(
        '--frame', action='store_true',
        help='Plot a single exposure spectrum from an uncalibrated spFrame file.')
    parser.add_argument(
        '--cframe', action='store_true',
        help='Plot a single exposure spectrum from a calibrated spCFrame file.')
    parser.add_argument(
        '--save-plot', type=str, default=None, const='', nargs='?', metavar='FILE',
        help='Save the generated plot to specified name '
             '(uses {plot-label}.png if name omitted).')
    parser.add_argument(
        '--save-data', type=str, default=None, const='', nargs='?', metavar='FILE',
        help='Save the spectrum data to specified name '
             '(uses {plot-label}.dat if name  omitted).')
    parser.add_argument(
        '--no-display', action='store_true',
        help='Do not display the image on screen (useful for batch processing).')
    parser.add_argument(
        '--scatter', action='store_true',
        help='Show scatter of flux instead of a flux error band.')
    parser.add_argument(
        '--show-mask', action='store_true',
        help='Indicate pixels with invalid data using vertical lines.')
    parser.add_argument(
        '--show-dispersion', action='store_true',
        help='Show the wavelength dispersion using the right-hand axis.')
    parser.add_argument(
        '--show-sky', action='store_true',
        help='Show the subtracted sky flux instead of the object flux.')
    parser.add_argument(
        '--add-sky', action='store_true',
        help='Add the subtracted sky to the object flux (overrides show-sky).')
    parser.add_argument(
        '--flux-range', type=str, default='[0.5%:99.5%]', metavar='[MIN:MAX]',
        help='Custom flux axis range, e.g. [0:100], [1%%:99%%], [0:].')
    parser.add_argument(
        '--wlen-range', type=str, default=None, metavar='[MIN:MAX]',
        help='Custom wavelength axis range, e.g. [4000:8000], [-5%%:105%%].')
    parser.add_argument(
        '--wdisp-range', type=str, default=None, metavar='[MIN:MAX]',
        help='Custom dispersion axis range, e.g. [0:2], [0:].')
    parser.add_argument(
        '--label-pos', metavar='POS', default='top-left',
        choices=['none', 'top-left', 'top-center', 'top-right', 'center-left', 'center-center',
                 'center-right', 'bottom-left', 'bottom-center', 'bottom-right'],
        help='Label position, e.g. top-right, bottom-left, center-center.')
    parser.add_argument(
        '--no-grid', action='store_true', help='Do not draw wavelength grid lines.')
    parser.add_argument(
        '--figsize', type=str, default='12,8', metavar='WIDTH,HEIGHT',
        help='Figure dimensions in inches.')
    args = parser.parse_args()

    if args.exposure is None:
        if args.frame or args.cframe:
            print('Coadds not available from frame and cframe files.')
            return -1
        # Compare by value: identity comparison with a string literal is unreliable.
        if args.band != 'both':
            print('Ignoring band = "{0}" for combined spectrum.'.format(args.band))
            args.band = 'both'
    else:
        if args.platefile:
            print('Single exposures not available from plate files.')
            return -1

    if args.allow_mask is None:
        pixel_quality_mask = None
    else:
        try:
            pixel_quality_mask = bossdata.bits.bitmask_from_text(
                bossdata.bits.SPPIXMASK, args.allow_mask)
        except ValueError as e:
            print(e)
            return -1

    try:
        finder = bossdata.path.Finder(verbose=args.verbose)
        mirror = bossdata.remote.Manager(verbose=args.verbose)
    except ValueError as e:
        print(e)
        return -1

    # Look up the MJD when it is not specified, requiring a unique answer.
    if args.mjd is None:
        mjd_list = bossdata.meta.get_plate_mjd_list(
            args.plate, finder=finder, mirror=mirror)
        if len(mjd_list) == 0:
            print('Plate {0:d} has not been observed with good data quality.'.format(
                args.plate))
            return -1
        elif len(mjd_list) > 1:
            print('Plate {0:d} has been observed on MJDs {1:s}.'.format(
                args.plate, ','.join(map(str, mjd_list))))
            print('Select one of these using the --mjd command-line argument.')
            return -1
        else:
            args.mjd = mjd_list[0]

    if args.verbose:
        observation = '{:04d}-{:5d}-{:04d}'.format(args.plate, args.mjd, args.fiber)
        if args.exposure is None:
            print('Plotting combined spectrum for {}.'.format(observation))
        else:
            print('Plotting exposure[{}] spectrum for {}.'.format(args.exposure, observation))

    # Load spectra into memory, downloading if necessary.
    try:
        if args.frame or args.cframe:
            frames = {}
            frame_path = finder.get_plate_path(plate=args.plate)
            plan_path = finder.get_plate_plan_path(plate=args.plate, mjd=args.mjd)
            plan = bossdata.plate.Plan(mirror.get(plan_path))
            num_exposures = plan.num_science_exposures
            if args.verbose:
                print('Exposure summary:')
                print(plan.exposure_table)
            ftype = 'spCFrame' if args.cframe else 'spFrame'
            if args.band in ('red', 'both'):
                red_name = plan.get_exposure_name(
                    args.exposure, 'red', args.fiber, ftype)
                if red_name is None:
                    print('Red band data not available.')
                    return -1
                frames['red'] = bossdata.plate.FrameFile(
                    mirror.get(os.path.join(frame_path, red_name)))
            if args.band in ('blue', 'both'):
                blue_name = plan.get_exposure_name(
                    args.exposure, 'blue', args.fiber, ftype)
                if blue_name is None:
                    print('Blue band data not available.')
                    return -1
                frames['blue'] = bossdata.plate.FrameFile(
                    mirror.get(os.path.join(frame_path, blue_name)))
        elif args.platefile:
            remote_path = finder.get_plate_spec_path(plate=args.plate, mjd=args.mjd)
            local_path = mirror.get(remote_path)
            platefile = bossdata.plate.PlateFile(local_path)
            num_exposures = platefile.num_exposures
            if args.verbose:
                print('Exposure summary:')
                print(platefile.exposures.table)
        else:
            lite = (args.exposure is None)
            remote_paths = [finder.get_spec_path(plate=args.plate, mjd=args.mjd,
                            fiber=args.fiber, lite=lite)]
            if lite:    # If lite, we can use the Full file if it exists but lite does not
                remote_paths.append(finder.get_spec_path(
                    plate=args.plate, mjd=args.mjd, fiber=args.fiber, lite=False))
            local_paths = []
            local_path = mirror.get(remote_paths, local_paths=local_paths)
            specfile = bossdata.spec.SpecFile(local_path)
            num_exposures = specfile.num_exposures
            if args.verbose:
                print('Exposure summary:')
                print(specfile.exposures.table)
                if local_path != local_paths[0]:
                    print('Using local full spec file instead of downloading lite file.')
    except (RuntimeError, ValueError) as e:
        print(str(e))
        return -1

    # Check for an invalid exposure index.
    if args.exposure is not None and (args.exposure < 0 or args.exposure >= num_exposures):
        print('The value --exposure {} is outside the allowed range 0-{} for this target.'
              .format(args.exposure, num_exposures - 1))
        if not args.verbose:
            print('Use the --verbose option for details on the available exposures.')
        return -1

    # Initialize the plot.
    try:
        # Unpacking raises ValueError unless exactly two numeric values are given.
        figsize_x, figsize_y = map(float, args.figsize.split(','))
        if figsize_x < 1 or figsize_y < 1:
            print('Minimum --figsize is 1,1.')
            return -1
        figure = plt.figure(figsize=(figsize_x, figsize_y), tight_layout=True)
    except ValueError:
        print('Invalid option --figsize {} (expected WIDTH,HEIGHT in inches).'.format(
            args.figsize))
        return -1
    left_axis = plt.gca()
    figure.set_facecolor('white')
    # Raw strings keep the LaTeX backslashes out of Python's string-escape processing.
    plt.xlabel(r'Wavelength ($\AA$)')
    if args.frame:
        left_axis.set_ylabel('Flux (electrons)')
    else:
        left_axis.set_ylabel(r'Flux ($\times 10^{-17}$ erg/s/cm$^{2}$/$\AA$)')
    if args.show_dispersion:
        right_axis = left_axis.twinx()
        right_axis.set_ylabel('Dispersion (pixels)')

    # We will potentially plot two spectra.
    spectra = []
    plot_colors = []
    data_args = dict(
        include_wdisp=args.show_dispersion, include_sky=args.show_sky or args.add_sky)
    if args.exposure is None:
        if args.platefile:
            fibers = np.array([args.fiber], dtype=int)
            spectra.append(platefile.get_valid_data(
                fibers, pixel_quality_mask=pixel_quality_mask, **data_args)[0])
            plot_colors.append('black')
            if args.verbose:
                print('Showing coadd of {:d} blue+red exposures.'.format(
                    platefile.num_exposures))
                print_mask_summary('Coadd (AND)', platefile.get_pixel_masks(fibers)[0])
        else:
            spectra.append(specfile.get_valid_data(
                pixel_quality_mask=pixel_quality_mask, **data_args))
            plot_colors.append('black')
            if args.verbose:
                print('Showing coadd of {:d} blue+red exposures.'.format(
                    specfile.num_exposures))
                print_mask_summary('Coadd (AND)', specfile.get_pixel_mask())
    elif args.frame or args.cframe:
        fibers = np.array([args.fiber], dtype=int)
        exposure_id = plan.exposures['science'][args.exposure]['EXPID']
        if args.verbose:
            print('Showing exposure {:08d}.'.format(exposure_id))
        if args.band in ('blue', 'both'):
            spectra.append(frames['blue'].get_valid_data(
                fibers, pixel_quality_mask=pixel_quality_mask, **data_args)[0])
            plot_colors.append('blue')
            if args.verbose:
                print_mask_summary('Blue', frames['blue'].get_pixel_masks(fibers)[0])
        if args.band in ('red', 'both'):
            spectra.append(frames['red'].get_valid_data(
                fibers, pixel_quality_mask=pixel_quality_mask, **data_args)[0])
            plot_colors.append('red')
            if args.verbose:
                print_mask_summary('Red', frames['red'].get_pixel_masks(fibers)[0])
    else:
        # Which spectrograph is this fiber on for this plate?
        num_fibers = bossdata.plate.get_num_fibers(args.plate)
        spec_index = '1' if args.fiber <= num_fibers // 2 else '2'
        if args.band in ('blue', 'both'):
            camera = 'b' + spec_index
            exposure_id = specfile.exposures.get_info(args.exposure, camera)['science']
            if args.verbose:
                print('Showing blue exposure {:08d}.'.format(exposure_id))
            spectra.append(specfile.get_valid_data(
                args.exposure, camera, pixel_quality_mask=pixel_quality_mask, **data_args))
            plot_colors.append('blue')
            if args.verbose:
                print_mask_summary('Blue', specfile.get_pixel_mask(args.exposure, camera))
        if args.band in ('red', 'both'):
            camera = 'r' + spec_index
            exposure_id = specfile.exposures.get_info(args.exposure, camera)['science']
            if args.verbose:
                print('Showing red exposure {:08d}.'.format(exposure_id))
            spectra.append(specfile.get_valid_data(
                args.exposure, camera, pixel_quality_mask=pixel_quality_mask, **data_args))
            plot_colors.append('red')
            if args.verbose:
                print_mask_summary('Red', specfile.get_pixel_mask(args.exposure, camera))

    # Determine this plot's label and use it for the plot window title.
    plot_label = '{plate}-{mjd}-{fiber}'.format(
        plate=args.plate, mjd=args.mjd, fiber=args.fiber)
    if args.exposure is not None:
        plot_label += '-{:08d}'.format(exposure_id)
    # Set the title through the figure manager: the canvas-level
    # set_window_title() is deprecated and absent from recent matplotlib releases.
    manager = getattr(figure.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(plot_label)

    # Save the spectrum data, if requested. An empty string means that --save-data
    # was given without a file name, so test against None rather than truthiness.
    if args.save_data is not None:
        if len(spectra) > 1:
            print('WARNING: saving data for the first spectrum only.')
        save_name = args.save_data
        if save_name == '':
            save_name = plot_label + '.dat'
        # Only save un-masked rows.
        valid = ~(spectra[0]['flux'].mask)
        table = astropy.table.Table(spectra[0][valid])
        try:
            table.write(save_name, format='ascii.basic', overwrite=True)
        except IOError as e:
            print('Unable to save data: {}.'.format(str(e)))
            return -1
        if args.verbose:
            print('Saved data to {}'.format(save_name))

    # Accumulate the plot limits and auto-ranging data over all plotted spectra.
    wlen_min, wlen_max = +np.inf, -np.inf
    flux_lo, flux_hi = np.array([], dtype=float), np.array([], dtype=float)
    wdisp_data = np.array([], dtype=float)
    wlen_data = np.array([], dtype=float)
    x_mask = []
    for data, plot_color in zip(spectra, plot_colors):

        wlen, dflux = data['wavelength'][:], data['dflux'][:]
        if args.add_sky:
            flux = data['sky'][:] + data['flux'][:]
        elif args.show_sky:
            flux = data['sky'][:]
        else:
            flux = data['flux'][:]

        # Test --show-invalid first so that invalid pixels are still drawn when
        # it is combined with --scatter.
        if args.show_invalid:
            left_axis.scatter(wlen.data, flux.data, color=plot_color, marker='.', s=0.1)
        elif args.scatter:
            left_axis.scatter(wlen, flux, color=plot_color, marker='.', s=0.1)
        else:
            left_axis.fill_between(
                wlen, flux - dflux, flux + dflux, color=plot_color, alpha=0.5)

        num_masked = np.count_nonzero(flux.mask)
        if args.show_mask and num_masked > 0:
            bad_pixels = np.where(flux.mask)
            # Each masked pixel contributes one (x, x, None) line segment.
            for x in data.data['wavelength'][bad_pixels]:
                x_mask.extend([x, x, None])

        if args.show_dispersion:
            wdisp = data['wdisp'][:]
            if args.show_invalid:
                right_axis.plot(wlen.data, wdisp.data, ls='-', color=plot_color)
                wdisp_data = np.append(wdisp_data, wdisp.data)
            else:
                right_axis.plot(wlen, wdisp, ls='-', color=plot_color)
                wdisp_data = numpy.ma.append(wdisp_data, data['wdisp'][:])

        # Update the plot wavelength limits to include this data.
        if args.show_invalid:
            wlen_min = min(wlen_min, np.min(wlen.data))
            wlen_max = max(wlen_max, np.max(wlen.data))
        else:
            # Only use pixels with valid data to update the limits.
            wlen_min = min(wlen_min, np.ma.min(wlen))
            wlen_max = max(wlen_max, np.ma.max(wlen))

        # Update the list of fluxes that we will use to auto-range the vertical scale.
        if args.show_invalid:
            flux_lo = numpy.append(flux_lo, flux.data)
        else:
            flux_lo = numpy.ma.append(flux_lo, flux - dflux)
            flux_hi = numpy.ma.append(flux_hi, flux + dflux)

        # Update the list of wavelengths that we will use to auto-range the horizontal scale.
        wlen_data = numpy.ma.append(wlen_data, wlen)

    # Is there any valid data to plot?
    if wlen_min == np.inf:
        print('There is no valid data. Try the --allow-mask and --show-invalid options?')
        return 0

    # The x-axis limits are reset by the twinx() function so we set them here.
    plt.xlim(wlen_min, wlen_max)
    # Set the wavelength axis range.
    if args.wlen_range:
        try:
            left_axis.set_xlim(*auto_range(
                args.wlen_range, wlen_min, wlen_max,
                parser.get_default('wlen_range') or left_axis.get_xlim()))
        except ValueError as e:
            print(e)
            return -2

    # Set the flux axis range.
    if args.flux_range:
        try:
            if args.show_invalid:
                left_axis.set_ylim(*auto_range(
                    args.flux_range, np.min(flux_lo), np.max(flux_lo),
                    parser.get_default('flux_range') or left_axis.get_ylim(),
                    data_lo=flux_lo, data_hi=flux_lo))
            else:
                valid = (~flux_lo.mask & ~flux_hi.mask &
                         (wlen_data > wlen_min) & (wlen_data < wlen_max))
                left_axis.set_ylim(*auto_range(
                    args.flux_range, np.min(flux_lo[valid]), np.max(flux_hi[valid]),
                    parser.get_default('flux_range') or left_axis.get_ylim(),
                    data_lo=flux_lo[valid], data_hi=flux_hi[valid]))
        except ValueError as e:
            print(e)
            return -2

    # Set the wavelength dispersion axis range.
    if args.show_dispersion and args.wdisp_range:
        try:
            # wdisp_data is a plain ndarray (no .mask attribute) with --show-invalid,
            # so use getmaskarray which handles both plain and masked arrays.
            wdisp_mask = np.ma.getmaskarray(wdisp_data)
            valid = ~wdisp_mask & (wlen_data > wlen_min) & (wlen_data < wlen_max)
            right_axis.set_ylim(*auto_range(
                args.wdisp_range, np.min(wdisp_data[valid]), np.max(wdisp_data[valid]),
                parser.get_default('wdisp_range') or right_axis.get_ylim(),
                wdisp_data[valid]))
        except ValueError as e:
            print(e)
            return -2

    # Draw mask lines that extend the full range of the left axis.
    if args.show_mask and len(x_mask) > 0:
        ymin, ymax = left_axis.get_ylim()
        assert len(x_mask) % 3 == 0, 'Internal error: len(x_mask)%3 != 0'
        y_mask = [ymin, ymax, None] * (len(x_mask) // 3)
        left_axis.plot(x_mask, y_mask, '-', color='DarkGreen', alpha=0.3)

    # Annotate the plot with a label identifying this spectrum.
    if args.label_pos != 'none':
        valign, halign = args.label_pos.split('-')
        x = dict(left=0.02, center=0.5, right=0.98)[halign]
        y = dict(top=0.98, center=0.5, bottom=0.02)[valign]
        plt.annotate(
            plot_label, xy=(x, y), xycoords='axes fraction',
            xytext=(x, y), textcoords='axes fraction',
            horizontalalignment=halign, verticalalignment=valign, fontsize=20)

    if not args.no_grid:
        left_axis.xaxis.grid()

    if args.save_plot is not None:
        save_name = args.save_plot
        if save_name == '':
            save_name = plot_label + '.png'
        try:
            figure.savefig(save_name)
        except IOError as e:
            print('Unable to save plot: {}.'.format(str(e)))
            return -1
        if args.verbose:
            print('Saved plot to {}'.format(save_name))
    if not args.no_display:
        plt.show()
    plt.close()

if __name__ == '__main__':
    # Propagate the status returned by main() (None on success, a non-zero
    # integer on error) to the calling shell instead of discarding it.
    raise SystemExit(main())
