Learning JAX: JAX and Flax NNX Tutorials
This is a complete collection of JAX and Flax NNX tutorials. Work through them to build high-performance ML models with the JAX ecosystem.
The material comes from Google’s Learning-JAX project and covers topics from fundamentals to advanced workflows.
✨ JAX Tutorial Series
About JAX
JAX is a Python library for high-performance numerical computing, ideal for machine learning research. It provides:
- Automatic differentiation via jax.grad
- JIT compilation via jax.jit
- Automatic vectorization via jax.vmap
- Parallelism across multiple devices (such as GPUs and TPUs) for distributed training
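A minimal sketch of how the first three transformations compose, assuming only that JAX is installed. The function f and the input arrays are hypothetical examples; multi-device parallelism is not shown because it depends on the available hardware.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued function of a vector: f(x) = sum(x**2)
def f(x):
    return jnp.sum(x ** 2)

# jax.grad returns a new function that computes df/dx = 2x
grad_f = jax.grad(f)

# jax.jit compiles that gradient function with XLA on its first call
fast_grad_f = jax.jit(grad_f)

# jax.vmap maps f over a leading batch axis without a Python loop
batched_f = jax.vmap(f)

x = jnp.array([1.0, 2.0, 3.0])
print(fast_grad_f(x))  # array with values 2., 4., 6.

batch = jnp.arange(6.0).reshape(2, 3)  # rows [0, 1, 2] and [3, 4, 5]
print(batched_f(batch))  # array with values 5., 50.
```

Because each transformation takes a function and returns a function, they can be stacked freely, for example jax.jit(jax.vmap(jax.grad(f))) to compute per-example gradients in one compiled call.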
Key Components
- JAX: core numerical library
- Flax NNX: modern neural network library
- Optax: optimizer library
- Orbax: checkpointing library
- Grain: data loading toolkit
- Chex: testing and reliability tools
Resources