Flax NNX 简介

欢迎来到 Flax NNX Colab 笔记本!本笔记本提供动手练习,旨在帮助 PyTorch 用户过渡到 Flax NNX 和 JAX 生态系统。

我们将涵盖核心概念并构建简单的模型。

!pip install -q jax-ai-stack==2025.9.3

import jax
import jax.numpy as jnp
import flax
from flax import nnx
import optax
from typing import Any, Dict, Tuple # 用于类型提示

print(f"JAX 版本: {jax.__version__}")
print(f"Flax 版本: {flax.__version__}")
# @title 练习 1: 理解模块和参数 (编程练习)

# 说明:
# 1. 创建一个名为 `MyLinearLayer` 的简单 NNX 模块。
# 2. 它应该有一个名为 `weight` 的 `nnx.Param` (使用形状 [input_size, output_size] 随机初始化)。
# 3. 它应该有一个名为 `bias` 的 `nnx.Param` (使用形状 [output_size] 初始化为零)。
# 4. 前向传播 (`__call__` 方法) 应执行线性变换: `x @ self.weight.value + self.bias.value`。
# 5. 使用 `input_size=10` 和 `output_size=5` 实例化该层。
# 6. 打印 `weight` 和 `bias` 参数的形状。

class MyLinearLayer(nnx.Module):
    def __init__(self, input_size: int, output_size: int, *, rngs: nnx.Rngs):

        pass # 填写此部分

    def __call__(self, x: jax.Array):
        pass # 填写此部分

# 实例化该层
key = jax.random.PRNGKey(0)
linear_layer = MyLinearLayer(
    input_size='填写此部分',
    output_size='填写此部分',
    rngs=nnx.Rngs(key))

# 打印参数的形状
print("权重形状:", '填写此部分')
print("偏置形状:", '填写此部分')

# 示例用法:
dummy_input = jnp.ones((1, 10))
output = linear_layer(dummy_input)
print("输出形状:", output.shape)
# @title 练习 1 解决方案

class MyLinearLayer(nnx.Module):
    def __init__(self, input_size: int, output_size: int, *, rngs: nnx.Rngs):
        self.weight = nnx.Param(jax.random.normal(rngs.params(), (input_size, output_size)))
        self.bias = nnx.Param(jnp.zeros((output_size,)))

    def __call__(self, x: jax.Array):
        return x @ self.weight.value + self.bias.value

# 实例化该层
key = jax.random.PRNGKey(0)
linear_layer = MyLinearLayer(input_size=10, output_size=5, rngs=nnx.Rngs(key))

# 打印参数的形状
print("权重形状:", linear_layer.weight.value.shape)
print("偏置形状:", linear_layer.bias.value.shape)

# 示例用法:
dummy_input = jnp.ones((1, 10))
output = linear_layer(dummy_input)
print("输出形状:", output.shape)
# @title 练习 2: 状态管理 (编程练习)
# 说明:
# 1. 创建一个名为 `CounterModule` 的 NNX 模块。
# 2. 它应该有一个名为 `count` 的 Python 实例属性,初始化为 0。
# 3. `__call__` 方法应该将 `count` 增加 1 并返回新值。
# 4. 实例化该模块。
# 5. 多次调用该模块并打印返回的值。
# 6. 使用 `nnx.split` 和 `nnx.merge` 保存和加载模块的状态。验证计数器从它离开的地方继续。

class CounterModule(nnx.Module):
    def __init__(self):
        pass # 填写此部分

    def __call__(self):
        pass # 填写此部分

# 实例化该模块
pass # 填写此部分。将其命名为“counter”

# 调用该模块并打印值
print("第一次调用:", counter())
print("第二次调用:", counter())

# 将模块拆分为 graphdef 和 state。
# 请记住,state 是一个 nnx.Variable
graphdef, state =  # 填写此部分

# 合并 graphdef 和 state 以创建一个新模块
new_counter =  # 填写此部分

# 调用新模块并打印值
print("拆分和合并后,第一次调用:", new_counter())
print("拆分和合并后,第二次调用:", new_counter())
# @title 练习 2 解决方案

class CounterModule(nnx.Module):
    def __init__(self):
        self.count = 0

    def __call__(self):
        self.count += 1
        return self.count

# 实例化该模块
counter = CounterModule()

# 调用该模块并打印值
print("第一次调用:", counter())
print("第二次调用:", counter())

# 将模块拆分为 graphdef 和 state
graphdef, state = nnx.split(counter, nnx.Variable)

# 合并 graphdef 和 state 以创建一个新模块
new_counter = nnx.merge(graphdef, state)

# 调用新模块并打印值
print("拆分和合并后,第一次调用:", new_counter())
print("拆分和合并后,第二次调用:", new_counter())
# @title 练习 3: 显式随机数生成 (编程练习)

# 说明:
# 1. 创建一个名为 `RandomNormalLayer` 的 NNX 模块。
# 2. 其 `__init__` 方法应接收一个 `size` 参数,用于定义要生成的随机向量的大小。
# 3. `__init__` 方法应接收一个 `rngs: nnx.Rngs` 参数,该参数用于使用 `jax.random.normal` 生成一个随机正态张量,并将该张量分配给 `self.random_vector`。
# 4. `__call__` 方法应返回 `self.random_vector` 的值 (一个新的随机正态张量)。
# 5. 实例化大小为 10 的层,并传入带有 `jax.random.PRNGKey` 的 `rngs` 参数。
# 6. 多次调用该模块,并观察返回的值是否不同。

# 创建 RandomNormalLayer

# 实例化该模块
key = # 使用 jax.random.PRNGKey 创建一个新密钥
random_layer = RandomNormalLayer(size='此处填写大小', rngs=nnx.Rngs(key))

# 调用该模块并打印值
print("第一次调用:", random_layer())
print("第二次调用:", random_layer())
# @title 练习 3 解决方案

class RandomNormalLayer(nnx.Module):
   def __init__(self, size: int, *, rngs: nnx.Rngs):
     self.size = size
     self.rngs = rngs
   def __call__(self):
      self.random_vector = jax.random.normal(self.rngs.params(), (self.size,))
      return self.random_vector

# 实例化该模块
key = jax.random.PRNGKey(0)
random_layer = RandomNormalLayer(size=10, rngs=nnx.Rngs(key))

# 调用该模块并打印值
print("第一次调用:", random_layer())
print("第二次调用:", random_layer())
# @title 练习 4: 构建一个简单的 CNN (编程练习)

# 说明:
# 1. 创建一个表示简单 CNN 的 NNX 模块,包含以下层:
#    - 卷积层 (nnx.Conv),具有 32 个滤波器,内核大小为 3,步幅为 1。
#    - ReLU 激活。
#    - 最大池化层 (nnx.max_pool),窗口大小为 2,步幅为 2。
#    - 展平层 (jax.numpy.reshape)。
#    - 线性层 (nnx.Linear),映射到 10 个输出类别。
# 2. 使用适当的输入和输出形状初始化 CNN。
# 3. 使用虚拟输入执行前向传播并打印输出形状。

class SimpleCNN(nnx.Module):
    def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
        self.conv = nnx.Conv('步幅', '滤波器', kernel_size=('X, X'), rngs=rngs)
        self.linear = nnx.Linear(in_features=6272, out_features=num_classes, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.conv(x)
        print(f'{x.shape = }') # 用于调试
        x = nnx.relu(x)
        print(f'{x.shape = }') # 用于调试
        x = nnx.max_pool(x, window_shape=('X, X'), strides=('X, X'))
        print(f'{x.shape = }') # 用于调试
        x = x.reshape(x.shape[0], -1)  # 展平
        print(f'{x.shape = }') # 用于调试
        x = self.linear(x)
        return x

# 实例化 CNN
key = jax.random.PRNGKey(0)
cnn = SimpleCNN(num_classes='输出类别', rngs=nnx.Rngs(key))

# 虚拟输入
dummy_input = jnp.ones((1, 28, 28, 1))

# 前向传播
output = cnn(dummy_input)
print("输出形状:", output.shape)
# @title 练习 4 解决方案

class SimpleCNN(nnx.Module):
    def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
        self.conv = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.linear = nnx.Linear(in_features=6272, out_features=num_classes, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.conv(x)
        print(f'{x.shape = }')
        x = nnx.relu(x)
        print(f'{x.shape = }')
        x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        print(f'{x.shape = }')
        x = x.reshape(x.shape[0], -1)  # 展平
        print(f'{x.shape = }')
        x = self.linear(x)
        return x

# 实例化 CNN
key = jax.random.PRNGKey(0)
cnn = SimpleCNN(num_classes=10, rngs=nnx.Rngs(key))

# 虚拟输入
dummy_input = jnp.ones((1, 28, 28, 1))

# 前向传播
output = cnn(dummy_input)
print("输出形状:", output.shape)
# @title 练习 5: 使用 Optax 的训练循环 (编程练习)

# 说明:
# 1. 定义一个简单的模型 (例如,一个线性层)。
# 2. 创建一个 nnx.Optimizer,确保使用现在必需的 wrt 参数指定要更新的变量类型 (例如,wrt=nnx.Param)。
# 3. 实现一个训练步骤函数,该函数:
#    - 计算损失 (例如,均方误差)。
#    - 使用 `nnx.value_and_grad` 计算梯度。
#    - 使用 `optimizer.update(model, grads)` 更新模型的状态。
# 4. 运行训练循环几个步骤。

# 定义一个简单的模型
class LinearModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.linear = '此处为线性层'

    def __call__(self, x: jax.Array):
        return self.linear(x)

# 实例化该模型
key = jax.random.PRNGKey(0)
model = LinearModel(rngs=nnx.Rngs(key))

# 创建一个 Optax 优化器
tx = '此处为 OPTAX SGD'
optimizer = nnx.Optimizer('包装优化器')

# 虚拟数据
x = jnp.array([[2.0]])
y = jnp.array([[4.0]])

# 训练步骤函数
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        y_pred = model(x)
        return jnp.mean((y_pred - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss, model

# 训练循环
num_steps = 10
for i in range(num_steps):
    loss, model = train_step(model, optimizer, x, y)
    print(f"步骤 {i+1}, 损失: {loss}")

print("训练后的模型输出:", model(x))
# @title 练习 5 解决方案

# 定义一个简单的模型
class LinearModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(in_features=1, out_features=1, rngs=rngs)

    def __call__(self, x: jax.Array):
        return self.linear(x)

# 实例化该模型
key = jax.random.PRNGKey(0)
model = LinearModel(rngs=nnx.Rngs(key))

# 创建一个 Optax 优化器
tx = optax.sgd(learning_rate=0.01)
optimizer = nnx.Optimizer(model, tx=tx, wrt=nnx.Param)

# 虚拟数据
x = jnp.array([[2.0]])
y = jnp.array([[4.0]])

# 训练步骤函数
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        y_pred = model(x)
        return jnp.mean((y_pred - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss, model

# 训练循环
num_steps = 10
for i in range(num_steps):
    loss, model = train_step(model, optimizer, x, y)
    print(f"步骤 {i+1}, 损失: {loss}")

print("训练后的模型输出:", model(x))

恭喜!

您现在已经完成了 Flax NNX 的基础知识学习!

请记住查阅官方文档以获取更深入的详细信息:

  • Flax NNX: (Flax 文档的一部分) https://flax.readthedocs.io
  • JAX: https://jax.readthedocs.io

继续练习,祝您 JAX 编码愉快!

请通过 https://goo.gle/jax-training-feedback 向我们发送反馈