#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from os.path import join
import argparse

import nibabel as nib
import numpy as np
from nibabel import trackvis
from tqdm import tqdm

from tractseg.libs import tractometry
from tractseg.data import dataset_specific_utils


def _build_parser():
    """Create the argument parser describing the command-line interface of this script."""
    parser = argparse.ArgumentParser(description="Evaluate image (e.g. FA) along fiber bundles.",
                                     epilog="Written by Jakob Wasserthal. Please reference 'Wasserthal et al. "
                                            "TractSeg - Fast and accurate white matter tract segmentation. "
                                            "https://doi.org/10.1016/j.neuroimage.2018.07.070)'")
    parser.add_argument("-i", metavar="tracking_dir", dest="tracking_dir",
                        help="Folder containing the TractSeg tractograms (normally '.../tractseg_output/TOM_tracking')",
                        required=True)

    parser.add_argument("-o", metavar="csv_output", dest="csv_file_out",
                        help="CSV output file containing the results", required=True)

    parser.add_argument("-e", metavar="endings_dir", dest="endings_dir",
                        help="Folder containing the TractSeg bundle endings segmentations "
                             "(normally '.../tractseg_output/endings_segmentations'). "
                             "Needed to ensure that all fibers are starting from the same side.", required=True)

    parser.add_argument("-s", metavar="scalar_img", dest="scalar_img",
                        help="Scalar image (e.g. FA) or peak image (MRtrix peaks) if using '--peak_length'",
                        required=True)

    # type=int lets argparse reject non-numeric input with a usage error instead of a traceback.
    parser.add_argument("--nr_points", metavar="n", dest="nr_points", type=int,
                        help="Number of points along streamline to evaluate (default: 100)",
                        default=100)

    parser.add_argument("--peak_length", action="store_true",
                        help="Instead of taking values of scalar image along streamlines (e.g. FA) take the length "
                             "of the peak pointing in the same direction as the respective bundle. Gives better "
                             "results in areas of crossing fibers than FA.",
                        default=False)

    parser.add_argument("--TOM", metavar="TOM_dir", dest="TOM_dir",
                        help="Folder containing Tract Orientation Maps (TOMs) (normally '.../tractseg_output/TOM'). "
                             "Needed if using the '--peak_length' option.")

    parser.add_argument("--tracking_format", metavar="tck|trk|trk_legacy", choices=["tck", "trk", "trk_legacy"],
                        help="Set output format of tracking. For trk also the option trk_legacy is available. This "
                             "uses the older trk convention (streamlines are stored in coordinate space and affine is "
                             "not applied. See nibabel.trackvis.read. (default: trk_legacy)",
                        default="trk_legacy")

    parser.add_argument("--test", metavar="1|2|3", choices=[0, 1, 2, 3], type=int,
                        help="Only needed for unittesting.",
                        default=0)
    return parser


def _select_bundles(test_mode):
    """Return ``(bundles, dilation)`` for the given ``--test`` mode.

    The first entry of every bundle list is skipped (presumably the background class -- confirm in
    ``dataset_specific_utils.get_bundle_names``).
    """
    # Dilation >0 important because otherwise some streamlines do not start/end in beginnings region and then
    # correct reorientation/flipping of streamlines does not work anymore
    dilation = 2
    if test_mode == 1:
        dataset = "test"
    elif test_mode == 2:
        dataset = "toy"
        dilation = 0
    elif test_mode == 3:
        dataset = "test_single"
    else:
        dataset = "All_tractometry"
    return dataset_specific_utils.get_bundle_names(dataset)[1:], dilation


def _load_streamlines(trk_path, tracking_format):
    """Read the streamlines stored at ``trk_path`` using the reader matching ``tracking_format``."""
    if tracking_format == "trk_legacy":
        streams, _ = trackvis.read(trk_path)
        return [s[0] for s in streams]
    return nib.streamlines.load(trk_path).streamlines


def main():
    """Sample an image along every bundle and write the mean profiles to a CSV file.

    For each bundle the tractogram is read from the tracking folder and evaluated with
    ``tractometry.evaluate_along_streamlines``. Bundles without a tractogram, or with fewer than 5
    streamlines (except in test mode 2), contribute a profile of zeros. The first and last point of
    every profile are dropped before saving. The output has one column per bundle, separated by ';',
    with the bundle names as header.

    Exits with a usage error if '--peak_length' is given without '--TOM' or if '--nr_points' is not
    an integer.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Fail early with a clear message; otherwise join(None, ...) raises a TypeError inside the loop.
    if args.peak_length and args.TOM_dir is None:
        parser.error("'--TOM' is required when using '--peak_length'.")

    nr_points = args.nr_points
    bundles, dilation = _select_bundles(args.test)

    scalar_image = nib.load(args.scalar_img)
    # Loop-invariant: replace NaNs once instead of copying the whole volume for every bundle.
    # NOTE(review): assumes evaluate_along_streamlines does not modify its input array in place.
    scalar_data = np.nan_to_num(scalar_image.get_fdata())

    # Legacy trk files use the normal ".trk" extension; only the reader differs.
    file_ending = "trk" if args.tracking_format == "trk_legacy" else args.tracking_format

    results = []
    for bundle in tqdm(bundles):
        trk_path = join(args.tracking_dir, bundle + "." + file_ending)

        streamlines = None
        if not os.path.exists(trk_path):
            print("WARNING: No tracking found for bundle {}. Returning zeros.".format(bundle))
        else:
            streamlines = _load_streamlines(trk_path, args.tracking_format)
            if len(streamlines) < 5 and args.test != 2:
                print("WARNING: bundle {} contains less than 5 streamlines. Saving value 0 for this bundle.".
                      format(bundle))
                streamlines = None

        if streamlines is None:
            mean = np.zeros(nr_points)
        else:
            # Endings and TOM volumes are only read for bundles that are actually evaluated.
            beginnings = nib.load(join(args.endings_dir, bundle + "_b.nii.gz")).get_fdata()
            if args.peak_length:
                predicted_peaks = nib.load(join(args.TOM_dir, bundle + ".nii.gz")).get_fdata()
            else:
                predicted_peaks = None
            mean, _ = tractometry.evaluate_along_streamlines(scalar_data, streamlines, beginnings, nr_points,
                                                             dilate=dilation, predicted_peaks=predicted_peaks,
                                                             affine=scalar_image.affine)

        # Remove first and last segment as those tend to be more noisy
        results.append(mean[1:-1])

    bundle_string = ";".join(bundles)

    np.savetxt(args.csv_file_out, np.array(results).transpose(), delimiter=";", header=bundle_string, comments="")

    # Notes on reproducibility
    # - map_coordinates, QuickBundles and cKDTree are deterministic for the same input streamlines
    # - Variance in final Tractometry results when running 2 probabilistic trackings (10k fibers and 100 points):
    #   - Max difference ~0.01, but on average difference around 0.004
    #   - When plotting the differences: very minor
    #   - Will get averaged out in group analysis


# Run the command-line tool only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()