#!python

import argparse
import pynsp as nsp

# Mirror the pynsp package version; main() reports it through --version.
__version__ = nsp.__version__


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so such an option can
    never be switched off from the command line.

    :param value: raw token from the command line (or an actual bool).
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if the token is not a recognised
        boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string is kept as "false": with the former ``type=bool`` it
    # was the only spelling that produced False.
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


def _build_parser():
    """Create the ``pynsp`` argument parser with all of its sub-commands.

    The selected sub-command name is stored in ``args.function``.

    :return: the configured :class:`argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(prog='pynsp',
                                     description="Command-line PyNSP tool")
    parser.add_argument("-v", "--version", action='version',
                        version='%(prog)s v{}'.format(__version__))

    subparsers = parser.add_subparsers(title='Sub-commands',
                                       description='List of available sub-commands',
                                       help='description',
                                       dest='function',
                                       metavar='command')

    qc = subparsers.add_parser("qc", help="QC parameters calculation")
    nuis = subparsers.add_parser("nuisance", help="Nuisance signal regression")
    reho = subparsers.add_parser("reho", help="Regional Homogeneity")
    alff = subparsers.add_parser("alff", help="Amplitude of Low Frequency Fluctuation")
    rcon = subparsers.add_parser("roi-conn", help="ROI-based Connectivity")

    # 1. Quality control
    qc.add_argument("-i", "--input", help="input file", type=str, required=True)
    qc.add_argument("-p", "--param", help="motion parameter file", type=str, required=True)
    qc.add_argument("-m", "--mask", help="brain mask file", type=str, default=None)
    qc.add_argument("-o", "--output", help="output file prefix", type=str, required=True)

    # 2. Nuisance signal regression
    nuis.add_argument("-i", "--input", help="input file", type=str, required=True)
    nuis.add_argument("-p", "--param", help="nuisance parameter files", nargs='*', type=str, required=False)
    nuis.add_argument("-t", "--dt", help="sampling rate", type=float, required=False)
    nuis.add_argument("-b", "--band", help="band pass filter range (Hz)", nargs=2, type=float, required=False)
    nuis.add_argument("-m", "--mask", help="brain mask file", type=str, required=False)
    nuis.add_argument("-o", "--output", help="output file prefix", type=str, required=True)
    nuis.add_argument("--polort", help="polynomial regressor", type=int, default=3)
    # NOTE(review): --decimals is accepted but nothing below consumes it;
    # confirm whether it should be forwarded to the RSFC pipeline.
    nuis.add_argument("--decimals", help="decimals for round up", type=int, default=None)

    # 3. Regional Homogeneity
    reho.add_argument("-i", "--input", help="input file", type=str, required=True)
    reho.add_argument("-n", "--NN", help="level of neighboring voxel involvement", type=int, default=3)
    reho.add_argument("-m", "--mask", help="brain mask file", type=str, required=False)
    reho.add_argument("-o", "--output", help="output file prefix", type=str, required=True)

    # 4. Amplitude of Low Frequency Fluctuation
    alff.add_argument("-i", "--input", help="input file", type=str, required=True)
    alff.add_argument("-b", "--band", help="LFF filter range (Hz)", nargs=2, type=float, required=True)
    alff.add_argument("-m", "--mask", help="brain mask file", type=str, required=False)
    alff.add_argument("-t", "--dt", help="sampling rate", type=float, required=False)
    alff.add_argument("-o", "--output", help="output file prefix", type=str, required=True)

    # 5. ROI-based connectivity
    rcon.add_argument("-i", "--input", help="input file", type=str, required=True)
    rcon.add_argument("-a", "--atlas", help="atlas file", type=str, required=True)
    rcon.add_argument("-l", "--label", help="atlas label file", type=str, required=False)
    rcon.add_argument("-m", "--mask", help="brain mask file", type=str, required=False)
    rcon.add_argument("-o", "--output", help="output file prefix", type=str, required=True)
    # Parsed with _str2bool so that "--PCA false" / "--FDR false" really
    # disable the option (type=bool treated every non-empty string as True).
    rcon.add_argument("--PCA", help="use PCA (true/false)", type=_str2bool, default=True)
    rcon.add_argument("--FDR", help="FDR correction (true/false)", type=_str2bool, default=True)

    return parser


def _run_qc(args):
    """Run the ``qc`` sub-command and write its outputs under ``args.output``.

    Writes the motion parameters, FD and DVARS tables as ``.xlsx`` files, the
    STD volume as a NIfTI image that reuses the input image's affine, and the
    VWI array as gzip-compressed text.
    """
    # Imported lazily so the other sub-commands do not pay for these imports.
    import nibabel as nib
    import numpy as np

    obj = nsp.QC(args.input, mparam_path=args.param, mask_path=args.mask, calc_all=True)
    obj.mparam.to_excel("{}_mparam.xlsx".format(args.output))
    obj.FD.to_excel("{}_FD.xlsx".format(args.output))
    obj.DVARS.to_excel("{}_DVARS.xlsx".format(args.output))
    std = nib.Nifti1Image(obj.STD, nib.load(args.input).affine)
    std.to_filename("{}_STD.nii.gz".format(args.output))
    np.savetxt("{}_VWI.txt.gz".format(args.output), obj.VWI)


def _run_nuisance(args):
    """Run the ``nuisance`` sub-command and save the last processing step.

    Band-pass filtering is applied only when ``--band`` was supplied.
    """
    obj = nsp.RSFC(args.input, ort_paths=args.param, mask_path=args.mask,
                   order=args.polort, band=args.band, dt=args.dt)
    obj.nuisance_denoising(order=args.polort)
    if args.band is not None:
        obj.bandpass_filtering()
    obj.save_as(args.output, obj.processed[-1])


def _run_reho(args):
    """Run the ``reho`` sub-command and save the first processed step."""
    obj = nsp.RSFC(args.input, mask_path=args.mask)
    obj.calc_ReHo(NN=args.NN)
    obj.save_as(args.output, obj.processed[0])


def _run_alff(args):
    """Run the ``alff`` sub-command and save every processed step.

    Steps whose name contains ``_freq`` are written as a comma-separated text
    file; all other steps are saved through ``obj.save_as``. Output names are
    ``<output>_<suffix>`` where the suffix is the part of the step name after
    its last dot.
    """
    obj = nsp.RSFC(args.input, mask_path=args.mask, band=args.band, dt=args.dt)
    obj.calc_ALFF()
    for step in obj.processed:
        suffix = step.split('.')[-1]
        if '_freq' in step:
            with open("{}_{}.txt".format(args.output, suffix), 'w') as f:
                f.write(", ".join(obj[step].astype(str).tolist()))
        else:
            obj.save_as("{}_{}".format(args.output, suffix), step)


def _atlas_labels(args):
    """Build the list of ROI labels for the ``roi-conn`` sub-command.

    With ``--label`` the names come from the label file for indices
    ``1..max_index``; indices missing from the file get a generated
    ``idx-NNN`` name. Without ``--label`` one ``idx-NNN`` name is generated
    per distinct value found in the atlas image.

    :raises IOError: if ``--label`` was given but the file does not exist.
    """
    import os

    if args.label is not None:
        if not os.path.exists(args.label):
            raise IOError("atlas label file not found: {}".format(args.label))

        from pynsp.methods.tools import parse_label

        label_dic, _ = parse_label(args.label)
        max_index = max(map(int, label_dic.keys()))
        labels = []
        for idx in range(1, max_index + 1):
            try:
                labels.append(label_dic[idx])
            except KeyError:
                # Index has no entry in the label file: use a placeholder.
                labels.append('idx-{}'.format(str(idx).zfill(3)))
        return labels

    import nibabel as nib
    import numpy as np

    # np.asanyarray(img.dataobj) is the documented replacement for the
    # deprecated (and, in recent nibabel, removed) img.get_data().
    atlas = np.asanyarray(nib.load(args.atlas).dataobj)
    # np.unique returns the distinct values already sorted.
    # NOTE(review): unlike the label-file branch, which starts at index 1,
    # this includes every value present in the atlas (so a 0 background
    # would become "idx-000") - confirm this is what calc_ROI_CC expects.
    return ["idx-{}".format(str(int(value)).zfill(3)) for value in np.unique(atlas)]


def _run_roi_conn(args):
    """Run the ``roi-conn`` sub-command and save every processed step."""
    atlas_label = _atlas_labels(args)

    obj = nsp.RSFC(args.input, mask_path=args.mask)
    fwe = "Benjamini-Hochberg" if args.FDR else None
    obj.calc_ROI_CC(args.atlas, atlas_label, use_PCA=args.PCA, fwe=fwe)
    for step in obj.processed:
        suffix = step.split('.')[-1]
        obj.save_as("{}_{}".format(args.output, suffix), step)


def main(argv=None):
    """Entry point of the ``pynsp`` command-line tool.

    Parses the command line and dispatches to the selected sub-command
    (``qc``, ``nuisance``, ``reho``, ``alff`` or ``roi-conn``).

    :param argv: optional list of argument strings; when ``None`` (the
        default) argparse reads ``sys.argv[1:]``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    handlers = {
        'qc': _run_qc,
        'nuisance': _run_nuisance,
        'reho': _run_reho,
        'alff': _run_alff,
        'roi-conn': _run_roi_conn,
    }
    handler = handlers.get(args.function)
    if handler is None:
        # No sub-command given: show usage instead of silently doing nothing.
        parser.print_help()
        return
    handler(args)


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()