#!/usr/bin/env python

from brkraw import BrukerLoader, __version__
import argparse
import os, re


def main():
    """Entry point of the ``brkraw`` command line tool.

    Builds the argument parser and dispatches to the handler of the
    sub-command selected on the command line:

    - ``summary``       print a summary of one or more datasets
    - ``chk_bckstatus`` cross-check raw datasets against their backups
    - ``gui``           start the graphical user interface
    - ``tonii``         convert a single dataset to NIfTI
    - ``tonii_all``     convert every dataset under a root folder
    """
    parser = argparse.ArgumentParser(prog='brkraw',
                                     description="Command line tool of Bruker Rawdata Handler")
    parser.add_argument("-v", "--version", action='version', version='%(prog)s v{}'.format(__version__))

    subparsers = parser.add_subparsers(title='Sub-commands',
                                       description='brkraw provides two major function reporting '
                                                   'contents on bruker raw data '
                                                   'and converting image data into NifTi format.',
                                       help='description',
                                       dest='function',
                                       metavar='command')

    summary = subparsers.add_parser("summary", help='Print out data summary')
    summary.add_argument("path", help="Folder location for the Bruker raw data", type=str)

    chkbckstatus = subparsers.add_parser("chk_bckstatus", help='Check the backup status')
    chkbckstatus.add_argument("raw_path", help="Folder location of the Bruker raw datasets", type=str)
    chkbckstatus.add_argument("backup_path", help="Folder location of the backed-up datasets", type=str)
    chkbckstatus.add_argument("-l", "--logging", help="option for logging output instead printing", action='store_true')

    gui = subparsers.add_parser("gui", help='Start GUI')
    gui.add_argument("-i", "--input", help="Folder location for the Bruker raw data", type=str, default=None)
    gui.add_argument("-o", "--output", help="Folder location for converted NifTi data", type=str, default=None)

    nii = subparsers.add_parser("tonii", help='Convert to NifTi format')
    nii.add_argument("path", help="Folder location for the Bruker raw data", type=str)
    nii.add_argument("-b", "--bids", help="Create JSON file with BIDS standard MRI acquisition parameter.", action='store_true')
    # default=None: falsy, so the subj/study-derived name is used when omitted
    nii.add_argument("-o", "--output", help="Filename w/o extension to export NifTi image", type=str, default=None)
    nii.add_argument("-r", "--recoid", help="RECO ID (if scan_id has multiple reconstruction data)", type=int, default=1)
    nii.add_argument("-s", "--scanid", help="Scan ID", type=str)

    niiall = subparsers.add_parser("tonii_all", help="Convert All Datasets inside input path, "
                                                     "Caution: Don't use this function on console computer!! "
                                                     "It will take forever!!")
    niiall.add_argument("path", help="Path of dataset root folder", type=str)
    niiall.add_argument("-b", "--bids", help="Create JSON file with BIDS standard MRI acquisition parameter.",
                        action='store_true')

    args = parser.parse_args()

    if args.function == 'summary':
        _print_summary(args.path)
    elif args.function == 'chk_bckstatus':
        _check_backup_status(args)
    elif args.function == 'gui':
        _launch_gui(args)
    elif args.function == 'tonii':
        _convert_dataset(args)
    elif args.function == 'tonii_all':
        _convert_all_datasets(args)
    else:
        # no (or unknown) sub-command
        parser.print_help()


def _is_dataset_path(path):
    """Return True when *path* looks like a Bruker dataset.

    A dataset is either a directory or a file whose name suggests an
    archive ('zip' or 'PvDataset' in the name) — same heuristic the
    original inline checks used.
    """
    return os.path.isdir(path) or ('zip' in path) or ('PvDataset' in path)


def _print_summary(path):
    """Handler of the ``summary`` sub-command.

    If *path* is itself a dataset, summarize it; otherwise treat *path*
    as a case-insensitive regex and summarize every matching dataset in
    the current working directory.
    """
    if _is_dataset_path(path):
        BrukerLoader(path).summary()
    else:
        for d in os.listdir('.'):
            if _is_dataset_path(d) and re.search(path, d, re.IGNORECASE):
                BrukerLoader(d).summary()


def _check_backup_status(args):
    """Handler of the ``chk_bckstatus`` sub-command.

    Cross-checks raw datasets in ``args.raw_path`` against backup
    archives in ``args.backup_path`` and prints a status report (to
    stdout, or to a dated log file when ``--logging`` is set).  Results
    are cached in a pickle file ``.brkraw_cache`` inside the backup
    folder so already-verified archives are not re-inspected.

    Fixes over the previous revision: cached lists/dicts are no longer
    mutated while being iterated (which skipped entries for lists and
    raised ``RuntimeError`` for dicts in Python 3), and the log file is
    closed on completion.
    """
    from os.path import join as opj, isdir, isfile, exists
    import brkraw
    import tqdm
    import pickle
    import zipfile

    rpath = args.raw_path
    bpath = args.backup_path
    bar_fmt = '{l_bar}{bar:20}{r_bar}{bar:-20b}'

    if args.logging:
        import datetime
        today = datetime.date.today().strftime("%y%m%d")
        fobj = open(opj(bpath, '{}_backup_status.log'.format(today)), 'w')
    else:
        import sys
        fobj = sys.stdout

    # load the cache of previously inspected datasets
    backup_cache = opj(bpath, '.brkraw_cache')
    if exists(backup_cache):
        with open(backup_cache, 'rb') as f:
            cached_dataset = pickle.load(f)
        # 'duplicated' was added in a later version; upgrade old caches
        cached_dataset.setdefault('duplicated', dict())
    else:
        cached_dataset = dict(failed_backup=[],        # crashed / unreadable backup archives
                              failed_raw=[],           # raw datasets that cannot be loaded
                              duplicated=dict(),       # raw dataset -> multiple backup archives
                              completed=dict(),        # archive -> raw dataset, backup verified
                              incompleted=dict(),      # archive -> raw dataset, reco count mismatch
                              backup_required=dict(),  # raw dataset without a complete backup
                              garbages=dict())         # archives without any reconstructed image

    # parse raw dataset folders and not-yet-verified archive candidates
    list_of_raw = sorted(d for d in os.listdir(rpath) if isdir(opj(rpath, d)))
    list_of_candidates = [d for d in os.listdir(bpath)
                          if (isfile(opj(bpath, d)) and (('zip' in d) or ('PvDatasets' in d)))]
    list_of_candidates = [d for d in list_of_candidates
                          if d not in cached_dataset['completed'].keys()]

    if len(cached_dataset['completed']):
        print('Checking completed backup datasets...')
        # detect raw datasets referenced by more than one completed archive
        cache_tested = []
        for bck_path, raw_path in tqdm.tqdm(cached_dataset['completed'].items(),
                                            bar_format=bar_fmt):
            if raw_path not in cache_tested:
                raw_in_list = [rp for rp in cached_dataset['completed'].values() if rp == raw_path]
                if len(raw_in_list) > 1:
                    bucket = cached_dataset['duplicated'].setdefault(raw_path, [])
                    if bck_path not in bucket:
                        bucket.append(bck_path)
                cache_tested.append(raw_path)
            else:
                # raw_path seen before, so its 'duplicated' bucket already exists
                if bck_path not in cached_dataset['duplicated'][raw_path]:
                    cached_dataset['duplicated'][raw_path].append(bck_path)
        del cache_tested

    if len(cached_dataset['duplicated']):
        print('Checking duplicated backup datasets...')
        # iterate over a snapshot: entries are deleted/replaced inside the loop
        for raw_path, bck_paths in tqdm.tqdm(list(cached_dataset['duplicated'].items()),
                                             bar_format=bar_fmt):
            remaining = []
            for bck_path in bck_paths:
                if os.path.exists(os.path.join(bpath, bck_path)):
                    remaining.append(bck_path)
                else:
                    print(" -'{}'[duplicated backup for {}] has been removed.".format(bck_path, raw_path))
            if len(remaining) > 1:
                cached_dataset['duplicated'][raw_path] = remaining
            else:
                # one (or no) archive left: no longer a duplicate
                del cached_dataset['duplicated'][raw_path]

    if len(cached_dataset['failed_backup']):
        print('Checking failed backup datasets...')
        # rebuild the list instead of removing while iterating
        still_failed = []
        for bck_path in tqdm.tqdm(cached_dataset['failed_backup'],
                                  bar_format=bar_fmt):
            if exists(opj(bpath, bck_path)):
                still_failed.append(bck_path)
            else:
                print(" -'{}' has been removed.".format(bck_path))
        cached_dataset['failed_backup'] = still_failed

    if len(cached_dataset['failed_raw']):
        print('Checking failed raw datasets...')
        still_failed = []
        for raw_path in tqdm.tqdm(cached_dataset['failed_raw'],
                                  bar_format=bar_fmt):
            if not exists(opj(rpath, raw_path)):
                print(" -'{}' has been removed.".format(raw_path))
            elif raw_path in cached_dataset['completed'].values():
                print(" -'{}' has been backed-up.".format(raw_path))
            else:
                still_failed.append(raw_path)
        cached_dataset['failed_raw'] = still_failed

    if len(cached_dataset['incompleted']):
        print('Checking incompleted backup datasets...')
        for bck_path, raw_path in tqdm.tqdm(list(cached_dataset['incompleted'].items()),
                                            bar_format=bar_fmt):
            if not exists(opj(bpath, bck_path)):
                print(" -'{}' has been removed.".format(bck_path))
                del cached_dataset['incompleted'][bck_path]
            else:
                # re-check whether the backup now contains all reconstructions
                bck_data = brkraw.load(opj(bpath, bck_path))
                raw_data = brkraw.load(opj(rpath, raw_path))
                if bck_data.num_recos == raw_data.num_recos:
                    del cached_dataset['incompleted'][bck_path]
                    cached_dataset['completed'][bck_path] = raw_path

    if len(cached_dataset['garbages']):
        print('Checking garbages datasets...')
        for bck_path, raw_path in tqdm.tqdm(list(cached_dataset['garbages'].items()),
                                            bar_format=bar_fmt):
            if not exists(opj(bpath, bck_path)):
                print(" -'{}' has been removed.".format(bck_path))
                if not exists(opj(rpath, raw_path)):
                    print(" -'{}' has been removed.".format(raw_path))
                    del cached_dataset['garbages'][bck_path]
                else:
                    print(" -'{}' was not removed.".format(raw_path))

    print('Updating condition of backup datasets...')
    for bck_path in tqdm.tqdm(list_of_candidates,
                              bar_format=bar_fmt):
        if not zipfile.is_zipfile(opj(bpath, bck_path)):
            if bck_path not in cached_dataset['failed_backup']:
                cached_dataset['failed_backup'].append(bck_path)
        else:
            bck_data = brkraw.load(opj(bpath, bck_path))
            raw_path = bck_data._pvobj.path
            if bck_data.is_pvdataset():
                if exists(opj(rpath, raw_path)):
                    # raw data still on disk: verify the backup is complete
                    raw_data = brkraw.load(opj(rpath, raw_path))
                    if raw_data.num_recos == bck_data.num_recos:
                        cached_dataset['completed'][bck_path] = raw_path
                    else:
                        cached_dataset['incompleted'][bck_path] = raw_path
                else:
                    # raw data already deleted: trust the archive, flag empty ones
                    cached_dataset['completed'][bck_path] = raw_path
                    if bck_data.num_recos < 1:
                        cached_dataset['garbages'][bck_path] = raw_path

    cached_dataset['backup_required'] = dict()  # reset; recomputed below

    print('Checking backup status of raw datasets...')
    for raw_path in tqdm.tqdm(list_of_raw,
                              bar_format=bar_fmt):
        if raw_path not in cached_dataset['completed'].values():
            try:
                raw_data = brkraw.load(opj(rpath, raw_path))
                cached_dataset['backup_required'][raw_path] = raw_data._pvobj.user_name
            except Exception:
                # unreadable raw dataset: record the failure, keep scanning
                if raw_path not in cached_dataset['failed_raw']:
                    cached_dataset['failed_raw'].append(raw_path)

    print('**** Summary ****', file=fobj)
    if len(cached_dataset['backup_required']):
        print('>> The list of raw datasets need backup... [no backup file]', file=fobj)
        for raw_path, user_name in cached_dataset['backup_required'].items():
            print(' -{} (user:{})'.format(raw_path, user_name), file=fobj)
    if len(cached_dataset['incompleted']):
        print('\n>> The list of raw datasets need re-backup... [number of reconstructed image mismatch between raw and backup]', file=fobj)
        for bck_path, raw_path in cached_dataset['incompleted'].items():
            print(' -{} (backup:{})'.format(raw_path, bck_path), file=fobj)
    if len(cached_dataset['duplicated']):
        print('\n>> The list of duplicated backup datasets... [the same rawdata backed up on separate file]', file=fobj)
        for raw_path, bck_paths in cached_dataset['duplicated'].items():
            print(' -{}: {}'.format(raw_path, bck_paths), file=fobj)
    if len(cached_dataset['garbages']):
        print('\n>> The raw datasets can be removed... [no reconstructed image]', file=fobj)
        for bck_path, raw_path in cached_dataset['garbages'].items():
            print(' -{} (backup:{})'.format(raw_path, bck_path), file=fobj)
    if len(cached_dataset['failed_raw']):
        print('\n>> The list of failed raw datasets... [issue found in rawdata]', file=fobj)
        for raw_path in cached_dataset['failed_raw']:
            print(' -{}'.format(raw_path), file=fobj)
    if len(cached_dataset['failed_backup']):
        print('\n>> The list of failed backup datasets... [crashed backup file]', file=fobj)
        for raw_path in cached_dataset['failed_backup']:
            print(' -{}'.format(raw_path), file=fobj)

    # save cache
    with open(backup_cache, 'wb') as f:
        pickle.dump(cached_dataset, f)

    # close the log file we opened (stdout must stay open)
    if args.logging:
        fobj.close()


def _launch_gui(args):
    """Handler of the ``gui`` sub-command: open the Tk main window."""
    from brkraw.ui.main_win import MainWindow
    root = MainWindow()
    if args.input is not None:
        root._path = args.input
        root._extend_layout()
        root._load_dataset()
    if args.output is not None:
        root._output = args.output
    root.mainloop()


def _save_nifti(study, scan_id, reco_id, output_fname, make_bids):
    """Save one scan/reco as NIfTI (plus BIDS JSON when *make_bids*).

    Conversion errors are reported as warnings so a multi-scan run can
    continue with the remaining scans.
    """
    try:
        study.save_as(scan_id, reco_id, output_fname)
        if make_bids:
            study.save_json(scan_id, reco_id, output_fname)
        print('NifTi file is generated... [{}]'.format(output_fname))
    except Exception as e:
        print('[Warning]::{}'.format(e))


def _convert_dataset(args):
    """Handler of the ``tonii`` sub-command.

    Converts a single scan when ``--scanid`` is given, otherwise every
    available scan/reco pair of the dataset.
    """
    study = BrukerLoader(args.path)
    if args.output:
        output = args.output
    else:
        output = '{}_{}'.format(study._pvobj.subj_id, study._pvobj.study_id)
    if args.scanid:
        output_fname = '{}-{}-{}'.format(output, args.scanid, args.recoid)
        _save_nifti(study, args.scanid, args.recoid, output_fname, args.bids)
    else:
        for scan_id, recos in study._pvobj.avail_reco_id.items():
            for reco_id in recos:
                output_fname = '{}-{}-{}'.format(output, str(scan_id).zfill(2), reco_id)
                _save_nifti(study, scan_id, reco_id, output_fname, args.bids)


def _bids_subdir(method):
    """Map a Bruker acquisition method name to a BIDS-like sub-folder."""
    if re.search('dti', method, re.IGNORECASE):
        return 'dwi'
    if re.search('epi', method, re.IGNORECASE):
        return 'func'
    if re.search('flash', method, re.IGNORECASE) or re.search('rare', method, re.IGNORECASE):
        return 'anat'
    return 'etc'


def _convert_all_datasets(args):
    """Handler of the ``tonii_all`` sub-command.

    Walks every dataset under ``args.path`` and writes converted images
    into a BIDS-like ``Data/sub-*/ses-*/<modality>`` tree in the current
    working directory.
    """
    from os.path import join as opj, isdir, isfile
    path = args.path
    list_of_raw = sorted(d for d in os.listdir(path)
                         if isdir(opj(path, d))
                         or (isfile(opj(path, d)) and (('zip' in d) or ('PvDataset' in d))))
    base_path = 'Data'
    os.makedirs(base_path, exist_ok=True)
    for raw in list_of_raw:
        study = BrukerLoader(opj(path, raw))
        if not len(study._pvobj.avail_scan_id):
            print('{} is empty...'.format(raw))
            continue
        subj_path = opj(base_path, 'sub-{}'.format(study._pvobj.subj_id))
        sess_path = opj(subj_path, 'ses-{}'.format(study._pvobj.study_id))
        os.makedirs(sess_path, exist_ok=True)  # creates subj_path as well
        for scan_id, recos in study._pvobj.avail_reco_id.items():
            method = study._pvobj._method[scan_id].parameters['Method']
            output_path = opj(sess_path, _bids_subdir(method))
            os.makedirs(output_path, exist_ok=True)
            filename = 'sub-{}_ses-{}_{}'.format(study._pvobj.subj_id, study._pvobj.study_id,
                                                 str(scan_id).zfill(2))
            for reco_id in recos:
                output_fname = opj(output_path, '{}_reco-{}'.format(filename,
                                                                    str(reco_id).zfill(2)))
                try:
                    study.save_as(scan_id, reco_id, output_fname)
                    if args.bids:
                        study.save_json(scan_id, reco_id, output_fname)
                    if re.search('dti', method, re.IGNORECASE):
                        # diffusion scans additionally need bval/bvec files
                        study.save_bdata(scan_id, reco_id, output_fname)
                except Exception as e:
                    print(e)
        print('{} is converted...'.format(raw))

# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()