#!python
"""
Create one volume for each ROI in the given NIfTI file.
"""
#------------------------------------------------------------------------------
# Author: Alexandre Manhaes Savio <alexsavio@gmail.com>
# Nuklear Medizin Department
# Klinikum rechts der Isar der Technische Universitaet Muenchen, Deutschland
#
# 2016, Alexandre Manhaes Savio
# Use this at your own risk!
#------------------------------------------------------------------------------
import os.path as op
import argparse

import nibabel as nib

from boyle.files.names import get_extension
from boyle.nifti.read import read_img
from boyle.nifti.roi import get_roilist_from_atlas


def split_rois(img):
    """
    Iterate over the ROIs of an image, yielding one masked volume per ROI.

    Parameters
    ----------
    img: nifti-like object or str
        The image, or the path to it; it is loaded through `read_img`.

    Yields
    ------
    roi: scalar
        One of the ROI values returned by `get_roilist_from_atlas` for `img`.

    roi_vol: numpy.ndarray
        The image data multiplied by the boolean mask `data == roi`, i.e.,
        an array of the same shape holding only the voxels of that ROI.
    """
    # Local import: this module does not import numpy at file level.
    import numpy as np

    img = read_img(img)

    rois = get_roilist_from_atlas(img)
    # `img.get_data()` was deprecated in nibabel 3.0 and removed in 5.0;
    # `np.asanyarray(img.dataobj)` is the documented equivalent.
    vol = np.asanyarray(img.dataobj)

    for r in rois:
        # Multiplying by the mask keeps the original values where vol == r.
        yield r, vol * (vol == r)


def set_parser():
    """
    Build the command-line argument parser of this script.

    Returns
    -------
    parser: argparse.ArgumentParser
        Parser with two required options: `-i/--in` (stored as `input`) and
        `-o/--out` (stored as `output`).
    """
    description = 'Create one volume for each ROI in the given NifTI file.'
    cli = argparse.ArgumentParser(description=description)

    # (option flags, destination attribute, help text)
    required_options = (
        (('-i', '--in'), 'input', 'Input file path.'),
        (('-o', '--out'), 'output', 'Output file base name.'),
    )
    for flags, dest, help_text in required_options:
        cli.add_argument(*flags, dest=dest, required=True, help=help_text)

    # TODO: add an optional '-r/--rois' option (dest='rois_file') taking the
    # path to a text file with a "\n" separated list of the ROI values of
    # interest.
    return cli


if __name__ == "__main__":
    # Command-line entry point: write one image file per ROI of the input.

    parser = set_parser()

    # parse_args() handles invalid command lines on its own (it prints the
    # usage message and exits), so no extra error handling is needed here.
    args = parser.parse_args()

    # Load the input once: the same image provides the data that is split
    # and the affine/header reused for every output file.
    img = read_img(args.input)
    ext = get_extension(args.input)

    # args.output is used as given, so a relative base name is resolved
    # against the current working directory.
    for r, vol in split_rois(img):
        roiimg = nib.Nifti1Image(vol, affine=img.affine, header=img.header)
        roiimg.to_filename(args.output + '_roi_{}{}'.format(str(r).zfill(3), ext))
