本笔记旨在配合“Leveraging the JAX AI Stack”讲座。您将亲身体验 JAX 核心概念、用于模型构建的 Flax NNX、用于优化的 Optax 以及用于检查点设置的 Orbax。
这些练习将指导您实现关键组件,并在适当时与 PyTorch 进行对比,以巩固您的理解。
让我们开始吧!
# @title 设置:安装并导入库
# 安装必要的库
!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
import flax
from flax import nnx
import optax
import orbax.checkpoint as ocp # 用于 Orbax
from typing import Any, Dict, Tuple # 用于类型提示
# 辅助函数,用于更美观地打印 PyTree 以供演示
import pprint
import os # 用于 Orbax 目录管理
import shutil # 用于清理 Orbax 目录
print(f"JAX 版本: {jax.__version__}")
print(f"Flax 版本: {flax.__version__}")
print(f"Optax 版本: {optax.__version__}")
print(f"Orbax 版本: {ocp.__version__}")
# 全局 JAX PRNG 密钥,用于练习中的可复现性
# 学生可以学习为不同操作拆分此密钥。
main_key = jax.random.key(0)
# 练习 1 的说明
key_ex1, main_key = jax.random.split(main_key) # 拆分主密钥
# 1. 创建 JAX 数组 a 和 b
# TODO: 创建数组 'a' (2x2 随机正态分布) 和 'b' (2x2 全一矩阵)
a = None # 占位符
b = None # 占位符
print("数组 a:\n", a)
print("数组 b:\n", b)
# 2. 执行逐元素加法
# TODO: 将 a 和 b 相加
c = None # 占位符
print("逐元素和 c = a + b:\n", c)
# 3. 执行矩阵乘法
# TODO: 矩阵乘以 a 和 b
d = None # 占位符
print("矩阵乘积 d = a @ b:\n", d)
# 4. 演示不可变性
# original_a_id = id(a)
# print(f"原始 id(a): {original_a_id}")
# TODO: 执行重新分配 'a' 的操作,例如 a = a + 1
# a_new_ref = None # 占位符
# new_a_id = id(a_new_ref)
# print(f"'a = a + 1' 之后的新 id(a): {new_a_id}")
# TODO: 检查 original_a_id 是否与 new_a_id 不同
# print(f"ID 是否不同: {None}") # 占位符
# @title 解决方案 1: JAX 核心与 NumPy API
key_ex1_sol, main_key = jax.random.split(main_key)
# 1. 创建 JAX 数组 a 和 b
a_sol = jax.random.normal(key_ex1_sol, (2, 2))
b_sol = jnp.ones((2, 2))
print("数组 a:\n", a_sol)
print("数组 b:\n", b_sol)
# 2. 执行逐元素加法
c_sol = a_sol + b_sol
print("逐元素和 c = a + b:\n", c_sol)
# 3. 执行矩阵乘法
d_sol = jnp.dot(a_sol, b_sol) # 或 d = a @ b
print("矩阵乘积 d = a @ b:\n", d_sol)
# 4. 演示不可变性
original_a_id_sol = id(a_sol)
print(f"原始 id(a_sol): {original_a_id_sol}")
a_sol_new_ref = a_sol + 1 # 这会创建一个新数组并重新绑定 Python 变量。
new_a_id_sol = id(a_sol_new_ref)
print(f"'a_sol = a_sol + 1' 之后的新 id(a_sol_new_ref): {new_a_id_sol}")
print(f"ID 是否不同: {original_a_id_sol != new_a_id_sol}")
print("这表明原始数组未被就地修改;而是创建了一个新数组。")
# 练习 2 说明
key_ex2_main, main_key = jax.random.split(main_key)
key_ex2_x, key_ex2_w, key_ex2_b = jax.random.split(key_ex2_main, 3)
# 1. 定义 Python 函数
def compute_heavy_stuff(x, w, b):
# TODO: 实现操作
y1 = None # 占位符
y2 = None # 占位符
y3 = None # 占位符
result = None # 占位符
return result
# 2. 创建 JIT 编译版本
# TODO: 使用 jax.jit 编译 compute_heavy_stuff
fast_compute_heavy_stuff = None # 占位符
# 3. 创建虚拟数据
dim1, dim2, dim3 = 500, 1000, 500
x_data = jax.random.normal(key_ex2_x, (dim1, dim2))
w_data = jax.random.normal(key_ex2_w, (dim2, dim3))
b_data = jax.random.normal(key_ex2_b, (dim3,))
# 4. 调用两个函数
result_original = None # 占位符 compute_heavy_stuff(x_data, w_data, b_data)
result_fast_first_call = None # 占位符 fast_compute_heavy_stuff(x_data, w_data, b_data) # 第一次调用(编译)
result_fast_second_call = None # 占位符 fast_compute_heavy_stuff(x_data, w_data, b_data) # 第二次调用(使用已编译代码)
print(f"结果 (原始): {result_original}")
print(f"结果 (快速, 第一次调用): {result_fast_first_call}")
print(f"结果 (快速, 第二次调用): {result_fast_second_call}")
# if result_original is not None and result_fast_first_call is not None:
# assert jnp.allclose(result_original, result_fast_first_call), "结果应该匹配!"
# print("\n原始函数和 JIT 编译函数的结果匹配。")
# 5. 可选:计时(为准确起见,在不同单元格中使用 %timeit)
# print("\n要查看速度差异,请在不同单元格中运行以下命令:")
# print("%timeit compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()")
# print("%timeit fast_compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()")
# @title 解决方案 2: `jax.jit` (即时编译)
key_ex2_sol_main, main_key = jax.random.split(main_key)
key_ex2_sol_x, key_ex2_sol_w, key_ex2_sol_b = jax.random.split(key_ex2_sol_main, 3)
# 1. 定义 Python 函数
def compute_heavy_stuff_sol(x, w, b):
y = jnp.dot(x, w)
y = y + b
y = jnp.tanh(y)
result = jnp.sum(y)
return result
# 2. 创建 JIT 编译版本
fast_compute_heavy_stuff_sol = jax.jit(compute_heavy_stuff_sol)
# 3. 创建虚拟数据
dim1_sol, dim2_sol, dim3_sol = 500, 1000, 500
x_data_sol = jax.random.normal(key_ex2_sol_x, (dim1_sol, dim2_sol))
w_data_sol = jax.random.normal(key_ex2_sol_w, (dim2_sol, dim3_sol))
b_data_sol = jax.random.normal(key_ex2_sol_b, (dim3_sol,))
# 4. 调用两个函数
# 调用一次原始函数,以确保如果它是第一个 JAX 操作,不会因任何 JAX 开销而计时
result_original_sol = compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
# 对 JIT 函数的第一次调用包含编译时间
result_fast_sol_first_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
# 后续调用使用缓存的已编译代码
result_fast_sol_second_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
print(f"结果 (原始): {result_original_sol}")
print(f"结果 (快速, 第一次调用): {result_fast_sol_first_call}")
print(f"结果 (快速, 第二次调用): {result_fast_sol_second_call}")
assert jnp.allclose(result_original_sol, result_fast_sol_first_call), "结果应该匹配!"
print("\n原始函数和 JIT 编译函数的结果匹配。")
# 5. 可选:计时
# 要准确测量,请在不同的 Colab 单元格中运行这些命令:
# 单元格 1:
# %timeit compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
# 单元格 2:
# %timeit fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
# 您应该观察到 JIT 编译的版本在初始编译后明显更快。
print("\n要查看速度差异,请在不同单元格中运行 %timeit 命令(如上注释中所示)。")
# 练习 3 说明
# 1. 定义 scalar_loss 函数
def scalar_loss(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:
# TODO: 实现预测和损失计算
y_pred = None # 占位符
loss = None # 占位符
return loss
# 2. 使用 jax.grad 创建梯度函数
# TODO: scalar_loss 关于 'params' 的梯度 (argnums=0)
compute_gradients = None # 占位符
# 3. 初始化虚拟数据
params_init = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
x_input_data = jnp.array([1.0, 2.0, 3.0])
y_target_data = jnp.array([7.0, 9.0, 11.0]) # 目标是 y = 3x + 4 (为了用 init_params 产生非零损失)
# 4. 调用梯度函数
gradients = None # 占位符 compute_gradients(params_init, x_input_data, y_target_data)
print("初始参数:", params_init)
print("关于参数的梯度:\n", gradients)
# 期望梯度(手动计算 y_pred = wx+b, loss = mean((y_pred - y_true)^2)):
# dL/dw = mean(2 * (wx+b - y_true) * x)
# dL/db = mean(2 * (wx+b - y_true) * 1)
# 对于 params_init={'w': 2.0, 'b': 1.0}, x=[1,2,3], y_true=[7,9,11]
# x=1: y_pred = 2*1+1 = 3. 误差 = 3-7 = -4. dL/dw_i_term = 2*(-4)*1 = -8. dL/db_i_term = 2*(-4)*1 = -8
# x=2: y_pred = 2*2+1 = 5. 误差 = 5-9 = -4. dL/dw_i_term = 2*(-4)*2 = -16. dL/db_i_term = 2*(-4)*1 = -8
# x=3: y_pred = 2*3+1 = 7. 误差 = 7-11 = -4. dL/dw_i_term = 2*(-4)*3 = -24. dL/db_i_term = 2*(-4)*1 = -8
# 平均梯度: dL/dw = (-8-16-24)/3 = -48/3 = -16. dL/db = (-8-8-8)/3 = -24/3 = -8.
# if gradients is not None:
# assert jnp.isclose(gradients['w'], -16.0)
# assert jnp.isclose(gradients['b'], -8.0)
# print("\n梯度与期望值匹配。")
# @title 解决方案 3: `jax.grad` (自动微分)
# 1. 定义 scalar_loss 函数
def scalar_loss_sol(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:
y_pred = params['w'] * x + params['b']
loss = jnp.mean((y_pred - y_true)**2)
return loss
# 2. 使用 jax.grad 创建梯度函数
# scalar_loss 关于 'params' 的梯度(即第 0 个参数)
compute_gradients_sol = jax.grad(scalar_loss_sol, argnums=0)
# 3. 初始化虚拟数据
params_init_sol = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
x_input_data_sol = jnp.array([1.0, 2.0, 3.0])
y_target_data_sol = jnp.array([7.0, 9.0, 11.0])
# 4. 调用梯度函数
gradients_sol = compute_gradients_sol(params_init_sol, x_input_data_sol, y_target_data_sol)
print("初始参数:", params_init_sol)
print("关于参数的梯度:\n", pprint.pformat(gradients_sol))
# 与期望值验证(在说明中计算)
expected_dL_dw = -16.0
expected_dL_db = -8.0
assert jnp.isclose(gradients_sol['w'], expected_dL_dw), f"关于 'w' 的梯度是 {gradients_sol['w']}, 期望值是 {expected_dL_dw}"
assert jnp.isclose(gradients_sol['b'], expected_dL_db), f"关于 'b' 的梯度是 {gradients_sol['b']}, 期望值是 {expected_dL_db}"
print("\n梯度与期望值匹配。")
# 练习 4 说明
key_ex4_main, main_key = jax.random.split(main_key)
key_ex4_vec, key_ex4_mat, key_ex4_bias = jax.random.split(key_ex4_main, 3)
# 1. 为单个向量定义 apply_affine
def apply_affine(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
# TODO: 计算 jnp.dot(matrix, vector) + bias
result = None # 占位符
return result
# 2. 准备数据
batch_size = 4
input_features = 3
output_features = 2
# batch_of_vectors: (batch_size, input_features)
# single_matrix: (output_features, input_features)
# single_bias: (output_features,)
batch_of_vectors = jax.random.normal(key_ex4_vec, (batch_size, input_features))
single_matrix = jax.random.normal(key_ex4_mat, (output_features, input_features))
single_bias = jax.random.normal(key_ex4_bias, (output_features,))
# 3. 使用 jax.vmap 创建 batched_apply_affine
# TODO: 正确指定 in_axes:向量是批处理的,矩阵和偏置不是。out_axes 应为 0。
batched_apply_affine = None # 占位符 jax.vmap(apply_affine, in_axes=(..., ... , ...), out_axes=...)
# 4. 测试 batched_apply_affine
result_vmap = None # 占位符 batched_apply_affine(batch_of_vectors, single_matrix, single_bias)
print("一批向量的形状:", batch_of_vectors.shape)
print("单个矩阵的形状:", single_matrix.shape)
print("单个偏置的形状:", single_bias.shape)
if result_vmap is not None:
print("使用 vmap 的结果形状:", result_vmap.shape) # 期望: (batch_size, output_features)
# 为了比较,一个手动循环(效率较低):
# manual_results = []
# for i in range(batch_size):
# manual_results.append(apply_affine(batch_of_vectors[i], single_matrix, single_bias))
# result_manual_loop = jnp.stack(manual_results)
# assert jnp.allclose(result_vmap, result_manual_loop)
# print("vmap 结果与手动循环结果匹配。")
else:
print("result_vmap 是 None。")
# @title 解决方案 4: `jax.vmap` (自动向量化)
key_ex4_sol_main, main_key = jax.random.split(main_key)
key_ex4_sol_vec, key_ex4_sol_mat, key_ex4_sol_bias = jax.random.split(key_ex4_sol_main, 3)
# 1. 为单个向量定义 apply_affine
def apply_affine_sol(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
return jnp.dot(matrix, vector) + bias
# 2. 准备数据
batch_size_sol = 4
input_features_sol = 3
output_features_sol = 2
batch_of_vectors_sol = jax.random.normal(key_ex4_sol_vec, (batch_size_sol, input_features_sol))
single_matrix_sol = jax.random.normal(key_ex4_sol_mat, (output_features_sol, input_features_sol))
single_bias_sol = jax.random.normal(key_ex4_sol_bias, (output_features_sol,))
# 3. 使用 jax.vmap 创建 batched_apply_affine
# 向量沿轴 0 进行批处理,矩阵和偏置不进行批处理(广播)。
# out_axes=0 表示输出也将沿其第一个轴进行批处理。
batched_apply_affine_sol = jax.vmap(apply_affine_sol, in_axes=(0, None, None), out_axes=0)
# 4. 测试 batched_apply_affine
result_vmap_sol = batched_apply_affine_sol(batch_of_vectors_sol, single_matrix_sol, single_bias_sol)
print("一批向量的形状:", batch_of_vectors_sol.shape)
print("单个矩阵的形状:", single_matrix_sol.shape)
print("单个偏置的形状:", single_bias_sol.shape)
print("使用 vmap 的结果形状:", result_vmap_sol.shape) # 期望: (batch_size, output_features)
assert result_vmap_sol.shape == (batch_size_sol, output_features_sol)
# 为了比较,一个手动循环(效率较低):
manual_results_sol = []
for i in range(batch_size_sol):
manual_results_sol.append(apply_affine_sol(batch_of_vectors_sol[i], single_matrix_sol, single_bias_sol))
result_manual_loop_sol = jnp.stack(manual_results_sol)
assert jnp.allclose(result_vmap_sol, result_manual_loop_sol)
print("\nvmap 结果与手动循环结果匹配,表明向量化正确。")
# 练习 5 说明
key_ex5_model_init, main_key = jax.random.split(main_key)
# 1, 2, 3. 定义 SimpleNNXModel
class SimpleNNXModel(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
# TODO: 定义一个名为 'dense_layer' 的 nnx.Linear 层
# self.dense_layer = nnx.Linear(...)
self.some_attribute = None # 占位符,稍后删除
pass # 如果类不为空,则删除此占位符
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
# TODO: 将输入 x 传递给 dense_layer
# return self.dense_layer(x)
return x # 占位符
# 4. 实例化模型
model_din = 3
model_dout = 2
# TODO: 为参数初始化创建 nnx.Rngs。使用 'params' 作为密钥名称。
model_rngs = None # 占位符 nnx.Rngs(params=key_ex5_model_init)
my_model = None # 占位符 SimpleNNXModel(din=model_din, dout=model_dout, rngs=model_rngs)
# 5. 使用虚拟数据进行测试
dummy_batch_size = 4
dummy_input_ex5 = jnp.ones((dummy_batch_size, model_din))
model_output = None # 占位符
if my_model is not None:
model_output = my_model(dummy_input_ex5)
print(f"模型输出形状: {model_output.shape}")
print(f"模型输出:\n{model_output}")
model_state = my_model.get_state()
print(f"\n模型状态 (参数等):")
pprint.pprint(model_state)
else:
print("my_model 是 None。")
# @title 解决方案 5: Flax NNX - 定义模型
key_ex5_sol_model_init, main_key = jax.random.split(main_key)
# 1, 2, 3. 定义 SimpleNNXModel
class SimpleNNXModel_Sol(nnx.Module): # 为解决方案单元重命名
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
# nnx.Linear 默认会使用 rngs 中的 'params' 密钥来初始化其参数
self.dense_layer = nnx.Linear(din, dout, rngs=rngs)
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
return self.dense_layer(x)
# 4. 实例化模型
model_din_sol = 3
model_dout_sol = 2
# 为参数初始化创建 nnx.Rngs。
# 'params' 是 nnx.Linear 在 rngs 对象中查找的默认密钥。
model_rngs_sol = nnx.Rngs(params=key_ex5_sol_model_init)
my_model_sol = SimpleNNXModel_Sol(din=model_din_sol, dout=model_dout_sol, rngs=model_rngs_sol)
# 5. 使用虚拟数据进行测试
dummy_batch_size_sol = 4
dummy_input_ex5_sol = jnp.ones((dummy_batch_size_sol, model_din_sol))
model_output_sol = my_model_sol(dummy_input_ex5_sol)
print(f"模型输出形状: {model_output_sol.shape}")
print(f"模型输出:\n{model_output_sol}")
# model_state_sol = my_model_sol.get_state()
_, model_state_sol = nnx.split(my_model_sol)
print(f"\n模型状态 (参数等):")
nnx.display(model_state_sol)
# 检查参数是否存在
assert 'dense_layer' in model_state_sol, "在 model_state 中找不到密钥 'dense_layer'"
assert 'kernel' in model_state_sol['dense_layer'], "在 model_state['dense_layer'] 中找不到密钥 'kernel'"
assert 'bias' in model_state_sol['dense_layer'], "在 model_state['dense_layer'] 中找不到密钥 'bias'"
print("\n模型参数(dense_layer 的内核和偏置)存在于状态中。")
# 练习 6 说明
# 1. 假设 my_model_sol 可从练习 5 的解决方案中获得
# (如果独立运行,请重新实例化它)
if 'my_model_sol' not in globals():
print("为 Ex6 重新初始化 Ex5 解决方案中的模型。")
key_ex6_model_init, main_key = jax.random.split(main_key)
_model_din_ex6 = 3
_model_dout_ex6 = 2
_model_rngs_ex6 = nnx.Rngs(params=key_ex6_model_init)
# 如果已定义,则使用解决方案类名,否则使用学生的类名
_ModelClass = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel
model_for_opt = _ModelClass(din=_model_din_ex6, dout=_model_dout_ex6, rngs=_model_rngs_ex6)
print("为优化器创建了模型。")
else:
model_for_opt = my_model_sol # 使用上一个解决方案中的模型
print("使用上一个练习中的 'my_model_sol' 作为 'model_for_opt'。")
# 2. 创建一个 Optax 优化器
learning_rate = 0.001
# TODO: 创建一个 optax.adam 优化器转换
optax_tx = None # 占位符 optax.adam(...)
# 3. 创建一个 nnx.Optimizer 包装器
# TODO: 包装模型 (model_for_opt) 和 optax 转换 (optax_tx)
# 现在需要 `wrt` 参数来指定要对什么进行微分。
nnx_optimizer = None # 占位符 nnx.Optimizer(...)
# 4. 打印优化器及其状态
print("\nFlax NNX Optimizer 包装器:")
nnx.display(nnx_optimizer)
print("\n初始优化器状态 (Optax 状态,例如 Adam 的动量):")
if nnx_optimizer is not None and hasattr(nnx_optimizer, 'opt_state'):
pprint.pprint(nnx_optimizer.state)
# if hasattr(nnx_optimizer, 'opt_state'):
# adam_state = nnx_optimizer.opt_state
# assert len(adam_state) > 0 and hasattr(adam_state[0], 'count')
# print("\n优化器状态结构对于 Adam 似乎是合理的。")
else:
print("nnx_optimizer 或其状态为 None 或结构不符合预期。")
# @title 解决方案 6: Optax 和 Flax NNX - 创建优化器
# 1. 使用练习 5 解决方案中的 my_model_sol
# 如果不是按顺序运行,请确保已定义 my_model_sol:
if 'my_model_sol' not in globals():
print("为 Ex6 解决方案重新初始化 Ex5 解决方案中的模型。")
key_ex6_sol_model_init, main_key = jax.random.split(main_key)
_model_din_sol_ex6 = 3
_model_dout_sol_ex6 = 2
_model_rngs_sol_ex6 = nnx.Rngs(params=key_ex6_sol_model_init)
# 确保使用 SimpleNNXModel_Sol
my_model_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex6, dout=_model_dout_sol_ex6, rngs=_model_rngs_sol_ex6)
print("为优化器重新创建了模型 'my_model_sol'。")
else:
print("使用上一个练习中的模型 'my_model_sol'。")
# 2. 创建一个 Optax 优化器
learning_rate_sol = 0.001
# 创建一个 optax.adam 优化器转换
optax_tx_sol = optax.adam(learning_rate=learning_rate_sol)
# 3. 创建一个 nnx.Optimizer 包装器
# 这将模型和 Optax 优化器连接起来。
# 优化器状态将根据模型的参数进行初始化。
nnx_optimizer_sol = nnx.Optimizer(my_model_sol, optax_tx_sol, wrt=nnx.Param)
# 4. 打印优化器及其状态
print("\nFlax NNX Optimizer 包装器:")
nnx.display(nnx_optimizer_sol) # 显示它关联的模型和 Optax 转换
print("\n初始优化器状态 (Optax 状态,例如 Adam 的动量):")
# nnx.Optimizer 将实际的 Optax 状态存储在其 .opt_state 属性中。
# 此状态是一个与模型参数结构匹配的 PyTree。
pprint.pprint(nnx_optimizer_sol.opt_state)
# 验证 Adam 的优化器状态结构(每个参数的 count、mu、nu)
assert hasattr(nnx_optimizer_sol, 'opt_state'), "在 nnx.Optimizer 中找不到 Optax opt_state"
# opt_state 是一个元组,对于 adam 通常是 (CountState(), ScaleByAdamState())
adam_optax_internal_state = nnx_optimizer_sol.opt_state
assert len(adam_optax_internal_state) > 0 and hasattr(adam_optax_internal_state[0], 'count'), "未找到 Adam 'count' 状态。"
# 元组的第二个元素通常是参数特定状态(如 mu 和 nu)所在的位置
if len(adam_optax_internal_state) > 1 and hasattr(adam_optax_internal_state[1], 'mu'):
param_specific_state = adam_optax_internal_state[1]
assert 'dense_layer' in param_specific_state.mu and 'kernel' in param_specific_state.mu['dense_layer'], "未找到内核的 Adam 'mu' 状态。"
print("\n优化器状态结构对于 Adam 是正确的。")
else:
print("\n警告:Adam 的优化器状态结构可能不同或未完全验证。")
# 练习 7 说明
key_ex7_main, main_key = jax.random.split(main_key)
key_ex7_x, key_ex7_y = jax.random.split(key_ex7_main, 2)
# 1. 使用上一个练习解决方案中的模型和优化器
# 确保 my_model_sol 和 nnx_optimizer_sol 可用
if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():
print("为 Ex7 重新初始化 Ex5/Ex6 解决方案中的模型和优化器。")
key_ex7_model_fallback, main_key = jax.random.split(main_key)
_model_din_ex7 = 3
_model_dout_ex7 = 2
_model_rngs_ex7 = nnx.Rngs(params=key_ex7_model_fallback)
# 确保使用 SimpleNNXModel_Sol
my_model_ex7 = SimpleNNXModel_Sol(din=_model_din_ex7, dout=_model_dout_ex7, rngs=_model_rngs_ex7)
_optax_tx_ex7 = optax.adam(learning_rate=0.001)
nnx_optimizer_ex7 = nnx.Optimizer(my_model_ex7, _optax_tx_ex7)
print("为 Ex7 重新创建了模型和优化器。")
else:
my_model_ex7 = my_model_sol
nnx_optimizer_ex7 = nnx_optimizer_sol
print("使用 'my_model_sol' 和 'nnx_optimizer_sol' 作为 'my_model_ex7' 和 'nnx_optimizer_ex7'。")
# 2, 3. 定义 train_step 函数
# TODO: 使用 @nnx.jit 装饰
# def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer, # 使用 nnx.Module 基类进行类型提示
# x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:
# TODO: 定义内部 loss_fn_for_grad(current_model_state_for_grad_fn)
# def loss_fn_for_grad(model_in_grad_fn: nnx.Module): # 使用 nnx.Module 基类进行类型提示
# y_pred = model_in_grad_fn(x_batch)
# loss = jnp.mean((y_pred - y_batch)**2)
# return loss
# return jnp.array(0.0) # 占位符
# TODO: 使用 nnx.value_and_grad 计算损失值和梯度
# loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg) # 传递 model_arg
# TODO: 更新优化器(这将就地更新 model_arg)
# optimizer_arg.update(model_arg, grads)
# return loss_value
# return jnp.array(0.0) # 占位符定义的 train_step 函数
# 供学生定义:
# 确保函数签名对于 nnx.jit 是正确的
@nnx.jit
def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer,
x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:
# 供学生使用的占位符实现
def loss_fn_for_grad(model_in_grad_fn: nnx.Module):
# y_pred = model_in_grad_fn(x_batch)
# loss = jnp.mean((y_pred - y_batch)**2)
# return loss
return jnp.array(0.0) # 学生 TODO:替换此行
# loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)
# optimizer_arg.update(grads)
# return loss_value
return jnp.array(-1.0) # 学生 TODO:替换此行
# 4. 创建虚拟数据
batch_s = 8
# 仔细访问 features_in 和 features_out
_din_from_model_ex7 = my_model_ex7.dense_layer.in_features if hasattr(my_model_ex7, 'dense_layer') else 3
_dout_from_model_ex7 = my_model_ex7.dense_layer.out_features if hasattr(my_model_ex7, 'dense_layer') else 2
x_batch_data = jax.random.normal(key_ex7_x, (batch_s, _din_from_model_ex7))
y_batch_data = jax.random.normal(key_ex7_y, (batch_s, _dout_from_model_ex7))
# 可选:存储初始参数值以进行比较
initial_kernel_val = None
if hasattr(my_model_ex7, 'get_state'):
_current_model_state_ex7 = my_model_ex7.get_state()
if 'dense_layer' in _current_model_state_ex7:
initial_kernel_val = _current_model_state_ex7['dense_layer']['kernel'].value[0,0].copy()
print(f"初始内核值(样本): {initial_kernel_val}")
# 5. 调用 train_step
# loss_after_step = train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data) # 学生将取消注释
loss_after_step = jnp.array(-1.0) # 占位符,直到学生实现 train_step
if train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data).item() != -1.0: # 检查学生是否已实现
loss_after_step = train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data)
print(f"一个训练步骤后的损失: {loss_after_step}")
else:
print("学生需要实现 `train_step` 函数。")
# # 6. 可选:验证参数更改
# updated_kernel_val_sol = None
# _, updated_model_state_sol = nnx.split(my_model_sol_ex7) # 更新后再次获取状态
# if 'dense_layer' in updated_model_state_sol:
# updated_kernel_val_sol = updated_model_state_sol['dense_layer']['kernel'].value[0,0]
# print(f"更新后的内核值(样本): {updated_kernel_val_sol}")
# if initial_kernel_val_sol is not None and updated_kernel_val_sol is not None:
# assert not jnp.allclose(initial_kernel_val_sol, updated_kernel_val_sol), "内核参数未更改!"
# print("内核参数在训练步骤后按预期更改。")
# else:
# print("无法验证内核更改(初始值或更新值为 None)。")
# @title 解决方案 7: 使用 Flax NNX 和 Optax 进行训练步骤
key_ex7_sol_main, main_key = jax.random.split(main_key)
key_ex7_sol_x, key_ex7_sol_y = jax.random.split(key_ex7_sol_main, 2)
# 1. 使用上一个练习解决方案中的模型和优化器
if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():
print("为 Ex7 解决方案重新初始化 Ex5/Ex6 解决方案中的模型和优化器。")
key_ex7_sol_model_fallback, main_key = jax.random.split(main_key)
_model_din_sol_ex7 = 3
_model_dout_sol_ex7 = 2
_model_rngs_sol_ex7 = nnx.Rngs(params=key_ex7_sol_model_fallback)
# 确保为解决方案使用 SimpleNNXModel_Sol
my_model_sol_ex7 = SimpleNNXModel_Sol(din=_model_din_sol_ex7, dout=_model_dout_sol_ex7, rngs=_model_rngs_sol_ex7)
_optax_tx_sol_ex7 = optax.adam(learning_rate=0.001)
nnx_optimizer_sol_ex7 = nnx.Optimizer(my_model_sol_ex7, _optax_tx_sol_ex7)
print("为 Ex7 解决方案重新创建了模型和优化器。")
else:
# 如果按顺序运行解决方案,这些将是正确的实例
my_model_sol_ex7 = my_model_sol
nnx_optimizer_sol_ex7 = nnx_optimizer_sol
print("为 Ex7 解决方案使用 'my_model_sol' 和 'nnx_optimizer_sol'。")
# 2, 3. 定义 train_step 函数
@nnx.jit # 使用 @nnx.jit 装饰器进行 JIT 编译
def train_step_sol(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer, # 为通用性使用基类 nnx.Module
x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:
# 定义内部的 loss_fn_for_grad。它将模型作为其第一个参数。
# 它从外部作用域捕获 x_batch 和 y_batch。
def loss_fn_for_grad(model_in_grad_fn: nnx.Module): # 使用基类 nnx.Module
y_pred = model_in_grad_fn(x_batch) # 使用传递给此内部函数的模型
loss = jnp.mean((y_pred - y_batch)**2)
return loss
# 使用 nnx.value_and_grad 计算损失值和梯度。
# 这将对 loss_fn_for_grad 关于其第一个参数 (model_in_grad_fn) 进行微分。
# 我们将模型的当前状态 (model_arg) 传递给它。
loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)
# 更新优化器。这将就地更新 model_arg(nnx_optimizer_sol_ex7 所引用的)。
optimizer_arg.update(model_arg, grads)
return loss_value
# 4. 创建虚拟数据
batch_s_sol = 8
# 确保 din 和 dout 与 Ex5/Ex6 的模型实例化匹配
# my_model_sol_ex7.dense_layer 是一个 nnx.Linear 对象
din_from_model_sol = my_model_sol_ex7.dense_layer.in_features
dout_from_model_sol = my_model_sol_ex7.dense_layer.out_features
x_batch_data_sol = jax.random.normal(key_ex7_sol_x, (batch_s_sol, din_from_model_sol))
y_batch_data_sol = jax.random.normal(key_ex7_sol_y, (batch_s_sol, dout_from_model_sol))
# 可选:存储初始参数值以进行比较
initial_kernel_val_sol = None
_, current_model_state_sol = nnx.split(my_model_sol_ex7)
if 'dense_layer' in current_model_state_sol:
initial_kernel_val_sol = current_model_state_sol['dense_layer']['kernel'].value[0,0].copy()
print(f"初始内核值(样本): {initial_kernel_val_sol}")
# 5. 调用 train_step
# 第一次调用将 JIT 编译 train_step_sol 函数。
loss_after_step_sol = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)
print(f"一个训练步骤后的损失 (第一次调用, JIT): {loss_after_step_sol}")
# 第二次调用以显示它更快(尽管 %timeit 更适合测量)
loss_after_step_sol_2 = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)
print(f"一个训练步骤后的损失 (第二次调用, 缓存): {loss_after_step_sol_2}")
# 6. 可选:验证参数更改
updated_kernel_val_sol = None
_, updated_model_state_sol = nnx.split(my_model_sol_ex7) # 更新后再次获取状态
if 'dense_layer' in updated_model_state_sol:
updated_kernel_val_sol = updated_model_state_sol['dense_layer']['kernel'].value[0,0]
print(f"更新后的内核值(样本): {updated_kernel_val_sol}")
if initial_kernel_val_sol is not None and updated_kernel_val_sol is not None:
assert not jnp.allclose(initial_kernel_val_sol, updated_kernel_val_sol), "内核参数未更改!"
print("内核参数在训练步骤后按预期更改。")
else:
print("无法验证内核更改(初始值或更新值为 None)。")
# 练习 8 说明
# import orbax.checkpoint as ocp # 已导入
# import os, shutil # 已导入
# 1. 使用上一个练习解决方案中的模型和优化器
if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():
print("为 Ex8 重新初始化 Ex7 解决方案中的模型和优化器。")
key_ex8_model_fallback, main_key = jax.random.split(main_key)
_model_din_ex8 = 3
_model_dout_ex8 = 2
_model_rngs_ex8 = nnx.Rngs(params=key_ex8_model_fallback)
_ModelClassEx8 = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel
model_to_save = _ModelClassEx8(din=_model_din_ex8, dout=_model_dout_ex8, rngs=_model_rngs_ex8)
_optax_tx_ex8 = optax.adam(learning_rate=0.001)
optimizer_to_save = nnx.Optimizer(model_to_save, _optax_tx_ex8)
print("为 Ex8 重新创建了模型和优化器。")
else:
model_to_save = my_model_sol_ex7
optimizer_to_save = nnx_optimizer_sol_ex7
print("使用 Ex7 解决方案中的模型和优化器。")
# 2. 定义检查点目录
# TODO: 定义 checkpoint_dir
checkpoint_dir = None # 占位符,例如 "/tmp/my_nnx_checkpoint_exercise/"
# if checkpoint_dir and os.path.exists(checkpoint_dir):
# shutil.rmtree(checkpoint_dir) # 为安全起见,清理以前的运行
# if checkpoint_dir:
# os.makedirs(checkpoint_dir, exist_ok=True)
# 3. 创建 Orbax CheckpointManager
# TODO: 创建选项和管理器
# options = ocp.CheckpointManagerOptions(...)
# mngr = ocp.CheckpointManager(...)
options = None
mngr = None
# 4. 捆绑状态
# current_step = 100 # 示例步骤
# TODO: 获取 model_state 和 optimizer_state
# model_state_to_save = nnx.split(model_to_save)
# 优化器状态现在通过 .state 属性访问。
# opt_state_to_save = optimizer_to_save.state
# save_bundle = {
# 'model': model_state_to_save,
# 'optimizer': opt_state_to_save,
# 'step': current_step
# }
save_bundle = None
# 5. 保存检查点
# if mngr and save_bundle:
# TODO: 保存检查点
# mngr.save(...)
# mngr.wait_until_finished()
# print(f"检查点已在步骤 {current_step} 保存到 {checkpoint_dir}")
# else:
# print("检查点管理器或 save_bundle 未初始化。")
# --- 恢复 ---
# 6.a 创建新模型和 Optax 转换(用于恢复)
# key_ex8_restore_model, main_key = jax.random.split(main_key)
# din_restore = model_to_save.dense_layer.in_features if hasattr(model_to_save, 'dense_layer') else 3
# dout_restore = model_to_save.dense_layer.out_features if hasattr(model_to_save, 'dense_layer') else 2
# _ModelClassRestore = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel
# restored_model = _ModelClassRestore(
# din=din_restore, dout=dout_restore,
# rngs=nnx.Rngs(params=key_ex8_restore_model) # 用于不同初始参数的新密钥
# )
# restored_optax_tx = optax.adam(learning_rate=0.001) # 相同的 Optax 配置
restored_model = None
restored_optax_tx = None
# 6.b 恢复检查点
# loaded_bundle = None
# if mngr:
# TODO: 恢复检查点
# latest_step = mngr.latest_step()
# if latest_step is not None:
# loaded_bundle = mngr.restore(...)
# print(f"已从步骤 {latest_step} 恢复检查点")
# else:
# print("未找到要恢复的检查点。")
# else:
# print("检查点管理器未初始化以进行恢复。")
# 6.c 应用加载的状态
# if loaded_bundle and restored_model:
# TODO: 更新 restored_model 状态
# nnx.update(restored_model, ...)
# print("已应用恢复的模型状态。")
# TODO: 创建新的 nnx.Optimizer 并分配其状态
# restored_optimizer = nnx.Optimizer(...)
# restored_optimizer.state = ...
# print("已应用恢复的优化器状态。")
# else:
# print("loaded_bundle 或 restored_model 为 None,无法应用状态。")
restored_optimizer = None
# 7. 验证恢复
# original_kernel_sol = save_bundle_sol['model']['dense_layer']['kernel']
# _, restored_model_state = nnx.split(restored_model_sol)
# kernel_after_restore_sol = restored_model_state['dense_layer']['kernel']
# assert jnp.array_equal(original_kernel_sol.value, kernel_after_restore_sol.value), \
# "恢复后模型内核参数不同!"
# print("\n模型参数已成功恢复和验证(内核匹配)。")
# # 验证优化器状态(例如,特定参数的 Adam 'mu')
# original_opt_state_adam_mu_kernel = save_bundle_sol['optimizer'][0].mu['dense_layer']['kernel'].value
# restored_opt_state_adam_mu_kernel = restored_optimizer_sol.state[0].mu['dense_layer']['kernel'].value
# assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \
# "优化器 Adam mu for kernel 不同!"
# print("优化器状态(样本 mu)已成功恢复和验证。")
# 8. 清理
# if mngr:
# mngr.close()
# if checkpoint_dir and os.path.exists(checkpoint_dir):
# shutil.rmtree(checkpoint_dir)
# print(f"已清理检查点目录:{checkpoint_dir}")
# @title 解决方案 8: Orbax - 保存和恢复检查点
# 1. 使用上一个练习解决方案中的模型和优化器
if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():
print("为 Ex8 解决方案重新初始化 Ex7 解决方案中的模型和优化器。")
key_ex8_sol_model_fallback, main_key = jax.random.split(main_key)
_model_din_sol_ex8 = 3
_model_dout_sol_ex8 = 2
_model_rngs_sol_ex8 = nnx.Rngs(params=key_ex8_sol_model_fallback)
# 确保为解决方案使用 SimpleNNXModel_Sol
model_to_save_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex8,
dout=_model_dout_sol_ex8,
rngs=_model_rngs_sol_ex8)
_optax_tx_sol_ex8 = optax.adam(learning_rate=0.001) # 存储转换以供稍后使用
optimizer_to_save_sol = nnx.Optimizer(model_to_save_sol, _optax_tx_sol_ex8)
print("为 Ex8 解决方案重新创建了模型和优化器。")
else:
model_to_save_sol = my_model_sol_ex7
optimizer_to_save_sol = nnx_optimizer_sol_ex7
# 我们需要用于创建优化器的 optax 转换来进行恢复
_optax_tx_sol_ex8 = optimizer_to_save_sol.tx # 访问原始 Optax 转换
print("为 Ex8 解决方案使用 Ex7 解决方案中的模型和优化器。")
# 2. 定义检查点目录
checkpoint_dir_sol = "/tmp/my_nnx_checkpoint_exercise_solution/"
if os.path.exists(checkpoint_dir_sol):
shutil.rmtree(checkpoint_dir_sol) # 清理以前的运行
os.makedirs(checkpoint_dir_sol, exist_ok=True)
print(f"Orbax 检查点目录: {checkpoint_dir_sol}")
# 3. 创建 Orbax CheckpointManager
options_sol = ocp.CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1)
mngr_sol = ocp.CheckpointManager(checkpoint_dir_sol, options=options_sol)
# 4. 捆绑状态
current_step_sol = 100 # 示例步骤
_, model_state_to_save_sol = nnx.split(model_to_save_sol)
# 优化器状态现在是 .state 属性中直接可用的 PyTree。
opt_state_to_save_sol = optimizer_to_save_sol.opt_state
save_bundle_sol = {
'model': model_state_to_save_sol,
'optimizer': opt_state_to_save_sol,
'step': current_step_sol
}
print("\n要保存的状态包:")
pprint.pprint(f"模型状态键: {model_state_to_save_sol.keys()}")
pprint.pprint(f"优化器状态类型: {type(opt_state_to_save_sol)}")
# 5. 保存检查点
mngr_sol.save(current_step_sol, args=ocp.args.StandardSave(save_bundle_sol))
mngr_sol.wait_until_finished()
print(f"\n检查点已在步骤 {current_step_sol} 保存到 {checkpoint_dir_sol}")
# --- 恢复 ---
# 6.a 创建新模型和 Optax 转换(用于恢复)
key_ex8_sol_restore_model, main_key = jax.random.split(main_key)
# 确保如果可能,从已保存模型的结构中正确获取 din/dout
# 假设 model_to_save_sol 是具有 dense_layer 的 SimpleNNXModel_Sol
din_restore_sol = model_to_save_sol.dense_layer.in_features
dout_restore_sol = model_to_save_sol.dense_layer.out_features
restored_model_sol = SimpleNNXModel_Sol( # 使用解决方案的模型类
din=din_restore_sol, dout=dout_restore_sol,
rngs=nnx.Rngs(params=key_ex8_sol_restore_model) # 用于不同初始参数的新密钥
)
# 我们需要原始的 Optax 转换定义来创建新的 nnx.Optimizer
# _optax_tx_sol_ex8 之前已存储,或者如果配置已知,则可以重新创建
restored_optax_tx_sol = _optax_tx_sol_ex8
# 在恢复前打印新模型的参数,以显示它不同
_, kernel_before_restore_sol = nnx.split(restored_model_sol)
print(f"\n恢复前 'restored_model_sol' 的示例内核:")
nnx.display(kernel_before_restore_sol['dense_layer']['kernel'])
# 6.b 恢复检查点
loaded_bundle_sol = None
latest_step_sol = mngr_sol.latest_step()
if latest_step_sol is not None:
# 对于 NNX,我们恢复原始 PyTree,StandardRestore 是合适的。
loaded_bundle_sol = mngr_sol.restore(latest_step_sol,
args=ocp.args.StandardRestore(save_bundle_sol))
print(f"\n已从步骤 {latest_step_sol} 恢复检查点")
print(f"加载的包包含键: {loaded_bundle_sol.keys()}")
else:
raise ValueError("未找到要恢复的检查点。")
# 6.c 应用加载的状态
assert loaded_bundle_sol is not None, "加载的包为 None"
nnx.update(restored_model_sol, loaded_bundle_sol['model'])
print("已将恢复的模型状态应用于 'restored_model_sol'。")
# 使用 restored_model 和原始 optax_tx 创建新的 nnx.Optimizer
restored_optimizer_sol = nnx.Optimizer(restored_model_sol, restored_optax_tx_sol,
wrt=nnx.Param)
# 现在分配加载的 Optax 状态 PyTree
restored_optimizer_sol.state = loaded_bundle_sol['optimizer']
print("已将恢复的优化器状态应用于 'restored_optimizer_sol'。")
# 7. 验证恢复
original_kernel_sol = save_bundle_sol['model']['dense_layer']['kernel']
_, restored_model_state = nnx.split(restored_model_sol)
kernel_after_restore_sol = restored_model_state['dense_layer']['kernel']
assert jnp.array_equal(original_kernel_sol.value, kernel_after_restore_sol.value), \
"恢复后模型内核参数不同!"
print("\n模型参数已成功恢复和验证(内核匹配)。")
# 验证优化器状态(例如,特定参数的 Adam 'mu')
original_opt_state_adam_mu_kernel = save_bundle_sol['optimizer'][0].mu['dense_layer']['kernel'].value
restored_opt_state_adam_mu_kernel = restored_optimizer_sol.state[0].mu['dense_layer']['kernel'].value
assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \
"优化器 Adam mu for kernel 不同!"
print("优化器状态(样本 mu)已成功恢复和验证。")
# 8. 清理
mngr_sol.close()
if os.path.exists(checkpoint_dir_sol):
shutil.rmtree(checkpoint_dir_sol)
print(f"已清理检查点目录:{checkpoint_dir_sol}")
您现在已经亲身体验了:
这为使用 JAX 生态系统开发高性能机器学习模型奠定了坚实的基础。
要进一步学习,请参阅官方文档:
不要忘记对培训课程提供反馈: https://goo.gle/jax-training-feedback