#! /usr/bin/env python
import argparse
import json
import logging as log
import os
import os.path as op
import re
import tempfile

import pandas as pd
import pydicom
import pyxnat
from tqdm import tqdm


class readable_dir(argparse.Action):
    """argparse action accepting only an existing, readable directory.

    On success the path is stored on the namespace under ``self.dest``;
    otherwise the parser reports a usage error and exits.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        prospective_dir = values
        if not os.path.isdir(prospective_dir):
            msg = "readable_dir:{0} is not a valid path".format(prospective_dir)
            # ArgumentError (not ArgumentTypeError) is what parse_args()
            # converts into a clean usage message when raised from an Action;
            # an ArgumentTypeError would escape as a raw traceback.
            raise argparse.ArgumentError(self, msg)
        if not os.access(prospective_dir, os.R_OK):
            msg = "readable_dir:{0} is not a readable dir".format(prospective_dir)
            raise argparse.ArgumentError(self, msg)
        setattr(namespace, self.dest, prospective_dir)

def check_xnat_item(a, x):
    """Classify an identifier as an XNAT project or experiment.

    Parameters
    ----------
    a : str
        Project label or experiment ID to look up.
    x : pyxnat.Interface
        Connected XNAT interface.

    Returns
    -------
    int
        0 if ``a`` is the label of a project returned by
        ``x.select.projects()``, 1 if it is the ID of an experiment in one
        of those projects, -1 otherwise.
    """
    projects = [p.label() for p in x.select.projects()]
    # Projects are checked first: when ``a`` is a project there is no need
    # to query the experiments of every project (one request per project).
    if a in projects:
        return 0

    for p in projects:
        experiments = x.array.experiments(project_id=p).data
        # Stop querying further projects as soon as a match is found.
        if any(e['ID'] == a for e in experiments):
            return 1
    return -1

def _scan_id_sort_key(scan_id):
    """Natural-order sort key for scan IDs, so that '10' sorts after '9'."""
    # re.split with a capturing group alternates non-digit / digit chunks,
    # always starting with a (possibly empty) non-digit chunk, so int and str
    # chunks sit at matching positions and stay comparable across IDs.
    return [int(chunk) if i % 2 else chunk
            for i, chunk in enumerate(re.split(r'(\d+)', scan_id))]


def get_scandate(experiment_id, x, t1_scan_label='T1_ALFA1'):
    """Return the acquisition date of an experiment's T1 scan.

    Among the experiment's scans whose type equals ``t1_scan_label``, the
    one with the highest scan ID (natural numeric order) is used. If none
    match, every scan whose ID does not start with 'OT-' or 'O-' is
    considered instead. The first file of that scan's DICOM resource is
    downloaded and its header read.

    Parameters
    ----------
    experiment_id : str
        XNAT experiment ID.
    x : pyxnat.Interface
        Connected XNAT interface.
    t1_scan_label : str
        Scan type identifying T1 scans.

    Returns
    -------
    The DICOM ``AcquisitionDate`` value, or the first 8 characters of
    ``AcquisitionDateTime`` (its date part) when the former is absent.

    Raises
    ------
    ValueError
        If the experiment has no usable scan.
    """
    columns = ['xsiType', 'xnat:imagescandata/type', 'xnat:imagescandata/ID']
    scans = x.array.scans(experiment_id=experiment_id, columns=columns).data
    t1_scans = [e['xnat:imagescandata/id'] for e in scans
                if e['xnat:imagescandata/type'] == t1_scan_label]

    if not t1_scans:
        msg = 'No T1 found for %s: %s. Trying with all of them.'\
            % (experiment_id, [e['xnat:imagescandata/id'] for e in scans])
        log.warning(msg)
        t1_scans = [e['xnat:imagescandata/id'] for e in scans
                    if not e['xnat:imagescandata/id'].startswith(('OT-', 'O-'))]
    if not t1_scans:
        raise ValueError('No usable scan found for %s' % experiment_id)

    # Plain string sorting would rank '9' above '10'.
    max_nb = max(t1_scans, key=_scan_id_sort_key)
    log.debug('Found scan: %s' % max_nb)
    scan = x.select.experiment(experiment_id).scan(max_nb)

    f = list(scan.resource('DICOM').files())[0]
    # A unique temp file per call: a fixed name would collide between
    # concurrent runs and could be pre-created by another user in /tmp.
    fd, fp = tempfile.mkstemp(suffix='.dcm')
    os.close(fd)
    try:
        f.get(dest=fp)
        # dcmread replaces the read_file alias removed in pydicom 3.0.
        d = pydicom.dcmread(fp)
    finally:
        os.remove(fp)

    if hasattr(d, 'AcquisitionDate'):
        return d.AcquisitionDate
    return d.AcquisitionDateTime[:8]

def collect_mrscandates(x, project_id=None, experiment_id=None, max_rows=None):
    """Collect MR scan dates for one experiment or for a whole project.

    Exactly one of ``project_id`` / ``experiment_id`` must be given;
    otherwise an error is logged and None is returned.

    - ``experiment_id``: return that experiment's scan date, as produced by
      ``get_scandate``.
    - ``project_id``: return a DataFrame indexed by experiment ID (sorted)
      with columns ``label``, ``subject_label`` and ``scandate`` (converted
      with ``pd.to_datetime``), covering at most ``max_rows`` experiments.
      Experiments that fail are logged and skipped; Ctrl-C stops the loop
      early and returns the rows gathered so far.
    """

    def _to_frame(rows):
        # Build the result table from the accumulated row lists.
        frame = pd.DataFrame(rows, columns=('ID', 'label', 'subject_label', 'scandate'))
        frame['scandate'] = pd.to_datetime(frame['scandate'])
        return frame.set_index('ID').sort_index()

    missing_project = project_id is None
    missing_experiment = experiment_id is None

    if missing_project and missing_experiment:
        log.error('project_id and experiment_id cannot be both None')
        return None
    if not (missing_project or missing_experiment):
        log.error('project_id and experiment_id cannot be provided both')
        return None
    if not missing_experiment:
        return get_scandate(experiment_id, x)

    rows = []
    wanted = ['label', 'subject_ID', 'subject_label']
    listing = x.array.experiments(project_id=project_id, columns=wanted).data
    for exp in tqdm(listing[:max_rows]):
        try:
            log.debug('Experiment ID: %s Subject label: %s'
                      % (exp['ID'], exp['subject_label']))
            head = [exp['ID'], exp['label'], exp['subject_label']]
            rows.append(head + [get_scandate(exp['ID'], x)])
        except KeyboardInterrupt:
            return _to_frame(rows)
        except Exception as exc:
            log.error('Failed with %s. Skipping it. (%s)' % (exp['ID'], exc))
    return _to_frame(rows)



def download_spm12(x, project_id=None, experiment_id=None,
        destdir=tempfile.gettempdir(), max_rows=5):
    """Download SPM12 segmentation outputs and their validation report.

    Exactly one of ``project_id`` / ``experiment_id`` must be given;
    otherwise an error is logged and nothing is downloaded. For a project,
    at most ``max_rows`` experiments are processed.

    For each experiment, the ``SPM12_SEGMENT`` resource is fetched into
    ``<destdir>/<experiment ID>``. Then, from the ``BBRC_VALIDATOR``
    resource, the PDF whose label contains 'SPM12SegmentValidator' and sorts
    last is saved alongside. Missing resources or reports are logged and
    the experiment's remaining steps are skipped.
    """
    if project_id is None and experiment_id is None:
        log.error('project_id and experiment_id cannot be both None')
        return
    if project_id is not None and experiment_id is not None:
        log.error('project_id and experiment_id cannot be provided both')
        return

    if experiment_id is not None:
        experiments = [experiment_id]
    else:
        data = x.array.experiments(project_id=project_id, columns=['label']).data
        experiments = [e['ID'] for e in data[:max_rows]]

    log.info('Now initiating download for %s experiments.' % len(experiments))
    for e in tqdm(experiments):
        log.debug(e)
        r = x.select.experiment(e).resource('SPM12_SEGMENT')
        if not r.exists():
            log.error('%s has no SPM12_SEGMENT resource' % e)
            continue
        dd = op.join(destdir, e)
        # exist_ok: re-running into the same destination must not abort the
        # whole batch with FileExistsError.
        os.makedirs(dd, exist_ok=True)
        r.get(dest_dir=dd)

        r = x.select.experiment(e).resource('BBRC_VALIDATOR')
        # Check existence *before* listing the resource's files.
        if not r.exists():
            log.error('%s has no BBRC_VALIDATOR resource' % e)
            continue
        pdf = {f.label(): f for f in r.files()
               if 'SPM12SegmentValidator' in f.label()
               and f.label().endswith('.pdf')}
        if not pdf:
            log.error('%s has no SPM12 Validation Report' % e)
            continue
        # Labels compared as strings; the last one in sorted order is kept.
        f = pdf[max(pdf)]
        f.get(dest=op.join(dd, f.label()))


def spm12_volumes(x, project_id=None, experiment_id=None, max_rows=None):
    """Compute per-experiment totals of the SPM12 c1/c2/c3 maps.

    Exactly one of ``project_id`` / ``experiment_id`` must be given;
    otherwise an error is logged and None is returned. For a project, at
    most ``max_rows`` experiments are processed.

    For each experiment, the files of the ``SPM12_SEGMENT`` resource whose
    ID starts with 'c1', 'c2' and 'c3' are downloaded. Each map's voxel
    values are summed and multiplied by the voxel volume taken from the
    NIfTI header (presumably mm^3 — depends on the header units).

    Returns
    -------
    pandas.DataFrame
        Indexed by experiment ID (sorted), columns ``c1``, ``c2``, ``c3``.
        Failing experiments are logged and skipped; Ctrl-C stops early and
        returns the rows gathered so far.
    """
    if project_id is None and experiment_id is None:
        log.error('project_id and experiment_id cannot be both None')
        return None
    if project_id is not None and experiment_id is not None:
        log.error('project_id and experiment_id cannot be provided both')
        return None

    # Imported lazily, only once there is actual work to do.
    import nibabel as nib
    import numpy as np

    def _as_table(rows):
        # Build the result table from the accumulated rows.
        return pd.DataFrame(rows, columns=['ID', 'c1', 'c2', 'c3']).set_index('ID').sort_index()

    def _map_volume(f):
        # Download one map to a private temp file, compute sum * voxel volume,
        # and always close/remove the temp file (previously leaked).
        fd, fp = tempfile.mkstemp('.nii.gz')
        os.close(fd)
        try:
            f.get(fp)
            img = nib.load(fp)
            # get_zooms() yields pixdim[1:]; pixdim[0] is the qfac sign flag,
            # not a dimension, and must not enter the product.
            voxel_volume = np.prod(img.header.get_zooms()[:3])
            return np.sum(img.dataobj) * voxel_volume
        finally:
            os.remove(fp)

    if experiment_id is not None:
        experiments = [experiment_id]
    else:
        data = x.array.experiments(project_id=project_id, columns=['label']).data
        experiments = [e['ID'] for e in data[:max_rows]]

    table = []
    for e in tqdm(experiments[:max_rows]):
        log.debug(e)
        try:
            r = x.select.experiment(e).resource('SPM12_SEGMENT')
            if not r.exists():
                log.error('%s has no SPM12_SEGMENT resource' % e)
                continue
            files = list(r.files())  # list once, reuse for every map
            vols = [e]
            for kls in ['c1', 'c2', 'c3']:
                matches = [each for each in files if each.id().startswith(kls)]
                if not matches:
                    raise ValueError('no %s map in SPM12_SEGMENT' % kls)
                vols.append(_map_volume(matches[0]))
            table.append(vols)
        except KeyboardInterrupt:
            return _as_table(table)
        except Exception as exc:
            log.error('Failed for %s. Skipping it.' % e)
            log.error(exc)
    return _as_table(table)



def _save_table(df, command, item, destdir):
    """Write ``df`` to ``<destdir>/bx_<command>_<item>_<YYYYMMDD>.xlsx``."""
    from datetime import datetime
    dt = datetime.today().strftime('%Y%m%d')
    # .xlsx rather than .xls: pandas >= 2.0 can no longer write .xls files.
    fn = 'bx_%s_%s_%s.xlsx' % (command, item, dt)
    fp = op.join(destdir, fn)
    log.info('Saving it in %s' % fp)
    df.to_excel(fp)
    return fp


def parse_args(command, args, x, destdir=tempfile.gettempdir(), test=False):
    """Dispatch a bx command.

    Parameters
    ----------
    command : str
        One of 'nifti', 'mrscandates', 'freesurfer', 'spm12'.
    args : list of str
        Command arguments; the last one is a project label or experiment ID.
    x : pyxnat.Interface
        Connected XNAT interface.
    destdir : str or None
        Output folder; None falls back to the system temp directory.
    test : bool
        If True, project-wide queries are limited to one experiment.

    Raises
    ------
    ValueError
        If ``command`` is not a supported command.
    """
    max_rows = 1 if test else None
    if destdir is None:
        destdir = tempfile.gettempdir()
    commands = ['nifti', 'mrscandates', 'freesurfer', 'spm12']
    if command not in commands:
        msg = '%s not found (valid commands: %s)' % (command, commands)
        log.error(msg)
        raise ValueError(msg)

    if command == 'mrscandates':
        if len(args) == 0:
            msg = 'display help message for %s' % command
            log.info(msg)
        elif len(args) == 1:
            a = args[0]  # should be a project or an experiment_id
            t = check_xnat_item(a, x)
            if t == 0:
                log.debug('Project detected: %s' % a)
                df = collect_mrscandates(x, project_id=a, max_rows=max_rows)
                _save_table(df, command, a, destdir)
            elif t == 1:
                log.debug('Experiment detected: %s' % a)
                sd = collect_mrscandates(x, experiment_id=a)
                print(sd)
                log.info('Scan date: %s' % sd)
            else:
                log.error('No project/experiment found: %s' % a)

    elif command == 'freesurfer':
        if len(args) == 0:
            msg = 'display help message for %s' % command
            print(msg)
        elif len(args) == 1:
            # error: missing arguments (at least a project)
            msg = 'missing argument(s)'
            print(msg)
        elif len(args) == 2:
            subcommand = args[0]
            a = args[1]  # should be a project or an experiment_id
            print(a)
            t = check_xnat_item(a, x)
            # NOTE(review): placeholder subcommands, nothing implemented yet.
            if subcommand in ['thickness', 'aparc']:
                pass
            elif subcommand == 'aseg':
                pass
            elif subcommand == 'hippoSfVolumes':
                pass

    elif command == 'spm12':
        if len(args) == 0:
            msg = 'display help message for %s' % command
            log.info(msg)
        elif len(args) == 1:
            # error: missing arguments (at least a project)
            msg = 'missing argument(s)'
            log.info(msg)
        elif len(args) == 2:
            subcommand, a = args  # a should be a project or an experiment_id
            log.info(a)
            t = check_xnat_item(a, x)
            if subcommand == 'maps':
                if t == 0:
                    log.debug('Project detected: %s' % a)
                    download_spm12(x, project_id=a, destdir=destdir, max_rows=max_rows)
                elif t == 1:
                    log.debug('Experiment detected: %s' % a)
                    download_spm12(x, experiment_id=a, destdir=destdir)
                else:
                    log.error('No project/experiment found: %s' % a)

            elif subcommand == 'volumes':
                if t == 0:
                    df = spm12_volumes(x, project_id=a, max_rows=max_rows)
                elif t == 1:
                    df = spm12_volumes(x, experiment_id=a)
                else:
                    log.error('No project/experiment found: %s' % a)
                    # Nothing was computed: previously fell through to
                    # df.to_excel with df unbound (NameError).
                    return
                _save_table(df, command, a, destdir)
            else:
                log.error('Unknown subcommand for %s: %s' % (command, subcommand))

    elif command == 'nifti':
        if len(args) == 0:
            msg = 'display help message for %s' % command
            print(msg)
        elif len(args) == 1:
            # error: missing arguments (at least a project)
            msg = 'missing argument(s)'
            print(msg)
        elif len(args) == 2:
            # NOTE(review): _type and t are computed but not used yet.
            _type = args[0]
            a = args[1]  # should be a project or an experiment_id
            print(a)
            t = check_xnat_item(a, x)


def create_parser():
    """Build the command-line parser for bx.

    Positional ``command`` selects the bx command and ``args`` collects its
    remaining arguments. ``--config`` is opened for reading by argparse
    (``argparse.FileType('r')``), defaulting to ``~/.xnat.cfg``; ``--dest``
    must be an existing, readable directory.
    """
    import argparse
    parser = argparse.ArgumentParser(description='bx')
    parser.add_argument('command', help='BX command')
    parser.add_argument('args', help='Arguments of the BX command', nargs="*")
    # argparse passes string defaults through ``type`` as well, and FileType
    # does not expand '~', so the literal '~/.xnat.cfg' could never be opened.
    parser.add_argument('--config', help='XNAT configuration file',
        required=False, type=argparse.FileType('r'),
        default=os.path.expanduser('~/.xnat.cfg'))
    parser.add_argument('--dest', help='Destination folder',
        required=False, action=readable_dir)
    parser.add_argument('--verbose', '-V', action='store_true', default=False,
        help='Display verbose information (optional)', required=False)
    return parser

if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    log.basicConfig(level=log.DEBUG if args.verbose else log.INFO)

    # argparse has already opened the config file; read through that handle
    # and close it, instead of opening a second, never-closed handle.
    with args.config as config_file:
        config_path = config_file.name
        if args.dest is not None:
            dd = args.dest
        else:
            # Assumes the config file is JSON with an optional 'destination'.
            dd = json.load(config_file).get('destination', None)
    if dd is None:
        dd = tempfile.gettempdir()
    log.info('Output folder: %s' % dd)
    x = pyxnat.Interface(config=config_path)
    parse_args(args.command, args.args, x, dd)
