[](https://colab.research.google.com/github/google/flax/blob/main/docsnnx/mnisttutorial.ipynb) 
欢迎来到 Flax NNX!在本教程中,您将学习如何使用 Flax NNX API 构建和训练一个简单的卷积神经网络 (CNN) 来对 MNIST 数据集上的手写数字进行分类。
Flax NNX 是一个基于 JAX 构建的 Python 神经网络库。如果您以前使用过 Flax Linen API,请查看 为什么选择 Flax NNX。您应该对深度学习的主要概念有一定的了解。
让我们开始吧!
!pip install -q jax-ai-stack==2025.9.3
首先,您需要加载 MNIST 数据集,然后通过 Tensorflow Datasets (TFDS) 准备训练集和测试集。您需要对图像值进行归一化、对数据进行混洗并将其分成批次,并预取样本以提高性能。
import tensorflow_datasets as tfds # TFDS 用于下载 MNIST。
import tensorflow as tf # TensorFlow / `tf.data` 操作。
tf.random.set_seed(0) # 设置随机种子以保证可复现性。
train_steps = 1200
eval_every = 200
batch_size = 32
train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')
train_ds = train_ds.map(
lambda sample: {
'image': tf.cast(sample['image'], tf.float32) / 255,
'label': sample['label'],
}
) # 归一化训练集
test_ds = test_ds.map(
lambda sample: {
'image': tf.cast(sample['image'], tf.float32) / 255,
'label': sample['label'],
}
) # 归一化测试集
# 通过分配一个 1024 的缓冲区大小来创建一个混洗的数据集,从中随机抽取元素。
train_ds = train_ds.repeat().shuffle(1024)
# 分成 `batch_size` 大小的批次,跳过不完整的批次,预取下一个样本以提高延迟。
train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)
# 分成 `batch_size` 大小的批次,跳过不完整的批次,预取下一个样本以提高延迟。
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
from flax import nnx # Flax NNX API。
from functools import partial
class CNN(nnx.Module):
"""一个简单的 CNN 模型。"""
def __init__(self, *, rngs: nnx.Rngs):
self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
self.linear2 = nnx.Linear(256, 10, rngs=rngs)
def __call__(self, x):
x = self.avg_pool(nnx.relu(self.conv1(x)))
x = self.avg_pool(nnx.relu(self.conv2(x)))
x = x.reshape(x.shape[0], -1) # 展平
x = nnx.relu(self.linear1(x))
x = self.linear2(x)
return x
# 实例化模型。
model = CNN(rngs=nnx.Rngs(0))
# 可视化它。
nnx.display(model)
import jax.numpy as jnp # JAX NumPy
y = model(jnp.ones((1, 28, 28, 1)))
y
在 Flax NNX 中,您需要创建一个 nnx.Optimizer 对象来管理模型的参数并在训练期间应用梯度。nnx.Optimizer 使用模型进行初始化以推断优化器状态的结构,并使用 Optax 优化器来定义更新规则。
import optax
learning_rate = 0.005
momentum = 0.9
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum), wrt=nnx.Param)
metrics = nnx.MultiMetric(
accuracy=nnx.metrics.Accuracy(),
loss=nnx.metrics.Average('loss'),
)
nnx.display(optimizer)
在本节中,您将使用交叉熵损失 (optax.softmax_cross_entropy_with_integer_labels()) 定义一个损失函数,CNN 模型将对此进行优化。
除了 loss 之外,在训练和测试期间,您还将获得 logits,它将用于计算准确率指标。
在训练期间 - train_step - 您将使用 nnx.value_and_grad 计算梯度,并使用您已经定义的 optimizer 更新模型的参数。在训练和测试(eval_step)期间,loss 和 logits 都将用于计算指标。
def loss_fn(model: CNN, batch):
logits = model(batch['image'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']
).mean()
return loss, logits
@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
"""训练一个步骤。"""
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, labels=batch['label']) # 就地更新。
optimizer.update(model, grads) # 就地更新。
@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
loss, logits = loss_fn(model, batch)
metrics.update(loss=loss, logits=logits, labels=batch['label']) # 就地更新。
在上面的代码中,nnx.jit 转换装饰器跟踪 train_step 函数,以便使用 XLA 进行即时编译,从而优化硬件加速器(如 Google TPU 和 GPU)上的性能。nnx.jit 是 jax.jit 转换的“提升”版本,允许其函数输入和输出是 Flax NNX 对象。同样,nnx.value_and_grad 是 jax.value_and_grad 的提升版本。查看 提升的转换指南 以了解更多信息。
> 注意: 代码显示了如何对模型、优化器和指标执行多个就地更新,但并未显式返回状态更新。这是因为 Flax NNX 转换尊重 Flax NNX 对象的引用语义,并将传播作为输入参数传递的对象的状更新。这是 Flax NNX 的一个关键特性,它使代码更简洁、更具可读性。您可以在 为什么选择 Flax NNX 中了解更多信息。
现在,您可以使用批处理数据训练 CNN 模型 10 个周期,在每个周期后评估模型在测试集上的性能,并在此过程中记录训练和测试指标(损失和准确率)。通常,这会使模型达到约 99% 的准确率。
from IPython.display import clear_output
import matplotlib.pyplot as plt
metrics_history = {
'train_loss': [],
'train_accuracy': [],
'test_loss': [],
'test_accuracy': [],
}
for step, batch in enumerate(train_ds.as_numpy_iterator()):
# 运行优化一个步骤,并对以下内容进行有状态的更新:
# - 训练状态的模型参数
# - 优化器状态
# - 训练损失和准确率批次指标
train_step(model, optimizer, metrics, batch)
if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # 一个训练周期已过。
# 记录训练指标。
for metric, value in metrics.compute().items(): # 计算指标。
metrics_history[f'train_{metric}'].append(value) # 记录指标。
metrics.reset() # 为测试集重置指标。
# 每个训练周期后计算测试集上的指标。
for test_batch in test_ds.as_numpy_iterator():
eval_step(model, metrics, test_batch)
# 记录测试指标。
for metric, value in metrics.compute().items():
metrics_history[f'test_{metric}'].append(value)
metrics.reset() # 为下一个训练周期重置指标。
clear_output(wait=True)
# 在子图中绘制损失和准确率
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('损失')
ax2.set_title('准确率')
for dataset in ('train', 'test'):
ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()
创建一个 jit 编译的模型推理函数 (使用 nnx.jit) - pred_step - 使用学习到的模型参数在测试集上生成预测。这将使您能够将测试图像与其预测标签一起可视化,以便对模型性能进行定性评估。
model.eval() # 切换到评估模式。
@nnx.jit
def pred_step(model: CNN, batch):
logits = model(batch['image'])
return logits.argmax(axis=1)
请注意,我们使用 .eval() 来确保模型处于评估模式,即使我们在此模型中没有使用 Dropout 或 BatchNorm,.eval() 也能确保输出是确定性的。
test_batch = test_ds.as_numpy_iterator().next()
pred = pred_step(model, test_batch)
fig, axs = plt.subplots(5, 5, figsize=(12, 12))
for i, ax in enumerate(axs.flatten()):
ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')
ax.set_title(f'label={pred[i]}')
ax.axis('off')
# 安装模型浏览器
!pip install --no-deps ai-edge-model-explorer-adapter ai-edge-model-explorer
# 使用一些虚拟输入并将模型 MLIR 写入文件
import jax
dummy_input = jnp.ones((1, 28, 28, 1))
stablehlo_mlir = jax.jit(model).lower(dummy_input).as_text(debug_info=True)
mlir_file = open("stablehlo_mlir.mlir", "w")
mlir_file.write(stablehlo_mlir)
mlir_file.close()
# 导入并使用模型浏览器运行模型
import model_explorer
model_explorer.visualize("stablehlo_mlir.mlir")
恭喜!您已经学会了如何使用 Flax NNX 在 MNIST 数据集上端到端地构建和训练一个简单的分类模型。
接下来,请查看 为什么选择 Flax NNX? 并开始学习一系列 Flax NNX 指南。