tensorflow>=2.5.0
docker
nltk
numpy
pandas
jax
flax
optax
wandb
tensorflow-hub
tqdm
PyYAML
spacy
transformers
scipy
matplotlib
seaborn

[cloud_tpu]
cloud-tpu-client
google-auth

[dev]
pytest
pytest-cov
sphinx
sphinx-rtd-theme
myst-parser
black
flake8

[gpu]
tensorflow-gpu>=2.5.0

[tpu]
jax[tpu]>=0.2.21
flax>=0.3.4
optax>=0.0.9
