#!python
# Create a table of key acquisition parameters for every image series in a BIDS dataset
#
# AUTHOR : Mike Tyszka
# PLACE  : Caltech
# DATES  : 2021-05-06 JMT From scratch
#          2024-05-03 JMT Fix handling of missing fields
#          2024-07-11 JMT Expand from just BOLD series to all series types

import os.path as op
import sys
import argparse
import json
import bids
import pandas as pd
import nibabel as nib
from tabulate import tabulate
from glob import glob


def _find_sidecars(bids_dir):
    """Return a sorted list of JSON sidecar paths below the subject folders.

    Both layouts are searched: with a session level
    (sub-*/ses-*/<folder>/*.json) and without one (sub-*/<folder>/*.json).
    """
    patterns = [
        op.join(bids_dir, 'sub-*', 'ses-*', '*', '*.json'),
        op.join(bids_dir, 'sub-*', '*', '*.json'),
    ]

    found = set()
    for pattern in patterns:
        found.update(glob(pattern))

    return sorted(found)


def _find_nifti(json_path):
    """Return the NIfTI image path paired with a JSON sidecar, or None.

    Only the trailing '.json' extension is swapped; compressed images are
    preferred over uncompressed ones.
    """
    stem = json_path[:-len('.json')]

    for ext in ('.nii.gz', '.nii'):
        candidate = stem + ext
        if op.isfile(candidate):
            return candidate

    return None


def _series_row(json_path, nii_path):
    """Build one table row from a sidecar and its NIfTI header.

    Sidecar fields that are absent are reported as None (so they show up as
    missing values in the table) rather than raising KeyError. Acceleration
    factors default to 1 and absent filename entities to '-' (run to 1).
    """
    # Only the header is needed for matrix size and voxel dimensions
    hdr = nib.load(nii_path).header
    dim = hdr['dim']
    vox = hdr['pixdim']

    # Parse BIDS filename entities (subject, session, task, run, ...)
    keys = bids.layout.parse_file_entities(json_path)

    # Load JSON metadata
    with open(json_path, 'r') as fd:
        meta = json.load(fd)

    return [
        meta.get('SeriesNumber'),
        meta.get('AcquisitionTime'),
        op.basename(nii_path),
        keys.get('subject', '-'),
        keys.get('session', '-'),
        keys.get('task', '-'),
        keys.get('run', 1),
        meta.get('RepetitionTime'),
        meta.get('EchoTime'),
        meta.get('FlipAngle'),
        # In-plane and multiband acceleration factors default to 1 (none)
        meta.get('ParallelReductionFactorInPlane', 1),
        meta.get('MultibandAccelerationFactor', 1),
        dim[1], dim[2], dim[3], dim[4],
        vox[1], vox[2], vox[3],
    ]


def main():
    """Tabulate key parameters for every image series in a BIDS dataset.

    Reads the dataset directory from the command line (-d/--dataset, default
    current directory), collects one row per JSON sidecar that has a matching
    NIfTI image, writes the table to <dataset>/bidsdump.csv and prints it to
    stdout. Exits with status 1 if no sidecars or no image series are found.
    """

    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Create table of key BOLD fMRI parameters for a BIDS dataset')
    parser.add_argument('-d', '--dataset', default='.', help='BIDS dataset directory')
    args = parser.parse_args()

    bids_dir = op.realpath(args.dataset)

    print('')
    print('BIDS Simple Parameter Table')
    print('-------------------------')
    print('BIDS dataset directory : {}'.format(bids_dir))

    print('Searching for series metadata (JSON sidecars)')
    json_list = _find_sidecars(bids_dir)

    if not json_list:
        print('* No metadata found - exiting')
        sys.exit(1)

    meta_list = []

    for json_path in json_list:

        # Sidecars without an image (events, physio, scans, ...) are skipped
        nii_path = _find_nifti(json_path)
        if nii_path is None:
            print('  Skipping {} (no matching NIfTI image)'.format(op.basename(json_path)))
            continue

        print('  Processing {}'.format(op.basename(nii_path)))
        meta_list.append(_series_row(json_path, nii_path))

    if not meta_list:
        print('* No image series found - exiting')
        sys.exit(1)

    # Convert to dataframe and save as CSV
    df = pd.DataFrame(
        meta_list,
        columns=[
            'SeriesNumber',
            'AcquisitionTime',
            'Filename',
            'Subject',
            'Session',
            'Task',
            'Run',
            'TR_secs',
            'TE_secs',
            'FlipAngle_degs',
            'R', 'M',
            'nx', 'ny', 'nz', 'nt',
            'vx_mm', 'vy_mm', 'vz_mm'
        ]
    )

    # Sort rows by SeriesNumber (rows with a missing number sort last)
    df.sort_values(by='SeriesNumber', inplace=True)

    csv_fname = op.join(bids_dir, 'bidsdump.csv')
    print('')
    print('Saving BIDS BOLD info to {}'.format(csv_fname))
    df.to_csv(csv_fname, index=False)

    # Pretty print dataframe as a table to stdout
    print('')
    print(tabulate(df, headers='keys', tablefmt='github', stralign='left', numalign='decimal', showindex=False))


# Run only when executed as a script, not when imported as a module
if __name__ == "__main__":

    main()
