jax
jaxlib
tensorflow_probability
jaxns==2.4.4
pydantic
chex>=0.0.8
mctx
pyDOE2
matplotlib
etils
