jax>=0.4.20
jaxlib>=0.4.20
tensorflow-probability>=0.23.0
matplotlib>=3.8.0
optax>=0.1.7
