Learning JAX: JAX 和 Flax NNX 教程集合

这是一个完整的 JAX 和 Flax NNX 学习教程集合。通过这些教程,你将学习如何使用 JAX 生态系统构建高性能的机器学习模型。

这些教程基于 Google 的 Learning-JAX 项目,包含了从基础到高级的各个主题。

✨ JAX 教程系列

关于 JAX

JAX 是一个用于高性能数值计算的 Python 库,特别适合机器学习研究。它提供了:

  • 自动微分:通过 jax.grad 轻松计算梯度
  • JIT 编译:通过 jax.jit 加速代码执行
  • 自动向量化:通过 jax.vmap 批处理操作
  • 并行化:轻松实现分布式训练

主要组件

  • JAX:核心数值计算库
  • Flax NNX:现代化的神经网络库
  • Optax:优化器库
  • Orbax:检查点保存和恢复
  • Grain:数据加载工具
  • Chex:测试和可靠性工具

资源链接