numpy
typing_extensions
saiunit==0.1.3

[cpu]
jax[cpu]

[cuda12]
jax[cuda12]

[cuda13]
jax[cuda13]

[testing]
pytest

[tpu]
jax[tpu]
