# Python package dependencies, one requirement specifier per line.
# NOTE(review): looks like a pip requirements file — confirm how it is consumed
# (e.g. `pip install -r <this file>` or read by a setup script).

# numpy, version 1.15 or newer.
numpy>=1.15
# jax, version 0.2.10 or newer.
# NOTE(review): depending on the jax release, jaxlib may have to be installed
# separately (it is not listed here) — confirm against the JAX install docs.
jax>=0.2.10
# tqdm, no version constraint (any release satisfies this line).
tqdm