jax
jaxopt
numpy
matplotlib