-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # the CUDA-enabled jaxlib wheels are served from this page rather than PyPI; the jax[cuda] install doesn't work right without it
jax[cuda]==0.3.25 # the SCC's jax module is 0.2.19, so pin to at least 0.3. Versions >=0.4 so far seem to have a weird issue with convolutions on GPU
einops
wandb
tqdm
transformers
datasets
omegaconf
dill

