jax-ai-stack
packages:
jax==0.5.0
↗️ chex==0.1.88
🆕flax==0.10.2
ml_dtypes==0.4.0
optax==0.2.4
orbax-checkpoint==0.11.1
↗️ orbax-export==0.0.6
jax-ai-stack[tfds]
packages:
tensorflow==2.18.0
tensorflow_datasets==4.9.7
jax-ai-stack[grain]
packages:
grain==0.2.3
Notes
This version of jax-ai-stack adds chex as a direct pinned dependency. Prior to this release, chex was an indirect, unpinned dependency via optax.
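To check that an environment matches the base jax-ai-stack pins, including the new direct chex pin, you can compare installed versions against the list above. The following is a minimal sketch. It assumes Python 3.8+ and the standard library's `importlib.metadata`, and it only compares exact version strings. The names `PINS`, `installed_version` and `mismatches` are hypothetical helpers and are not part of jax-ai-stack.

```python
from importlib import metadata

# Hypothetical helper data: base-package pins copied from this release's notes
# (the [tfds] and [grain] extras are not included).
PINS = {
    "jax": "0.5.0",
    "chex": "0.1.88",
    "flax": "0.10.2",
    "ml_dtypes": "0.4.0",
    "optax": "0.2.4",
    "orbax-checkpoint": "0.11.1",
    "orbax-export": "0.0.6",
}


def installed_version(name):
    """Return the installed version string of `name`, or None if not installed."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def mismatches(pins, lookup=installed_version):
    """Map each package whose installed version differs from its pin to (pinned, installed).

    `lookup` is injectable so the check can be exercised without real installs.
    """
    result = {}
    for name, pinned in pins.items():
        found = lookup(name)
        if found != pinned:  # exact string match only; no PEP 440 normalization
            result[name] = (pinned, found)
    return result


if __name__ == "__main__":
    for name, (pinned, found) in sorted(mismatches(PINS).items()):
        print(f"{name}: pinned {pinned}, installed {found or 'missing'}")
```

Under the old behavior, chex would come in transitively through optax, and its installed version could drift. With the direct pin, any version other than 0.1.88 would show up in the mismatch report.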