#!python
# -*- coding: utf-8 -*-
# Licensed under a MIT style license - see LICENSE.rst

""" Plot a single BOSS spectrum.
"""

from __future__ import division,print_function

from astropy.utils.compat import argparse

import os.path
import math
import re

import numpy as np
import numpy.ma
import matplotlib.pyplot as plt
import astropy.table

import bossdata.path
import bossdata.remote
import bossdata.spec
import bossdata.bits
import bossdata.plate

def auto_range(range_str, abs_lo, abs_hi, default_limits, data_lo=None, data_hi=None):
    """
    Interpret an axis range specification.

    Each limit of a range is either a number, a percentage (a number followed by ``%``)
    or empty. An empty limit is replaced with the corresponding default limit.

    Percentage lo/hi limits are interpreted as percentiles when data_lo/data_hi is
    provided and the limit value is strictly between 0% and 100%.  Otherwise, the limit
    value is calculated as::

        abs_lo + (abs_hi - abs_lo) * limit/100

    Args:
        range_str(str): A range specification of the form [ <lo> : <hi> ].
        abs_lo(float): The absolute lower bound for data on this axis.
        abs_hi(float): The absolute upper bound for data on this axis.
        default_limits(str): A default range specification of the form [ <lo> : <hi> ].
            Empty limits in a default specification are replaced with zero. Any
            non-string value is unpacked directly as a (lo, hi) pair, e.g. the limits
            returned by matplotlib's get_xlim.
        data_lo(numpy.ndarray): An array of values whose percentile is used to
            calculate lo limits when the limit value is 0-100%.
        data_hi(numpy.ndarray): An array of values whose percentile is used to
            calculate hi limits when the limit value is 0-100%. Uses data_lo when
            not specified.

    Returns:
        tuple: The (lo, hi) limits described by range_str.

    Raises:
        ValueError: A range does not have the format [<lo>:<hi>], a limit is not a
            finite number or percentage, or the resulting limits do not satisfy lo < hi.
    """
    # basestring only exists in python 2, where it also covers unicode strings.
    try:
        string_types = basestring
    except NameError:
        string_types = str

    if data_hi is None:
        data_hi = data_lo

    def _limit_parser(limit, default_limit, data):
        # Convert a single limit string into a finite value.
        try:
            if limit.endswith('%'):
                limit_percent = float(limit.rstrip('%'))
                if data is not None and limit_percent > 0 and limit_percent < 100:
                    value = np.percentile(data, limit_percent)
                else:
                    value = abs_lo + (abs_hi - abs_lo) * limit_percent / 100.0
            elif limit == '':
                value = default_limit
            else:
                value = float(limit)
            if math.isinf(value) or math.isnan(value):
                raise ValueError()
            return value
        except ValueError:
            raise ValueError('Invalid limit specified ({}).'.format(limit))

    range_pattern = re.compile(r'\s*\[\s*(.*?)\s*:\s*(.*?)\s*\]\s*')

    if isinstance(default_limits, string_types):
        parsed = range_pattern.match(default_limits)
        if not parsed:
            raise ValueError('Default range {} does not have the format [<lo>:<hi>].'.format(
                default_limits))
        def_lo, def_hi = parsed.groups()
        default_lo = _limit_parser(def_lo, 0, data_lo)
        default_hi = _limit_parser(def_hi, 0, data_hi)
    else:
        # Assume e.g. limits from get_xlim
        default_lo, default_hi = default_limits

    # The range must have the format [ <lo> : <hi> ]
    parsed = range_pattern.match(range_str)
    if not parsed:
        raise ValueError('Range {} does not have the format [<lo>:<hi>].'.format(range_str))
    lo, hi = parsed.groups()
    new_lo = _limit_parser(lo, default_lo, data_lo)
    new_hi = _limit_parser(hi, default_hi, data_hi)
    if new_hi <= new_lo:
        raise ValueError('Must have min < max in range {}.'.format(range_str))
    return new_lo, new_hi

def print_mask_summary(label, mask_values):
    """Print how many pixels have each SPPIXMASK bit set.

    Prints a single line saying that no pixels are masked when mask_values has no
    non-zero entries. Otherwise prints a heading followed by one line per entry of
    the summary returned by bossdata.bits.summarize_bitmask_values.

    Args:
        label(str): Text used in the heading to identify the pixels being summarized.
        mask_values(numpy.ndarray): Pixel mask values to summarize.
    """
    if not np.any(mask_values):
        print('No pixels masked.')
        return
    print('{0} pixel mask summary:'.format(label))
    # The summary is used here as a mapping of bit name to pixel count.
    bit_summary = bossdata.bits.summarize_bitmask_values(
        bossdata.bits.SPPIXMASK, mask_values)
    # items() works under both python 2 and 3, unlike the python 2 only iteritems().
    for bit_name, bit_count in bit_summary.items():
        print('{0:5d} {1}'.format(bit_count, bit_name))

def main():
    """Plot a single BOSS spectrum as specified by the command-line arguments.

    Loads the requested spectrum (downloading it if necessary), optionally saves its
    data and the generated plot to files, and displays the plot unless --no-display
    is set.

    Returns:
        None after the plot has been generated, 0 when there is no valid data to
        plot, -1 when the options are inconsistent, the data is unavailable or a file
        cannot be saved, and -2 when an axis range option is invalid.
    """
    # Initialize and parse command-line arguments.
    parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter,
        description = 'Plot a single BOSS spectrum.')
    parser.add_argument('--verbose', action = 'store_true',
        help = 'Provide verbose output.')
    parser.add_argument('--plate',type = int, default = 6641, metavar = 'PLATE',
        help = 'Plate number of spectrum to plot.')
    parser.add_argument('--mjd',type = int, default = 56383, metavar = 'MJD',
        help = 'Modified Julian date of plate observation to use.')
    parser.add_argument('--fiber',type = int,default = 30, metavar = 'FIBER',
        help = 'Fiber number identifying the spectrum of the requested PLATE-MJD to plot.')
    parser.add_argument('--exposure',type = int,default = None, metavar = 'EXP',
        help = 'Exposure sequence number starting from 0, or plot the coadd if not set.')
    parser.add_argument('--camera',type = str, choices = ['blue','red','both'],
        default = 'both', help = 'Camera to use when plotting a single exposure.')
    parser.add_argument('--allow-mask', type = str, default = None,
        help = 'SPPIXMASK bit names to allow in valid data. Separate multiple names with |.')
    parser.add_argument('--show-invalid', action='store_true',
        help = 'Include pixels with invalid data in the plot (implies --scatter).')
    parser.add_argument('--platefile', action='store_true',
        help = 'Plot the combined spectrum from an spPlate file.')
    parser.add_argument('--frame', action='store_true',
        help = 'Plot a single exposure spectrum from an uncalibrated spFrame file.')
    parser.add_argument('--cframe', action='store_true',
        help = 'Plot a single exposure spectrum from a calibrated spCFrame file.')
    parser.add_argument('--save-plot', type=str, default=None, const='', nargs='?',
        metavar='FILE', help = ('Save the generated plot to specified name ' +
            '(uses {plot-label}.png if name omitted).'))
    parser.add_argument('--save-data', type=str, default=None, const='', nargs='?',
        metavar='FILE', help = ('Save the spectrum data to specified name ' +
            '(uses {plot-label}.dat if name omitted).'))
    parser.add_argument('--no-display', action = 'store_true',
        help = 'Do not display the image on screen (useful for batch processing).')
    parser.add_argument('--scatter', action = 'store_true',
        help = 'Show scatter of flux instead of a flux error band.')
    parser.add_argument('--show-mask', action = 'store_true',
        help = 'Indicate pixels with invalid data using vertical lines.')
    parser.add_argument('--show-dispersion', action = 'store_true',
        help = 'Show the wavelength dispersion using the right-hand axis.')
    parser.add_argument('--show-sky', action = 'store_true',
        help = 'Show the subtracted sky flux instead of the object flux.')
    parser.add_argument('--add-sky', action = 'store_true',
        help = 'Add the subtracted sky to the object flux (overrides show-sky).')
    parser.add_argument('--flux-range', type=str, default='[0.5%:99.5%]', metavar='[MIN:MAX]',
        help='Custom flux axis range, e.g. [0:100], [1%%:99%%], [0:].')
    parser.add_argument('--wlen-range', type=str, default=None, metavar='[MIN:MAX]',
        help='Custom wavelength axis range, e.g. [4000:8000], [-5%%:105%%].')
    parser.add_argument('--wdisp-range', type=str, default=None, metavar='[MIN:MAX]',
        help='Custom dispersion axis range, e.g. [0:2], [0:].')
    parser.add_argument('--label-pos', metavar='POS', default='top-left',choices=[
        'none','top-left','top-center','top-right','center-left','center-center',
        'center-right','bottom-left','bottom-center','bottom-right'],
        help='Label position, e.g. top-right, bottom-left, center-center.')
    parser.add_argument('--no-grid', action='store_true',
        help='Do not draw wavelength grid lines.')
    parser.add_argument('--figsize', type=str, default='12,8', metavar='WIDTH,HEIGHT',
        help='Figure dimensions in inches.')
    args = parser.parse_args()

    # Check that the requested file type is compatible with coadd vs single exposure.
    if args.exposure is None:
        if args.frame or args.cframe:
            print('Coadds not available from frame and cframe files.')
            return -1
        # Compare by value: identity of equal strings is an implementation detail.
        if args.camera != 'both':
            print('Ignoring camera = "{0}" for combined spectrum.'.format(args.camera))
            args.camera = 'both'
    else:
        if args.platefile:
            print('Single exposures not available from plate files.')
            return -1

    if args.allow_mask is None:
        pixel_quality_mask = None
    else:
        try:
            pixel_quality_mask = bossdata.bits.bitmask_from_text(
                bossdata.bits.SPPIXMASK,args.allow_mask)
        except ValueError as e:
            print(e)
            return -1

    try:
        finder = bossdata.path.Finder(verbose=args.verbose)
        mirror = bossdata.remote.Manager(verbose=args.verbose)
    except ValueError as e:
        print(e)
        return -1

    if args.verbose:
        observation = '{:04d}-{:5d}-{:04d}'.format(args.plate, args.mjd, args.fiber)
        if args.exposure is None:
            print('Plotting combined spectrum for {}.'.format(observation))
        else:
            print('Plotting exposure {} spectrum for {}.'.format(args.exposure, observation))

    # Load spectra into memory, downloading if necessary.
    try:
        if args.frame or args.cframe:
            frames = {}
            frame_path = finder.get_plate_path(plate=args.plate)
            plan_path = finder.get_plate_plan_path(plate=args.plate, mjd=args.mjd)
            plan = bossdata.plate.Plan(mirror.get(plan_path))
            num_exposures = plan.num_science_exposures
            if args.verbose:
                print('Exposure summary:')
                print(plan.exposure_table)
            spec_index = plan.get_spectrograph_index(args.fiber)
            if args.camera in ('red','both'):
                red_name = plan.get_exposure_name(
                    args.exposure, 'red', args.fiber, calibrated=args.cframe)
                if red_name is None:
                    print('Red camera data not available.')
                    return -1
                frames['red'] = bossdata.plate.FrameFile(
                    mirror.get(os.path.join(frame_path, red_name)),
                    index=spec_index, calibrated=args.cframe)
            if args.camera in ('blue','both'):
                blue_name = plan.get_exposure_name(
                    args.exposure, 'blue', args.fiber, calibrated=args.cframe)
                if blue_name is None:
                    print('Blue camera data not available.')
                    return -1
                frames['blue'] = bossdata.plate.FrameFile(
                    mirror.get(os.path.join(frame_path, blue_name)),
                    index=spec_index, calibrated=args.cframe)
        elif args.platefile:
            remote_path = finder.get_plate_spec_path(plate=args.plate, mjd=args.mjd)
            local_path = mirror.get(remote_path)
            platefile = bossdata.plate.PlateFile(local_path)
            num_exposures = len(platefile.exposure_table)
            if args.verbose:
                print('Exposure summary:')
                print(platefile.exposure_table)
        else:
            lite=(args.exposure is None)
            remote_paths = [finder.get_spec_path(plate=args.plate, mjd=args.mjd,
                            fiber=args.fiber, lite=lite)]
            if lite:    # If lite, we can use the Full file if it exists but lite does not
                remote_paths.append(finder.get_spec_path(plate=args.plate, mjd=args.mjd,
                                fiber=args.fiber, lite=False))
            local_paths = []
            local_path = mirror.get(remote_paths, local_paths=local_paths)
            specfile = bossdata.spec.SpecFile(local_path)
            num_exposures = len(specfile.exposure_table)
            if args.verbose:
                print('Exposure summary:')
                print(specfile.exposure_table)
                if local_path != local_paths[0]:
                    print('Using local full spec file instead of downloading lite file.')
    except (RuntimeError, ValueError) as e:
        print(str(e))
        return -1

    # Check for an invalid exposure index.
    if args.exposure is not None and (args.exposure < 0 or args.exposure >= num_exposures):
        print('The value --exposure {} is outside the allowed range 0-{}.'.format(
            args.exposure, num_exposures-1))
        if not args.verbose:
            print('Use the --verbose option for details on the available exposures.')
        return -1

    # Initialize the plot.
    try:
        # A ValueError here covers both non-numeric values and a wrong number of values.
        figsize_x, figsize_y = map(float, args.figsize.split(','))
        if figsize_x < 1 or figsize_y < 1:
            print('Minimum --figsize is 1,1.')
            return -1
        figure = plt.figure(figsize=(figsize_x, figsize_y), tight_layout=True)
    except ValueError:
        print('Invalid option --figsize {} (expected WIDTH,HEIGHT in inches).'.format(
            args.figsize))
        return -1
    left_axis = plt.gca()
    figure.set_facecolor('white')
    # Raw strings keep the LaTeX backslashes without relying on invalid escape sequences.
    plt.xlabel(r'Wavelength ($\AA$)')
    if args.frame:
        left_axis.set_ylabel('Flux (electrons)')
    else:
        left_axis.set_ylabel(r'Flux ($\times 10^{-17}$ erg/s/cm$^{2}$/$\AA$)')
    if args.show_dispersion:
        right_axis = left_axis.twinx()
        right_axis.set_ylabel('Dispersion (pixels)')

    # We will potentially plot two spectra.
    spectra = [ ]
    plot_colors = [ ]
    data_args = dict(include_wdisp=args.show_dispersion,
        include_sky=args.show_sky or args.add_sky)
    if args.exposure is None:
        if args.platefile:
            fibers = np.array([args.fiber],dtype=int)
            spectra.append(platefile.get_valid_data(fibers,
                pixel_quality_mask=pixel_quality_mask, **data_args)[0])
            plot_colors.append('black')
            if args.verbose:
                print('Showing coadd of {:d} blue+red exposures.'.format(
                    platefile.num_exposures))
                print_mask_summary('Coadd (AND)', platefile.get_pixel_masks(fibers)[0])
        else:
            spectra.append(specfile.get_valid_data(pixel_quality_mask=pixel_quality_mask,
                **data_args))
            plot_colors.append('black')
            if args.verbose:
                print('Showing coadd of {:d} blue+red exposures.'.format(
                    specfile.num_exposures))
                print_mask_summary('Coadd (AND)',specfile.get_pixel_mask())
    elif args.frame or args.cframe:
        fibers = np.array([args.fiber],dtype=int)
        exposure_id = plan.exposures['science'][args.exposure]['EXPID']
        if args.verbose:
            print('Showing exposure {:08d}.'.format(exposure_id))
        if args.camera in ('blue','both'):
            spectra.append(frames['blue'].get_valid_data(fibers,
                pixel_quality_mask=pixel_quality_mask, **data_args)[0])
            plot_colors.append('blue')
            if args.verbose:
                print_mask_summary('Blue',frames['blue'].get_pixel_masks(fibers)[0])
        if args.camera in ('red','both'):
            spectra.append(frames['red'].get_valid_data(fibers,
                pixel_quality_mask=pixel_quality_mask, **data_args)[0])
            plot_colors.append('red')
            if args.verbose:
                print_mask_summary('Red',frames['red'].get_pixel_masks(fibers)[0])
    else:
        exposure_id = specfile.exposure_table[args.exposure]['exp']
        if args.verbose:
            print('Showing exposure {:08d}.'.format(exposure_id))
        if args.camera in ('blue','both'):
            spectra.append(specfile.get_valid_data(args.exposure,'blue',
                pixel_quality_mask=pixel_quality_mask, **data_args))
            plot_colors.append('blue')
            if args.verbose:
                print_mask_summary('Blue',specfile.get_pixel_mask(args.exposure,'blue'))
        if args.camera in ('red','both'):
            spectra.append(specfile.get_valid_data(args.exposure,'red',
                pixel_quality_mask=pixel_quality_mask, **data_args))
            plot_colors.append('red')
            if args.verbose:
                print_mask_summary('Red',specfile.get_pixel_mask(args.exposure,'red'))

    # Determine this plot's label and use it for the plot window title.
    plot_label = '{plate}-{mjd}-{fiber}'.format(
        plate=args.plate, mjd=args.mjd, fiber=args.fiber)
    if args.exposure is not None:
        plot_label += '-{:08d}'.format(exposure_id)
    # Newer matplotlib releases only provide set_window_title on the figure manager,
    # so prefer the manager and fall back to the canvas method when necessary.
    manager = getattr(figure.canvas, 'manager', None)
    if manager is not None and hasattr(manager, 'set_window_title'):
        manager.set_window_title(plot_label)
    elif hasattr(figure.canvas, 'set_window_title'):
        figure.canvas.set_window_title(plot_label)

    # Save the spectrum data, if requested. An omitted FILE is stored as '' so we
    # must test against None rather than for a non-empty name.
    if args.save_data is not None:
        if len(spectra) > 1:
            print('WARNING: saving data for the first spectrum only.')
        save_name = args.save_data
        if save_name == '':
            save_name = plot_label + '.dat'
        # Only save un-masked rows.
        valid = ~(spectra[0]['wavelength'].mask)
        table = astropy.table.Table(spectra[0][valid])
        try:
            table.write(save_name, format='ascii.basic')
        except IOError as e:
            print('Unable to save data: {}.'.format(str(e)))
            return -1
        if args.verbose:
            print('Saved data to {}'.format(save_name))

    # Accumulate the values from all plotted spectra that are needed to set axis ranges.
    wlen_min,wlen_max = +np.inf,-np.inf
    flux_lo, flux_hi = np.array([], dtype=float), np.array([], dtype=float)
    wdisp_data = np.array([], dtype=float)
    wlen_data = np.array([], dtype=float)
    # Coordinates of vertical mask lines, stored as (x, x, None) triplets.
    x_mask = [ ]
    for data,plot_color in zip(spectra,plot_colors):

        wlen,dflux = data['wavelength'][:],data['dflux'][:]
        if args.add_sky:
            flux = data['sky'][:] + data['flux'][:]
        elif args.show_sky:
            flux = data['sky'][:]
        else:
            flux = data['flux'][:]

        if args.scatter:
            left_axis.scatter(wlen,flux,color=plot_color,marker='.',s=0.1)
        elif args.show_invalid:
            left_axis.scatter(wlen.data, flux.data, color=plot_color, marker='.', s=0.1)
        else:
            left_axis.fill_between(wlen,flux-dflux,flux+dflux,color=plot_color,alpha=0.5)

        num_masked = np.count_nonzero(data.mask)
        if args.show_mask and num_masked > 0:
            bad_pixels = np.where(data.mask)
            for x in data.data['wavelength'][bad_pixels]:
                x_mask.extend([x,x,None])

        if args.show_dispersion:
            wdisp = data['wdisp'][:]
            if args.show_invalid:
                right_axis.plot(wlen.data, wdisp.data, ls='-', color=plot_color)
                wdisp_data = np.append(wdisp_data, wdisp.data)
            else:
                right_axis.plot(wlen,wdisp, ls='-', color=plot_color)
                wdisp_data = numpy.ma.append(wdisp_data, data['wdisp'][:])

        # Update the plot wavelength limits to include this data.
        if args.show_invalid:
            wlen_min = min(wlen_min,np.min(wlen.data))
            wlen_max = max(wlen_max,np.max(wlen.data))
        else:
            # Only use pixels with valid data to update the limits.
            wlen_min = min(wlen_min,np.ma.min(wlen))
            wlen_max = max(wlen_max,np.ma.max(wlen))

        # Update the list of fluxes that we will use to auto-range the vertical scale.
        if args.show_invalid:
            flux_lo = numpy.append(flux_lo, flux.data)
        else:
            flux_lo = numpy.ma.append(flux_lo, flux - dflux)
            flux_hi = numpy.ma.append(flux_hi, flux + dflux)

        # Update the list of wavelengths that we will use to auto-range the horizontal scale.
        wlen_data = numpy.ma.append(wlen_data, wlen)

    # Is there any valid data to plot?
    if wlen_min == np.inf:
        print('There is no valid data. Try the --allow-mask and --show-invalid options?')
        return 0

    # The x-axis limits are reset by the twinx() function so we set them here.
    plt.xlim(wlen_min, wlen_max)
    # Set the wavelength axis range.
    if args.wlen_range:
        try:
            left_axis.set_xlim(*auto_range(args.wlen_range,
                wlen_min, wlen_max,
                parser.get_default('wlen_range') or left_axis.get_xlim()))
        except ValueError as e:
            print(e)
            return -2

    # Set the flux axis range.
    if args.flux_range:
        try:
            if args.show_invalid:
                left_axis.set_ylim(*auto_range(args.flux_range,
                    np.min(flux_lo), np.max(flux_lo),
                    parser.get_default('flux_range') or left_axis.get_ylim(),
                    data_lo=flux_lo, data_hi=flux_lo))
            else:
                valid = (~flux_lo.mask & ~flux_hi.mask & (wlen_data > wlen_min) &
                    (wlen_data < wlen_max))
                left_axis.set_ylim(*auto_range(args.flux_range,
                    np.min(flux_lo[valid]), np.max(flux_hi[valid]),
                    parser.get_default('flux_range') or left_axis.get_ylim(),
                    data_lo=flux_lo[valid], data_hi=flux_hi[valid]))
        except ValueError as e:
            print(e)
            return -2

    # Set the wavelength dispersion axis range.
    if args.show_dispersion and args.wdisp_range:
        try:
            # wdisp_data is a plain array without a mask attribute when --show-invalid
            # is set, so use getmaskarray which handles both cases.
            valid = (~np.ma.getmaskarray(wdisp_data) & (wlen_data > wlen_min) &
                (wlen_data < wlen_max))
            right_axis.set_ylim(*auto_range(args.wdisp_range,
                np.min(wdisp_data[valid]), np.max(wdisp_data[valid]),
                parser.get_default('wdisp_range') or right_axis.get_ylim(),
                wdisp_data[valid]))
        except ValueError as e:
            print(e)
            return -2

    # Draw mask lines that extend the full range of the left axis.
    if args.show_mask and len(x_mask) > 0:
        ymin, ymax = left_axis.get_ylim()
        assert len(x_mask) % 3 == 0, 'Internal error: len(x_mask)%3 != 0'
        # The None entries break the single plotted line into separate vertical segments.
        y_mask = [ymin, ymax, None] * (len(x_mask) // 3)
        left_axis.plot(x_mask, y_mask, '-', color='DarkGreen', alpha=0.3)

    # Annotate the plot with a label identifying this spectrum.
    if args.label_pos != 'none':
        valign, halign = args.label_pos.split('-')
        x = dict(left=0.02, center=0.5, right=0.98)[halign]
        y = dict(top=0.98, center=0.5, bottom=0.02)[valign]
        plt.annotate(plot_label, xy=(x,y), xycoords='axes fraction',
            xytext=(x,y), textcoords='axes fraction',
            horizontalalignment=halign, verticalalignment=valign, fontsize=20)

    if not args.no_grid:
        left_axis.xaxis.grid()

    # Save the plot, if requested, using a default name when none was specified.
    if args.save_plot is not None:
        save_name = args.save_plot
        if save_name == '':
            save_name = plot_label + '.png'
        try:
            figure.savefig(save_name)
        except IOError as e:
            print('Unable to save plot: {}.'.format(str(e)))
            return -1
        if args.verbose:
            print('Saved plot to {}'.format(save_name))
    if not args.no_display:
        plt.show()
    plt.close()

if __name__ == '__main__':
    # Use the value returned by main() (None on success) as the process exit status,
    # so that callers such as batch scripts can detect failures.
    import sys
    sys.exit(main())
