jaxlib
jax
matplotlib==3.7.2
numpy==1.24.4
pandas==2.0.3
scikit_learn==1.3.0
scipy==1.12.0
setuptools==68.1.2
torch>=1.13.1
pyyaml

[cuda]
jax[cuda]

[dev]
pytest
mkdocs
mkdocs-material
mkdocstrings-python
mike
