#!/usr/bin/env python

import argparse
import math
from decimal import Decimal

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats
import pandas as pd
from tqdm import tqdm

from tractseg.data import dataset_specific_utils
from tractseg.libs.AFQ_MultiCompCorrection import AFQ_MultiCompCorrection
from tractseg.libs.AFQ_MultiCompCorrection import get_significant_areas
from tractseg.libs import metric_utils


def parse_subjects_file(file_path):
    """
    Parse the subjects file that configures the tractometry statistics.

    Layout enforced by this function:
        line 1:  "# tractometry_path=<path>"       (mandatory)
        line 2:  "# bundles=<name> <name> ..."     (optional; otherwise all bundles are used)
        rest:    space separated table, '#' starts a comment. The first column is 'subject_id',
                 every other column has to be numeric.

    Args:
        file_path: path to the subjects file

    Returns:
        Tuple (base_path, df, bundles):
            base_path: everything after the first '=' of the first line
            df: pandas DataFrame of the table, with 'subject_id' as strings
            bundles: list of bundle names (manually specified ones or all valid ones)

    Raises:
        ValueError: if the first line is malformed or a bundle name is unknown/missing
        IOError: if a column after the first one is not numeric, or if a 'group' column in second
                 position contains anything other than 0 and 1
    """
    with open(file_path) as f:
        first_line = f.readline().strip()
        if not first_line.startswith("# tractometry_path="):
            raise ValueError("Invalid first line in subjects file. Must start with '# tractometry_path='")
        # Split on the first '=' only, so a path that itself contains '=' is not truncated
        base_path = first_line.split("=", 1)[1]

        # If the optional bundles line is missing this is simply the next line of the file;
        # the table itself is read independently by pandas below.
        second_line = f.readline().strip()

    # First entry of the name list is dropped (presumably the background class -- TODO confirm)
    valid_bundles = dataset_specific_utils.get_bundle_names("All_tractometry")[1:]

    if second_line.startswith("# bundles="):
        # split() without argument tolerates repeated blanks / tabs between the names
        bundles = second_line.split("=", 1)[1].split()
        if not bundles:
            raise ValueError("Invalid bundle name: {}".format(""))
        for bundle in bundles:
            if bundle not in valid_bundles:
                raise ValueError("Invalid bundle name: {}".format(bundle))
        print("Using {} manually specified bundles.".format(len(bundles)))
    else:
        bundles = valid_bundles

    # Read the ids as text right away: letting pandas infer a number first would strip leading
    # zeros (e.g. "001" -> "1") and the id would no longer match the files on disk.
    df = pd.read_csv(file_path, sep=" ", comment="#", dtype={"subject_id": str})
    df["subject_id"] = df["subject_id"].astype(str)

    # Check that each column (except for first one) is correctly parsed as a number
    for col in df.columns[1:]:
        if not np.issubdtype(df[col].dtype, np.number):
            raise IOError("Column {} contains non-numeric values".format(col))

    # Reject every label that is not exactly 0 or 1 (also negative, fractional and missing values)
    if df.columns[1] == "group" and not df["group"].isin([0, 1]).all():
        raise IOError("Column 'group' may only contain 0 and 1.")

    return base_path, df, bundles


def correct_for_confounds(values, meta_data, bundles, selected_bun_indices, NR_POINTS, analysis_type, confound_names):
    """
    Run metric_utils.unconfound on the tractometry values of every selected bundle and position.

    Side effect: for every analysis_type other than "group", meta_data["target"] is replaced in
    place by its unconfounded version, so callers see the corrected target afterwards. This is
    done exactly once per call (the result does not depend on bundle or position).

    Args:
        values: dict subject_id -> array indexed as [bundle, position]
        meta_data: DataFrame with column "subject_id", the confound columns and either
                   "group" (group analysis) or "target" (otherwise)
        bundles: list of all bundle names; only its length is used (first axis of the result)
        selected_bun_indices: indices of the bundles to process
        NR_POINTS: number of positions per bundle
        analysis_type: "group" selects the group branch, anything else the target branch
        confound_names: list of column names in meta_data used as confounds

    Returns:
        dict subject_id -> array of shape [len(bundles), NR_POINTS]. Rows of bundles that were
        not selected stay zero.
    """
    subject_ids = list(meta_data["subject_id"])
    is_group = analysis_type == "group"

    # The confound matrix is identical for every bundle and position, so build it only once.
    if is_group:
        confounds = meta_data[["group"] + confound_names].values
    else:
        confounds = meta_data[confound_names].values
        # Correct the target once instead of re-correcting the already corrected column in
        # every single loop iteration.
        meta_data["target"] = metric_utils.unconfound(meta_data["target"].values, confounds,
                                                      group_data=False)

    values_cor = np.zeros([len(bundles), NR_POINTS, len(subject_ids)])
    for b_idx in selected_bun_indices:
        for jdx in range(NR_POINTS):
            # One value per subject, in the row order of meta_data
            target = np.array([values[s][b_idx][jdx] for s in subject_ids])
            values_cor[b_idx, jdx, :] = metric_utils.unconfound(target, confounds, group_data=is_group)

    # Restore original data structure: [bundle, position, subject] -> [subject, bundle, position]
    values_cor = values_cor.transpose(2, 0, 1)
    # todo: nicer way: use numpy array right from beginning instead of dict
    return {subject: values_cor[idx] for idx, subject in enumerate(subject_ids)}


def get_corrected_alpha(values_allp, meta_data, analysis_type, subjects_A, subjects_B, alpha, bundles, nperm, b_idx):
    """
    Call AFQ_MultiCompCorrection and hand back its corrected alpha and cluster size.

    Args:
        values_allp: per-subject rows of values; converted with np.array() before the call.
                     For group analysis the rows must be ordered like subjects_A + subjects_B,
                     because the labels are generated in exactly that order.
        meta_data: DataFrame; its "target" column is used as y when analysis_type is not "group"
        analysis_type: "group" -> y is a 0/1 label vector, anything else -> y is the target column
        subjects_A: subjects labelled 0 (only the length is used)
        subjects_B: subjects labelled 1 (only the length is used)
        alpha: alpha level forwarded to AFQ_MultiCompCorrection
        bundles: not used by this function
        nperm: number of permutations forwarded to AFQ_MultiCompCorrection
        b_idx: not used by this function

    Returns:
        Tuple (alphaFWE, clusterFWE): first and third element of the AFQ_MultiCompCorrection result
    """
    if analysis_type != "group":
        y = meta_data["target"].values
    else:
        # len(subjects_A) zeros followed by len(subjects_B) ones
        y = np.array([0] * len(subjects_A) + [1] * len(subjects_B))

    correction = AFQ_MultiCompCorrection(np.array(values_allp), y, alpha, nperm=nperm)
    # Second and fourth element (statFWE, stats) are not needed here
    alphaFWE, _, clusterFWE, _ = correction
    return alphaFWE, clusterFWE


def format_number(num):
    """
    Make a number compact enough for a plot annotation.

    Values above 1e-5 are rounded to 6 decimals and stay numeric; everything else is
    returned as a string in scientific notation with two decimals.
    """
    threshold = 0.00001
    if not num > threshold:
        return '%.2e' % Decimal(num)
    return round(num, 6)


def plot_tractometry_with_pvalue(values, meta_data, bundles, selected_bundles, output_path, alpha, FWE_method,
                                 analysis_type, correct_mult_tract_comp, show_detailed_p, nperm=1000):
    """
    Plot the tractometry values of each selected bundle, mark significant positions and save
    the figure.

    Args:
        values: dict subject_id -> array indexed as [bundle, position]
        meta_data: DataFrame with column "subject_id" plus "group" (group analysis) or "target"
                   (otherwise). All columns from the third one on are treated as confounds.
        bundles: list of all bundle names; the position of a name is the bundle index in values
        selected_bundles: names of the bundles to plot (each must be contained in bundles)
        output_path: file the figure is saved to
        alpha: uncorrected alpha level
        FWE_method: "alphaFWE" marks positions with p below the corrected alpha; any other value
                    marks clusters of p-values below the uncorrected alpha
        analysis_type: "group" -> t-test between group 0 and group 1; anything else -> pearson
                       correlation with meta_data["target"]
        correct_mult_tract_comp: if True run one correction over all selected bundles together,
                                 otherwise one correction per bundle
        show_detailed_p: if True plot -log10(p) of every position instead of only marking the
                         significant areas
        nperm: number of permutations passed on to the multiple comparison correction

    Returns:
        None. The figure is written to output_path.
    """
    # Positional access, so this also works if meta_data does not have a default integer index
    NR_POINTS = values[meta_data["subject_id"].iloc[0]].shape[1]
    selected_bun_indices = [bundles.index(b) for b in selected_bundles]

    if analysis_type == "group":
        subjects_A = list(meta_data[meta_data["group"] == 0]["subject_id"])
        subjects_B = list(meta_data[meta_data["group"] == 1]["subject_id"])
    else:
        subjects_A = list(meta_data["subject_id"])
        subjects_B = []

    # Every data matrix handed to get_corrected_alpha has to use this row order, because for
    # group analysis it labels the first len(subjects_A) rows 0 and the remaining rows 1.
    # For the non-group case this is identical to the row order of meta_data.
    ordered_subjects = subjects_A + subjects_B
    members_A = set(subjects_A)  # constant time membership test inside the plotting loop

    confound_names = list(meta_data.columns[2:])

    cols = 5
    rows = math.ceil(len(selected_bundles) / cols)

    a4_dims = (cols*3, rows*5)
    f, axes = plt.subplots(rows, cols, figsize=a4_dims)

    axes = axes.flatten()
    sns.set(font_scale=1.2)
    sns.set_style("whitegrid")

    # Correct for confounds
    values = correct_for_confounds(values, meta_data, bundles, selected_bun_indices, NR_POINTS, analysis_type,
                                   confound_names)

    # Significance testing with multiple correction of bundles
    if correct_mult_tract_comp:
        # [subjects, NR_POINTS * nr_bundles]: all selected bundles of a subject concatenated.
        # Rows follow ordered_subjects (not the file order) so that they match the group labels.
        values_allp = [np.concatenate([values[s][b_idx] for b_idx in selected_bun_indices])
                       for s in ordered_subjects]
        alphaFWE, clusterFWE = get_corrected_alpha(values_allp, meta_data, analysis_type, subjects_A, subjects_B,
                                                   alpha, bundles, nperm, selected_bun_indices[-1])

    for i, b_idx in enumerate(tqdm(selected_bun_indices)):
        # Bring data into right format for seaborn (one entry per subject and position)
        data = {"position": [],
                "fa": [],
                "group": [],
                "subject": []}
        for subject in ordered_subjects:
            group_label = "Group 0" if subject in members_A else "Group 1"
            for position in range(NR_POINTS):
                data["position"].append(position)
                data["subject"].append(subject)
                data["fa"].append(values[subject][b_idx][position])
                data["group"].append(group_label)

        # Plot
        # (to draw each subject as a single line add: units="subject", estimator=None, lw=1)
        ax = sns.lineplot(x="position", y="fa", data=data, ax=axes[i], hue="group")

        ax.set(xlabel='position along tract', ylabel='metric')
        ax.set_title(bundles[b_idx])
        if analysis_type == "group" and i > 0:
            ax.legend_.remove()  # only show legend on first subplot
        if analysis_type == "correlation":
            ax.legend_.remove()

        # Significance testing without multiple correction of bundles
        if not correct_mult_tract_comp:
            values_allp = [values[s][b_idx] for s in ordered_subjects]  # [subjects, NR_POINTS]
            alphaFWE, clusterFWE = get_corrected_alpha(values_allp, meta_data, analysis_type, subjects_A, subjects_B,
                                                       alpha, bundles, nperm, b_idx)

        # Calc p-values (one test per position)
        pvalues = np.zeros(NR_POINTS)
        for jdx in range(NR_POINTS):
            values_controls = [values[s][b_idx][jdx] for s in subjects_A]
            if analysis_type == "group":
                values_patients = [values[s][b_idx][jdx] for s in subjects_B]
                pvalues[jdx] = scipy.stats.ttest_ind(values_controls, values_patients).pvalue
            else:
                _, pvalues[jdx] = scipy.stats.pearsonr(values_controls, meta_data["target"].values)

        # Plot significant areas
        if show_detailed_p:
            ax2 = axes[i].twinx()
            ax2.bar(range(len(pvalues)), -np.log10(pvalues), color="gray", edgecolor="none", alpha=0.5)
            # Horizontal line at the corrected alpha level
            ax2.plot([0, NR_POINTS-1], (-np.log10(alphaFWE),)*2, color="red", linestyle=":")
            ax2.set(xlabel='position', ylabel='-log10(p)')
        else:
            if FWE_method == "alphaFWE":
                sig_areas = get_significant_areas(pvalues, 1, alphaFWE)
            else:
                sig_areas = get_significant_areas(pvalues, clusterFWE, alpha)
            # Draw the indicator line near the top of the data range at marked positions and
            # near the bottom everywhere else
            fa_values = np.array(data["fa"])
            sig_areas = sig_areas * np.quantile(fa_values, 0.98)
            sig_areas[sig_areas == 0] = np.quantile(fa_values, 0.02)
            axes[i].plot(range(len(sig_areas)), sig_areas, color="red", linestyle=":")

        # Plot text
        if FWE_method == "alphaFWE":
            axes[i].annotate("alphaFWE:   {}".format(format_number(alphaFWE)),
                             (0, 0), (0, -35), xycoords='axes fraction', textcoords='offset points', va='top',
                             fontsize=10)
            axes[i].annotate("min p-value: {}".format(format_number(pvalues.min())),
                             (0, 0), (0, -45), xycoords='axes fraction', textcoords='offset points', va='top',
                             fontsize=10)
        else:
            axes[i].annotate("clusterFWE:   {}".format(clusterFWE),
                             (0, 0), (0, -35), xycoords='axes fraction', textcoords='offset points', va='top',
                             fontsize=10)

    plt.tight_layout()
    plt.savefig(output_path, dpi=200)


def main():
    """
    Command line entry point: read the arguments and the subjects file, load the tractometry
    values of every subject and create the plot with significance markers.

    Raises:
        ValueError: if the second column of the subjects file is neither 'group' nor 'target'
    """
    parser = argparse.ArgumentParser(
        description="Test for significant differences and plot tractometry results.",
        epilog="Written by Jakob Wasserthal.")
    parser.add_argument("-i", dest="subjects_file", metavar="subjects_file_path", required=True,
                        help="txt file containing path of subjects")
    parser.add_argument("-o", dest="output_path", metavar="plot_path", required=True,
                        help="output png file containing plots")
    parser.add_argument("--mc", action="store_true", default=False,
                        help="correct for multiple tract comparison")
    parser.add_argument("--nperm", metavar="n", type=int, default=5000,
                        help="Number of permutations (default: 5000)")
    parser.add_argument("--alpha", metavar="a", type=float, default=0.05,
                        help="The desired alpha level (default: 0.05)")
    cli = parser.parse_args()

    # How significance is defined: "alphaFWE" (corrected alpha) or "clusterFWE" (clusters of values
    # below the uncorrected alpha). clusterFWE is not recommended because it misses highly
    # significant areas just because the cluster is not big enough.
    FWE_method = "alphaFWE"  # alphaFWE | clusterFWE
    # True: show the p-value of each position instead of only the significant areas
    show_detailed_p = False
    nperm = cli.nperm

    # Whether to correct for the comparison of multiple bundles
    correct_mult_tract_comp = cli.mc
    print("Correcting for comparison of multiple tracts: {}".format(correct_mult_tract_comp))

    base_path, meta_data, selected_bundles = parse_subjects_file(cli.subjects_file)

    # The header of the second column decides which kind of analysis is done
    design_column = meta_data.columns[1]
    if design_column not in ("group", "target"):
        raise ValueError("Invalid second column header (only 'group' or 'target' allowed)")

    if design_column == "group":
        analysis_type = "group"
        print("Doing group analysis.")
        print("Number of subjects:")
        for label in (0, 1):
            print("  group {}: {}".format(label, (meta_data["group"] == label).sum()))
    else:
        analysis_type = "correlation"
        print("Doing correlation analysis.")
        nperm = int(nperm / 5)
        print("Using one fifth ({}) of the number of permutations for correlation analysis "
              "because this has longer runtime".format(nperm))
        print("Number of subjects: {}".format(len(meta_data)))

    all_bundles = dataset_specific_utils.get_bundle_names("All_tractometry")[1:]

    # For testing a subset of bundles can be defined manually here, e.g.
    # selected_bundles = ["CST_right", "CST_left", "CG_left"]

    # One ';' separated file per subject (header row skipped), transposed after loading
    values = {
        subject: np.loadtxt(base_path.replace("SUBJECT_ID", subject), delimiter=";", skiprows=1).transpose()
        for subject in meta_data["subject_id"]
    }

    plot_tractometry_with_pvalue(values, meta_data, all_bundles, selected_bundles, cli.output_path,
                                 cli.alpha, FWE_method, analysis_type, correct_mult_tract_comp,
                                 show_detailed_p, nperm=nperm)


# Run the command line interface only when this file is executed as a script (not on import)
if __name__ == '__main__':
    main()
