#!python
# -*- coding: utf-8 -*-
# Licensed under a MIT style license - see LICENSE.rst

"""Fetch BOSS data files containing the spectra of specified observations and mirror them locally.
"""

from __future__ import division,print_function

import os.path
import multiprocessing

from astropy.utils.compat import argparse

from progressbar import ProgressBar,Percentage,Bar

import astropy.table

import bossdata.path
import bossdata.remote

def fetch(remote_paths,response_queue):
    """Mirror remote files locally and report the outcome of each one on a queue.

    This is the target function of the download subprocesses. Exactly one response
    tuple is put on the queue for every entry of remote_paths, so that a consumer
    counting responses is never left waiting for a path that was silently skipped:

    - ``(size,)`` after a successful fetch, where size is the size in bytes of the
      local file as reported by os.path.getsize.
    - ``(0,remote_path,message)`` when fetching the path failed, where message is
      the string form of the exception that was raised.

    Args:
        remote_paths: Iterable of remote paths, each passed to the get() method of
            a bossdata.remote.Manager.
        response_queue: Object with a put() method (a multiprocessing.Queue when
            launched from main) that receives the response tuples.
    """
    mirror = bossdata.remote.Manager()
    for remote_path in remote_paths:
        try:
            local_path = mirror.get(remote_path)
            response = (os.path.getsize(local_path),)
        except Exception as e:
            # Catch broadly at this subprocess boundary: an uncaught exception would
            # end the subprocess without reporting the remaining paths, and the parent
            # blocks on the queue until it has one response per path.
            response = (0,remote_path,str(e))
        response_queue.put(response)

def main():
    """Fetch the spectrum data files of the observations listed in an input file.

    Parses the command line, reads a table with PLATE, MJD and FIBER columns, maps
    each row to a remote path and splits the paths between 1-5 subprocesses that
    run fetch(). Progress is monitored through a shared queue on which every
    subprocess puts one response per path.

    Returns:
        -1 if nproc is out of range or if creating the path finder or the remote
        manager raises ValueError; otherwise None.
    """
    # Initialize and parse command-line arguments.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--verbose', action='store_true',
        help='Provide verbose output.')
    parser.add_argument('observations', type=str, default=None, metavar='FILE',
        help='File containing PLATE,MJD,FIBER columns that specify the observations to fetch.')
    parser.add_argument('--full', action='store_true',
        help='Fetch the full version of each spectrum data file.')
    parser.add_argument('--nproc', type=int, default=2,
        help='Number of subprocesses to use to parallelize downloads (1-5).')
    args = parser.parse_args()

    if args.nproc < 1 or args.nproc > 5:
        print('nproc must be 1-5.')
        return -1

    # Read the list of observations to fetch. Files with a .dat or .txt extension are
    # read as ascii; otherwise the format is left for astropy to determine.
    ext = os.path.splitext(args.observations)[1]
    input_format = 'ascii' if ext in ('.dat','.txt') else None
    table = astropy.table.Table.read(args.observations,format=input_format)
    num_obs = len(table)

    # The byte counter is updated for every successful download, whether or not verbose
    # output was requested, so it must always be initialized.
    num_bytes = 0
    if args.verbose:
        print('Fetching {:d} observations...'.format(num_obs))
        progress_bar = ProgressBar(widgets = [Percentage(),Bar()],maxval = num_obs).start()

    try:
        finder = bossdata.path.Finder()
        # The manager is not used here since each subprocess creates its own in fetch().
        # Creating one up front reports a ValueError before any subprocess is launched.
        bossdata.remote.Manager()
    except ValueError as e:
        print(e)
        return -1

    # Build a list of remote paths from the input plate-mjd-fiber values.
    remote_paths = [
        finder.get_spec_path(
            plate = row['PLATE'],mjd = row['MJD'],fiber = row['FIBER'],lite = not args.full)
        for row in table]

    # Initialize a queue that subprocesses use to signal their progress.
    response_queue = multiprocessing.Queue()

    # Launch subprocesses to handle subsets of remote paths. The chunk size is rounded up
    # so that nproc chunks always cover all of the paths.
    chunk_size = (len(remote_paths) + args.nproc - 1)//args.nproc
    processes = []
    for i in range(args.nproc):
        # The last chunk will be shorter if the number of paths does not evenly divide
        # between the subprocesses (and trailing chunks can be empty).
        chunk = remote_paths[i*chunk_size:(i+1)*chunk_size]
        process = multiprocessing.Process(target=fetch, args=(chunk, response_queue))
        processes.append(process)
        process.start()

    # Monitor subprocess progress. Each path yields exactly one response, either
    # (size,) on success or (0,remote_path,message) on failure.
    num_fetched = 0
    try:
        while num_fetched < len(remote_paths):
            response = response_queue.get()
            if len(response) > 1:
                # Identify errors by the tuple length rather than a zero size, so that
                # a successfully fetched empty file is not mistaken for a failure.
                print('Download error for {file}:\n{msg}'.format(file=response[1], msg=response[2]))
            else:
                num_bytes += response[0]
            # Failed downloads also count as processed so that the loop terminates.
            num_fetched += 1
            if args.verbose:
                progress_bar.update(num_fetched)
        if args.verbose:
            progress_bar.finish()
        # Give subprocesses a chance to finish normally.
        for process in processes:
            process.join(timeout=1)
    except KeyboardInterrupt:
        print('Stopping after keyboard interrupt.')

    # Ensure that all subprocesses have terminated. This should never be necessary
    # after normal completion.
    for process in processes:
        if process.is_alive():
            print('Killing subprocess {}.'.format(process.name))
            process.terminate()

    if args.verbose:
        print('Processed {:.1f} Mb of data files for {:d} observations.'.format(
            num_bytes/float(1<<20),num_fetched))

    # The counts only differ when monitoring was interrupted before every path reported.
    if num_fetched != num_obs:
        print('WARNING: {:d} of {:d} observations were not fetched.'.format(
            num_obs-num_fetched,num_obs),'Re-run the command after any problems are fixed.')

if __name__ == '__main__':
    # Use the return value of main() as the process exit status, so that the -1 it
    # returns on errors is visible to the calling shell (None maps to a zero status).
    import sys
    sys.exit(main())
