欢迎!本笔记本包含练习,可帮助您根据讲座中涵盖的概念,练习将 Chex 与 JAX 和 Flax NNX 结合使用。
目标: 巩固您对 Chex 如何增强基于 JAX 的项目的可靠性和可调试性的理解。 说明:TODO 部分。让我们开始吧!
# 首先运行此单元格以安装和导入必要的库。
!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
import chex
import flax
from flax import nnx
import functools # 用于 functools.partial
# 用于重置 assert_max_traces 练习的跟踪计数器的辅助函数
def reset_trace_counter():
chex.clear_trace_counter()
# 对于某些 JAX 版本,如果您频繁地重新运行单元格,可能需要一个小技巧来完全重置
# JAX 内部缓存。如果按顺序运行单元格,这些练习通常不需要这样做。
print(f"JAX 版本: {jax.__version__}")
print(f"Chex 版本: {chex.__version__}")
print(f"Flax 版本: {flax.__version__}")
print(f"正在运行: {jax.default_backend()}")
chex.assert_shape 和 chex.assert_typeprocess_data 函数。
input_array 的形状是否为 (3, None)input_array 的数据类型是否为 jnp.float32。output_array 的形状是否为 (3, 1)。def process_data_v1(input_array: chex.Array) -> chex.Array:
"""处理一个数组,断言形状和类型。"""
# TODO: 断言 input_array 的形状是 (3, None)
chex.assert_shape(input_array, <TODO>)
# TODO: 断言 input_array 的类型是 jnp.float32
chex.assert_type(<TODO>)
# 模拟一些将最后一个维度减少到 1 的处理
output_array = input_array[:, :1] * 2.0
# TODO: 断言 output_array 的形状是 (3, 1)
chex.assert_shape(output_array, (3, 1))
return output_array
# 测试用例
key = jax.random.PRNGKey(0)
valid_input = jax.random.normal(key, (3, 5), dtype=jnp.float32)
print("正在使用有效输入进行测试...")
result = process_data_v1(valid_input)
print(f"成功处理有效输入。输出形状: {result.shape}\n")
print("正在使用无效形状输入进行测试...")
invalid_shape_input = jax.random.normal(key, (4, 5), dtype=jnp.float32)
try:
process_data_v1(invalid_shape_input)
except AssertionError as e:
print(f"捕获到无效形状的预期错误:\n{e}\n")
print("正在使用无效类型输入进行测试...")
invalid_type_input = jnp.ones((3, 5), dtype=jnp.int32)
try:
process_data_v1(invalid_type_input)
except AssertionError as e:
print(f"捕获到无效类型的预期错误: {e}\n")
def process_data_v1(input_array: chex.Array) -> chex.Array:
"""处理一个数组,断言形状和类型。"""
# TODO: 断言 input_array 的形状是 (3, None)
chex.assert_shape(input_array, (3, None))
# TODO: 断言 input_array 的类型是 jnp.float32
chex.assert_type(input_array, expected_types=jnp.float32)
# 模拟一些将最后一个维度减少到 1 的处理
output_array = input_array[:, :1] * 2.0
# TODO: 断言 output_array 的形状是 (3, 1)
chex.assert_shape(output_array, (3, 1))
return output_array
# 测试用例
key = jax.random.PRNGKey(0)
valid_input = jax.random.normal(key, (3, 5), dtype=jnp.float32)
print("正在使用有效输入进行测试...")
result = process_data_v1(valid_input)
print(f"成功处理有效输入。输出形状: {result.shape}\n")
print("正在使用无效形状输入进行测试...")
invalid_shape_input = jax.random.normal(key, (4, 5), dtype=jnp.float32)
try:
process_data_v1(invalid_shape_input)
except AssertionError as e:
print(f"捕获到无效形状的预期错误:\n{e}\n")
print("正在使用无效类型输入进行测试...")
invalid_type_input = jnp.ones((3, 5), dtype=jnp.int32)
try:
process_data_v1(invalid_type_input)
except AssertionError as e:
print(f"捕获到无效类型的预期错误: {e}\n")
chex.assert_rank 和 chex.assert_scalarprocess_data_v2 函数。
matrix_input 是一个二维数组 (秩为 2)。scalar_input 是一个标量。result 也是一个二维数组。def process_data_v2(matrix_input: chex.Array, scalar_input: chex.Array) -> chex.Array:
"""处理一个矩阵和一个标量。"""
# TODO: 断言 matrix_input 的秩为 2
chex.assert_rank(matrix_input, <TODO>)
# TODO: 断言 scalar_input 是一个标量
chex.assert_scalar(<TODO>)
result = matrix_input * scalar_input + 1.0
# TODO: 断言 result 的秩为 2
chex.assert_rank(result, <TODO>)
return result
# 测试用例
matrix = jnp.ones((3, 4))
scalar = 5.0
not_a_scalar = jnp.array([5.0])
not_a_matrix = jnp.ones((3,4,1))
print("正在使用有效的秩/标量输入进行测试...")
try:
res_valid = process_data_v2(matrix, scalar)
print(f"成功处理有效的秩/标量。结果形状: {res_valid.shape}\n")
except AssertionError as e:
print(f"捕获到有效秩/标量的意外错误:\n{e}\n")
print("正在使用无效秩输入进行测试...")
try:
process_data_v2(not_a_matrix, scalar)
print(f"成功处理无效秩。结果形状: {res_valid.shape}\n")
except AssertionError as e:
print(f"捕获到无效秩的预期错误:\n{e}\n")
print("正在使用非标量输入进行测试...")
try:
process_data_v2(matrix, not_a_scalar)
print(f"成功处理非标量。结果形状: {res_valid.shape}\n")
except AssertionError as e:
print(f"捕获到非标量的预期错误:\n{e}\n")
def process_data_v2(matrix_input: chex.Array, scalar_input: chex.Array) -> chex.Array:
"""处理一个矩阵和一个标量。"""
# TODO: 断言 matrix_input 的秩为 2
chex.assert_rank(matrix_input, expected_ranks=2)
# TODO: 断言 scalar_input 是一个标量
chex.assert_scalar(scalar_input)
result = matrix_input * scalar_input + 1.0
# TODO: 断言 result 的秩为 2
chex.assert_rank(result, expected_ranks=2)
return result
# 测试用例
matrix = jnp.ones((3, 4))
scalar = 5.0
not_a_scalar = jnp.array([5.0])
not_a_matrix = jnp.ones((3,4,1))
print("正在使用有效的秩/标量输入进行测试...")
try:
res_valid = process_data_v2(matrix, scalar)
print(f"成功处理有效的秩/标量。结果形状: {res_valid.shape}\n")
except AssertionError as e:
print(f"捕获到有效秩/标量的意外错误:\n{e}\n")
print("正在使用无效秩输入进行测试...")
try:
process_data_v2(not_a_matrix, scalar)
print(f"成功处理无效秩。结果形状: {res_valid.shape}\n")
except AssertionError as e:
print(f"捕获到无效秩的预期错误:\n{e}\n")
print("正在使用非标量输入进行测试...")
try:
process_data_v2(matrix, not_a_scalar)
print(f"成功处理非标量。结果形状: {res_valid.shape}\n")
except AssertionError as e:
print(f"捕获到非标量的预期错误:\n{e}\n")
assert_trees_all_close, assert_tree_all_finite)def process_pytree(tree1, tree2):
"""
检查两个 PyTree 是否接近,以及第一个 PyTree 是否有限。
返回一个新树,其中元素是 tree1 + tree2。
"""
# TODO: 断言 tree1 和 tree2(近似)相等。使用较小的容差。
chex.assert_trees_all_close(<TODO> rtol=1e-5, atol=1e-8)
# TODO: 断言 tree1 中的所有元素都是有限的 (不是 NaN 或 Inf)。
chex.assert_tree_all_finite(<TODO>)
# 执行一些操作
return jax.tree_util.tree_map(lambda x, y: x + y, tree1, tree2)
# 测试用例
tree_a = {'params': {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}}
tree_b_close = {'params': {'w': jnp.array([1.000001, 2.000001]), 'b': jnp.array(0.500001)}}
tree_c_not_close = {'params': {'w': jnp.array([1.1, 2.1]), 'b': jnp.array(0.6)}}
tree_d_nan = {'params': {'w': jnp.array([1.0, jnp.nan]), 'b': jnp.array(0.5)}}
print("正在使用近似和有限的 PyTree 进行测试...")
result_valid = process_pytree(tree_a, tree_b_close)
print("成功处理有效的 PyTree。\n")
print("正在使用非近似的 PyTree 进行测试...")
try:
process_pytree(tree_a, tree_c_not_close)
except AssertionError as e:
print(f"捕获到非近似树的预期错误:\n\n{e}\n")
print("正在使用非有限的 PyTree 进行测试...")
try:
process_pytree(tree_d_nan, tree_b_close) # tree_d_nan 将被检查是否有限
except AssertionError as e:
print(f"捕获到非有限树的预期错误:\n\n{e}\n")
def process_pytree(tree1, tree2):
"""
检查两个 PyTree 是否接近,以及第一个 PyTree 是否有限。
返回一个新树,其中元素是 tree1 + tree2。
"""
# TODO: 断言 tree1 和 tree2(近似)相等。使用较小的容差。
chex.assert_trees_all_close(tree1, tree2, rtol=1e-5, atol=1e-8)
# TODO: 断言 tree1 中的所有元素都是有限的 (不是 NaN 或 Inf)。
chex.assert_tree_all_finite(tree1)
# 执行一些操作
return jax.tree_util.tree_map(lambda x, y: x + y, tree1, tree2)
# 测试用例
tree_a = {'params': {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}}
tree_b_close = {'params': {'w': jnp.array([1.000001, 2.000001]), 'b': jnp.array(0.500001)}}
tree_c_not_close = {'params': {'w': jnp.array([1.1, 2.1]), 'b': jnp.array(0.6)}}
tree_d_nan = {'params': {'w': jnp.array([1.0, jnp.nan]), 'b': jnp.array(0.5)}}
print("正在使用近似和有限的 PyTree 进行测试...")
result_valid = process_pytree(tree_a, tree_b_close)
print("成功处理有效的 PyTree。\n")
print("正在使用非近似的 PyTree 进行测试...")
try:
process_pytree(tree_a, tree_c_not_close)
except AssertionError as e:
print(f"捕获到非近似树的预期错误:\n\n{e}\n")
print("正在使用非有限的 PyTree 进行测试...")
try:
process_pytree(tree_d_nan, tree_b_close) # tree_d_nan 将被检查是否有限
except AssertionError as e:
print(f"捕获到非有限树的预期错误:\n\n{e}\n")
jax.jit 和 jax.vmap 等 JAX 转换中能够正确工作。
@jax.jit 内部的断言process_data_v1 函数。@jax.jit
def process_data_jitted(input_array: chex.Array) -> chex.Array:
"""JIT 编译版本的 process_data_v1 及其 Chex 断言。"""
# (断言在 process_data_v1 内部,我们在这里有效地重用它)
# 为清楚起见,我们在这里直接用断言重新定义它。
chex.assert_shape(input_array, (3, None))
chex.assert_type(input_array, jnp.float32)
output_array = input_array[:, :1] * 2.0
chex.assert_shape(output_array, (3, 1))
return output_array
# JIT 版本的测试用例
key = jax.random.PRNGKey(1) # 使用不同的密钥以获得可能不同的值
valid_input_jit = jax.random.normal(key, (3, 5), dtype=jnp.float32)
print("正在使用有效输入测试 JIT 函数...")
# 第一次调用将编译
result_jit = process_data_jitted(<TODO>)
print(f"成功处理 JIT 编译的有效输入。输出形状: {result_jit.shape}")
# 第二次调用使用缓存的编译
result_jit_cached = process_data_jitted(<TODO> * 2)
print(f"成功处理 JIT 编译的有效输入(已缓存)。输出形状: {result_jit_cached.shape}\n")
print("正在使用无效形状输入测试 JIT 函数...")
invalid_shape_input_jit = jax.random.normal(key, (4, 5), dtype=jnp.float32)
try:
process_data_jitted(<TODO>)
except AssertionError as e:
print(f"捕获到无效形状的预期 JIT 错误:\n\n{e}\n")
@jax.jit
def process_data_jitted(input_array: chex.Array) -> chex.Array:
"""JIT 编译版本的 process_data_v1 及其 Chex 断言。"""
# (断言在 process_data_v1 内部,我们在这里有效地重用它)
# 为清楚起见,我们在这里直接用断言重新定义它。
chex.assert_shape(input_array, (3, None))
chex.assert_type(input_array, jnp.float32)
output_array = input_array[:, :1] * 2.0
chex.assert_shape(output_array, (3, 1))
return output_array
# JIT 版本的测试用例
key = jax.random.PRNGKey(1) # 使用不同的密钥以获得可能不同的值
valid_input_jit = jax.random.normal(key, (3, 5), dtype=jnp.float32)
print("正在使用有效输入测试 JIT 函数...")
# 第一次调用将编译
result_jit = process_data_jitted(valid_input_jit)
print(f"成功处理 JIT 编译的有效输入。输出形状: {result_jit.shape}")
# 第二次调用使用缓存的编译
result_jit_cached = process_data_jitted(valid_input_jit * 2)
print(f"成功处理 JIT 编译的有效输入(已缓存)。输出形状: {result_jit_cached.shape}\n")
print("正在使用无效形状输入测试 JIT 函数...")
invalid_shape_input_jit = jax.random.normal(key, (4, 5), dtype=jnp.float32)
try:
process_data_jitted(invalid_shape_input_jit)
except AssertionError as e:
print(f"捕获到无效形状的预期 JIT 错误:\n\n{e}\n")
Chex 断言在 JIT 编译的函数中无缝工作,根据运行时传递的具体值捕获错误,即使检查是在编译后的代码中定义的。
---
@jax.vmap 进行多级验证(10,) 的一维数组。
process_single_item_vmap,它处理一个项目。item 的形状为 (10,)。
- 该函数应将项目的值加倍。
- 断言 result (process_single_item_vmap 的输出) 的形状也为 (10,)。
jax.vmap 创建 process_batch。process_batch 之前,断言 batch_input 的形状为 (BATCH_SIZE, 10)。process_batch 之后,断言 batch_output 的形状为 (BATCH_SIZE, 10)。BATCH_SIZE = 5
ITEM_SIZE = 10
def process_single_item_vmap(item: chex.Array) -> chex.Array:
"""处理单个项目,断言其形状。"""
# TODO: 断言单个项目的形状是 (ITEM_SIZE,)
chex.assert_shape(item, <TODO>)
result = item * 2.0
# TODO: 断言单个项目输出的形状是 (ITEM_SIZE,)
chex.assert_shape(result, <TODO>)
return result
# TODO: 使用 jax.vmap 对 process_single_item_vmap 函数进行向量化
process_batch = jax.vmap(<TODO>, in_axes=0, out_axes=0)
# 测试用例
key = jax.random.PRNGKey(2)
valid_batch_input = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE))
invalid_batch_input_item_shape = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE + 1))
print("正在使用有效的批处理输入测试 vmap...")
# TODO: 在 vmap 之前断言完整 BATCHED 输入的形状
chex.assert_shape(valid_batch_input, <TODO>)
batch_output = process_batch(valid_batch_input)
# TODO: 在 vmap 之后断言完整 BATCHED 输出的形状
chex.assert_shape(batch_output, <TODO>)
print(f"Vmap 断言通过。输出形状: {batch_output.shape}\n")
print("正在使用批处理中无效的项目形状测试 vmap (来自 vmap 内部的错误)...")
try:
# 这将在 vmap 映射的函数 'process_single_item_vmap' 内部失败
process_batch(invalid_batch_input_item_shape)
except AssertionError as e:
print(f"捕获到预期的 vmap 错误 (来自内部函数):\n{e}\n")
print("正在使用无效的批处理形状测试 vmap (来自外部断言的错误)...")
invalid_batch_input_outer_shape = jax.random.normal(key, (BATCH_SIZE + 1, ITEM_SIZE))
try:
# 这将在调用 process_batch 之前使断言失败
chex.assert_shape(invalid_batch_input_outer_shape, (BATCH_SIZE, ITEM_SIZE)) # 此行将失败
process_batch(invalid_batch_input_outer_shape)
except AssertionError as e:
print(f"捕获到预期的 vmap 错误 (来自外部断言):\n{e}\n")
BATCH_SIZE = 5
ITEM_SIZE = 10
def process_single_item_vmap(item: chex.Array) -> chex.Array:
"""处理单个项目,断言其形状。"""
# TODO: 断言单个项目的形状是 (ITEM_SIZE,)
chex.assert_shape(item, (ITEM_SIZE,))
result = item * 2.0
# TODO: 断言单个项目输出的形状是 (ITEM_SIZE,)
chex.assert_shape(result, (ITEM_SIZE,))
return result
# TODO: 使用 jax.vmap 对函数进行向量化
process_batch = jax.vmap(process_single_item_vmap, in_axes=0, out_axes=0)
# 测试用例
key = jax.random.PRNGKey(2)
valid_batch_input = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE))
invalid_batch_input_item_shape = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE + 1))
print("正在使用有效的批处理输入测试 vmap...")
# TODO: 在 vmap 之前断言完整 BATCHED 输入的形状
chex.assert_shape(valid_batch_input, (BATCH_SIZE, ITEM_SIZE))
batch_output = process_batch(valid_batch_input)
# TODO: 在 vmap 之后断言完整 BATCHED 输出的形状
chex.assert_shape(batch_output, (BATCH_SIZE, ITEM_SIZE))
print(f"Vmap 断言通过。输出形状: {batch_output.shape}\n")
print("正在使用批处理中无效的项目形状测试 vmap (来自 vmap 内部的错误)...")
try:
# 这将在 vmap 映射的函数 'process_single_item_vmap' 内部失败
process_batch(invalid_batch_input_item_shape)
except AssertionError as e:
print(f"捕获到预期的 vmap 错误 (来自内部函数):\n{e}\n")
print("正在使用无效的批处理形状测试 vmap (来自外部断言的错误)...")
invalid_batch_input_outer_shape = jax.random.normal(key, (BATCH_SIZE + 1, ITEM_SIZE))
try:
# 这将在调用 process_batch 之前使断言失败
chex.assert_shape(invalid_batch_input_outer_shape, (BATCH_SIZE, ITEM_SIZE)) # 此行将失败
process_batch(invalid_batch_input_outer_shape)
except AssertionError as e:
print(f"捕获到预期的 vmap 错误 (来自外部断言):\n{e}\n")
call 方法内。
SimpleMLP 模块:
call 中,验证输入 x:
[batch, features])。
- 特征维度(轴 1)必须与 self.linear1.in_features 匹配。
- 类型必须是 jnp.float32。
call 中,在返回之前验证输出 x:
self.linear2.out_features 匹配。
class SimpleMLP(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x: chex.Array) -> chex.Array:
# TODO: 验证输入 x
# - 必须是二维的 ([batch, features])
chex.assert_rank(<TODO>)
# - 特征维度 (轴 1) 必须与 self.linear1.in_features 匹配
chex.assert_axis_dimension(x, 1, <TODO>)
# - 类型必须是 jnp.float32
chex.assert_type(x, <TODO>)
# 前向传播
x = self.linear1(x)
x = nnx.relu(x)
x = self.linear2(x)
# TODO: 在返回前验证输出 x
# - 必须是二维的
chex.assert_rank(<TODO>)
# - 特征维度 (轴 1) 必须与 self.linear2.out_features 匹配
chex.assert_axis_dimension(x, 1, self.linear2.out_features)
return x
# SimpleMLP 的测试用例
key_nnx = nnx.Rngs(params=jax.random.key(0)) # 用于有状态操作的 NNX Rngs
din, dmid, dout = 10, 20, 5
batch_size_nnx = 4
model = SimpleMLP(din, dmid, dout, rngs=key_nnx)
print("正在使用有效输入测试 NNX 模块:")
x_valid_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.float32)
output_nnx = model(x_valid_nnx)
print(f"NNX I/O 检查通过。输出形状: {output_nnx.shape}\n")
print("正在使用无效输入秩测试 NNX 模块:")
x_invalid_rank_nnx = jnp.ones((batch_size_nnx, din, 1), dtype=jnp.float32)
try:
model(x_invalid_rank_nnx)
except AssertionError as e:
print(f"捕获到预期的 NNX 错误 (无效输入秩):\n{e}\n")
print("正在使用无效输入特征维度测试 NNX 模块:")
x_invalid_feat_nnx = jnp.ones((batch_size_nnx, din + 1), dtype=jnp.float32)
try:
model(x_invalid_feat_nnx)
except AssertionError as e:
print(f"捕获到预期的 NNX 错误 (无效输入特征):\n{e}\n")
print("正在使用无效输入类型测试 NNX 模块:")
x_invalid_type_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.int32)
try:
model(x_invalid_type_nnx)
except AssertionError as e:
print(f"捕获到预期的 NNX 错误 (无效输入类型):\n{e}\n")
class SimpleMLP(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x: chex.Array) -> chex.Array:
# TODO: 验证输入 x
# - 必须是二维的 ([batch, features])
chex.assert_rank(x, 2)
# - 特征维度 (轴 1) 必须与 self.linear1.in_features 匹配
chex.assert_axis_dimension(x, 1, self.linear1.in_features)
# - 类型必须是 jnp.float32
chex.assert_type(x, jnp.float32)
# 前向传播
x = self.linear1(x)
x = nnx.relu(x)
x = self.linear2(x)
# TODO: 在返回前验证输出 x
# - 必须是二维的
chex.assert_rank(x, 2)
# - 特征维度 (轴 1) 必须与 self.linear2.out_features 匹配
chex.assert_axis_dimension(x, 1, self.linear2.out_features)
return x
# SimpleMLP 的测试用例
key_nnx = nnx.Rngs(params=jax.random.key(0)) # 用于有状态操作的 NNX Rngs
din, dmid, dout = 10, 20, 5
batch_size_nnx = 4
model = SimpleMLP(din, dmid, dout, rngs=key_nnx)
print("正在使用有效输入测试 NNX 模块:")
x_valid_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.float32)
output_nnx = model(x_valid_nnx)
print(f"NNX I/O 检查通过。输出形状: {output_nnx.shape}\n")
print("正在使用无效输入秩测试 NNX 模块:")
x_invalid_rank_nnx = jnp.ones((batch_size_nnx, din, 1), dtype=jnp.float32)
try:
model(x_invalid_rank_nnx)
except AssertionError as e:
print(f"捕获到预期的 NNX 错误 (无效输入秩):\n{e}\n")
print("正在使用无效输入特征维度测试 NNX 模块:")
x_invalid_feat_nnx = jnp.ones((batch_size_nnx, din + 1), dtype=jnp.float32)
try:
model(x_invalid_feat_nnx)
except AssertionError as e:
print(f"捕获到预期的 NNX 错误 (无效输入特征):\n{e}\n")
print("正在使用无效输入类型测试 NNX 模块:")
x_invalid_type_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.int32)
try:
model(x_invalid_type_nnx)
except AssertionError as e:
print(f"捕获到预期的 NNX 错误 (无效输入类型):\n{e}\n")
在组合多个层或更改模型配置时,这些断言如何帮助您及早发现错误?它们充当层之间以及模型外部 API 的契约。
---
jax.jit 和 jax.vmap 中的行为方式。@chex.chexify 的用途和用法(及其注意事项)。@chex.assert_max_traces 检测重新编译问题。持续使用 Chex 可以显著提高 JAX 项目的可靠性和可维护性。
进一步探索(可选):chex.chexify。chex.assert_devices_available)。@chex.variants。