[](https://colab.research.google.com/github/google/flax/blob/main/docsnnx/mnisttutorial.ipynb) ![在 GitHub 上打开](https://github.com/google/flax/blob/main/docsnnx/mnisttutorial.ipynb)

MNIST 教程

欢迎来到 Flax NNX!在本教程中,您将学习如何使用 Flax NNX API 构建和训练一个简单的卷积神经网络 (CNN) 来对 MNIST 数据集上的手写数字进行分类。

Flax NNX 是一个基于 JAX 构建的 Python 神经网络库。如果您以前使用过 Flax Linen API,请查看 为什么选择 Flax NNX。您应该对深度学习的主要概念有一定的了解。

让我们开始吧!

1. 安装 JAX AI 技术栈

!pip install -q jax-ai-stack==2025.9.3

2. 加载 MNIST 数据集

首先,您需要加载 MNIST 数据集,然后通过 Tensorflow Datasets (TFDS) 准备训练集和测试集。您需要对图像值进行归一化、对数据进行混洗并将其分成批次,并预取样本以提高性能。

import tensorflow_datasets as tfds  # TFDS 用于下载 MNIST。
import tensorflow as tf  # TensorFlow / `tf.data` 操作。

tf.random.set_seed(0)  # 设置随机种子以保证可复现性。

train_steps = 1200
eval_every = 200
batch_size = 32

train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')

train_ds = train_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }
)  # 归一化训练集

test_ds = test_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }
)  # 归一化测试集

# 通过分配一个 1024 的缓冲区大小来创建一个混洗的数据集,从中随机抽取元素。
train_ds = train_ds.repeat().shuffle(1024)
# 分成 `batch_size` 大小的批次,跳过不完整的批次,预取下一个样本以提高延迟。
train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)
# 分成 `batch_size` 大小的批次,跳过不完整的批次,预取下一个样本以提高延迟。
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)

3. 使用 Flax NNX 定义模型

通过子类化 nnx.Module 来创建一个用于分类的 CNN:

from flax import nnx  # Flax NNX API。
from functools import partial

class CNN(nnx.Module):
  """一个简单的 CNN 模型。"""

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # 展平
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x

# 实例化模型。
model = CNN(rngs=nnx.Rngs(0))
# 可视化它。
nnx.display(model)

运行模型

让我们来测试一下 CNN 模型!在这里,您将使用任意数据执行前向传播并打印结果。

import jax.numpy as jnp  # JAX NumPy

y = model(jnp.ones((1, 28, 28, 1)))
y

4. 创建优化器并定义一些指标

在 Flax NNX 中,您需要创建一个 nnx.Optimizer 对象来管理模型的参数并在训练期间应用梯度。nnx.Optimizer 使用模型进行初始化以推断优化器状态的结构,并使用 Optax 优化器来定义更新规则。

import optax

learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum), wrt=nnx.Param)
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)

nnx.display(optimizer)

5. 定义训练步骤函数

在本节中,您将使用交叉熵损失 (optax.softmax_cross_entropy_with_integer_labels()) 定义一个损失函数,CNN 模型将对此进行优化。

除了 loss 之外,在训练和测试期间,您还将获得 logits,它将用于计算准确率指标。

在训练期间 - train_step - 您将使用 nnx.value_and_grad 计算梯度,并使用您已经定义的 optimizer 更新模型的参数。在训练和测试(eval_step)期间,losslogits 都将用于计算指标。

def loss_fn(model: CNN, batch):
  logits = model(batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']
  ).mean()
  return loss, logits

@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """训练一个步骤。"""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # 就地更新。
  optimizer.update(model, grads)  # 就地更新。

@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
  loss, logits = loss_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # 就地更新。

在上面的代码中,nnx.jit 转换装饰器跟踪 train_step 函数,以便使用 XLA 进行即时编译,从而优化硬件加速器(如 Google TPU 和 GPU)上的性能。nnx.jitjax.jit 转换的“提升”版本,允许其函数输入和输出是 Flax NNX 对象。同样,nnx.value_and_gradjax.value_and_grad 的提升版本。查看 提升的转换指南 以了解更多信息。

> 注意: 代码显示了如何对模型、优化器和指标执行多个就地更新,但并未显式返回状态更新。这是因为 Flax NNX 转换尊重 Flax NNX 对象的引用语义,并将传播作为输入参数传递的对象的状更新。这是 Flax NNX 的一个关键特性,它使代码更简洁、更具可读性。您可以在 为什么选择 Flax NNX 中了解更多信息。

6. 训练和评估模型

现在,您可以使用批处理数据训练 CNN 模型 10 个周期,在每个周期后评估模型在测试集上的性能,并在此过程中记录训练和测试指标(损失和准确率)。通常,这会使模型达到约 99% 的准确率。

from IPython.display import clear_output
import matplotlib.pyplot as plt

metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'test_loss': [],
  'test_accuracy': [],
}

for step, batch in enumerate(train_ds.as_numpy_iterator()):
  # 运行优化一个步骤,并对以下内容进行有状态的更新:
  # - 训练状态的模型参数
  # - 优化器状态
  # - 训练损失和准确率批次指标
  train_step(model, optimizer, metrics, batch)

  if step > 0 and (step % eval_every == 0 or step == train_steps - 1):  # 一个训练周期已过。
    # 记录训练指标。
    for metric, value in metrics.compute().items():  # 计算指标。
      metrics_history[f'train_{metric}'].append(value)  # 记录指标。
    metrics.reset()  # 为测试集重置指标。

    # 每个训练周期后计算测试集上的指标。
    for test_batch in test_ds.as_numpy_iterator():
      eval_step(model, metrics, test_batch)

    # 记录测试指标。
    for metric, value in metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)
    metrics.reset()  # 为下一个训练周期重置指标。

    clear_output(wait=True)
    # 在子图中绘制损失和准确率
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    ax1.set_title('损失')
    ax2.set_title('准确率')
    for dataset in ('train', 'test'):
      ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
      ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
    ax1.legend()
    ax2.legend()
    plt.show()

7. 在测试集上执行推理

创建一个 jit 编译的模型推理函数 (使用 nnx.jit) - pred_step - 使用学习到的模型参数在测试集上生成预测。这将使您能够将测试图像与其预测标签一起可视化,以便对模型性能进行定性评估。

model.eval() # 切换到评估模式。

@nnx.jit
def pred_step(model: CNN, batch):
  logits = model(batch['image'])
  return logits.argmax(axis=1)

请注意,我们使用 .eval() 来确保模型处于评估模式,即使我们在此模型中没有使用 DropoutBatchNorm.eval() 也能确保输出是确定性的。

test_batch = test_ds.as_numpy_iterator().next()
pred = pred_step(model, test_batch)

fig, axs = plt.subplots(5, 5, figsize=(12, 12))
for i, ax in enumerate(axs.flatten()):
  ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')
  ax.set_title(f'label={pred[i]}')
  ax.axis('off')

使用模型浏览器探索您的模型

要真正深入了解模型并了解其操作和连接,模型浏览器是一个很棒的工具!现在让我们来看看我们的 MNIST 模型。请随意四处看看并探索模型!

# 安装模型浏览器

!pip install --no-deps ai-edge-model-explorer-adapter ai-edge-model-explorer
# 使用一些虚拟输入并将模型 MLIR 写入文件

import jax
dummy_input = jnp.ones((1, 28, 28, 1))
stablehlo_mlir = jax.jit(model).lower(dummy_input).as_text(debug_info=True)
mlir_file = open("stablehlo_mlir.mlir", "w")
mlir_file.write(stablehlo_mlir)
mlir_file.close()
# 导入并使用模型浏览器运行模型

import model_explorer

model_explorer.visualize("stablehlo_mlir.mlir")

恭喜!您已经学会了如何使用 Flax NNX 在 MNIST 数据集上端到端地构建和训练一个简单的分类模型。

接下来,请查看 为什么选择 Flax NNX? 并开始学习一系列 Flax NNX 指南