#! /usr/bin/env python
import logging as log
import argparse
import pyxnat
import pandas as pd
import pydicom
import json
import os.path as op
import os
import tempfile
from tqdm import tqdm
from datetime import datetime
import nibabel as nib
import numpy as np
import shutil

class readable_dir(argparse.Action):
    """argparse action accepting only an existing, readable directory.

    On success the given path is stored unchanged on the namespace under the
    option's destination name.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Validate ``values`` as a directory path and store it.

        Raises:
            argparse.ArgumentError: if ``values`` is not a directory or the
                current process cannot read it.
        """
        prospective_dir = values
        # ArgumentError (bound to this action) is the exception that
        # ArgumentParser turns into a regular usage message; an
        # ArgumentTypeError raised from an action would escape as a traceback.
        if not os.path.isdir(prospective_dir):
            msg = "readable_dir:{0} is not a valid path".format(prospective_dir)
            raise argparse.ArgumentError(self, msg)
        if not os.access(prospective_dir, os.R_OK):
            msg = "readable_dir:{0} is not a readable dir".format(prospective_dir)
            raise argparse.ArgumentError(self, msg)
        setattr(namespace, self.dest, prospective_dir)

def check_xnat_item(a, x):
    """Tell whether ``a`` names a project or an experiment on the server.

    Args:
        a: identifier to look up; compared against project labels and
            experiment IDs.
        x: XNAT interface exposing ``select.projects()`` and
            ``array.experiments(project_id=...)``.

    Returns:
        0 if ``a`` is a project label, 1 if it is the ID of an experiment
        listed under one of the projects, -1 otherwise.
    """
    projects = [e.label() for e in x.select.projects()]
    # A project match takes precedence, so answer before paying for one
    # experiment listing per project.
    if a in projects:
        return 0

    for p in projects:
        exp = x.array.experiments(project_id=p).data
        if any(e['ID'] == a for e in exp):
            return 1
    return -1

def get_scandate(experiment_id, x, t1_scan_label='T1_ALFA1'):
    """Return the acquisition date of an experiment, read from a DICOM file.

    Scans whose type equals ``t1_scan_label`` are preferred. When there is
    none, every scan whose ID does not start with ``OT-`` or ``O-`` is
    considered instead. Among the candidates the scan with the greatest ID
    (string comparison) is used, and the first file of its ``DICOM``
    resource is downloaded and parsed.

    Args:
        experiment_id: ID of the experiment to inspect.
        x: XNAT interface used for the queries and the download.
        t1_scan_label: scan type identifying the preferred scans.

    Returns:
        The ``AcquisitionDate`` value of the DICOM file or, when that
        attribute is absent, the first 8 characters of
        ``AcquisitionDateTime`` (presumably ``YYYYMMDD`` - TODO confirm).

    Raises:
        IndexError: if no candidate scan, or no DICOM file, is found.
    """
    columns = ['xsiType', 'xnat:imagescandata/type', 'xnat:imagescandata/ID']
    scans = x.array.scans(experiment_id=experiment_id, columns=columns).data
    t1_scans = {e['xnat:imagescandata/id']: e for e in scans
                if e['xnat:imagescandata/type'] == t1_scan_label}

    if not t1_scans:
        msg = 'No T1 found for %s: %s. Trying with all of them.'\
            %(experiment_id, [e['xnat:imagescandata/id'] for e in scans])
        log.warning(msg)
        t1_scans = {e['xnat:imagescandata/id']: e for e in scans
                    if not e['xnat:imagescandata/id'].startswith(('OT-', 'O-'))}

    if not t1_scans:
        raise IndexError('No usable scan found for %s' % experiment_id)

    # IDs are compared as strings, exactly like sorted(keys)[-1] would.
    max_nb = max(t1_scans)
    log.debug('Found scan: %s'%max_nb)
    scan = x.select.experiment(experiment_id).scan(max_nb)

    files = list(scan.resource('DICOM').files())
    if not files:
        raise IndexError('No DICOM file found for scan %s of %s'
                         % (max_nb, experiment_id))
    f = files[0]

    # Download into a private scratch folder: a fixed file name in the shared
    # temp dir could collide with another run and was left behind on failure.
    workdir = tempfile.mkdtemp()
    try:
        fp = op.join(workdir, 'test.dcm')
        f.get(dest=fp)
        d = pydicom.dcmread(fp)

        if hasattr(d, 'AcquisitionDate'):
            acquisition_date = d.AcquisitionDate
        else:
            acquisition_date = d.AcquisitionDateTime[:8]
    finally:
        shutil.rmtree(workdir, ignore_errors=True)
    return acquisition_date

def collect_mrdates(x, project_id=None, experiment_id=None, max_rows=None,
        overwrite=False):
    """Collect scan dates for one experiment or for a whole project.

    Exactly one of ``project_id`` / ``experiment_id`` is expected; otherwise
    an error is logged and ``None`` is returned.

    Args:
        x: XNAT interface.
        project_id: project whose experiments should be listed.
        experiment_id: single experiment to look up.
        max_rows: upper bound on the number of project experiments processed
            (``None`` means no bound).
        overwrite: accepted but not used by this function.

    Returns:
        With ``experiment_id``: whatever ``get_scandate`` returns for it.
        With ``project_id``: a DataFrame indexed by experiment ``ID`` with
        columns ``label``, ``subject_label`` and ``scandate``. Experiments
        that fail are logged and left out; on Ctrl-C the rows gathered so far
        are returned.
    """

    def _as_table(rows):
        table = pd.DataFrame(rows, columns=('ID', 'label', 'subject_label', 'scandate'))
        table['scandate'] = pd.to_datetime(table['scandate'])
        return table.set_index('ID').sort_index()

    if project_id is None and experiment_id is None:
        log.error('project_id and experiment_id cannot be both None')
        return None
    if project_id is not None and experiment_id is not None:
        log.error('project_id and experiment_id cannot be provided both')
        return None
    if experiment_id is not None:
        return get_scandate(experiment_id, x)

    rows = []
    wanted = ['label', 'subject_ID', 'subject_label']
    listing = x.array.experiments(project_id=project_id, columns=wanted).data
    for entry in tqdm(listing[:max_rows]):
        try:
            log.debug('Experiment ID: %s Subject label: %s'%(entry['ID'], entry['subject_label']))
            record = [entry['ID'], entry['label'], entry['subject_label']]
            record.append(get_scandate(entry['ID'], x))
            rows.append(record)
        except KeyboardInterrupt:
            # Interrupted by the user: hand back what was gathered so far.
            return _as_table(rows)
        except Exception as exc:
            log.error('Failed with %s. Skipping it. (%s)'%(entry['ID'], exc))
    return _as_table(rows)



def download_spm12(x, project_id=None, experiment_id=None,
        destdir=tempfile.gettempdir(), max_rows=5, overwrite=False):
    """Download ``SPM12_SEGMENT`` resources and their validation reports.

    Exactly one of ``project_id`` / ``experiment_id`` is expected; otherwise
    an error is logged and nothing is downloaded. Each experiment gets its
    own folder ``<destdir>/<experiment ID>`` holding the resource and, when
    available, the last (by file name) ``SPM12SegmentValidator`` PDF of the
    ``BBRC_VALIDATOR`` resource.

    Args:
        x: XNAT interface.
        project_id: project whose experiments should be downloaded.
        experiment_id: single experiment to download.
        destdir: parent folder for the per-experiment folders.
        max_rows: upper bound on the number of project experiments processed
            (``None`` means no bound).
        overwrite: when true an existing experiment folder is deleted and
            downloaded again; when false such experiments are skipped.
    """
    if project_id is None and experiment_id is None:
        log.error('project_id and experiment_id cannot be both None')
        return
    if project_id is not None and experiment_id is not None:
        log.error('project_id and experiment_id cannot be provided both')
        return

    if experiment_id is not None:
        experiments = [experiment_id]
    else:
        listing = x.array.experiments(project_id=project_id, columns=['label']).data
        experiments = [e['ID'] for e in listing[:max_rows]]

    log.info('Now initiating download for %s experiments.'%len(experiments))
    for e in tqdm(experiments):
        log.debug(e)
        r = x.select.experiment(e).resource('SPM12_SEGMENT')
        if not r.exists():
            log.error('%s has no SPM12_SEGMENT resource'%e)
            continue

        dd = op.join(destdir, e)
        if op.isdir(dd):
            if not overwrite:
                log.error('%s already exists. Skipping %s.'%(dd, e))
                continue
            log.warning('%s already exists. Overwriting %s.'%(dd, e))
            shutil.rmtree(dd)

        os.mkdir(dd)
        r.get(dest_dir=dd)

        validator = x.select.experiment(e).resource('BBRC_VALIDATOR')
        # Check existence before listing files, not after using the resource.
        if not validator.exists():
            log.error('%s has no BBRC_VALIDATOR resource'%e)
            continue
        pdf = {each.label(): each for each in validator.files()
               if 'SPM12SegmentValidator' in each.label()
               and each.label().endswith('.pdf')}
        if not pdf:
            log.error('%s has no SPM12 Validation Report'%e)
            continue
        report = pdf[max(pdf)]
        report.get(dest=op.join(dd, report.label()))


def spm12_volumes(x, project_id=None, experiment_id=None, max_rows=None,
        overwrite=False):
    """Compute the ``c1``/``c2``/``c3`` volumes stored in ``SPM12_SEGMENT``.

    Exactly one of ``project_id`` / ``experiment_id`` is expected; otherwise
    an error is logged and ``None`` is returned. For every experiment the
    first file whose ID starts with ``c1``, ``c2`` and ``c3`` is downloaded
    to a temporary file, and its volume is the sum of the voxel values times
    the voxel volume (units follow the image header - presumably mm^3, TODO
    confirm).

    Args:
        x: XNAT interface.
        project_id: project whose experiments should be processed.
        experiment_id: single experiment to process.
        max_rows: upper bound on the number of experiments processed
            (``None`` means no bound).
        overwrite: accepted but not used by this function.

    Returns:
        A DataFrame indexed by experiment ``ID`` with columns ``c1``, ``c2``
        and ``c3``. Failing experiments are logged and left out; on Ctrl-C
        the rows gathered so far are returned.
    """

    def _as_table(rows):
        frame = pd.DataFrame(rows, columns=['ID', 'c1', 'c2', 'c3'])
        return frame.set_index('ID').sort_index()

    if project_id is None and experiment_id is None:
        log.error('project_id and experiment_id cannot be both None')
        return None
    if project_id is not None and experiment_id is not None:
        log.error('project_id and experiment_id cannot be provided both')
        return None

    if experiment_id is not None:
        experiments = [experiment_id]
    else:
        listing = x.array.experiments(project_id=project_id, columns=['label']).data
        experiments = [e['ID'] for e in listing]

    table = []
    for e in tqdm(experiments[:max_rows]):
        log.debug(e)
        try:
            r = x.select.experiment(e).resource('SPM12_SEGMENT')
            if not r.exists():
                log.error('%s has no SPM12_SEGMENT resource'%e)
                continue
            # One listing per experiment instead of one per tissue class.
            files = list(r.files())
            vols = [e]
            for kls in ['c1', 'c2', 'c3']:
                f = [each for each in files if each.id().startswith(kls)][0]
                fd, fp = tempfile.mkstemp('.nii.gz')
                os.close(fd)  # only the path is needed; do not leak the descriptor
                try:
                    f.get(fp)
                    d = nib.load(fp)
                    # pixdim[0] is the NIfTI qfac flag (+1/-1), not a voxel
                    # size; the spatial voxel sizes are pixdim[1:4].
                    size = np.prod(d.header['pixdim'].tolist()[1:4])
                    v = np.sum(d.dataobj) * size
                finally:
                    os.remove(fp)
                vols.append(v)
            table.append(vols)
        except KeyboardInterrupt:
            return _as_table(table)
        except Exception as exc:
            log.error('Failed for %s. Skipping it.'%e)
            log.error(exc)
            continue
    return _as_table(table)




def download_freesurfer6(x, project_id=None, experiment_id=None,
        destdir=tempfile.gettempdir(), max_rows=5, overwrite=False):
    """Download ``FREESURFER6`` resources and their validation reports.

    Exactly one of ``project_id`` / ``experiment_id`` is expected; otherwise
    an error is logged and nothing is downloaded. Each experiment gets its
    own folder ``<destdir>/<experiment ID>`` holding the resource and, when
    available, the last (by file name) ``FreeSurferValidator`` PDF of the
    ``BBRC_VALIDATOR`` resource.

    Args:
        x: XNAT interface.
        project_id: project whose experiments should be downloaded.
        experiment_id: single experiment to download.
        destdir: parent folder for the per-experiment folders.
        max_rows: upper bound on the number of project experiments processed
            (``None`` means no bound).
        overwrite: when true an existing experiment folder is deleted and
            downloaded again; when false such experiments are skipped.
    """
    if project_id is None and experiment_id is None:
        log.error('project_id and experiment_id cannot be both None')
        return
    if project_id is not None and experiment_id is not None:
        log.error('project_id and experiment_id cannot be provided both')
        return

    if experiment_id is not None:
        experiments = [experiment_id]
    else:
        listing = x.array.experiments(project_id=project_id, columns=['label']).data
        experiments = [e['ID'] for e in listing[:max_rows]]

    log.info('Now initiating download for %s experiments.'%len(experiments))
    for e in tqdm(experiments):
        log.debug(e)
        r = x.select.experiment(e).resource('FREESURFER6')
        if not r.exists():
            log.error('%s has no FREESURFER6 resource'%e)
            continue

        dd = op.join(destdir, e)
        if op.isdir(dd):
            if not overwrite:
                log.error('%s already exists. Skipping %s.'%(dd, e))
                continue
            log.warning('%s already exists. Overwriting %s.'%(dd, e))
            # Without this the os.mkdir() below fails on the existing folder.
            shutil.rmtree(dd)

        os.mkdir(dd)
        r.get(dest_dir=dd)

        validator = x.select.experiment(e).resource('BBRC_VALIDATOR')
        # Check existence before listing files, not after using the resource.
        if not validator.exists():
            log.error('%s has no BBRC_VALIDATOR resource'%e)
            continue
        pdf = {each.label(): each for each in validator.files()
               if 'FreeSurferValidator' in each.label()
               and each.label().endswith('.pdf')}
        if not pdf:
            log.error('%s has no FreeSurfer Validation Report'%e)
            continue
        report = pdf[max(pdf)]
        report.get(dest=op.join(dd, report.label()))

def freesurfer6_measurements(x, func, project_id=None, experiment_id=None,
        max_rows=None, overwrite=False):
    """Gather FreeSurfer measurements for one experiment or a whole project.

    Exactly one of ``project_id`` / ``experiment_id`` is expected; otherwise
    an error is logged and ``None`` is returned. For every experiment the
    method named ``func`` of its ``FREESURFER6`` resource is called and the
    returned table is tagged with the subject label.

    Args:
        x: XNAT interface.
        func: measurement to collect: ``'aparc'``, ``'aseg'`` or
            ``'hippoSfVolumes'``.
        project_id: project whose experiments should be processed.
        experiment_id: single experiment to process.
        max_rows: upper bound on the number of experiments processed
            (``None`` means no bound).
        overwrite: accepted but not used by this function.

    Returns:
        The per-experiment tables concatenated, indexed by ``subject`` and
        sorted (empty when nothing could be collected). Failing experiments
        are logged and left out; on Ctrl-C the tables gathered so far are
        returned.

    Raises:
        ValueError: if ``func`` is not one of the supported measurements.
    """

    def _as_table(frames):
        # pd.concat() refuses an empty list; return an empty table instead.
        if not frames:
            return pd.DataFrame(columns=['subject']).set_index('subject')
        return pd.concat(frames).set_index('subject').sort_index()

    if project_id is None and experiment_id is None:
        log.error('project_id and experiment_id cannot be both None')
        return None
    if project_id is not None and experiment_id is not None:
        log.error('project_id and experiment_id cannot be provided both')
        return None

    # Reject an unknown measurement up front instead of failing on every
    # experiment and then crashing on the empty result.
    measurements = ('aparc', 'aseg', 'hippoSfVolumes')
    if func not in measurements:
        raise ValueError('Unknown measurement %s (valid: %s)'
                         % (func, list(measurements)))

    columns = ['label', 'subject_ID', 'subject_label']
    if experiment_id is not None:
        experiments = [x.array.experiments(experiment_id=experiment_id, columns=columns).data[0]]
    else:
        experiments = x.array.experiments(project_id=project_id, columns=columns).data

    table = []
    for e in tqdm(experiments[:max_rows]):
        log.debug(e)
        try:
            s = e['subject_label']
            r = x.select.experiment(e['ID']).resource('FREESURFER6')
            if not r.exists():
                log.error('%s has no FREESURFER6 resource'%e['ID'])
                continue
            volumes = getattr(r, func)()
            volumes['subject'] = s
            table.append(volumes)
        except KeyboardInterrupt:
            return _as_table(table)
        except Exception as exc:
            log.error('Failed for %s. Skipping it.'%e)
            log.error(exc)
            continue
    return _as_table(table)



def _save_table(df, destdir, stem):
    """Write ``df`` to ``<destdir>/<stem>_<timestamp>.xls`` and log the path."""
    dt = datetime.today().strftime('%Y%m%d_%H%M%S')
    fp = op.join(destdir, '%s_%s.xls'%(stem, dt))
    log.info('Saving it in %s'%fp)
    df.to_excel(fp)


def parse_args(command, args, x, destdir=tempfile.gettempdir(), overwrite=False,
        test=False):
    """Run a ``bx`` command against an XNAT server.

    Supported commands and their arguments:

    * ``mrdates <project|experiment>``: scan dates; saved to a spreadsheet
      for a project, printed for an experiment.
    * ``freesurfer <aparc|aseg|hippoSfVolumes|files> <project|experiment>``:
      measurements saved to a spreadsheet, or resources downloaded.
    * ``spm12 <files|volumes> <project|experiment>``: resources downloaded,
      or volumes saved to a spreadsheet.
    * ``nifti <type> <project|experiment>``: only resolves the identifier;
      the result is currently unused.

    Args:
        command: name of the command.
        args: list of positional arguments of the command.
        x: XNAT interface.
        destdir: output folder; ``None`` falls back to the system temp dir.
        overwrite: forwarded to the download functions.
        test: when true, limits the number of processed experiments (1, or
            25 for the FreeSurfer measurements).

    Raises:
        ValueError: if ``command`` is not a known command.
    """
    max_rows = 1 if test else None
    commands = ['nifti', 'mrdates', 'freesurfer', 'spm12']
    if command not in commands:
        msg = '%s not found (valid commands: %s)'%(command, commands)
        log.error(msg)
        raise ValueError(msg)
    log.debug('Command: %s'%command)

    # Normalise once so that every branch (not only mrdates) accepts None.
    if destdir is None:
        destdir = tempfile.gettempdir()

    if command == 'mrdates':
        if len(args) == 0:
            msg = 'display help message for %s'%command
            log.info(msg)
        elif len(args) == 1:
            a = args[0] #should be a project or an experiment_id
            t = check_xnat_item(a, x)
            if t == 0:
                log.debug('Project detected: %s'%a)
                df = collect_mrdates(x, project_id=a, max_rows=max_rows)
                _save_table(df, destdir, 'bx_%s_%s'%(command, a))
            elif t == 1:
                log.debug('Experiment detected: %s'%a)
                sd = collect_mrdates(x, experiment_id=a)
                print(sd)
                log.info('Scan date: %s'%sd)
            else:
                log.error('No project/experiment found: %s'%a)
        else:
            log.error('Unexpected number of arguments for %s: %s'%(command, list(args)))

    elif command == 'freesurfer':
        if len(args) == 0:
            msg = 'display help message for %s'%command
            print(msg)
        elif len(args) == 1:
            # error: missing arguments (at least a project)
            msg = 'missing argument(s)'
            print(msg)
        elif len(args) == 2:
            subcommand = args[0]
            a = args[1] #should be a project or an experiment_id
            print(a)
            t = check_xnat_item(a, x)
            if subcommand in ['aparc', 'aseg', 'hippoSfVolumes']:
                max_rows = 25 if test else None
                df = None
                if t == 0:
                    df = freesurfer6_measurements(x, subcommand, project_id=a, max_rows=max_rows)
                elif t == 1:
                    df = freesurfer6_measurements(x, subcommand, experiment_id=a)
                else:
                    log.error('No project/experiment found: %s'%a)
                if df is not None:
                    _save_table(df, destdir, 'bx_%s_%s_%s'%(command, subcommand, a))
            elif subcommand == 'files':
                if t == 0:
                    log.debug('Project detected: %s'%a)
                    download_freesurfer6(x, project_id=a, destdir=destdir, max_rows=max_rows, overwrite=overwrite)
                elif t == 1:
                    log.debug('Experiment detected: %s'%a)
                    download_freesurfer6(x, experiment_id=a, destdir=destdir, overwrite=overwrite)
                else:
                    log.error('No project/experiment found: %s'%a)
            else:
                log.error('Unknown %s subcommand: %s'%(command, subcommand))
        else:
            log.error('Unexpected number of arguments for %s: %s'%(command, list(args)))

    elif command == 'spm12':
        if len(args) == 0:
            msg = 'display help message for %s'%command
            log.info(msg)
        elif len(args) == 1:
            # error: missing arguments (at least a project)
            msg = 'missing argument(s)'
            log.info(msg)
        elif len(args) == 2:
            subcommand = args[0]
            a = args[1] #should be a project or an experiment_id
            log.info(a)
            t = check_xnat_item(a, x)
            if subcommand == 'files':
                if t == 0:
                    log.debug('Project detected: %s'%a)
                    download_spm12(x, project_id=a, destdir=destdir, max_rows=max_rows, overwrite=overwrite)
                elif t == 1:
                    log.debug('Experiment detected: %s'%a)
                    download_spm12(x, experiment_id=a, destdir=destdir, overwrite=overwrite)
                else:
                    log.error('No project/experiment found: %s'%a)
            elif subcommand == 'volumes':
                df = None
                if t == 0:
                    log.debug('Project detected: %s'%a)
                    df = spm12_volumes(x, project_id=a, max_rows=max_rows)
                elif t == 1:
                    log.debug('Experiment detected: %s'%a)
                    df = spm12_volumes(x, experiment_id=a)
                else:
                    log.error('No project/experiment found: %s'%a)
                if df is not None:
                    _save_table(df, destdir, 'bx_%s_%s'%(command, a))
            else:
                log.error('Unknown %s subcommand: %s'%(command, subcommand))
        else:
            log.error('Unexpected number of arguments for %s: %s'%(command, list(args)))

    elif command == 'nifti':
        if len(args) == 0:
            msg = 'display help message for %s'%command
            print(msg)
        elif len(args) == 1:
            # error: missing arguments (at least a project)
            msg = 'missing argument(s)'
            print(msg)
        elif len(args) == 2:
            a = args[1] #should be a project or an experiment_id
            print(a)
            # The lookup is kept for parity with the other commands; nothing
            # is done with its result (nor with args[0]) in this branch.
            check_xnat_item(a, x)
        else:
            log.error('Unexpected number of arguments for %s: %s'%(command, list(args)))



def create_parser():
    """Build the command-line parser of the ``bx`` script.

    Positional arguments are ``command`` and a free list ``args``. Options
    are ``--config`` (opened for reading, defaulting to ``~/.xnat.cfg``),
    ``--dest`` (validated by ``readable_dir``) and the ``--verbose``/``-V``
    and ``--overwrite``/``-O`` flags.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    home_cfg = op.join(op.expanduser('~'), '.xnat.cfg')
    # (names, keyword options) in declaration order, which is also the order
    # used for the positional arguments.
    argument_specs = [
        (('command',), dict(help='BX command')),
        (('args',), dict(help='BX command', nargs="*")),
        (('--config',), dict(help='XNAT configuration file', required=False,
            type=argparse.FileType('r'), default=home_cfg)),
        (('--dest',), dict(help='Destination folder', required=False,
            action=readable_dir)),
        (('--verbose', '-V'), dict(action='store_true', default=False,
            help='Display verbosal information (optional)', required=False)),
        (('--overwrite', '-O'), dict(action='store_true', default=False,
            help='Overwrite', required=False)),
    ]

    cli = argparse.ArgumentParser(description='bx')
    for names, options in argument_specs:
        cli.add_argument(*names, **options)
    return cli

if __name__=="__main__" :
    # Script entry point: parse the command line, pick the output folder
    # (--dest, else the 'destination' key of the configuration file, else
    # the system temp dir), connect to XNAT and run the requested command.
    parser = create_parser()
    args = parser.parse_args()
    log.basicConfig(level=log.DEBUG if args.verbose else log.INFO)

    # argparse opened the configuration file, but only its path is used
    # below, so release that handle right away.
    config_path = args.config.name
    args.config.close()

    dd = args.dest
    if dd is None:
        # Context manager so the configuration file is closed after reading.
        with open(config_path) as config_file:
            dd = json.load(config_file).get('destination', None)
    if dd is None:
        dd = tempfile.gettempdir()
    log.info('Output folder: %s'%dd)
    x = pyxnat.Interface(config=config_path)
    parse_args(args.command, args.args, x, dd, overwrite=args.overwrite)
