#!python
# -*- coding: utf-8 -*-
# Licensed under a MIT style license - see LICENSE.rst

""" Plot a single BOSS spectrum.
"""

from __future__ import division,print_function

from astropy.utils.compat import argparse

import numpy as np
import numpy.ma
import matplotlib.pyplot as plt

import bossdata.path
import bossdata.remote
import bossdata.spec
import bossdata.bits

def print_mask_summary(mask_values):
    """Print a summary of the SPPIXMASK bits that are set in a pixel mask.

    Prints 'No pixels masked.' when no element of ``mask_values`` is non-zero.
    Otherwise prints a header line followed by one line for each entry of the
    mapping returned by ``bossdata.bits.summarize_bitmask_values``, showing the
    entry's count and then its bit name.

    Args:
        mask_values: Array-like of pixel mask values to summarize.
    """
    if not np.any(mask_values):
        print('No pixels masked.')
        return

    print('Pixel mask summary:')
    bit_summary = bossdata.bits.summarize_bitmask_values(
        bossdata.bits.SPPIXMASK, mask_values)
    # items() works under both Python 2 and 3, whereas iteritems() only exists
    # on Python 2 dictionaries and raises AttributeError under Python 3.
    for bit_name, bit_count in bit_summary.items():
        print('{0:5d} {1}'.format(bit_count, bit_name))

def main():
    """Parse the command line, fetch one BOSS spec file and plot its spectrum.

    The spectrum is selected with --plate, --mjd and --fiber. When --exposure is
    not set the coadd is plotted (and the "lite" spec file is requested);
    otherwise the requested exposure is plotted for the blue and/or red camera
    selected with --camera. The plot is optionally saved to a file and, unless
    --no-display is given, shown on screen.

    Returns:
        -1 if creating the ``bossdata.path.Finder`` or ``bossdata.remote.Manager``
        raises ValueError (the error is printed), otherwise None.
    """
    # Initialize and parse command-line arguments.
    parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter,
        description = 'Plot a single BOSS spectrum.')
    parser.add_argument('--verbose', action = 'store_true',
        help = 'Provide verbose output.')
    parser.add_argument('--plate',type = int, default = 6641, metavar = 'PLATE',
        help = 'Plate number of spectrum to plot.')
    parser.add_argument('--mjd',type = int, default = 56383, metavar = 'MJD',
        help = 'Modified Julian date of plate observation to use.')
    parser.add_argument('--fiber',type = int,default = 30, metavar = 'FIBER',
        help = 'Fiber number identifying the spectrum of the requested PLATE-MJD to plot.')
    parser.add_argument('--exposure',type = int,default = None, metavar = 'EXP',
        help = 'Exposure sequence number starting from 0, or plot the coadd if not set.')
    parser.add_argument('--camera',type = str, choices = ['blue','red','both'], default = 'both',
        help = 'Camera to use when plotting a single exposure.')
    parser.add_argument('--allow-mask', type = str, default = None,
        help = 'SPPIXMASK bit names to allow in valid data. Separate multiple names with |.')
    parser.add_argument('--save-plot', type = str, default = None, metavar = 'FILE',
        help = 'File name to save the generated plot to.')
    parser.add_argument('--no-display', action = 'store_true',
        help = 'Do not display the image on screen (useful for batch processing).')
    parser.add_argument('--scatter', action = 'store_true',
        help = 'Show scatter of flux instead of a flux error band.')
    parser.add_argument('--show-mask', action = 'store_true',
        help = 'Indicate pixels with invalid data using vertical lines.')
    parser.add_argument('--show-dispersion', action = 'store_true',
        help = 'Show the wavelength dispersion using the right-hand axis.')
    parser.add_argument('--show-sky', action = 'store_true',
        help = 'Show the subtracted sky flux instead of the object flux.')
    parser.add_argument('--add-sky', action = 'store_true',
        help = 'Add the subtracted sky to the object flux (overrides show-sky).')
    args = parser.parse_args()

    # The camera choice only applies to a single exposure. Compare by value: an
    # identity test against a string literal depends on string interning and can
    # report a mismatch for an equal string that was typed on the command line.
    if args.exposure is None and args.camera != 'both':
        print('Ignoring camera = "{0}" for coadded spectrum.'.format(args.camera))
        args.camera = 'both'

    # Translate the optional list of allowed bit names into a bitmask value.
    if args.allow_mask is None:
        pixel_quality_mask = None
    else:
        pixel_quality_mask = bossdata.bits.bitmask_from_text(
            bossdata.bits.SPPIXMASK,args.allow_mask)

    try:
        finder = bossdata.path.Finder()
        mirror = bossdata.remote.Manager()
    except ValueError as e:
        print(e)
        return -1

    # The lite spec file is requested for the coadd; a single exposure needs the full file.
    remote_path = finder.get_spec_path(plate = args.plate,mjd = args.mjd,fiber = args.fiber,
        lite = (args.exposure is None))
    local_path = mirror.get(remote_path)
    specfile = bossdata.spec.SpecFile(local_path)

    # Initialize the plot.
    figure = plt.figure(figsize=(12,8))
    left_axis = plt.gca()
    figure.set_facecolor('white')
    plt.xlabel('Wavelength (Angstrom)')
    left_axis.set_ylabel('Flux (1e-17 erg/s/cm**2)')
    if args.show_dispersion:
        # The dispersion is drawn against its own y-axis on the right-hand side.
        right_axis = left_axis.twinx()
        right_axis.set_ylabel('Dispersion (Angstrom)')

    # We will potentially plot two spectra (blue and red cameras of one exposure).
    spectra = [ ]
    plot_colors = [ ]
    data_args = dict(include_wdisp=args.show_dispersion, include_sky=args.show_sky or args.add_sky)
    if args.exposure is None:
        spectra.append(specfile.get_valid_data(pixel_quality_mask=pixel_quality_mask, **data_args))
        plot_colors.append('black')
        if args.verbose:
            print('Showing coadd of {0:d} exposures:'.format(
                len(specfile.exposure_sequence)),(','.join(specfile.exposure_sequence)))
            print_mask_summary(specfile.get_pixel_mask())
    else:
        if args.verbose:
            print('Showing exposure',specfile.exposure_sequence[args.exposure])
        if args.camera in ('blue','both'):
            spectra.append(specfile.get_valid_data(args.exposure,'b',
                pixel_quality_mask=pixel_quality_mask, **data_args))
            plot_colors.append('blue')
            if args.verbose:
                print_mask_summary(specfile.get_pixel_mask(args.exposure,'b'))
        if args.camera in ('red','both'):
            spectra.append(specfile.get_valid_data(args.exposure,'r',
                pixel_quality_mask=pixel_quality_mask, **data_args))
            plot_colors.append('red')
            if args.verbose:
                print_mask_summary(specfile.get_pixel_mask(args.exposure,'r'))

    # Start from an inverted range so the first spectrum always replaces both limits.
    wlen_min,wlen_max = 20000.,0.
    for data,plot_color in zip(spectra,plot_colors):

        wlen,dflux = data['wavelength'][:],data['dflux'][:]
        # --add-sky takes precedence over --show-sky.
        if args.add_sky:
            flux = data['sky'][:] + data['flux'][:]
        elif args.show_sky:
            flux = data['sky'][:]
        else:
            flux = data['flux'][:]

        if args.scatter:
            left_axis.scatter(wlen,flux,color=plot_color,marker='.',s=0.1)
        else:
            # Draw the flux as a band of +/- dflux around each value.
            left_axis.fill_between(wlen,flux-dflux,flux+dflux,color=plot_color,alpha=0.5)

        # NOTE(review): len() gives the total number of mask entries, not the number
        # of masked pixels, so this guard only skips empty spectra - confirm intent.
        num_masked = len(data.mask)
        if args.show_mask and num_masked > 0:
            x_mask = [ ]
            y_mask = [ ]
            ymin,ymax = left_axis.get_ylim()
            bad_pixels = np.where(data.mask)
            # Build one full-height vertical segment per bad pixel; the None entries
            # separate consecutive segments within a single plot() call.
            for x in data.data['wavelength'][bad_pixels]:
                x_mask.extend([x,x,None])
                y_mask.extend([ymin,ymax,None])
            plt.plot(x_mask,y_mask,'-',color=plot_color,alpha=0.2)

        if args.show_dispersion:
            right_axis.plot(wlen,data['wdisp'][:],ls='-',color=plot_color)

        # Update the plot wavelength limits to include this data.
        wlen_min = min(wlen_min,np.ma.min(wlen))
        wlen_max = max(wlen_max,np.ma.max(wlen))

    # The x-axis limits are reset by the twinx() function so we set them here.
    plt.xlim(wlen_min,wlen_max)

    if args.save_plot:
        figure.savefig(args.save_plot)
    if not args.no_display:
        plt.show()
    plt.close()

if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status so that the -1
    # it returns on a configuration error is visible to the calling shell.
    # A return value of None (normal completion) exits with status 0.
    raise SystemExit(main())
