#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# ------------------------------------------------------------------------------
# 
# MIT License
# 
# Copyright (c) 2017 Gabriele Girelli
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# 
# ------------------------------------------------------------------------------

# ------------------------------------------------------------------------------
# 
# Author: Gabriele Girelli
# Email: gigi.ga90@gmail.com
# Date: 20180316
# Project: bioimaging
# Description: split TIFF in smaller images.
# 
# Changelog:
#  v1.0.0 - 20180316: converted from MATLAB to Python3 and merged with pygpseq.
# 
# ------------------------------------------------------------------------------



# DEPENDENCIES =================================================================

import argparse
import numpy as np
import os
import sys
from tqdm import tqdm

from ggc.prompt import ask

from pygpseq.tools import image as imt
from pygpseq.tools import plot
from pygpseq.tools.io import printout

# PARAMETERS ===================================================================

# Command line interface. The description is printed verbatim by --help
# because of the raw description formatter.
parser = argparse.ArgumentParser(
    formatter_class = argparse.RawDescriptionHelpFormatter,
    description = '''
Split a TIFF image in smaller TIFF images of the specified side(s). If two
different sides are provided, the smaller images will be rectangular. The
first side corresponds to the X (columns) and the second to the Y (rows).
By default, only one side is required, which is used by the script for both
X and Y sides. I.e., square smaller images are produced.

If the original dimensions are not multiples of the specified side, a portion
of the image is lost, unless the --enlarge option is used. In that case,
the smaller images generated from the image border will contain empty pixels.

If the input image is a 3D stack, it will be split only on XY and the output
images will have the same number of slices. Using the -2 option, only the first
slice is split (i.e., the output will be in 2D).

By default, split images are generated left-to-right, top-to-bottom, e.g.,
1 2 3
4 5 6
7 8 9

Use the option -I to generate them top-to-bottom, left-to-right, e.g.,
1 4 7
2 5 8
3 6 9

Examples:

- Square images of 100x100 px
tiff_split big_image.tif split_out_dir 100 -e

- Rectangular images of 125x100 px
tiff_split big_image.tif split_out_dir 100 125 -e
''')

# Positional (mandatory) arguments: input image, output folder, tile side(s)
parser.add_argument('input', type = str,
    help = 'Path to the TIFF image to split.')
parser.add_argument('outdir', type = str,
    help = 'Path to output TIFF folder, created if missing.')
parser.add_argument('side', type = int, nargs = '+',
    help = 'One or two (XY) sides,\n'
    '    used to specify the smaller images dimensions.')

# Boolean switches, all False unless given on the command line
parser.add_argument('-e', '--enlarge', dest = 'enlarge', action = 'store_true',
    help = 'Expand to avoid pixel loss.')
parser.add_argument('-2', '--2d', dest = 'force2D', action = 'store_true',
    help = 'Enforce 2D split even if the image is a 3D stack.\n'
    '    In that case the first slice will be split.')
parser.add_argument('-I', '--invert', dest = 'inverted', action = 'store_true',
    help = 'Split top-to-bottom, left-to-right.')
parser.add_argument('-y', '--do-all', action = 'store_true',
    help = 'Do not ask for settings confirmation and proceed.')

# Version switch, reports "<script name> <version>"
version = "1.1.0"
parser.add_argument('--version', action = 'version',
    version = ' '.join((sys.argv[0], version)))

# Read the command line
args = parser.parse_args()

# Additional checks
# Validation uses explicit checks rather than assert statements, which are
# stripped when Python runs with -O and would let bad input through.
if not os.path.isfile(args.input):
    parser.error("input file not found: %s" % args.input)
if os.path.isfile(args.outdir):
    parser.error("output directory cannot be a file: %s" % (args.outdir))
# A zero side would trigger a division by zero when counting tiles, and a
# negative one has no meaning as an image dimension.
if any(side <= 0 for side in args.side):
    parser.error("sides must be positive integers, got: %s" % (
        " ".join(str(side) for side in args.side)))

# One side produces square tiles; with more, the first two are X and Y and
# any further value is ignored.
if 1 == len(args.side):
    x_side = args.side[0]
    y_side = args.side[0]
else:
    x_side, y_side = args.side[:2]

# Create the output folder (and missing parents); no error if it exists.
os.makedirs(args.outdir, exist_ok = True)

# FUNCTIONS ====================================================================

def tiff_enlarge(img, x_side, y_side):
    '''Zero-pad an image so that its XY dimensions are multiples of the sides.

    Only the last two axes (rows, columns) are enlarged; any leading axis is
    left untouched. The original content is kept in the top-left corner and
    the added rows/columns are filled with zeros. The output has the same
    dtype as the input.

    Args:
        img (np.ndarray): image to resize, with at least 2 dimensions.
        x_side (int): column side in px.
        y_side (int): row side in px.

    Returns:
        np.ndarray: resized image.
    '''

    n_rows, n_cols = img.shape[-2:]

    # Round each XY dimension up to the next multiple of its side. Negated
    # floor division is an integer ceiling, which avoids float rounding.
    new_rows = -(-n_rows // y_side) * y_side
    new_cols = -(-n_cols // x_side) * x_side
    new_shape = tuple(img.shape[:-2]) + (int(new_rows), int(new_cols))

    # Allocate with the input dtype: the default (float64) would silently
    # change the pixel type and multiply the memory footprint.
    new_img = np.zeros(new_shape, dtype = img.dtype)

    # Copy the original pixels in the top-left corner of every XY plane.
    new_img[..., :n_rows, :n_cols] = img
    return(new_img)

def calc_split_loss(img, x_side, y_side):
    '''Calculates how many pixels/voxels would be lost if image split is
    performed without enlarging.

    The split keeps, in every XY plane, only the largest top-left region
    whose rows and columns are multiples of y_side and x_side. The loss is
    the total number of elements minus the elements in the kept regions.

    Args:
        img (np.ndarray): image to split, with at least 2 dimensions.
        x_side (int): column side in px.
        y_side (int): row side in px.

    Returns:
        int: number of pixels/voxels lost during split.
    '''

    n_rows, n_cols = (int(d) for d in img.shape[-2:])

    # Number of XY planes, i.e., product of the leading axes (1 for 2D).
    # Plain Python ints are used so that the result is always an int and
    # cannot overflow on large stacks.
    n_planes = 1
    for d in img.shape[:-2]:
        n_planes *= int(d)

    # Rows/columns that fit an integer number of tiles
    kept_rows = n_rows - n_rows % y_side
    kept_cols = n_cols - n_cols % x_side

    loss = n_planes * (n_rows * n_cols - kept_rows * kept_cols)
    return(int(loss))

def tiff_split_2d(img, x_side, y_side, outdir, imgpath, inverted = False):
    '''Split 2D image in sub-images of x_side x y_side.
    Output is saved to outdir with the suffix .subN,
    where N is the sub-image index (starting from 1).

    Only complete sub-images are produced: rows/columns left over at the
    bottom/right border are discarded. Nothing is written if no complete
    sub-image fits.

    Args:
        img (np.ndarray): image to split.
        x_side (int): column side in px.
        y_side (int): row side in px.
        outdir (str): path to output directory.
        imgpath (str): path to input image.
        inverted (bool): split top-to-bottom, left-to-right.
    '''

    # Get output file name parts
    prefix, ext = os.path.splitext(os.path.basename(imgpath))

    # Top-left corner positions of the sub-images that fit entirely
    ys = range(0, img.shape[-2] - y_side + 1, y_side)
    xs = range(0, img.shape[-1] - x_side + 1, x_side)

    # Count cells to output. The count comes from the actual positions:
    # an expression such as W//x * H//y evaluates as ((W//x) * H)//y and
    # gives a wrong total.
    n = len(xs) * len(ys)
    print("Output %d images." % n)
    if 0 == n: return

    if inverted:
        print("Image split top-to-bottom, left-to-right.")
        xy_gen = ((x, y) for x in xs for y in ys)
    else:
        print("Image split left-to-right, top-to-bottom.")
        xy_gen = ((x, y) for y in ys for x in xs)

    with tqdm(total = n) as pbar:
        # ic is the 1-based sub-image index used in the output file name
        for ic, (x_start, y_start) in enumerate(xy_gen, start = 1):
            # Identify sub-image
            oimg = img[y_start:(y_start+y_side), x_start:(x_start+x_side)]

            # Output image
            opath = os.path.join(outdir, "%s.sub%d%s" % (prefix, ic, ext))
            plot.save_tif(opath, oimg, imt.get_dtype(oimg.max()), False)

            pbar.update(1)

def tiff_split_3d(img, x_side, y_side, outdir, imgpath, inverted = False):
    '''Split 3D stack in sub-stacks of x_side x y_side x stack_depth.
    Output is saved to outdir with the suffix .subN,
    where N is the sub-image index (starting from 1).

    The stack is split only along the last two axes (rows, columns); every
    output keeps all the slices of the first axis. Only complete sub-images
    are produced: rows/columns left over at the bottom/right border are
    discarded. Nothing is written if no complete sub-image fits.

    Args:
        img (np.ndarray): 3D image to split.
        x_side (int): column side in px.
        y_side (int): row side in px.
        outdir (str): path to output directory.
        imgpath (str): path to input image.
        inverted (bool): split top-to-bottom, left-to-right.
    '''

    # Get output file name parts
    prefix, ext = os.path.splitext(os.path.basename(imgpath))

    # Top-left corner positions of the sub-images that fit entirely
    ys = range(0, img.shape[-2] - y_side + 1, y_side)
    xs = range(0, img.shape[-1] - x_side + 1, x_side)

    # Count cells to output. The count comes from the actual positions:
    # an expression such as W//x * H//y evaluates as ((W//x) * H)//y and
    # gives a wrong total.
    n = len(xs) * len(ys)
    print("Output %d images." % n)
    if 0 == n: return

    if inverted:
        print("Image split top-to-bottom, left-to-right.")
        xy_gen = ((x, y) for x in xs for y in ys)
    else:
        print("Image split left-to-right, top-to-bottom.")
        xy_gen = ((x, y) for y in ys for x in xs)

    with tqdm(total = n) as pbar:
        # ic is the 1-based sub-image index used in the output file name
        for ic, (x_start, y_start) in enumerate(xy_gen, start = 1):
            # Identify sub-image, keeping every slice of the first axis
            oimg = img[:,
                y_start:(y_start + y_side),
                x_start:(x_start + x_side)
            ]

            # Output image
            opath = os.path.join(outdir, "%s.sub%d%s" % (prefix, ic, ext))
            plot.save_tif(opath, oimg, imt.get_dtype(oimg.max()), False)

            pbar.update(1)

def print_settings(args, clear = True):
    '''Build the settings summary, print it, and return it.

    The summary reports the script version, the input/output paths, the
    sub-image sides (read from the module-level x_side and y_side) and the
    enlarge, 2D and inverted flags.

    Args:
        args (Namespace): arguments parsed by argparse.
        clear (bool): clear screen before printing.

    Returns:
        str: the settings summary, without the screen-clearing prefix.
    '''

    header = " # TIFF split v%s\n" % version
    body = """
        Input file :  %s
  Output directory :  %s

            X side : %d
            Y side : %d

           Enlarge : %r
        Enforce 2D : %r
          Inverted : %r
    """ % (args.input, args.outdir, x_side, y_side,
        args.enlarge, args.force2D, args.inverted)
    s = header + body

    # When clearing, the summary is preceded by the escape codes that reset
    # the terminal content.
    lead = "\033[H\033[J" if clear else ""
    print("%s%s" % (lead, s))
    return(s)

# RUN ==========================================================================

# Display the settings and, unless -y was given, wait for confirmation
settings_string = print_settings(args)
if not args.do_all:
    ask("Confirm settings and proceed?")

# Load the image to split
img = imt.read_tiff(args.input, k = 3)

# Pick the split mode and the unit of measure from the image shape -------------

status = np.nan
n_axes = len(img.shape)
if 2 == n_axes:
    print("2D image found: %s" % str(img.shape))
    status, umes = "2D", "pixel"
elif 3 == n_axes:
    print("3D stack found: %s" % str(img.shape))
    if not args.force2D:
        status, umes = "3D", "voxel"
    else:
        # Only the first slice of the stack is retained
        print("Enforcing 2D split (extracting 1st slice only).")
        status, umes = "2D", "pixel"
        img = img[0].copy()
else:
    printout("Cannot split a 1D image. File: %s" % args.input, -2)

# Either report the expected loss or pad the image to avoid it -----------------

if not args.enlarge:
    loss = calc_split_loss(img, x_side, y_side)
    loss_perc = loss / np.prod(img.shape) * 100
    print("%d %ss lost (%.2f%%). Use -e to avoid loss." % (
        loss, umes, loss_perc))
else:
    img = tiff_enlarge(img, x_side, y_side)
    print("Image enlarged to %s" % str(img.shape))

# Run the split matching the selected mode -------------------------------------

split_fun = {"2D" : tiff_split_2d, "3D" : tiff_split_3d}.get(status)
if split_fun is None:
    printout("Unrecognized analysis mode.", -2)
else:
    split_fun(img, x_side, y_side, args.outdir, args.input, args.inverted)

# END ==========================================================================

print("DONE")

################################################################################
