numpy
jax
jaxopt
optax
scipy>=1.15
tqdm
