欢迎来到 Flax NNX Colab 笔记本!本笔记本提供动手练习,旨在帮助 PyTorch 用户过渡到 Flax NNX 和 JAX 生态系统。
我们将涵盖核心概念并构建简单的模型。
!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
import flax
from flax import nnx
import optax
from typing import Any, Dict, Tuple # 用于类型提示
print(f"JAX 版本: {jax.__version__}")
print(f"Flax 版本: {flax.__version__}")
# @title 练习 1: 理解模块和参数 (编程练习)
# 说明:
# 1. 创建一个名为 `MyLinearLayer` 的简单 NNX 模块。
# 2. 它应该有一个名为 `weight` 的 `nnx.Param` (使用形状 [input_size, output_size] 随机初始化)。
# 3. 它应该有一个名为 `bias` 的 `nnx.Param` (使用形状 [output_size] 初始化为零)。
# 4. 前向传播 (`__call__` 方法) 应执行线性变换: `x @ self.weight.value + self.bias.value`。
# 5. 使用 `input_size=10` 和 `output_size=5` 实例化该层。
# 6. 打印 `weight` 和 `bias` 参数的形状。
class MyLinearLayer(nnx.Module):
def __init__(self, input_size: int, output_size: int, *, rngs: nnx.Rngs):
pass # 填写此部分
def __call__(self, x: jax.Array):
pass # 填写此部分
# 实例化该层
key = jax.random.PRNGKey(0)
linear_layer = MyLinearLayer(
input_size='填写此部分',
output_size='填写此部分',
rngs=nnx.Rngs(key))
# 打印参数的形状
print("权重形状:", '填写此部分')
print("偏置形状:", '填写此部分')
# 示例用法:
dummy_input = jnp.ones((1, 10))
output = linear_layer(dummy_input)
print("输出形状:", output.shape)
# @title 练习 1 解决方案
class MyLinearLayer(nnx.Module):
def __init__(self, input_size: int, output_size: int, *, rngs: nnx.Rngs):
self.weight = nnx.Param(jax.random.normal(rngs.params(), (input_size, output_size)))
self.bias = nnx.Param(jnp.zeros((output_size,)))
def __call__(self, x: jax.Array):
return x @ self.weight.value + self.bias.value
# 实例化该层
key = jax.random.PRNGKey(0)
linear_layer = MyLinearLayer(input_size=10, output_size=5, rngs=nnx.Rngs(key))
# 打印参数的形状
print("权重形状:", linear_layer.weight.value.shape)
print("偏置形状:", linear_layer.bias.value.shape)
# 示例用法:
dummy_input = jnp.ones((1, 10))
output = linear_layer(dummy_input)
print("输出形状:", output.shape)
# @title 练习 2: 状态管理 (编程练习)
# 说明:
# 1. 创建一个名为 `CounterModule` 的 NNX 模块。
# 2. 它应该有一个名为 `count` 的 Python 实例属性,初始化为 0。
# 3. `__call__` 方法应该将 `count` 增加 1 并返回新值。
# 4. 实例化该模块。
# 5. 多次调用该模块并打印返回的值。
# 6. 使用 `nnx.split` 和 `nnx.merge` 保存和加载模块的状态。验证计数器从它离开的地方继续。
class CounterModule(nnx.Module):
def __init__(self):
pass # 填写此部分
def __call__(self):
pass # 填写此部分
# 实例化该模块
pass # 填写此部分。将其命名为“counter”
# 调用该模块并打印值
print("第一次调用:", counter())
print("第二次调用:", counter())
# 将模块拆分为 graphdef 和 state。
# 请记住,state 是一个 nnx.Variable
graphdef, state = # 填写此部分
# 合并 graphdef 和 state 以创建一个新模块
new_counter = # 填写此部分
# 调用新模块并打印值
print("拆分和合并后,第一次调用:", new_counter())
print("拆分和合并后,第二次调用:", new_counter())
# @title 练习 2 解决方案
class CounterModule(nnx.Module):
def __init__(self):
self.count = 0
def __call__(self):
self.count += 1
return self.count
# 实例化该模块
counter = CounterModule()
# 调用该模块并打印值
print("第一次调用:", counter())
print("第二次调用:", counter())
# 将模块拆分为 graphdef 和 state
graphdef, state = nnx.split(counter, nnx.Variable)
# 合并 graphdef 和 state 以创建一个新模块
new_counter = nnx.merge(graphdef, state)
# 调用新模块并打印值
print("拆分和合并后,第一次调用:", new_counter())
print("拆分和合并后,第二次调用:", new_counter())
# @title 练习 3: 显式随机数生成 (编程练习)
# 说明:
# 1. 创建一个名为 `RandomNormalLayer` 的 NNX 模块。
# 2. 其 `__init__` 方法应接收一个 `size` 参数,用于定义要生成的随机向量的大小。
# 3. `__init__` 方法应接收一个 `rngs: nnx.Rngs` 参数,该参数用于使用 `jax.random.normal` 生成一个随机正态张量,并将该张量分配给 `self.random_vector`。
# 4. `__call__` 方法应返回 `self.random_vector` 的值 (一个新的随机正态张量)。
# 5. 实例化大小为 10 的层,并传入带有 `jax.random.PRNGKey` 的 `rngs` 参数。
# 6. 多次调用该模块,并观察返回的值是否不同。
# 创建 RandomNormalLayer
# 实例化该模块
key = # 使用 jax.random.PRNGKey 创建一个新密钥
random_layer = RandomNormalLayer(size='此处填写大小', rngs=nnx.Rngs(key))
# 调用该模块并打印值
print("第一次调用:", random_layer())
print("第二次调用:", random_layer())
# @title 练习 3 解决方案
class RandomNormalLayer(nnx.Module):
def __init__(self, size: int, *, rngs: nnx.Rngs):
self.size = size
self.rngs = rngs
def __call__(self):
self.random_vector = jax.random.normal(self.rngs.params(), (self.size,))
return self.random_vector
# 实例化该模块
key = jax.random.PRNGKey(0)
random_layer = RandomNormalLayer(size=10, rngs=nnx.Rngs(key))
# 调用该模块并打印值
print("第一次调用:", random_layer())
print("第二次调用:", random_layer())
# @title 练习 4: 构建一个简单的 CNN (编程练习)
# 说明:
# 1. 创建一个表示简单 CNN 的 NNX 模块,包含以下层:
# - 卷积层 (nnx.Conv),具有 32 个滤波器,内核大小为 3,步幅为 1。
# - ReLU 激活。
# - 最大池化层 (nnx.max_pool),窗口大小为 2,步幅为 2。
# - 展平层 (jax.numpy.reshape)。
# - 线性层 (nnx.Linear),映射到 10 个输出类别。
# 2. 使用适当的输入和输出形状初始化 CNN。
# 3. 使用虚拟输入执行前向传播并打印输出形状。
class SimpleCNN(nnx.Module):
def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
self.conv = nnx.Conv('步幅', '滤波器', kernel_size=('X, X'), rngs=rngs)
self.linear = nnx.Linear(in_features=6272, out_features=num_classes, rngs=rngs)
def __call__(self, x: jax.Array):
x = self.conv(x)
print(f'{x.shape = }') # 用于调试
x = nnx.relu(x)
print(f'{x.shape = }') # 用于调试
x = nnx.max_pool(x, window_shape=('X, X'), strides=('X, X'))
print(f'{x.shape = }') # 用于调试
x = x.reshape(x.shape[0], -1) # 展平
print(f'{x.shape = }') # 用于调试
x = self.linear(x)
return x
# 实例化 CNN
key = jax.random.PRNGKey(0)
cnn = SimpleCNN(num_classes='输出类别', rngs=nnx.Rngs(key))
# 虚拟输入
dummy_input = jnp.ones((1, 28, 28, 1))
# 前向传播
output = cnn(dummy_input)
print("输出形状:", output.shape)
# @title 练习 4 解决方案
class SimpleCNN(nnx.Module):
def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
self.conv = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
self.linear = nnx.Linear(in_features=6272, out_features=num_classes, rngs=rngs)
def __call__(self, x: jax.Array):
x = self.conv(x)
print(f'{x.shape = }')
x = nnx.relu(x)
print(f'{x.shape = }')
x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
print(f'{x.shape = }')
x = x.reshape(x.shape[0], -1) # 展平
print(f'{x.shape = }')
x = self.linear(x)
return x
# 实例化 CNN
key = jax.random.PRNGKey(0)
cnn = SimpleCNN(num_classes=10, rngs=nnx.Rngs(key))
# 虚拟输入
dummy_input = jnp.ones((1, 28, 28, 1))
# 前向传播
output = cnn(dummy_input)
print("输出形状:", output.shape)
# @title 练习 5: 使用 Optax 的训练循环 (编程练习)
# 说明:
# 1. 定义一个简单的模型 (例如,一个线性层)。
# 2. 创建一个 nnx.Optimizer,确保使用现在必需的 wrt 参数指定要更新的变量类型 (例如,wrt=nnx.Param)。
# 3. 实现一个训练步骤函数,该函数:
# - 计算损失 (例如,均方误差)。
# - 使用 `nnx.value_and_grad` 计算梯度。
# - 使用 `optimizer.update(model, grads)` 更新模型的状态。
# 4. 运行训练循环几个步骤。
# 定义一个简单的模型
class LinearModel(nnx.Module):
def __init__(self, *, rngs: nnx.Rngs):
self.linear = '此处为线性层'
def __call__(self, x: jax.Array):
return self.linear(x)
# 实例化该模型
key = jax.random.PRNGKey(0)
model = LinearModel(rngs=nnx.Rngs(key))
# 创建一个 Optax 优化器
tx = '此处为 OPTAX SGD'
optimizer = nnx.Optimizer('包装优化器')
# 虚拟数据
x = jnp.array([[2.0]])
y = jnp.array([[4.0]])
# 训练步骤函数
@nnx.jit
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x)
return jnp.mean((y_pred - y) ** 2)
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(model, grads)
return loss, model
# 训练循环
num_steps = 10
for i in range(num_steps):
loss, model = train_step(model, optimizer, x, y)
print(f"步骤 {i+1}, 损失: {loss}")
print("训练后的模型输出:", model(x))
# @title 练习 5 解决方案
# 定义一个简单的模型
class LinearModel(nnx.Module):
def __init__(self, *, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features=1, out_features=1, rngs=rngs)
def __call__(self, x: jax.Array):
return self.linear(x)
# 实例化该模型
key = jax.random.PRNGKey(0)
model = LinearModel(rngs=nnx.Rngs(key))
# 创建一个 Optax 优化器
tx = optax.sgd(learning_rate=0.01)
optimizer = nnx.Optimizer(model, tx=tx, wrt=nnx.Param)
# 虚拟数据
x = jnp.array([[2.0]])
y = jnp.array([[4.0]])
# 训练步骤函数
@nnx.jit
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x)
return jnp.mean((y_pred - y) ** 2)
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(model, grads)
return loss, model
# 训练循环
num_steps = 10
for i in range(num_steps):
loss, model = train_step(model, optimizer, x, y)
print(f"步骤 {i+1}, 损失: {loss}")
print("训练后的模型输出:", model(x))
请记住查阅官方文档以获取更深入的详细信息:
继续练习,祝您 JAX 编码愉快!
请通过 https://goo.gle/jax-training-feedback 向我们发送反馈