Source code for WORC.external.ComBatHarmonization.Python.neuroCombat.neuroCombat.neuroCombat

# Originally written by Nick Cullen
# Extended and currently maintained by JP Fortin
from __future__ import absolute_import, print_function
import pandas as pd
import numpy as np
import numpy.linalg as la
import math

def neuroCombat(dat,
                covars,
                batch_col,
                categorical_cols=None,
                continuous_cols=None,
                eb=True,
                parametric=True,
                mean_only=False):
    """
    Run ComBat to remove scanner effects in multi-site imaging data

    Arguments
    ---------
    dat : a pandas data frame or numpy array
        neuroimaging data to correct with shape = (features, samples)
        e.g. cortical thickness measurements, image voxels, etc

    covars : a pandas data frame w/ shape = (samples, covariates)
        demographic/phenotypic/behavioral/batch data

    batch_col : string
        indicating the batch (scanner) variable in covars

    categorical_cols : string or list of strings
        indicating categorical variables to adjust for
        - e.g. male or female

    continuous_cols : string or list of strings
        indicating continuous variables to adjust for
        - e.g. age

    eb : should Empirical Bayes be performed?
        - True by default

    parametric : should parametric adjustments be performed?
        - True by default

    mean_only : should only the mean be adjusted (no scaling)?
        - False by default

    Returns
    -------
    - A numpy array with the same shape as `dat` which has now been ComBat-harmonized
    """
    ##############################
    ### CLEANING UP INPUT DATA ###
    ##############################
    if not isinstance(covars, pd.DataFrame):
        raise ValueError('covars must be a pandas DataFrame -> try: covars = pandas.DataFrame(covars)')

    if not isinstance(categorical_cols, (list, tuple)):
        if categorical_cols is None:
            categorical_cols = []
        else:
            categorical_cols = [categorical_cols]
    if not isinstance(continuous_cols, (list, tuple)):
        if continuous_cols is None:
            continuous_cols = []
        else:
            continuous_cols = [continuous_cols]

    covar_labels = np.array(covars.columns)
    covars = np.array(covars, dtype='object')
    for i in range(covars.shape[-1]):
        try:
            covars[:, i] = covars[:, i].astype('float32')
        except:
            pass

    if isinstance(dat, pd.DataFrame):
        dat = np.array(dat, dtype='float32')
        #dat = dat.T # transpose data to make it (features, samples)... a weird genetics convention..
    ##############################

    # get column indices for relevant variables
    batch_col = np.where(covar_labels == batch_col)[0][0]
    cat_cols = [np.where(covar_labels == c_var)[0][0] for c_var in categorical_cols]
    num_cols = [np.where(covar_labels == n_var)[0][0] for n_var in continuous_cols]

    # convert batch col to integer
    covars[:, batch_col] = np.unique(covars[:, batch_col], return_inverse=True)[-1]

    # create dictionary that stores batch info
    (batch_levels, sample_per_batch) = np.unique(covars[:, batch_col], return_counts=True)
    info_dict = {
        'batch_levels': batch_levels.astype('int'),
        'n_batch': len(batch_levels),
        'n_sample': int(covars.shape[0]),
        'sample_per_batch': sample_per_batch.astype('int'),
        'batch_info': [list(np.where(covars[:, batch_col] == idx)[0]) for idx in batch_levels]
    }

    # create design matrix
    print('[neuroCombat] Creating design matrix')
    design = make_design_matrix(covars, batch_col, cat_cols, num_cols)

    # standardize data across features
    print('[neuroCombat] Standardizing data across features')
    s_data, s_mean, v_pool = standardize_across_features(dat, design, info_dict)

    # fit L/S models and find priors
    print('[neuroCombat] Fitting L/S model and finding priors')
    LS_dict = fit_LS_model_and_find_priors(s_data, design, info_dict, mean_only)

    # find adjustments
    if eb:
        if parametric:
            print('[neuroCombat] Finding parametric adjustments')
            gamma_star, delta_star = find_parametric_adjustments(s_data, LS_dict, info_dict, mean_only)
        else:
            print('[neuroCombat] Finding non-parametric adjustments')
            gamma_star, delta_star = find_non_parametric_adjustments(s_data, LS_dict, info_dict, mean_only)
    else:
        print('[neuroCombat] Finding L/S adjustments without Empirical Bayes')
        gamma_star, delta_star = find_non_eb_adjustments(s_data, LS_dict, info_dict)

    # adjust data
    print('[neuroCombat] Final adjustment of data')
    bayes_data = adjust_data_final(s_data, design, gamma_star, delta_star, s_mean, v_pool, info_dict)

    bayes_data = np.array(bayes_data)

    return bayes_data
def make_design_matrix(Y, batch_col, cat_cols, num_cols):
    """
    Return a matrix containing the following parts:
        - one-hot matrix of batch variable (full)
        - one-hot matrix for each categorical_cols (removing the first column)
        - column for each continuous_cols
    """
    def to_categorical(y, nb_classes=None):
        if not nb_classes:
            nb_classes = np.max(y) + 1
        Y = np.zeros((len(y), nb_classes))
        for i in range(len(y)):
            Y[i, y[i]] = 1.
        return Y

    hstack_list = []

    ### batch one-hot ###
    # convert batch column to integer in case it's string
    batch = np.unique(Y[:, batch_col], return_inverse=True)[-1]
    batch_onehot = to_categorical(batch, len(np.unique(batch)))
    hstack_list.append(batch_onehot)

    ### categorical one-hots ###
    for cat_col in cat_cols:
        cat = np.unique(np.array(Y[:, cat_col]), return_inverse=True)[1]
        cat_onehot = to_categorical(cat, len(np.unique(cat)))[:, 1:]
        hstack_list.append(cat_onehot)

    ### numerical vectors ###
    for num_col in num_cols:
        num = np.array(Y[:, num_col], dtype='float32')
        num = num.reshape(num.shape[0], 1)
        hstack_list.append(num)

    design = np.hstack(hstack_list)
    return design
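# Illustration (added for clarity, not part of the original module): for a covars
# table with a batch column taking two values, one categorical column such as sex,
# and one continuous column such as age (names here are hypothetical), the design
# matrix columns are laid out as
#   [batch level 0, batch level 1, sex (reference level dropped), age]
# i.e. the first n_batch columns are always the full batch one-hot encoding, which
# the downstream functions rely on when slicing design[:, :n_batch].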
def standardize_across_features(X, design, info_dict):
    n_batch = info_dict['n_batch']
    n_sample = info_dict['n_sample']
    sample_per_batch = info_dict['sample_per_batch']

    B_hat = np.dot(np.dot(la.inv(np.dot(design.T, design)), design.T), X.T)
    grand_mean = np.dot((sample_per_batch / float(n_sample)).T, B_hat[:n_batch, :])
    var_pooled = np.dot(((X - np.dot(design, B_hat).T)**2), np.ones((n_sample, 1)) / float(n_sample))

    stand_mean = np.dot(grand_mean.T.reshape((len(grand_mean), 1)), np.ones((1, n_sample)))
    tmp = np.array(design.copy())
    tmp[:, :n_batch] = 0
    stand_mean += np.dot(tmp, B_hat).T

    s_data = ((X - stand_mean) / np.dot(np.sqrt(var_pooled), np.ones((1, n_sample))))

    return s_data, stand_mean, var_pooled
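# Note (added for clarity, not part of the original source): standardize_across_features
# fits the least-squares coefficients B_hat for all design columns, builds a per-sample
# expected mean stand_mean (the sample-size-weighted grand mean over batches plus the
# covariate effects, with the batch columns zeroed out) and a pooled per-feature
# variance var_pooled, and returns
#   s_data = (X - stand_mean) / sqrt(var_pooled)
# so that batch location/scale effects can be estimated on a common scale.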
def aprior(delta_hat):
    m = np.mean(delta_hat)
    s2 = np.var(delta_hat, ddof=1)
    return (2 * s2 + m**2) / float(s2)
def bprior(delta_hat):
    m = delta_hat.mean()
    s2 = np.var(delta_hat, ddof=1)
    return (m * s2 + m**3) / s2
def postmean(g_hat, g_bar, n, d_star, t2):
    return (t2 * n * g_hat + d_star * g_bar) / (t2 * n + d_star)
def postvar(sum2, n, a, b):
    return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
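# Note (added for clarity, not part of the original source): aprior and bprior are
# method-of-moments estimates of the inverse-gamma hyperparameters placed on the
# per-batch variances, and postmean/postvar are the conditional posterior estimates
# of the batch location (gamma) and scale (delta) parameters used in the Empirical
# Bayes step of ComBat (Johnson et al., Biostatistics 2007).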
def fit_LS_model_and_find_priors(s_data, design, info_dict, mean_only):
    n_batch = info_dict['n_batch']
    batch_info = info_dict['batch_info']

    batch_design = design[:, :n_batch]
    gamma_hat = np.dot(np.dot(la.inv(np.dot(batch_design.T, batch_design)), batch_design.T), s_data.T)

    delta_hat = []
    for i, batch_idxs in enumerate(batch_info):
        if mean_only:
            delta_hat.append(np.repeat(1, s_data.shape[0]))
        else:
            delta_hat.append(np.var(s_data[:, batch_idxs], axis=1, ddof=1))

    gamma_bar = np.mean(gamma_hat, axis=1)
    t2 = np.var(gamma_hat, axis=1, ddof=1)

    if mean_only:
        a_prior = None
        b_prior = None
    else:
        a_prior = list(map(aprior, delta_hat))
        b_prior = list(map(bprior, delta_hat))

    LS_dict = {}
    LS_dict['gamma_hat'] = gamma_hat
    LS_dict['delta_hat'] = delta_hat
    LS_dict['gamma_bar'] = gamma_bar
    LS_dict['t2'] = t2
    LS_dict['a_prior'] = a_prior
    LS_dict['b_prior'] = b_prior
    return LS_dict
# Helper function for parametric adjustments:
def it_sol(sdat, g_hat, d_hat, g_bar, t2, a, b, conv=0.0001):
    n = (1 - np.isnan(sdat)).sum(axis=1)
    g_old = g_hat.copy()
    d_old = d_hat.copy()

    change = 1
    count = 0
    while change > conv:
        g_new = postmean(g_hat, g_bar, n, d_old, t2)
        sum2 = ((sdat - np.dot(g_new.reshape((g_new.shape[0], 1)), np.ones((1, sdat.shape[1])))) ** 2).sum(axis=1)
        d_new = postvar(sum2, n, a, b)

        change = max((abs(g_new - g_old) / g_old).max(), (abs(d_new - d_old) / d_old).max())
        g_old = g_new  # .copy()
        d_old = d_new  # .copy()
        count = count + 1
    adjust = (g_new, d_new)
    return adjust
# Helper function for non-parametric adjustments:
def int_eprior(sdat, g_hat, d_hat):
    r = sdat.shape[0]
    gamma_star, delta_star = [], []
    for i in range(0, r, 1):
        g = np.delete(g_hat, i)
        d = np.delete(d_hat, i)
        x = sdat[i, :]
        n = x.shape[0]
        j = np.repeat(1, n)
        A = np.repeat(x, g.shape[0])
        A = A.reshape(n, g.shape[0])
        A = np.transpose(A)
        B = np.repeat(g, n)
        B = B.reshape(g.shape[0], n)
        resid2 = np.square(A - B)
        sum2 = resid2.dot(j)
        LH = 1 / (2 * math.pi * d)**(n / 2) * np.exp(-sum2 / (2 * d))
        LH = np.nan_to_num(LH)
        gamma_star.append(sum(g * LH) / sum(LH))
        delta_star.append(sum(d * LH) / sum(LH))
    adjust = (gamma_star, delta_star)
    return adjust
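# Note (added for clarity, not part of the original source): int_eprior computes the
# non-parametric Empirical Bayes estimates with a leave-one-out scheme: for each
# feature i, the remaining features' (gamma_hat, delta_hat) estimates are weighted by
# the Gaussian likelihood LH of feature i's data, and the likelihood-weighted averages
# become gamma_star[i] and delta_star[i].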
def find_parametric_adjustments(s_data, LS, info_dict, mean_only):
    batch_info = info_dict['batch_info']

    gamma_star, delta_star = [], []
    for i, batch_idxs in enumerate(batch_info):
        if mean_only:
            gamma_star.append(postmean(LS['gamma_hat'][i], LS['gamma_bar'][i], 1, 1, LS['t2'][i]))
            delta_star.append(np.repeat(1, s_data.shape[0]))
        else:
            temp = it_sol(s_data[:, batch_idxs], LS['gamma_hat'][i], LS['delta_hat'][i],
                          LS['gamma_bar'][i], LS['t2'][i], LS['a_prior'][i], LS['b_prior'][i])
            gamma_star.append(temp[0])
            delta_star.append(temp[1])

    return np.array(gamma_star), np.array(delta_star)
def find_non_parametric_adjustments(s_data, LS, info_dict, mean_only):
    batch_info = info_dict['batch_info']

    gamma_star, delta_star = [], []
    for i, batch_idxs in enumerate(batch_info):
        if mean_only:
            LS['delta_hat'][i] = np.repeat(1, s_data.shape[0])
        temp = int_eprior(s_data[:, batch_idxs], LS['gamma_hat'][i], LS['delta_hat'][i])

        gamma_star.append(temp[0])
        delta_star.append(temp[1])

    return np.array(gamma_star), np.array(delta_star)
def find_non_eb_adjustments(s_data, LS, info_dict):
    return LS['gamma_hat'], LS['delta_hat']
def adjust_data_final(s_data, design, gamma_star, delta_star, stand_mean, var_pooled, info_dict):
    sample_per_batch = info_dict['sample_per_batch']
    n_batch = info_dict['n_batch']
    n_sample = info_dict['n_sample']
    batch_info = info_dict['batch_info']

    batch_design = design[:, :n_batch]

    bayesdata = s_data
    gamma_star = np.array(gamma_star)
    delta_star = np.array(delta_star)

    for j, batch_idxs in enumerate(batch_info):
        dsq = np.sqrt(delta_star[j, :])
        dsq = dsq.reshape((len(dsq), 1))
        denom = np.dot(dsq, np.ones((1, sample_per_batch[j])))
        numer = np.array(bayesdata[:, batch_idxs] - np.dot(batch_design[batch_idxs, :], gamma_star).T)

        bayesdata[:, batch_idxs] = numer / denom

    vpsq = np.sqrt(var_pooled).reshape((len(var_pooled), 1))
    bayesdata = bayesdata * np.dot(vpsq, np.ones((1, n_sample))) + stand_mean

    return bayesdata
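

# A minimal usage sketch (added for illustration; not part of the original module).
# It builds synthetic data with shape (features, samples) and a covariate table
# containing a 'batch' column plus 'sex' and 'age' covariates; all names and values
# below are made up for the example.
if __name__ == "__main__":
    np.random.seed(0)
    n_features, n_samples = 20, 40

    # two scanners: the second has an additive offset and a larger scale
    batch = np.array([0] * 20 + [1] * 20)
    dat = np.random.randn(n_features, n_samples)
    dat[:, batch == 1] = dat[:, batch == 1] * 1.5 + 2.0

    covars = pd.DataFrame({'batch': batch,
                           'sex': np.random.choice(['M', 'F'], n_samples),
                           'age': np.random.uniform(40, 80, n_samples)})

    diff_before = np.abs(dat[:, batch == 0].mean() - dat[:, batch == 1].mean())

    harmonized = neuroCombat(dat, covars, batch_col='batch',
                             categorical_cols='sex', continuous_cols='age')

    # after harmonization the per-batch means should be much closer together
    diff_after = np.abs(harmonized[:, batch == 0].mean() - harmonized[:, batch == 1].mean())
    print('Batch mean difference before:', diff_before)
    print('Batch mean difference after:', diff_after)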