#!python
# -*- coding: utf-8 -*-

import argparse
import logging
import os

# Set OMP_NUM_THREADS to 1 before importing numpy. numpy is not imported
# directly in this file; presumably the packages imported below load it, so
# this assignment must stay above those imports.
os.environ["OMP_NUM_THREADS"] = "1"

import nibabel as nib
from math import pi
from dipy.core.gradients import gradient_table
from dipy.io import read_bvals_bvecs

import bonndit as bd
from bonndit.shoredeconv import fa_guided_mask
from bonndit.io import fsl_vectors_to_worldspace, fsl_gtab_to_worldspace
from bonndit.michi import fields, dwmri


def _build_parser():
    """Create the argument parser describing the command-line interface.

    Returns
    -------
    argparse.ArgumentParser
        Parser with the positional ``indir`` argument and all optional
        argument groups (input filenames, flags, shore options,
        deconvolution options, multiprocessing and logging).
    """
    parser = argparse.ArgumentParser(
        description='This script computes fiber orientation distribution '
                    'functions (fODFs) as described in "Versatile, Robust and '
                    'Efficient Tractography With Constrained Higher Order '
                    'Tensor fODFs" by Ankele et al. (2017). It is assumed that'
                    ' the input data is saved using FSL.', add_help=False)

    parser.add_argument('indir',
                        help='Folder containing all required input files')

    parser.add_argument('-o', '--outdir',
                        help='Folder in which the output will be saved (default: same as indir)')

    inputfiles = parser.add_argument_group('Custom input filenames', 'It is not recommended to specify \
    custom names for input files.')
    inputfiles.add_argument('-d', '--data', default='data.nii.gz',
                            help='Diffusion weighted data (default: data.nii.gz)')
    inputfiles.add_argument('-e', '--dtivecs', default='dti_V1.nii.gz',
                            help='First eigenvectors of a DTI model (default: dti_V1.nii.gz)')
    inputfiles.add_argument('-a', '--dtifa', default='dti_FA.nii.gz',
                            help='Fractional anisotropy values from a DTI model (default: dti_FA.nii.gz)')
    inputfiles.add_argument('-m', '--brainmask', default='mask.nii.gz',
                            help='Brain mask (default: mask.nii.gz)')
    inputfiles.add_argument('-W', '--wmmask', default='fast_pve_2.nii.gz',
                            help='White matter mask (default: fast_pve_2.nii.gz)')
    inputfiles.add_argument('-G', '--gmmask', default='fast_pve_1.nii.gz',
                            help='Gray matter mask (default: fast_pve_1.nii.gz)')
    inputfiles.add_argument('-F', '--csfmask', default='fast_pve_0.nii.gz',
                            help='Cerebrospinal fluid mask (default: fast_pve_0.nii.gz)')

    # The default -h/--help is disabled above (add_help=False) and re-added
    # here so that it is listed inside the "flags" group.
    flags = parser.add_argument_group('flags (optional)', '')
    flags.add_argument("-h", "--help", action="help", help="Show this help message and exit")
    flags.add_argument('-v', '--verbose', action='store_true',
                       help='Activate progress bars and console logging')
    flags.add_argument('-R', '--responseonly', action='store_true',
                       help='Calculate and save only the response functions')
    flags.add_argument('-M', '--tissuemasks', action='store_true',
                       help='Output the DTI improved tissue masks (csf/gm/wm)')

    shoreopts = parser.add_argument_group('shore options (optional)', 'Optional arguments for the computation of \
    the shore response functions')
    shoreopts.add_argument('-k', '--kernel', choices=["rank1", "delta"],
                           default="rank1", type=str,
                           help='Kernel type (default: rank1)')
    shoreopts.add_argument('-r', '--order', default=4, type=int,
                           help='Order of the shore basis (default: 4)')
    shoreopts.add_argument('-z', '--zeta', default=700, type=float,
                           help='Radial scaling factor (default: 700)')
    shoreopts.add_argument('-t', '--tau', default=1 / (4 * pi ** 2),
                           type=float,
                           help='q-scaling (default: 1 / (4 * math.pi ** 2))')
    shoreopts.add_argument('-f', '--fawm', default=0.7, type=float,
                           help='White matter fractional anisotropy threshold (default: 0.7)')

    deconvopts = parser.add_argument_group('deconvolution options (optional)', '')
    deconvopts.add_argument('-C', '--constraint', choices=['hpsd', 'nonneg', 'none'], default='hpsd',
                            help='Constraint for the fODFs (default: hpsd)')

    mpopts = parser.add_argument_group('multiprocessing (optional)', 'Configure the multiprocessing behaviour \
    (only supported for Python 3)')
    mpopts.add_argument('-w', '--workers', default=None, type=int,
                        help='Number of cpus (default: all available cpus)')

    logopts = parser.add_argument_group('logging (optional)', 'Configure the logging behaviour')
    logopts.add_argument('-L', '--loglevel', choices=['INFO', 'WARNING', 'ERROR'],
                         default='INFO',
                         help='Specify the logging level for the console')

    return parser


def _configure_logging(outdir, loglevel, verbose):
    """Set up file logging and, if requested, console logging.

    Parameters
    ----------
    outdir : str
        Folder in which the log file ``mtdeconv.log`` is written.
    loglevel : str
        One of ``'INFO'``, ``'WARNING'`` or ``'ERROR'``.
    verbose : bool
        If true, a console handler is attached to the root logger as well.
    """
    levels = {'INFO': logging.INFO,
              'WARNING': logging.WARNING,
              'ERROR': logging.ERROR}
    level = levels[loglevel]

    # Logging setup for file; filemode='w' overwrites the log on every run.
    logging.basicConfig(filename=os.path.join(outdir, 'mtdeconv.log'),
                        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                        datefmt='%y-%m-%d %H:%M',
                        level=level,
                        filemode='w')

    if not verbose:
        return

    # Console logging with a simpler format than the one used for the file.
    console = logging.StreamHandler()
    console.setLevel(level)
    console.setFormatter(
        logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console)


def _require_files(indir, filenames):
    """Raise ``FileNotFoundError`` for the first missing input file.

    Parameters
    ----------
    indir : str
        Folder in which the files are expected.
    filenames : iterable of str
        File names relative to ``indir``.

    Raises
    ------
    FileNotFoundError
        If one of the files does not exist. The error is logged before it is
        raised.
    """
    for name in filenames:
        filepath = os.path.join(indir, name)
        if not os.path.isfile(filepath):
            msg = 'No such file or directory: "{}"'.format(filepath)
            logging.error(msg)
            raise FileNotFoundError(msg)


def _strip_nifti_suffix(path):
    """Return ``path`` without a trailing ``.gz`` and/or ``.nii`` extension.

    ``str.rstrip`` cannot be used for this because it strips a *set of
    characters* rather than a suffix (``'brain.nii'.rstrip('.nii')`` yields
    ``'bra'``).
    """
    for suffix in ('.gz', '.nii'):
        if path.endswith(suffix):
            path = path[:-len(suffix)]
    return path


def main():
    """Command-line entry point: estimate responses and compute fODFs.

    Parses the command line, configures logging and then either loads an
    existing ``response.npz`` from the output folder or estimates the
    multi-tissue response functions and saves them there. Unless
    ``--responseonly`` is given, the data is afterwards deconvolved and the
    fODFs and volume fractions are written to the output folder.

    Raises
    ------
    FileNotFoundError
        If a required input file is missing from the input folder.
    """
    args = _build_parser().parse_args()

    # Create outdir if it does not exist (defaults to the input folder)
    indir = args.indir
    outdir = args.outdir if args.outdir else indir
    os.makedirs(outdir, exist_ok=True)

    _configure_logging(outdir, args.loglevel, args.verbose)

    logging.info('mtdeconv has been called with:')
    param_string = 'Order: {} Zeta: {}, Tau: {}, FAWM: {}, Constraint: {}, ' \
                   'Kernel {}'
    logging.info(param_string.format(args.order, args.zeta, args.tau,
                                     args.fawm, args.constraint, args.kernel))

    # If deconvolution is wanted check whether a response exists already.
    # With --responseonly the response is always recomputed.
    response_path = os.path.join(outdir, 'response.npz')
    response_available = (not args.responseonly
                          and os.path.exists(response_path))

    # If response available load data and response
    if response_available:
        # Only mask and data are needed. Gradients are saved in response.
        _require_files(indir, [args.brainmask, args.data])

        # Load existing response
        fit = bd.ShoreMultiTissueResponse.load(response_path)
        logging.info('Existing response functions loaded.')

        # Load DTI mask
        dti_mask = bd.load(os.path.join(indir, args.brainmask))

        # Load diffusion weighted data
        data = bd.load(os.path.join(indir, args.data))

        logging.info('Input loaded.')

    # If response not available load files and compute response
    else:
        # Test whether all required files are available
        _require_files(indir, [args.brainmask, args.dtifa, args.csfmask,
                               args.gmmask, args.wmmask, args.dtivecs,
                               args.data, 'bvals', 'bvecs'])

        # Load fractional anisotropy
        dti_fa = bd.load(os.path.join(indir, args.dtifa))

        # Load DTI mask
        dti_mask = bd.load(os.path.join(indir, args.brainmask))

        # Load tissue segmentation masks
        csf_mask = bd.load(os.path.join(indir, args.csfmask))
        gm_mask = bd.load(os.path.join(indir, args.gmmask))
        wm_mask = bd.load(os.path.join(indir, args.wmmask))

        # Load first eigenvectors of a precalculated diffusion tensor
        dti_vecs = bd.load(os.path.join(indir, args.dtivecs))

        # Load diffusion weighted data
        data = bd.load(os.path.join(indir, args.data))

        # Load bvals and bvecs and initialize a GradientTable object
        bvals, bvecs = read_bvals_bvecs(os.path.join(indir, "bvals"),
                                        os.path.join(indir, "bvecs"))
        gtab = gradient_table(bvals, bvecs)

        logging.info('Input loaded.')

        # Create tissue masks based on fractional anisotropy values. The
        # white matter FA threshold comes from --fawm (default 0.7).
        wm_mask = fa_guided_mask(wm_mask, dti_fa, dti_mask,
                                 tissue_threshold=0.95,
                                 fa_lower_thresh=args.fawm)
        gm_mask = fa_guided_mask(gm_mask, dti_fa, dti_mask,
                                 tissue_threshold=0.95,
                                 fa_upper_thresh=0.2)
        csf_mask = fa_guided_mask(csf_mask, dti_fa, dti_mask,
                                  tissue_threshold=0.95,
                                  fa_upper_thresh=0.2)

        if args.tissuemasks:
            nib.save(wm_mask, os.path.join(outdir, 'wm_mask.nii.gz'))
            nib.save(gm_mask, os.path.join(outdir, 'gm_mask.nii.gz'))
            nib.save(csf_mask, os.path.join(outdir, 'csf_mask.nii.gz'))

        logging.info('Fractional anisotropy based tissue masks created.')

        # Flip sign of x-coordinate if affine determinant is positive and rotate to worldspace
        gtab = fsl_gtab_to_worldspace(gtab, data.affine)
        dti_vecs = fsl_vectors_to_worldspace(dti_vecs)
        logging.info('Rotation to worldspace finished')

        model = bd.ShoreMultiTissueResponseEstimator(gtab, args.order,
                                                     args.zeta, args.tau)
        fit = model.fit(data, dti_vecs, wm_mask, gm_mask, csf_mask,
                        verbose=args.verbose, cpus=args.workers)
        fit.save(response_path)
        logging.info('Response functions estimated and saved.')

    # We need this Meta object for saving later. Try the gzipped file first
    # and fall back to the uncompressed one; if both are missing the
    # FileNotFoundError of the second attempt propagates.
    base_filename = _strip_nifti_suffix(os.path.join(indir, args.data))
    try:
        _, _, meta = dwmri.load(base_filename + '.nii.gz')
    except FileNotFoundError:
        _, _, meta = dwmri.load(base_filename + '.nii')

    # Deconvolution if 'responseonly' is not set
    if not args.responseonly:
        out, wmout, gmout, csfout = fit.fodf(data, pos=args.constraint,
                                             mask=dti_mask, kernel=args.kernel,
                                             verbose=args.verbose,
                                             cpus=args.workers)
        logging.info('Signal deconvolved with multiple tissue response functions.')

        # NOTE(review): if bd.load returns a nibabel image, get_data() was
        # deprecated in nibabel 3.0 and removed in 5.0 -- confirm and migrate.
        fields.save_tensor(os.path.join(outdir, "fodf.nrrd"), out,
                           mask=dti_mask.get_data(), meta=meta)
        logging.info('fODFs saved.')

        # Save volumes
        fields.save_scalar(os.path.join(outdir, "wmvolume.nrrd"),
                           wmout, meta)
        fields.save_scalar(os.path.join(outdir, "gmvolume.nrrd"),
                           gmout, meta)
        fields.save_scalar(os.path.join(outdir, "csfvolume.nrrd"),
                           csfout, meta)
        logging.info('Volume fractions saved.')

    logging.info('Success!')


# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
