jax
jaxopt>=0.6
typing-extensions>=4.5.0
wandb>=0.14.2
pandas>=2.0.1
plotly>=5.14.1
flax>=0.6.10
tensorflow_datasets
