jax>=0.1
numpy>=1.22
einops>=0.3
