简介

欢迎来到 JAX AI 技术栈练习!

本笔记旨在配合“Leveraging the JAX AI Stack”讲座。您将亲身体验 JAX 核心概念、用于模型构建的 Flax NNX、用于优化的 Optax 以及用于检查点设置的 Orbax。

这些练习将指导您实现关键组件,并在适当时与 PyTorch 进行对比,以巩固您的理解。

让我们开始吧!

# @title 设置:安装并导入库
# 安装必要的库
!pip install -q jax-ai-stack==2025.9.3

import jax
import jax.numpy as jnp
import flax
from flax import nnx
import optax
import orbax.checkpoint as ocp # 用于 Orbax
from typing import Any, Dict, Tuple # 用于类型提示

# 辅助函数,用于更美观地打印 PyTree 以供演示
import pprint
import os # 用于 Orbax 目录管理
import shutil # 用于清理 Orbax 目录

print(f"JAX 版本: {jax.__version__}")
print(f"Flax 版本: {flax.__version__}")
print(f"Optax 版本: {optax.__version__}")
print(f"Orbax 版本: {ocp.__version__}")

# 全局 JAX PRNG 密钥,用于练习中的可复现性
# 学生可以学习为不同操作拆分此密钥。
main_key = jax.random.key(0)

练习 1: JAX 核心与 NumPy API

目标: 熟悉 jax.numpy 和 JAX 的函数式编程风格。

说明:

  1. 使用 jax.numpy (jnp) 创建两个 JAX 数组,a(一个 2x2 的随机数矩阵)和 b(一个 2x2 的全一矩阵)。您需要一个 jax.random.key 来创建随机数。
  2. 对 a 和 b 进行逐元素加法。
  3. 对 a 和 b 进行矩阵乘法。
  4. 演示 JAX 的不可变性:
- 存储数组 a 的 Python id()。 - 执行类似 a = a + 1 的操作。 - 打印 a 的新 id(),并观察它已更改,表明创建了一个新数组。
# 练习 1 的说明
key_ex1, main_key = jax.random.split(main_key) # 拆分主密钥

# 1. 创建 JAX 数组 a 和 b
# TODO: 创建数组 'a' (2x2 随机正态分布) 和 'b' (2x2 全一矩阵)
a = None # 占位符
b = None # 占位符

print("数组 a:\n", a)
print("数组 b:\n", b)

# 2. 执行逐元素加法
# TODO: 将 a 和 b 相加
c = None # 占位符
print("逐元素和 c = a + b:\n", c)

# 3. 执行矩阵乘法
# TODO: 矩阵乘以 a 和 b
d = None # 占位符
print("矩阵乘积 d = a @ b:\n", d)

# 4. 演示不可变性
# original_a_id = id(a)
# print(f"原始 id(a): {original_a_id}")

# TODO: 执行重新分配 'a' 的操作,例如 a = a + 1
# a_new_ref = None # 占位符
# new_a_id = id(a_new_ref)
# print(f"'a = a + 1' 之后的新 id(a): {new_a_id}")

# TODO: 检查 original_a_id 是否与 new_a_id 不同
# print(f"ID 是否不同: {None}") # 占位符
# @title 解决方案 1: JAX 核心与 NumPy API
key_ex1_sol, main_key = jax.random.split(main_key)

# 1. 创建 JAX 数组 a 和 b
a_sol = jax.random.normal(key_ex1_sol, (2, 2))
b_sol = jnp.ones((2, 2))

print("数组 a:\n", a_sol)
print("数组 b:\n", b_sol)

# 2. 执行逐元素加法
c_sol = a_sol + b_sol
print("逐元素和 c = a + b:\n", c_sol)

# 3. 执行矩阵乘法
d_sol = jnp.dot(a_sol, b_sol) # 或 d = a @ b
print("矩阵乘积 d = a @ b:\n", d_sol)

# 4. 演示不可变性
original_a_id_sol = id(a_sol)
print(f"原始 id(a_sol): {original_a_id_sol}")

a_sol_new_ref = a_sol + 1 # 这会创建一个新数组并重新绑定 Python 变量。
new_a_id_sol = id(a_sol_new_ref)
print(f"'a_sol = a_sol + 1' 之后的新 id(a_sol_new_ref): {new_a_id_sol}")
print(f"ID 是否不同: {original_a_id_sol != new_a_id_sol}")
print("这表明原始数组未被就地修改;而是创建了一个新数组。")

练习 2: jax.jit (即时编译)

目标: 了解如何使用 jax.jit 编译 JAX 函数以提高性能。

说明:

  1. 定义一个 Python 函数 `compute_heavy_stuff(x, w, b)`,该函数执行一系列 `jnp` 操作:
- y = jnp.dot(x, w) - y = y + b - y = jnp.tanh(y) - result = jnp.sum(y) - 返回 result。
  1. 使用 `jax.jit` 创建此函数的 JIT 编译版本 `fast_compute_heavy_stuff`。
  2. 为 x、w 和 b 创建一些大型虚拟 JAX 数组。
  3. 使用虚拟数据调用原始函数和 JIT 编译函数。
  4. (可选) 在 Colab 的不同单元格中使用 `%timeit` 魔法命令比较它们的执行速度。请记住,对 JIT 编译函数的第一次调用包含编译时间。
# 练习 2 说明
key_ex2_main, main_key = jax.random.split(main_key)
key_ex2_x, key_ex2_w, key_ex2_b = jax.random.split(key_ex2_main, 3)

# 1. 定义 Python 函数
def compute_heavy_stuff(x, w, b):
    # TODO: 实现操作
    y1 = None # 占位符
    y2 = None # 占位符
    y3 = None # 占位符
    result = None # 占位符
    return result

# 2. 创建 JIT 编译版本
# TODO: 使用 jax.jit 编译 compute_heavy_stuff
fast_compute_heavy_stuff = None # 占位符

# 3. 创建虚拟数据
dim1, dim2, dim3 = 500, 1000, 500
x_data = jax.random.normal(key_ex2_x, (dim1, dim2))
w_data = jax.random.normal(key_ex2_w, (dim2, dim3))
b_data = jax.random.normal(key_ex2_b, (dim3,))

# 4. 调用两个函数
result_original = None # 占位符 compute_heavy_stuff(x_data, w_data, b_data)
result_fast_first_call = None # 占位符 fast_compute_heavy_stuff(x_data, w_data, b_data) # 第一次调用(编译)
result_fast_second_call = None # 占位符 fast_compute_heavy_stuff(x_data, w_data, b_data) # 第二次调用(使用已编译代码)

print(f"结果 (原始): {result_original}")
print(f"结果 (快速, 第一次调用): {result_fast_first_call}")
print(f"结果 (快速, 第二次调用): {result_fast_second_call}")

# if result_original is not None and result_fast_first_call is not None:
#   assert jnp.allclose(result_original, result_fast_first_call), "结果应该匹配!"
#   print("\n原始函数和 JIT 编译函数的结果匹配。")

# 5. 可选:计时(为准确起见,在不同单元格中使用 %timeit)
# print("\n要查看速度差异,请在不同单元格中运行以下命令:")
# print("%timeit compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()")
# print("%timeit fast_compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()")
# @title 解决方案 2: `jax.jit` (即时编译)
key_ex2_sol_main, main_key = jax.random.split(main_key)
key_ex2_sol_x, key_ex2_sol_w, key_ex2_sol_b = jax.random.split(key_ex2_sol_main, 3)

# 1. 定义 Python 函数
def compute_heavy_stuff_sol(x, w, b):
    y = jnp.dot(x, w)
    y = y + b
    y = jnp.tanh(y)
    result = jnp.sum(y)
    return result

# 2. 创建 JIT 编译版本
fast_compute_heavy_stuff_sol = jax.jit(compute_heavy_stuff_sol)

# 3. 创建虚拟数据
dim1_sol, dim2_sol, dim3_sol = 500, 1000, 500
x_data_sol = jax.random.normal(key_ex2_sol_x, (dim1_sol, dim2_sol))
w_data_sol = jax.random.normal(key_ex2_sol_w, (dim2_sol, dim3_sol))
b_data_sol = jax.random.normal(key_ex2_sol_b, (dim3_sol,))

# 4. 调用两个函数
# 调用一次原始函数,以确保如果它是第一个 JAX 操作,不会因任何 JAX 开销而计时
result_original_sol = compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()

# 对 JIT 函数的第一次调用包含编译时间
result_fast_sol_first_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()

# 后续调用使用缓存的已编译代码
result_fast_sol_second_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()

print(f"结果 (原始): {result_original_sol}")
print(f"结果 (快速, 第一次调用): {result_fast_sol_first_call}")
print(f"结果 (快速, 第二次调用): {result_fast_sol_second_call}")

assert jnp.allclose(result_original_sol, result_fast_sol_first_call), "结果应该匹配!"
print("\n原始函数和 JIT 编译函数的结果匹配。")

# 5. 可选:计时
# 要准确测量,请在不同的 Colab 单元格中运行这些命令:
# 单元格 1:
# %timeit compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
# 单元格 2:
# %timeit fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
# 您应该观察到 JIT 编译的版本在初始编译后明显更快。
print("\n要查看速度差异,请在不同单元格中运行 %timeit 命令(如上注释中所示)。")

练习 3: jax.grad (自动微分)

目标: 学习使用 jax.grad 计算函数的梯度。

说明:

  1. 定义一个 Python 函数 `scalar_loss(params, x, y_true)`,它:
- 接受一个字典 `params`,其中包含键 'w' 和 'b'。 - 计算 `y_pred = params['w'] * x + params['b']`。 - 返回一个标量损失,例如 `jnp.mean((y_pred - y_true)**2)`。
  1. 使用 `jax.grad` 创建一个新函数 `compute_gradients`,该函数计算 `scalar_loss` 关于其第一个参数 (params) 的梯度。
  2. 初始化一些虚拟的 `params`、`x_input` 和 `y_target` 值。
  3. 调用 `compute_gradients` 获取梯度。打印梯度。
# 练习 3 说明

# 1. 定义 scalar_loss 函数
def scalar_loss(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:
    # TODO: 实现预测和损失计算
    y_pred = None # 占位符
    loss = None # 占位符
    return loss

# 2. 使用 jax.grad 创建梯度函数
# TODO: scalar_loss 关于 'params' 的梯度 (argnums=0)
compute_gradients = None # 占位符

# 3. 初始化虚拟数据
params_init = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
x_input_data = jnp.array([1.0, 2.0, 3.0])
y_target_data = jnp.array([7.0, 9.0, 11.0]) # 目标是 y = 3x + 4 (为了用 init_params 产生非零损失)

# 4. 调用梯度函数
gradients = None # 占位符 compute_gradients(params_init, x_input_data, y_target_data)
print("初始参数:", params_init)
print("关于参数的梯度:\n", gradients)

# 期望梯度(手动计算 y_pred = wx+b, loss = mean((y_pred - y_true)^2)):
# dL/dw = mean(2 * (wx+b - y_true) * x)
# dL/db = mean(2 * (wx+b - y_true) * 1)
# 对于 params_init={'w': 2.0, 'b': 1.0}, x=[1,2,3], y_true=[7,9,11]
# x=1: y_pred = 2*1+1 = 3. 误差 = 3-7 = -4. dL/dw_i_term = 2*(-4)*1 = -8.  dL/db_i_term = 2*(-4)*1 = -8
# x=2: y_pred = 2*2+1 = 5. 误差 = 5-9 = -4. dL/dw_i_term = 2*(-4)*2 = -16. dL/db_i_term = 2*(-4)*1 = -8
# x=3: y_pred = 2*3+1 = 7. 误差 = 7-11 = -4. dL/dw_i_term = 2*(-4)*3 = -24. dL/db_i_term = 2*(-4)*1 = -8
# 平均梯度: dL/dw = (-8-16-24)/3 = -48/3 = -16.  dL/db = (-8-8-8)/3 = -24/3 = -8.
# if gradients is not None:
#     assert jnp.isclose(gradients['w'], -16.0)
#     assert jnp.isclose(gradients['b'], -8.0)
#     print("\n梯度与期望值匹配。")
# @title 解决方案 3: `jax.grad` (自动微分)

# 1. 定义 scalar_loss 函数
def scalar_loss_sol(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:
    y_pred = params['w'] * x + params['b']
    loss = jnp.mean((y_pred - y_true)**2)
    return loss

# 2. 使用 jax.grad 创建梯度函数
# scalar_loss 关于 'params' 的梯度(即第 0 个参数)
compute_gradients_sol = jax.grad(scalar_loss_sol, argnums=0)

# 3. 初始化虚拟数据
params_init_sol = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
x_input_data_sol = jnp.array([1.0, 2.0, 3.0])
y_target_data_sol = jnp.array([7.0, 9.0, 11.0])

# 4. 调用梯度函数
gradients_sol = compute_gradients_sol(params_init_sol, x_input_data_sol, y_target_data_sol)
print("初始参数:", params_init_sol)
print("关于参数的梯度:\n", pprint.pformat(gradients_sol))

# 与期望值验证(在说明中计算)
expected_dL_dw = -16.0
expected_dL_db = -8.0
assert jnp.isclose(gradients_sol['w'], expected_dL_dw), f"关于 'w' 的梯度是 {gradients_sol['w']}, 期望值是 {expected_dL_dw}"
assert jnp.isclose(gradients_sol['b'], expected_dL_db), f"关于 'b' 的梯度是 {gradients_sol['b']}, 期望值是 {expected_dL_db}"
print("\n梯度与期望值匹配。")

练习 4: jax.vmap (自动向量化)

目标: 使用 jax.vmap 自动批处理操作。

说明:

  1. 定义一个函数 `apply_affine(vector, matrix, bias)`,它接受一个一维向量、一个二维矩阵和一个一维偏置。它应该计算 `jnp.dot(matrix, vector) + bias`。
  2. 您有一批向量(一个二维数组,其中每行是一个向量),但只有一个矩阵和一个偏置,应将它们应用于批次中的每个向量。
  3. 使用 `jax.vmap` 创建一个 `batched_apply_affine`,以高效地将 `apply_affine` 应用于批次中的每个向量。
- 提示:`jax.vmap` 的 `in_axes` 应该为批处理的向量参数指定 0,为矩阵和偏置指定 `None`,因为它们没有被批处理(广播)。`out_axes` 应该为 0,以表示输出沿第一个轴进行批处理。
  1. 使用样本数据测试 `batched_apply_affine`。
# 练习 4 说明
key_ex4_main, main_key = jax.random.split(main_key)
key_ex4_vec, key_ex4_mat, key_ex4_bias = jax.random.split(key_ex4_main, 3)

# 1. 为单个向量定义 apply_affine
def apply_affine(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
    # TODO: 计算 jnp.dot(matrix, vector) + bias
    result = None # 占位符
    return result

# 2. 准备数据
batch_size = 4
input_features = 3
output_features = 2

# batch_of_vectors: (batch_size, input_features)
# single_matrix: (output_features, input_features)
# single_bias: (output_features,)
batch_of_vectors = jax.random.normal(key_ex4_vec, (batch_size, input_features))
single_matrix = jax.random.normal(key_ex4_mat, (output_features, input_features))
single_bias = jax.random.normal(key_ex4_bias, (output_features,))


# 3. 使用 jax.vmap 创建 batched_apply_affine
# TODO: 正确指定 in_axes:向量是批处理的,矩阵和偏置不是。out_axes 应为 0。
batched_apply_affine = None # 占位符 jax.vmap(apply_affine, in_axes=(..., ... , ...), out_axes=...)


# 4. 测试 batched_apply_affine
result_vmap = None # 占位符 batched_apply_affine(batch_of_vectors, single_matrix, single_bias)
print("一批向量的形状:", batch_of_vectors.shape)
print("单个矩阵的形状:", single_matrix.shape)
print("单个偏置的形状:", single_bias.shape)
if result_vmap is not None:
    print("使用 vmap 的结果形状:", result_vmap.shape) # 期望: (batch_size, output_features)

    # 为了比较,一个手动循环(效率较低):
    # manual_results = []
    # for i in range(batch_size):
    #     manual_results.append(apply_affine(batch_of_vectors[i], single_matrix, single_bias))
    # result_manual_loop = jnp.stack(manual_results)
    # assert jnp.allclose(result_vmap, result_manual_loop)
    # print("vmap 结果与手动循环结果匹配。")
else:
    print("result_vmap 是 None。")
# @title 解决方案 4: `jax.vmap` (自动向量化)
key_ex4_sol_main, main_key = jax.random.split(main_key)
key_ex4_sol_vec, key_ex4_sol_mat, key_ex4_sol_bias = jax.random.split(key_ex4_sol_main, 3)

# 1. 为单个向量定义 apply_affine
def apply_affine_sol(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
    return jnp.dot(matrix, vector) + bias

# 2. 准备数据
batch_size_sol = 4
input_features_sol = 3
output_features_sol = 2

batch_of_vectors_sol = jax.random.normal(key_ex4_sol_vec, (batch_size_sol, input_features_sol))
single_matrix_sol = jax.random.normal(key_ex4_sol_mat, (output_features_sol, input_features_sol))
single_bias_sol = jax.random.normal(key_ex4_sol_bias, (output_features_sol,))

# 3. 使用 jax.vmap 创建 batched_apply_affine
# 向量沿轴 0 进行批处理,矩阵和偏置不进行批处理(广播)。
# out_axes=0 表示输出也将沿其第一个轴进行批处理。
batched_apply_affine_sol = jax.vmap(apply_affine_sol, in_axes=(0, None, None), out_axes=0)

# 4. 测试 batched_apply_affine
result_vmap_sol = batched_apply_affine_sol(batch_of_vectors_sol, single_matrix_sol, single_bias_sol)
print("一批向量的形状:", batch_of_vectors_sol.shape)
print("单个矩阵的形状:", single_matrix_sol.shape)
print("单个偏置的形状:", single_bias_sol.shape)
print("使用 vmap 的结果形状:", result_vmap_sol.shape) # 期望: (batch_size, output_features)
assert result_vmap_sol.shape == (batch_size_sol, output_features_sol)

# 为了比较,一个手动循环(效率较低):
manual_results_sol = []
for i in range(batch_size_sol):
    manual_results_sol.append(apply_affine_sol(batch_of_vectors_sol[i], single_matrix_sol, single_bias_sol))
result_manual_loop_sol = jnp.stack(manual_results_sol)

assert jnp.allclose(result_vmap_sol, result_manual_loop_sol)
print("\nvmap 结果与手动循环结果匹配,表明向量化正确。")

练习 5: Flax NNX - 定义模型

目标: 学习使用 Flax NNX 定义一个简单的神经网络模型。

说明:

  1. 定义一个 Flax NNX 模型类 `SimpleNNXModel`,它继承自 `nnx.Module`。
  2. 在其 `__init__` 方法中,定义一个 `nnx.Linear` 层。该层应接受 `din`(输入特征)和 `dout`(输出特征)作为参数。请记住将 `rngs` 参数传递给 `nnx.Linear` 以进行参数初始化(例如,`rngs=rngs`)。
  3. 实现 `__call__` 方法(前向传播),该方法接受输入 `x` 并将其传递给线性层。
  4. 实例化您的 `SimpleNNXModel`。您需要使用 JAX PRNG 密钥创建一个 `nnx.Rngs` 对象(例如,`nnx.Rngs(params=jax.random.key(seed))`)。密钥名称 `params` 是 `nnx.Linear` 的惯例。
  5. 使用虚拟输入批次测试您的模型实例。使用 `nnx.display()` 打印输出和模型的状态(参数)。
# 练习 5 说明
key_ex5_model_init, main_key = jax.random.split(main_key)

# 1, 2, 3. 定义 SimpleNNXModel
class SimpleNNXModel(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        # TODO: 定义一个名为 'dense_layer' 的 nnx.Linear 层
        # self.dense_layer = nnx.Linear(...)
        self.some_attribute = None # 占位符,稍后删除
        pass # 如果类不为空,则删除此占位符

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # TODO: 将输入 x 传递给 dense_layer
        # return self.dense_layer(x)
        return x # 占位符

# 4. 实例化模型
model_din = 3
model_dout = 2
# TODO: 为参数初始化创建 nnx.Rngs。使用 'params' 作为密钥名称。
model_rngs = None # 占位符 nnx.Rngs(params=key_ex5_model_init)
my_model = None # 占位符 SimpleNNXModel(din=model_din, dout=model_dout, rngs=model_rngs)

# 5. 使用虚拟数据进行测试
dummy_batch_size = 4
dummy_input_ex5 = jnp.ones((dummy_batch_size, model_din))

model_output = None # 占位符
if my_model is not None:
    model_output = my_model(dummy_input_ex5)
    print(f"模型输出形状: {model_output.shape}")
    print(f"模型输出:\n{model_output}")

    model_state = my_model.get_state()
    print(f"\n模型状态 (参数等):")
    pprint.pprint(model_state)
else:
    print("my_model 是 None。")
# @title 解决方案 5: Flax NNX - 定义模型
key_ex5_sol_model_init, main_key = jax.random.split(main_key)

# 1, 2, 3. 定义 SimpleNNXModel
class SimpleNNXModel_Sol(nnx.Module): # 为解决方案单元重命名
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        # nnx.Linear 默认会使用 rngs 中的 'params' 密钥来初始化其参数
        self.dense_layer = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.dense_layer(x)

# 4. 实例化模型
model_din_sol = 3
model_dout_sol = 2
# 为参数初始化创建 nnx.Rngs。
# 'params' 是 nnx.Linear 在 rngs 对象中查找的默认密钥。
model_rngs_sol = nnx.Rngs(params=key_ex5_sol_model_init)
my_model_sol = SimpleNNXModel_Sol(din=model_din_sol, dout=model_dout_sol, rngs=model_rngs_sol)

# 5. 使用虚拟数据进行测试
dummy_batch_size_sol = 4
dummy_input_ex5_sol = jnp.ones((dummy_batch_size_sol, model_din_sol))

model_output_sol = my_model_sol(dummy_input_ex5_sol)
print(f"模型输出形状: {model_output_sol.shape}")
print(f"模型输出:\n{model_output_sol}")

# model_state_sol = my_model_sol.get_state()
_, model_state_sol = nnx.split(my_model_sol)
print(f"\n模型状态 (参数等):")
nnx.display(model_state_sol)

# 检查参数是否存在
assert 'dense_layer' in model_state_sol, "在 model_state 中找不到密钥 'dense_layer'"
assert 'kernel' in model_state_sol['dense_layer'], "在 model_state['dense_layer'] 中找不到密钥 'kernel'"
assert 'bias' in model_state_sol['dense_layer'], "在 model_state['dense_layer'] 中找不到密钥 'bias'"
print("\n模型参数(dense_layer 的内核和偏置)存在于状态中。")

练习 6: Optax 和 Flax NNX - 创建优化器

目标: 设置一个 Optax 优化器,并使用 `nnx.Optimizer` 将其包装起来,以便与 Flax NNX 模型一起使用。

说明:

  1. 使用前一个练习解决方案中的 `SimpleNNXModel_Sol` 类和实例 `my_model_sol`。(如果独立运行,请重新实例化它)。
  2. 创建一个 Optax 优化器,例如学习率为 0.001 的 `optax.adam`。
  3. 创建一个 `nnx.Optimizer` 实例。这个包装器将 Optax 优化器与您的 Flax NNX 模型 (`my_model_sol`) 连接起来。
  4. 打印 `nnx.Optimizer` 实例及其 `state` 属性,以查看初始化的优化器状态(例如,Adam 的动量项)。
# 练习 6 说明

# 1. 假设 my_model_sol 可从练习 5 的解决方案中获得
# (如果独立运行,请重新实例化它)
if 'my_model_sol' not in globals():
    print("为 Ex6 重新初始化 Ex5 解决方案中的模型。")
    key_ex6_model_init, main_key = jax.random.split(main_key)
    _model_din_ex6 = 3
    _model_dout_ex6 = 2
    _model_rngs_ex6 = nnx.Rngs(params=key_ex6_model_init)
    # 如果已定义,则使用解决方案类名,否则使用学生的类名
    _ModelClass = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel
    model_for_opt = _ModelClass(din=_model_din_ex6, dout=_model_dout_ex6, rngs=_model_rngs_ex6)
    print("为优化器创建了模型。")
else:
    model_for_opt = my_model_sol # 使用上一个解决方案中的模型
    print("使用上一个练习中的 'my_model_sol' 作为 'model_for_opt'。")


# 2. 创建一个 Optax 优化器
learning_rate = 0.001
# TODO: 创建一个 optax.adam 优化器转换
optax_tx = None # 占位符 optax.adam(...)

# 3. 创建一个 nnx.Optimizer 包装器
# TODO: 包装模型 (model_for_opt) 和 optax 转换 (optax_tx)
# 现在需要 `wrt` 参数来指定要对什么进行微分。
nnx_optimizer = None # 占位符 nnx.Optimizer(...)

# 4. 打印优化器及其状态
print("\nFlax NNX Optimizer 包装器:")
nnx.display(nnx_optimizer)

print("\n初始优化器状态 (Optax 状态,例如 Adam 的动量):")
if nnx_optimizer is not None and hasattr(nnx_optimizer, 'opt_state'):
   pprint.pprint(nnx_optimizer.state)
   # if hasattr(nnx_optimizer, 'opt_state'):
   #     adam_state = nnx_optimizer.opt_state
   #     assert len(adam_state) > 0 and hasattr(adam_state[0], 'count')
   #     print("\n优化器状态结构对于 Adam 似乎是合理的。")
else:
   print("nnx_optimizer 或其状态为 None 或结构不符合预期。")
# @title 解决方案 6: Optax 和 Flax NNX - 创建优化器

# 1. 使用练习 5 解决方案中的 my_model_sol
# 如果不是按顺序运行,请确保已定义 my_model_sol:
if 'my_model_sol' not in globals():
    print("为 Ex6 解决方案重新初始化 Ex5 解决方案中的模型。")
    key_ex6_sol_model_init, main_key = jax.random.split(main_key)
    _model_din_sol_ex6 = 3
    _model_dout_sol_ex6 = 2
    _model_rngs_sol_ex6 = nnx.Rngs(params=key_ex6_sol_model_init)
    # 确保使用 SimpleNNXModel_Sol
    my_model_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex6, dout=_model_dout_sol_ex6, rngs=_model_rngs_sol_ex6)
    print("为优化器重新创建了模型 'my_model_sol'。")
else:
    print("使用上一个练习中的模型 'my_model_sol'。")


# 2. 创建一个 Optax 优化器
learning_rate_sol = 0.001
# 创建一个 optax.adam 优化器转换
optax_tx_sol = optax.adam(learning_rate=learning_rate_sol)

# 3. 创建一个 nnx.Optimizer 包装器
# 这将模型和 Optax 优化器连接起来。
# 优化器状态将根据模型的参数进行初始化。
nnx_optimizer_sol = nnx.Optimizer(my_model_sol, optax_tx_sol, wrt=nnx.Param)

# 4. 打印优化器及其状态
print("\nFlax NNX Optimizer 包装器:")
nnx.display(nnx_optimizer_sol) # 显示它关联的模型和 Optax 转换

print("\n初始优化器状态 (Optax 状态,例如 Adam 的动量):")
# nnx.Optimizer 将实际的 Optax 状态存储在其 .opt_state 属性中。
# 此状态是一个与模型参数结构匹配的 PyTree。
pprint.pprint(nnx_optimizer_sol.opt_state)

# 验证 Adam 的优化器状态结构(每个参数的 count、mu、nu)
assert hasattr(nnx_optimizer_sol, 'opt_state'), "在 nnx.Optimizer 中找不到 Optax opt_state"
# opt_state 是一个元组,对于 adam 通常是 (CountState(), ScaleByAdamState())
adam_optax_internal_state = nnx_optimizer_sol.opt_state
assert len(adam_optax_internal_state) > 0 and hasattr(adam_optax_internal_state[0], 'count'), "未找到 Adam 'count' 状态。"
# 元组的第二个元素通常是参数特定状态(如 mu 和 nu)所在的位置
if len(adam_optax_internal_state) > 1 and hasattr(adam_optax_internal_state[1], 'mu'):
    param_specific_state = adam_optax_internal_state[1]
    assert 'dense_layer' in param_specific_state.mu and 'kernel' in param_specific_state.mu['dense_layer'], "未找到内核的 Adam 'mu' 状态。"
    print("\n优化器状态结构对于 Adam 是正确的。")
else:
    print("\n警告:Adam 的优化器状态结构可能不同或未完全验证。")

练习 7: 使用 Flax NNX 和 Optax 进行训练步骤

目标: 为 Flax NNX 模型实现一个完整的、JIT 编译的训练步骤。

说明:

  1. 您需要:
- 您的模型类的一个实例(例如,来自 Ex 5/6 解决方案的 `my_model_sol`)。 - `nnx.Optimizer` 的一个实例(例如,来自 Ex 6 解决方案的 `nnx_optimizer_sol`)。
  1. 定义一个使用 `@nnx.jit` 装饰的 `train_step` 函数。此函数应将模型、优化器、输入 `x_batch` 和目标 `y_batch` 作为参数。
  2. 在 `train_step` 内部:
- 定义一个内部的 `loss_fn_for_grad`。此函数必须将模型作为其第一个参数。在内部,它计算模型对 `x_batch` 的预测,然后计算与 `y_batch` 的均方误差 (MSE)。 - 使用 `nnx.value_and_grad(loss_fn_for_grad)(model_arg)` 计算损失值和相对于传递给 `loss_fn_for_grad` 的模型的梯度。(`model_arg` 是传递给 `train_step` 的模型实例)。 - 使用 `optimizer_arg.update(model_arg, grads)` 更新模型的参数(和优化器的状态)。`update` 方法接受模型和梯度,并就地更新模型的状态。 - 返回计算出的 `loss_value`。
  1. 创建虚拟的 `x_batch` 和 `y_batch` 数据。
  2. 调用您的 `train_step` 函数。打印返回的损失。
  3. (可选) 通过比较调用前后参数值的变化,验证模型的参数在 `train_step` 后是否已更改。
# 练习 7 说明
key_ex7_main, main_key = jax.random.split(main_key)
key_ex7_x, key_ex7_y = jax.random.split(key_ex7_main, 2)

# 1. 使用上一个练习解决方案中的模型和优化器
# 确保 my_model_sol 和 nnx_optimizer_sol 可用
if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():
    print("为 Ex7 重新初始化 Ex5/Ex6 解决方案中的模型和优化器。")
    key_ex7_model_fallback, main_key = jax.random.split(main_key)
    _model_din_ex7 = 3
    _model_dout_ex7 = 2
    _model_rngs_ex7 = nnx.Rngs(params=key_ex7_model_fallback)
    # 确保使用 SimpleNNXModel_Sol
    my_model_ex7 = SimpleNNXModel_Sol(din=_model_din_ex7, dout=_model_dout_ex7, rngs=_model_rngs_ex7)
    _optax_tx_ex7 = optax.adam(learning_rate=0.001)
    nnx_optimizer_ex7 = nnx.Optimizer(my_model_ex7, _optax_tx_ex7)
    print("为 Ex7 重新创建了模型和优化器。")
else:
    my_model_ex7 = my_model_sol
    nnx_optimizer_ex7 = nnx_optimizer_sol
    print("使用 'my_model_sol' 和 'nnx_optimizer_sol' 作为 'my_model_ex7' 和 'nnx_optimizer_ex7'。")


# 2, 3. 定义 train_step 函数
# TODO: 使用 @nnx.jit 装饰
# def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer, # 使用 nnx.Module 基类进行类型提示
#                x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:

    # TODO: 定义内部 loss_fn_for_grad(current_model_state_for_grad_fn)
    # def loss_fn_for_grad(model_in_grad_fn: nnx.Module): # 使用 nnx.Module 基类进行类型提示
        # y_pred = model_in_grad_fn(x_batch)
        # loss = jnp.mean((y_pred - y_batch)**2)
        # return loss
    #    return jnp.array(0.0) # 占位符

    # TODO: 使用 nnx.value_and_grad 计算损失值和梯度
    # loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg) # 传递 model_arg

    # TODO: 更新优化器(这将就地更新 model_arg)
    # optimizer_arg.update(model_arg, grads)

    # return loss_value
#    return jnp.array(0.0) # 占位符定义的 train_step 函数

# 供学生定义:
# 确保函数签名对于 nnx.jit 是正确的
@nnx.jit
def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer,
               x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:
    # 供学生使用的占位符实现
    def loss_fn_for_grad(model_in_grad_fn: nnx.Module):
        # y_pred = model_in_grad_fn(x_batch)
        # loss = jnp.mean((y_pred - y_batch)**2)
        # return loss
        return jnp.array(0.0) # 学生 TODO:替换此行

    # loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)
    # optimizer_arg.update(grads)
    # return loss_value
    return jnp.array(-1.0) # 学生 TODO:替换此行


# 4. 创建虚拟数据
batch_s = 8
# 仔细访问 features_in 和 features_out
_din_from_model_ex7 = my_model_ex7.dense_layer.in_features if hasattr(my_model_ex7, 'dense_layer') else 3
_dout_from_model_ex7 = my_model_ex7.dense_layer.out_features if hasattr(my_model_ex7, 'dense_layer') else 2

x_batch_data = jax.random.normal(key_ex7_x, (batch_s, _din_from_model_ex7))
y_batch_data = jax.random.normal(key_ex7_y, (batch_s, _dout_from_model_ex7))

# 可选:存储初始参数值以进行比较
initial_kernel_val = None
if hasattr(my_model_ex7, 'get_state'):
    _current_model_state_ex7 = my_model_ex7.get_state()
    if 'dense_layer' in _current_model_state_ex7:
       initial_kernel_val = _current_model_state_ex7['dense_layer']['kernel'].value[0,0].copy()
print(f"初始内核值(样本): {initial_kernel_val}")

# 5. 调用 train_step
# loss_after_step = train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data) # 学生将取消注释
loss_after_step = jnp.array(-1.0) # 占位符,直到学生实现 train_step
if train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data).item() != -1.0: # 检查学生是否已实现
    loss_after_step = train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data)
    print(f"一个训练步骤后的损失: {loss_after_step}")
else:
    print("学生需要实现 `train_step` 函数。")


# # 6. 可选:验证参数更改
# updated_kernel_val_sol = None
# _, updated_model_state_sol = nnx.split(my_model_sol_ex7) # 更新后再次获取状态
# if 'dense_layer' in updated_model_state_sol:
#   updated_kernel_val_sol = updated_model_state_sol['dense_layer']['kernel'].value[0,0]
#   print(f"更新后的内核值(样本): {updated_kernel_val_sol}")

# if initial_kernel_val_sol is not None and updated_kernel_val_sol is not None:
#     assert not jnp.allclose(initial_kernel_val_sol, updated_kernel_val_sol), "内核参数未更改!"
#     print("内核参数在训练步骤后按预期更改。")
# else:
#     print("无法验证内核更改(初始值或更新值为 None)。")
# @title 解决方案 7: 使用 Flax NNX 和 Optax 进行训练步骤
key_ex7_sol_main, main_key = jax.random.split(main_key)
key_ex7_sol_x, key_ex7_sol_y = jax.random.split(key_ex7_sol_main, 2)

# 1. 使用上一个练习解决方案中的模型和优化器
if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():
    print("为 Ex7 解决方案重新初始化 Ex5/Ex6 解决方案中的模型和优化器。")
    key_ex7_sol_model_fallback, main_key = jax.random.split(main_key)
    _model_din_sol_ex7 = 3
    _model_dout_sol_ex7 = 2
    _model_rngs_sol_ex7 = nnx.Rngs(params=key_ex7_sol_model_fallback)
    # 确保为解决方案使用 SimpleNNXModel_Sol
    my_model_sol_ex7 = SimpleNNXModel_Sol(din=_model_din_sol_ex7, dout=_model_dout_sol_ex7, rngs=_model_rngs_sol_ex7)
    _optax_tx_sol_ex7 = optax.adam(learning_rate=0.001)
    nnx_optimizer_sol_ex7 = nnx.Optimizer(my_model_sol_ex7, _optax_tx_sol_ex7)
    print("为 Ex7 解决方案重新创建了模型和优化器。")
else:
    # 如果按顺序运行解决方案,这些将是正确的实例
    my_model_sol_ex7 = my_model_sol
    nnx_optimizer_sol_ex7 = nnx_optimizer_sol
    print("为 Ex7 解决方案使用 'my_model_sol' 和 'nnx_optimizer_sol'。")


# 2, 3. 定义 train_step 函数
@nnx.jit # 使用 @nnx.jit 装饰器进行 JIT 编译
def train_step_sol(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer, # 为通用性使用基类 nnx.Module
                   x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:

    # 定义内部的 loss_fn_for_grad。它将模型作为其第一个参数。
    # 它从外部作用域捕获 x_batch 和 y_batch。
    def loss_fn_for_grad(model_in_grad_fn: nnx.Module): # 使用基类 nnx.Module
        y_pred = model_in_grad_fn(x_batch) # 使用传递给此内部函数的模型
        loss = jnp.mean((y_pred - y_batch)**2)
        return loss

    # 使用 nnx.value_and_grad 计算损失值和梯度。
    # 这将对 loss_fn_for_grad 关于其第一个参数 (model_in_grad_fn) 进行微分。
    # 我们将模型的当前状态 (model_arg) 传递给它。
    loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)

    # 更新优化器。这将就地更新 model_arg(nnx_optimizer_sol_ex7 所引用的)。
    optimizer_arg.update(model_arg, grads)

    return loss_value


# 4. 创建虚拟数据
batch_s_sol = 8
# 确保 din 和 dout 与 Ex5/Ex6 的模型实例化匹配
# my_model_sol_ex7.dense_layer 是一个 nnx.Linear 对象
din_from_model_sol = my_model_sol_ex7.dense_layer.in_features
dout_from_model_sol = my_model_sol_ex7.dense_layer.out_features

x_batch_data_sol = jax.random.normal(key_ex7_sol_x, (batch_s_sol, din_from_model_sol))
y_batch_data_sol = jax.random.normal(key_ex7_sol_y, (batch_s_sol, dout_from_model_sol))

# 可选:存储初始参数值以进行比较
initial_kernel_val_sol = None
_, current_model_state_sol = nnx.split(my_model_sol_ex7)
if 'dense_layer' in current_model_state_sol:
    initial_kernel_val_sol = current_model_state_sol['dense_layer']['kernel'].value[0,0].copy()
print(f"初始内核值(样本): {initial_kernel_val_sol}")


# 5. 调用 train_step
# 第一次调用将 JIT 编译 train_step_sol 函数。
loss_after_step_sol = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)
print(f"一个训练步骤后的损失 (第一次调用, JIT): {loss_after_step_sol}")
# 第二次调用以显示它更快(尽管 %timeit 更适合测量)
loss_after_step_sol_2 = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)
print(f"一个训练步骤后的损失 (第二次调用, 缓存): {loss_after_step_sol_2}")


# 6. 可选:验证参数更改
updated_kernel_val_sol = None
_, updated_model_state_sol = nnx.split(my_model_sol_ex7) # 更新后再次获取状态
if 'dense_layer' in updated_model_state_sol:
  updated_kernel_val_sol = updated_model_state_sol['dense_layer']['kernel'].value[0,0]
  print(f"更新后的内核值(样本): {updated_kernel_val_sol}")

if initial_kernel_val_sol is not None and updated_kernel_val_sol is not None:
    assert not jnp.allclose(initial_kernel_val_sol, updated_kernel_val_sol), "内核参数未更改!"
    print("内核参数在训练步骤后按预期更改。")
else:
    print("无法验证内核更改(初始值或更新值为 None)。")

练习 8: Orbax - 保存和恢复检查点

目标: 学习使用 Orbax 保存和恢复 JAX PyTree,特别是 Flax NNX 模型状态和 Optax 优化器状态。

说明:

  1. 您需要上一个练习解决方案中的模型(例如 `my_model_sol_ex7`)和优化器(例如 `nnx_optimizer_sol_ex7`)。
  2. 定义一个检查点目录(例如 `/tmp/my_nnx_checkpoint/`)。
  3. 创建一个 Orbax `CheckpointManagerOptions`,然后创建一个 `CheckpointManager`。
  4. 将要保存的状态捆绑到一个字典中。对于 NNX,这是模型的 `my_model_sol_ex7.get_state()` 和优化器内部状态的 `nnx_optimizer_sol_ex7.state`。还要包括一个训练步骤计数器。
  5. 使用 `checkpoint_manager.save()` 和 `ocp.args.StandardSave()` 保存捆绑的状态。调用 `checkpoint_manager.wait_until_finished()` 以确保保存完成。
  6. 要恢复:
- 为您的模型 (`restored_model`) 和 Optax 转换 (`restored_optax_tx`) 创建新实例。新模型应具有不同的 PRNG 密钥以用于其初始参数,以证明恢复有效。 - 使用 `checkpoint_manager.restore()` 和 `ocp.args.StandardRestore()` 加载捆绑的状态。 - 使用 `restored_model.update_state(loaded_bundle['model'])` 将加载的模型状态应用于 `restored_model`。 - 创建一个新的 `nnx.Optimizer` (`restored_optimizer`),关联 `restored_model` 和 `restored_optax_tx`。 - 将加载的优化器状态分配给新的优化器:`restored_optimizer.state = loaded_bundle['optimizer']`。
  1. 验证 `restored_model` 中的参数与原始 `my_model_sol_ex7`(保存前或来自已保存状态)中的相应参数匹配。如果可能,还要比较优化器状态。
  2. 清理检查点目录。
# 练习 8 说明
# import orbax.checkpoint as ocp # 已导入
# import os, shutil # 已导入

# 1. 使用上一个练习解决方案中的模型和优化器
if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():
    print("为 Ex8 重新初始化 Ex7 解决方案中的模型和优化器。")
    key_ex8_model_fallback, main_key = jax.random.split(main_key)
    _model_din_ex8 = 3
    _model_dout_ex8 = 2
    _model_rngs_ex8 = nnx.Rngs(params=key_ex8_model_fallback)
    _ModelClassEx8 = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel
    model_to_save = _ModelClassEx8(din=_model_din_ex8, dout=_model_dout_ex8, rngs=_model_rngs_ex8)
    _optax_tx_ex8 = optax.adam(learning_rate=0.001)
    optimizer_to_save = nnx.Optimizer(model_to_save, _optax_tx_ex8)
    print("为 Ex8 重新创建了模型和优化器。")
else:
    model_to_save = my_model_sol_ex7
    optimizer_to_save = nnx_optimizer_sol_ex7
    print("使用 Ex7 解决方案中的模型和优化器。")

# 2. 定义检查点目录
# TODO: 定义 checkpoint_dir
checkpoint_dir = None # 占位符,例如 "/tmp/my_nnx_checkpoint_exercise/"
# if checkpoint_dir and os.path.exists(checkpoint_dir):
#    shutil.rmtree(checkpoint_dir) # 为安全起见,清理以前的运行
# if checkpoint_dir:
#    os.makedirs(checkpoint_dir, exist_ok=True)


# 3. 创建 Orbax CheckpointManager
# TODO: 创建选项和管理器
# options = ocp.CheckpointManagerOptions(...)
# mngr = ocp.CheckpointManager(...)
options = None
mngr = None

# 4. 捆绑状态
# current_step = 100 # 示例步骤
# TODO: 获取 model_state 和 optimizer_state
# model_state_to_save = nnx.split(model_to_save)
# 优化器状态现在通过 .state 属性访问。
# opt_state_to_save = optimizer_to_save.state
# save_bundle = {
#     'model': model_state_to_save,
#     'optimizer': opt_state_to_save,
#     'step': current_step
# }
save_bundle = None

# 5. 保存检查点
# if mngr and save_bundle:
#   TODO: 保存检查点
#   mngr.save(...)
#   mngr.wait_until_finished()
#   print(f"检查点已在步骤 {current_step} 保存到 {checkpoint_dir}")
# else:
#   print("检查点管理器或 save_bundle 未初始化。")

# --- 恢复 ---
# 6.a 创建新模型和 Optax 转换(用于恢复)
# key_ex8_restore_model, main_key = jax.random.split(main_key)
# din_restore = model_to_save.dense_layer.in_features if hasattr(model_to_save, 'dense_layer') else 3
# dout_restore = model_to_save.dense_layer.out_features if hasattr(model_to_save, 'dense_layer') else 2
# _ModelClassRestore = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel
# restored_model = _ModelClassRestore(
#     din=din_restore, dout=dout_restore,
#     rngs=nnx.Rngs(params=key_ex8_restore_model) # 用于不同初始参数的新密钥
# )
# restored_optax_tx = optax.adam(learning_rate=0.001) # 相同的 Optax 配置
restored_model = None
restored_optax_tx = None

# 6.b 恢复检查点
# loaded_bundle = None
# if mngr:
#   TODO: 恢复检查点
#   latest_step = mngr.latest_step()
#   if latest_step is not None:
#       loaded_bundle = mngr.restore(...)
#       print(f"已从步骤 {latest_step} 恢复检查点")
#   else:
#       print("未找到要恢复的检查点。")
# else:
#   print("检查点管理器未初始化以进行恢复。")

# 6.c 应用加载的状态
# if loaded_bundle and restored_model:
#   TODO: 更新 restored_model 状态
#   nnx.update(restored_model, ...)
#   print("已应用恢复的模型状态。")

    # TODO: 创建新的 nnx.Optimizer 并分配其状态
#   restored_optimizer = nnx.Optimizer(...)
#   restored_optimizer.state = ...
#   print("已应用恢复的优化器状态。")
# else:
#   print("loaded_bundle 或 restored_model 为 None,无法应用状态。")
restored_optimizer = None

# 7. 验证恢复
# original_kernel_sol = save_bundle_sol['model']['dense_layer']['kernel']
# _, restored_model_state = nnx.split(restored_model_sol)
# kernel_after_restore_sol = restored_model_state['dense_layer']['kernel']
# assert jnp.array_equal(original_kernel_sol.value, kernel_after_restore_sol.value), \
#        "恢复后模型内核参数不同!"
# print("\n模型参数已成功恢复和验证(内核匹配)。")

# # 验证优化器状态(例如,特定参数的 Adam 'mu')
# original_opt_state_adam_mu_kernel = save_bundle_sol['optimizer'][0].mu['dense_layer']['kernel'].value
# restored_opt_state_adam_mu_kernel = restored_optimizer_sol.state[0].mu['dense_layer']['kernel'].value
# assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \
#                        "优化器 Adam mu for kernel 不同!"
# print("优化器状态(样本 mu)已成功恢复和验证。")


# 8. 清理
# if mngr:
#   mngr.close()
# if checkpoint_dir and os.path.exists(checkpoint_dir):
#   shutil.rmtree(checkpoint_dir)
#   print(f"已清理检查点目录:{checkpoint_dir}")
# @title 解决方案 8: Orbax - 保存和恢复检查点

# 1. 使用上一个练习解决方案中的模型和优化器
if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():
    print("为 Ex8 解决方案重新初始化 Ex7 解决方案中的模型和优化器。")
    key_ex8_sol_model_fallback, main_key = jax.random.split(main_key)
    _model_din_sol_ex8 = 3
    _model_dout_sol_ex8 = 2
    _model_rngs_sol_ex8 = nnx.Rngs(params=key_ex8_sol_model_fallback)
    # 确保为解决方案使用 SimpleNNXModel_Sol
    model_to_save_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex8,
                                           dout=_model_dout_sol_ex8,
                                           rngs=_model_rngs_sol_ex8)
    _optax_tx_sol_ex8 = optax.adam(learning_rate=0.001) # 存储转换以供稍后使用
    optimizer_to_save_sol = nnx.Optimizer(model_to_save_sol, _optax_tx_sol_ex8)
    print("为 Ex8 解决方案重新创建了模型和优化器。")
else:
    model_to_save_sol = my_model_sol_ex7
    optimizer_to_save_sol = nnx_optimizer_sol_ex7
    # 我们需要用于创建优化器的 optax 转换来进行恢复
    _optax_tx_sol_ex8 = optimizer_to_save_sol.tx # 访问原始 Optax 转换
    print("为 Ex8 解决方案使用 Ex7 解决方案中的模型和优化器。")

# 2. 定义检查点目录
checkpoint_dir_sol = "/tmp/my_nnx_checkpoint_exercise_solution/"
if os.path.exists(checkpoint_dir_sol):
   shutil.rmtree(checkpoint_dir_sol) # 清理以前的运行
os.makedirs(checkpoint_dir_sol, exist_ok=True)
print(f"Orbax 检查点目录: {checkpoint_dir_sol}")

# 3. 创建 Orbax CheckpointManager
options_sol = ocp.CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1)
mngr_sol = ocp.CheckpointManager(checkpoint_dir_sol, options=options_sol)

# 4. 捆绑状态
current_step_sol = 100 # 示例步骤
_, model_state_to_save_sol = nnx.split(model_to_save_sol)
# 优化器状态现在是 .state 属性中直接可用的 PyTree。
opt_state_to_save_sol = optimizer_to_save_sol.opt_state
save_bundle_sol = {
    'model': model_state_to_save_sol,
    'optimizer': opt_state_to_save_sol,
    'step': current_step_sol
}
print("\n要保存的状态包:")
pprint.pprint(f"模型状态键: {model_state_to_save_sol.keys()}")
pprint.pprint(f"优化器状态类型: {type(opt_state_to_save_sol)}")


# 5. 保存检查点
mngr_sol.save(current_step_sol, args=ocp.args.StandardSave(save_bundle_sol))
mngr_sol.wait_until_finished()
print(f"\n检查点已在步骤 {current_step_sol} 保存到 {checkpoint_dir_sol}")

# --- 恢复 ---
# 6.a 创建新模型和 Optax 转换(用于恢复)
key_ex8_sol_restore_model, main_key = jax.random.split(main_key)
# 确保如果可能,从已保存模型的结构中正确获取 din/dout
# 假设 model_to_save_sol 是具有 dense_layer 的 SimpleNNXModel_Sol
din_restore_sol = model_to_save_sol.dense_layer.in_features
dout_restore_sol = model_to_save_sol.dense_layer.out_features

restored_model_sol = SimpleNNXModel_Sol( # 使用解决方案的模型类
    din=din_restore_sol, dout=dout_restore_sol,
    rngs=nnx.Rngs(params=key_ex8_sol_restore_model) # 用于不同初始参数的新密钥
)
# 我们需要原始的 Optax 转换定义来创建新的 nnx.Optimizer
# _optax_tx_sol_ex8 之前已存储,或者如果配置已知,则可以重新创建
restored_optax_tx_sol = _optax_tx_sol_ex8

# 在恢复前打印新模型的参数,以显示它不同
_, kernel_before_restore_sol = nnx.split(restored_model_sol)
print(f"\n恢复前 'restored_model_sol' 的示例内核:")
nnx.display(kernel_before_restore_sol['dense_layer']['kernel'])

# 6.b 恢复检查点
loaded_bundle_sol = None
latest_step_sol = mngr_sol.latest_step()
if latest_step_sol is not None:
    # 对于 NNX,我们恢复原始 PyTree,StandardRestore 是合适的。
    loaded_bundle_sol = mngr_sol.restore(latest_step_sol,
                                         args=ocp.args.StandardRestore(save_bundle_sol))
    print(f"\n已从步骤 {latest_step_sol} 恢复检查点")
    print(f"加载的包包含键: {loaded_bundle_sol.keys()}")
else:
    raise ValueError("未找到要恢复的检查点。")

# 6.c 应用加载的状态
assert loaded_bundle_sol is not None, "加载的包为 None"
nnx.update(restored_model_sol, loaded_bundle_sol['model'])
print("已将恢复的模型状态应用于 'restored_model_sol'。")

# 使用 restored_model 和原始 optax_tx 创建新的 nnx.Optimizer
restored_optimizer_sol = nnx.Optimizer(restored_model_sol, restored_optax_tx_sol,
                                       wrt=nnx.Param)
# 现在分配加载的 Optax 状态 PyTree
restored_optimizer_sol.state = loaded_bundle_sol['optimizer']
print("已将恢复的优化器状态应用于 'restored_optimizer_sol'。")


# 7. 验证恢复
original_kernel_sol = save_bundle_sol['model']['dense_layer']['kernel']
_, restored_model_state = nnx.split(restored_model_sol)
kernel_after_restore_sol = restored_model_state['dense_layer']['kernel']
assert jnp.array_equal(original_kernel_sol.value, kernel_after_restore_sol.value), \
       "恢复后模型内核参数不同!"
print("\n模型参数已成功恢复和验证(内核匹配)。")

# 验证优化器状态(例如,特定参数的 Adam 'mu')
original_opt_state_adam_mu_kernel = save_bundle_sol['optimizer'][0].mu['dense_layer']['kernel'].value
restored_opt_state_adam_mu_kernel = restored_optimizer_sol.state[0].mu['dense_layer']['kernel'].value
assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \
                       "优化器 Adam mu for kernel 不同!"
print("优化器状态(样本 mu)已成功恢复和验证。")


# 8. 清理
mngr_sol.close()
if os.path.exists(checkpoint_dir_sol):
  shutil.rmtree(checkpoint_dir_sol)
  print(f"已清理检查点目录:{checkpoint_dir_sol}")

结论

恭喜您完成 JAX AI 技术栈练习!

您现在已经亲身体验了:

  • JAX 核心:jax.numpy、函数式编程、jax.jit、jax.grad、jax.vmap。
  • Flax NNX:定义和实例化 Pythonic 的神经网络模型。
  • Optax:使用 Flax NNX 创建和使用可组合的优化器。
  • 训练循环:在 Flax NNX 中实现端到端的训练步骤。
  • Orbax:保存和恢复模型和优化器状态。

这为使用 JAX 生态系统开发高性能机器学习模型奠定了坚实的基础。

要进一步学习,请参阅官方文档:

  • JAX AI 技术栈:https://jaxstack.ai
  • JAX: https://jax.dev
  • Flax NNX: https://flax.readthedocs.io
  • Optax: https://optax.readthedocs.io
  • Orbax: https://orbax.readthedocs.io
  • 不要忘记对培训课程提供反馈: https://goo.gle/jax-training-feedback