torch>=2.2
numpy<2,>=1.24.4
jaxtyping>=0.2.34
pot>=0.9.5
scikit-learn>=1.6.0
matplotlib>=3.3.2
