Chex 练习:构建稳健的 JAX 和 Flax NNX 应用程序

欢迎!本笔记本包含练习,可帮助您根据讲座中涵盖的概念,练习将 Chex 与 JAX 和 Flax NNX 结合使用。

目标: 巩固您对 Chex 如何增强基于 JAX 的项目的可靠性和可调试性的理解。 说明:
  1. 阅读每个练习的题目描述。
  2. 用您的代码填写 TODO 部分。
  3. 运行单元格以测试您的解决方案。
  4. 将您的结果与提供的预期输出或提示进行比较。

让我们开始吧!

# 首先运行此单元格以安装和导入必要的库。
!pip install -q jax-ai-stack==2025.9.3

import jax
import jax.numpy as jnp
import chex
import flax
from flax import nnx
import functools # 用于 functools.partial

# 用于重置 assert_max_traces 练习的跟踪计数器的辅助函数
def reset_trace_counter():
    chex.clear_trace_counter()
    # 对于某些 JAX 版本,如果您频繁地重新运行单元格,可能需要一个小技巧来完全重置
    # JAX 内部缓存。如果按顺序运行单元格,这些练习通常不需要这样做。

print(f"JAX 版本: {jax.__version__}")
print(f"Chex 版本: {chex.__version__}")
print(f"Flax 版本: {flax.__version__}")
print(f"正在运行: {jax.default_backend()}")

第 1 部分:核心 Chex 断言

Chex 提供了一套断言函数来验证数组属性。 让我们来练习最常用的一些。

练习 1.1: chex.assert_shapechex.assert_type

完成下面的 process_data 函数。
  • 添加断言以检查 input_array 的形状是否为 (3, None)
(表示 3 行,任意数量的列)。
  • 添加断言以检查 input_array 的数据类型是否为 jnp.float32
  • 添加断言以检查 output_array 的形状是否为 (3, 1)
def process_data_v1(input_array: chex.Array) -> chex.Array:
  """处理一个数组,断言形状和类型。"""
  # TODO: 断言 input_array 的形状是 (3, None)
  chex.assert_shape(input_array, <TODO>)

  # TODO: 断言 input_array 的类型是 jnp.float32
  chex.assert_type(<TODO>)

  # 模拟一些将最后一个维度减少到 1 的处理
  output_array = input_array[:, :1] * 2.0

  # TODO: 断言 output_array 的形状是 (3, 1)
  chex.assert_shape(output_array, (3, 1))

  return output_array

# 测试用例
key = jax.random.PRNGKey(0)
valid_input = jax.random.normal(key, (3, 5), dtype=jnp.float32)
print("正在使用有效输入进行测试...")
result = process_data_v1(valid_input)
print(f"成功处理有效输入。输出形状: {result.shape}\n")

print("正在使用无效形状输入进行测试...")
invalid_shape_input = jax.random.normal(key, (4, 5), dtype=jnp.float32)
try:
  process_data_v1(invalid_shape_input)
except AssertionError as e:
  print(f"捕获到无效形状的预期错误:\n{e}\n")

print("正在使用无效类型输入进行测试...")
invalid_type_input = jnp.ones((3, 5), dtype=jnp.int32)
try:
  process_data_v1(invalid_type_input)
except AssertionError as e:
  print(f"捕获到无效类型的预期错误: {e}\n")

练习 1.1 解决方案

def process_data_v1(input_array: chex.Array) -> chex.Array:
  """处理一个数组,断言形状和类型。"""
  # TODO: 断言 input_array 的形状是 (3, None)
  chex.assert_shape(input_array, (3, None))

  # TODO: 断言 input_array 的类型是 jnp.float32
  chex.assert_type(input_array, expected_types=jnp.float32)

  # 模拟一些将最后一个维度减少到 1 的处理
  output_array = input_array[:, :1] * 2.0

  # TODO: 断言 output_array 的形状是 (3, 1)
  chex.assert_shape(output_array, (3, 1))

  return output_array

# 测试用例
key = jax.random.PRNGKey(0)
valid_input = jax.random.normal(key, (3, 5), dtype=jnp.float32)
print("正在使用有效输入进行测试...")
result = process_data_v1(valid_input)
print(f"成功处理有效输入。输出形状: {result.shape}\n")

print("正在使用无效形状输入进行测试...")
invalid_shape_input = jax.random.normal(key, (4, 5), dtype=jnp.float32)
try:
  process_data_v1(invalid_shape_input)
except AssertionError as e:
  print(f"捕获到无效形状的预期错误:\n{e}\n")

print("正在使用无效类型输入进行测试...")
invalid_type_input = jnp.ones((3, 5), dtype=jnp.int32)
try:
  process_data_v1(invalid_type_input)
except AssertionError as e:
  print(f"捕获到无效类型的预期错误: {e}\n")

练习 1.2: chex.assert_rankchex.assert_scalar

完成 process_data_v2 函数。
  • 添加断言以确保 matrix_input 是一个二维数组 (秩为 2)。
  • 添加断言以确保 scalar_input 是一个标量。
  • 添加断言以确保 result 也是一个二维数组。
def process_data_v2(matrix_input: chex.Array, scalar_input: chex.Array) -> chex.Array:
  """处理一个矩阵和一个标量。"""
  # TODO: 断言 matrix_input 的秩为 2
  chex.assert_rank(matrix_input, <TODO>)

  # TODO: 断言 scalar_input 是一个标量
  chex.assert_scalar(<TODO>)

  result = matrix_input * scalar_input + 1.0

  # TODO: 断言 result 的秩为 2
  chex.assert_rank(result, <TODO>)
  return result

# 测试用例
matrix = jnp.ones((3, 4))
scalar = 5.0
not_a_scalar = jnp.array([5.0])
not_a_matrix = jnp.ones((3,4,1))

print("正在使用有效的秩/标量输入进行测试...")
try:
    res_valid = process_data_v2(matrix, scalar)
    print(f"成功处理有效的秩/标量。结果形状: {res_valid.shape}\n")
except AssertionError as e:
    print(f"捕获到有效秩/标量的意外错误:\n{e}\n")

print("正在使用无效秩输入进行测试...")
try:
  process_data_v2(not_a_matrix, scalar)
  print(f"成功处理无效秩。结果形状: {res_valid.shape}\n")
except AssertionError as e:
  print(f"捕获到无效秩的预期错误:\n{e}\n")

print("正在使用非标量输入进行测试...")
try:
  process_data_v2(matrix, not_a_scalar)
  print(f"成功处理非标量。结果形状: {res_valid.shape}\n")
except AssertionError as e:
  print(f"捕获到非标量的预期错误:\n{e}\n")

练习 1.2 解决方案

def process_data_v2(matrix_input: chex.Array, scalar_input: chex.Array) -> chex.Array:
  """处理一个矩阵和一个标量。"""
  # TODO: 断言 matrix_input 的秩为 2
  chex.assert_rank(matrix_input, expected_ranks=2)

  # TODO: 断言 scalar_input 是一个标量
  chex.assert_scalar(scalar_input)

  result = matrix_input * scalar_input + 1.0

  # TODO: 断言 result 的秩为 2
  chex.assert_rank(result, expected_ranks=2)
  return result

# 测试用例
matrix = jnp.ones((3, 4))
scalar = 5.0
not_a_scalar = jnp.array([5.0])
not_a_matrix = jnp.ones((3,4,1))

print("正在使用有效的秩/标量输入进行测试...")
try:
    res_valid = process_data_v2(matrix, scalar)
    print(f"成功处理有效的秩/标量。结果形状: {res_valid.shape}\n")
except AssertionError as e:
    print(f"捕获到有效秩/标量的意外错误:\n{e}\n")

print("正在使用无效秩输入进行测试...")
try:
  process_data_v2(not_a_matrix, scalar)
  print(f"成功处理无效秩。结果形状: {res_valid.shape}\n")
except AssertionError as e:
  print(f"捕获到无效秩的预期错误:\n{e}\n")

print("正在使用非标量输入进行测试...")
try:
  process_data_v2(matrix, not_a_scalar)
  print(f"成功处理非标量。结果形状: {res_valid.shape}\n")
except AssertionError as e:
  print(f"捕获到非标量的预期错误:\n{e}\n")

练习 1.3: PyTree 断言 (assert_trees_all_close, assert_tree_all_finite)

PyTree(嵌套的数组结构,如模型参数)在 JAX 中很常见。 Chex 为它们提供了断言。
def process_pytree(tree1, tree2):
  """
  检查两个 PyTree 是否接近,以及第一个 PyTree 是否有限。
  返回一个新树,其中元素是 tree1 + tree2。
  """
  # TODO: 断言 tree1 和 tree2(近似)相等。使用较小的容差。
  chex.assert_trees_all_close(<TODO> rtol=1e-5, atol=1e-8)

  # TODO: 断言 tree1 中的所有元素都是有限的 (不是 NaN 或 Inf)。
  chex.assert_tree_all_finite(<TODO>)

  # 执行一些操作
  return jax.tree_util.tree_map(lambda x, y: x + y, tree1, tree2)

# 测试用例
tree_a = {'params': {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}}
tree_b_close = {'params': {'w': jnp.array([1.000001, 2.000001]), 'b': jnp.array(0.500001)}}
tree_c_not_close = {'params': {'w': jnp.array([1.1, 2.1]), 'b': jnp.array(0.6)}}
tree_d_nan = {'params': {'w': jnp.array([1.0, jnp.nan]), 'b': jnp.array(0.5)}}

print("正在使用近似和有限的 PyTree 进行测试...")
result_valid = process_pytree(tree_a, tree_b_close)
print("成功处理有效的 PyTree。\n")

print("正在使用非近似的 PyTree 进行测试...")
try:
  process_pytree(tree_a, tree_c_not_close)
except AssertionError as e:
  print(f"捕获到非近似树的预期错误:\n\n{e}\n")

print("正在使用非有限的 PyTree 进行测试...")
try:
  process_pytree(tree_d_nan, tree_b_close) # tree_d_nan 将被检查是否有限
except AssertionError as e:
  print(f"捕获到非有限树的预期错误:\n\n{e}\n")

练习 1.3 解决方案

def process_pytree(tree1, tree2):
  """
  检查两个 PyTree 是否接近,以及第一个 PyTree 是否有限。
  返回一个新树,其中元素是 tree1 + tree2。
  """
  # TODO: 断言 tree1 和 tree2(近似)相等。使用较小的容差。
  chex.assert_trees_all_close(tree1, tree2, rtol=1e-5, atol=1e-8)

  # TODO: 断言 tree1 中的所有元素都是有限的 (不是 NaN 或 Inf)。
  chex.assert_tree_all_finite(tree1)

  # 执行一些操作
  return jax.tree_util.tree_map(lambda x, y: x + y, tree1, tree2)

# 测试用例
tree_a = {'params': {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}}
tree_b_close = {'params': {'w': jnp.array([1.000001, 2.000001]), 'b': jnp.array(0.500001)}}
tree_c_not_close = {'params': {'w': jnp.array([1.1, 2.1]), 'b': jnp.array(0.6)}}
tree_d_nan = {'params': {'w': jnp.array([1.0, jnp.nan]), 'b': jnp.array(0.5)}}

print("正在使用近似和有限的 PyTree 进行测试...")
result_valid = process_pytree(tree_a, tree_b_close)
print("成功处理有效的 PyTree。\n")

print("正在使用非近似的 PyTree 进行测试...")
try:
  process_pytree(tree_a, tree_c_not_close)
except AssertionError as e:
  print(f"捕获到非近似树的预期错误:\n\n{e}\n")

print("正在使用非有限的 PyTree 进行测试...")
try:
  process_pytree(tree_d_nan, tree_b_close) # tree_d_nan 将被检查是否有限
except AssertionError as e:
  print(f"捕获到非有限树的预期错误:\n\n{e}\n")

第 2 部分:Chex 断言与 JAX 转换

Chex 的一个主要优点是其断言在 jax.jitjax.vmap 等 JAX 转换中能够正确工作。

练习 2.1: @jax.jit 内部的断言

  • 使用练习 1.1 中的 process_data_v1 函数。
  • 对其进行 JIT 编译,并验证 Chex 断言是否仍然按预期工作。
@jax.jit
def process_data_jitted(input_array: chex.Array) -> chex.Array:
  """JIT 编译版本的 process_data_v1 及其 Chex 断言。"""
  # (断言在 process_data_v1 内部,我们在这里有效地重用它)
  # 为清楚起见,我们在这里直接用断言重新定义它。
  chex.assert_shape(input_array, (3, None))
  chex.assert_type(input_array, jnp.float32)
  output_array = input_array[:, :1] * 2.0
  chex.assert_shape(output_array, (3, 1))
  return output_array

# JIT 版本的测试用例
key = jax.random.PRNGKey(1) # 使用不同的密钥以获得可能不同的值
valid_input_jit = jax.random.normal(key, (3, 5), dtype=jnp.float32)
print("正在使用有效输入测试 JIT 函数...")

# 第一次调用将编译
result_jit = process_data_jitted(<TODO>)
print(f"成功处理 JIT 编译的有效输入。输出形状: {result_jit.shape}")

# 第二次调用使用缓存的编译
result_jit_cached = process_data_jitted(<TODO> * 2)
print(f"成功处理 JIT 编译的有效输入(已缓存)。输出形状: {result_jit_cached.shape}\n")

print("正在使用无效形状输入测试 JIT 函数...")
invalid_shape_input_jit = jax.random.normal(key, (4, 5), dtype=jnp.float32)
try:
  process_data_jitted(<TODO>)
except AssertionError as e:
  print(f"捕获到无效形状的预期 JIT 错误:\n\n{e}\n")

练习 2.1 解决方案

@jax.jit
def process_data_jitted(input_array: chex.Array) -> chex.Array:
  """JIT 编译版本的 process_data_v1 及其 Chex 断言。"""
  # (断言在 process_data_v1 内部,我们在这里有效地重用它)
  # 为清楚起见,我们在这里直接用断言重新定义它。
  chex.assert_shape(input_array, (3, None))
  chex.assert_type(input_array, jnp.float32)
  output_array = input_array[:, :1] * 2.0
  chex.assert_shape(output_array, (3, 1))
  return output_array

# JIT 版本的测试用例
key = jax.random.PRNGKey(1) # 使用不同的密钥以获得可能不同的值
valid_input_jit = jax.random.normal(key, (3, 5), dtype=jnp.float32)
print("正在使用有效输入测试 JIT 函数...")

# 第一次调用将编译
result_jit = process_data_jitted(valid_input_jit)
print(f"成功处理 JIT 编译的有效输入。输出形状: {result_jit.shape}")

# 第二次调用使用缓存的编译
result_jit_cached = process_data_jitted(valid_input_jit * 2)
print(f"成功处理 JIT 编译的有效输入(已缓存)。输出形状: {result_jit_cached.shape}\n")

print("正在使用无效形状输入测试 JIT 函数...")
invalid_shape_input_jit = jax.random.normal(key, (4, 5), dtype=jnp.float32)
try:
  process_data_jitted(invalid_shape_input_jit)
except AssertionError as e:
  print(f"捕获到无效形状的预期 JIT 错误:\n\n{e}\n")
观察:

Chex 断言在 JIT 编译的函数中无缝工作,根据运行时传递的具体值捕获错误,即使检查是在编译后的代码中定义的。

---

练习 2.2: 使用 @jax.vmap 进行多级验证

我们想要处理一批项目。每个项目都是一个形状为 (10,) 的一维数组。
  1. 定义 process_single_item_vmap,它处理一个项目。
- 在此函数内部,断言 item 的形状为 (10,)。 - 该函数应将项目的值加倍。 - 断言 result (process_single_item_vmap 的输出) 的形状也为 (10,)
  1. 使用 jax.vmap 创建 process_batch
  2. 在调用 process_batch 之前,断言 batch_input 的形状为 (BATCH_SIZE, 10)
  3. 在调用 process_batch 之后,断言 batch_output 的形状为 (BATCH_SIZE, 10)
BATCH_SIZE = 5
ITEM_SIZE = 10

def process_single_item_vmap(item: chex.Array) -> chex.Array:
  """处理单个项目,断言其形状。"""
  # TODO: 断言单个项目的形状是 (ITEM_SIZE,)
  chex.assert_shape(item, <TODO>)
  result = item * 2.0
  # TODO: 断言单个项目输出的形状是 (ITEM_SIZE,)
  chex.assert_shape(result, <TODO>)
  return result

# TODO: 使用 jax.vmap 对 process_single_item_vmap 函数进行向量化
process_batch = jax.vmap(<TODO>, in_axes=0, out_axes=0)

# 测试用例
key = jax.random.PRNGKey(2)
valid_batch_input = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE))
invalid_batch_input_item_shape = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE + 1))

print("正在使用有效的批处理输入测试 vmap...")
# TODO: 在 vmap 之前断言完整 BATCHED 输入的形状
chex.assert_shape(valid_batch_input, <TODO>)

batch_output = process_batch(valid_batch_input)

# TODO: 在 vmap 之后断言完整 BATCHED 输出的形状
chex.assert_shape(batch_output, <TODO>)
print(f"Vmap 断言通过。输出形状: {batch_output.shape}\n")

print("正在使用批处理中无效的项目形状测试 vmap (来自 vmap 内部的错误)...")
try:
  # 这将在 vmap 映射的函数 'process_single_item_vmap' 内部失败
  process_batch(invalid_batch_input_item_shape)
except AssertionError as e:
  print(f"捕获到预期的 vmap 错误 (来自内部函数):\n{e}\n")

print("正在使用无效的批处理形状测试 vmap (来自外部断言的错误)...")
invalid_batch_input_outer_shape = jax.random.normal(key, (BATCH_SIZE + 1, ITEM_SIZE))
try:
  # 这将在调用 process_batch 之前使断言失败
  chex.assert_shape(invalid_batch_input_outer_shape, (BATCH_SIZE, ITEM_SIZE)) # 此行将失败
  process_batch(invalid_batch_input_outer_shape)
except AssertionError as e:
  print(f"捕获到预期的 vmap 错误 (来自外部断言):\n{e}\n")

练习 2.2 解决方案

BATCH_SIZE = 5
ITEM_SIZE = 10

def process_single_item_vmap(item: chex.Array) -> chex.Array:
  """处理单个项目,断言其形状。"""
  # TODO: 断言单个项目的形状是 (ITEM_SIZE,)
  chex.assert_shape(item, (ITEM_SIZE,))
  result = item * 2.0
  # TODO: 断言单个项目输出的形状是 (ITEM_SIZE,)
  chex.assert_shape(result, (ITEM_SIZE,))
  return result

# TODO: 使用 jax.vmap 对函数进行向量化
process_batch = jax.vmap(process_single_item_vmap, in_axes=0, out_axes=0)

# 测试用例
key = jax.random.PRNGKey(2)
valid_batch_input = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE))
invalid_batch_input_item_shape = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE + 1))

print("正在使用有效的批处理输入测试 vmap...")
# TODO: 在 vmap 之前断言完整 BATCHED 输入的形状
chex.assert_shape(valid_batch_input, (BATCH_SIZE, ITEM_SIZE))

batch_output = process_batch(valid_batch_input)

# TODO: 在 vmap 之后断言完整 BATCHED 输出的形状
chex.assert_shape(batch_output, (BATCH_SIZE, ITEM_SIZE))
print(f"Vmap 断言通过。输出形状: {batch_output.shape}\n")

print("正在使用批处理中无效的项目形状测试 vmap (来自 vmap 内部的错误)...")
try:
  # 这将在 vmap 映射的函数 'process_single_item_vmap' 内部失败
  process_batch(invalid_batch_input_item_shape)
except AssertionError as e:
  print(f"捕获到预期的 vmap 错误 (来自内部函数):\n{e}\n")

print("正在使用无效的批处理形状测试 vmap (来自外部断言的错误)...")
invalid_batch_input_outer_shape = jax.random.normal(key, (BATCH_SIZE + 1, ITEM_SIZE))
try:
  # 这将在调用 process_batch 之前使断言失败
  chex.assert_shape(invalid_batch_input_outer_shape, (BATCH_SIZE, ITEM_SIZE)) # 此行将失败
  process_batch(invalid_batch_input_outer_shape)
except AssertionError as e:
  print(f"捕获到预期的 vmap 错误 (来自外部断言):\n{e}\n")

第 3 部分:Chex 与 Flax NNX

神经网络是复杂的,因此验证至关重要。Chex 自然地集成到 Flax NNX 模块中,通常在 call 方法内。

练习 3.1: NNX 模块中的输入/输出验证

完成 SimpleMLP 模块:
  • call 中,验证输入 x:
- 必须是二维的 ([batch, features])。 - 特征维度(轴 1)必须与 self.linear1.in_features 匹配。 - 类型必须是 jnp.float32
  • call 中,在返回之前验证输出 x:
- 必须是二维的。 - 特征维度(轴 1)必须与 self.linear2.out_features 匹配。
class SimpleMLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: chex.Array) -> chex.Array:
    # TODO: 验证输入 x
    # - 必须是二维的 ([batch, features])
    chex.assert_rank(<TODO>)
    # - 特征维度 (轴 1) 必须与 self.linear1.in_features 匹配
    chex.assert_axis_dimension(x, 1, <TODO>)
    # - 类型必须是 jnp.float32
    chex.assert_type(x, <TODO>)

    # 前向传播
    x = self.linear1(x)
    x = nnx.relu(x)
    x = self.linear2(x)

    # TODO: 在返回前验证输出 x
    # - 必须是二维的
    chex.assert_rank(<TODO>)
    # - 特征维度 (轴 1) 必须与 self.linear2.out_features 匹配
    chex.assert_axis_dimension(x, 1, self.linear2.out_features)

    return x

# SimpleMLP 的测试用例
key_nnx = nnx.Rngs(params=jax.random.key(0)) # 用于有状态操作的 NNX Rngs
din, dmid, dout = 10, 20, 5
batch_size_nnx = 4

model = SimpleMLP(din, dmid, dout, rngs=key_nnx)

print("正在使用有效输入测试 NNX 模块:")
x_valid_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.float32)
output_nnx = model(x_valid_nnx)
print(f"NNX I/O 检查通过。输出形状: {output_nnx.shape}\n")


print("正在使用无效输入秩测试 NNX 模块:")
x_invalid_rank_nnx = jnp.ones((batch_size_nnx, din, 1), dtype=jnp.float32)
try:
  model(x_invalid_rank_nnx)
except AssertionError as e:
  print(f"捕获到预期的 NNX 错误 (无效输入秩):\n{e}\n")

print("正在使用无效输入特征维度测试 NNX 模块:")
x_invalid_feat_nnx = jnp.ones((batch_size_nnx, din + 1), dtype=jnp.float32)
try:
  model(x_invalid_feat_nnx)
except AssertionError as e:
  print(f"捕获到预期的 NNX 错误 (无效输入特征):\n{e}\n")

print("正在使用无效输入类型测试 NNX 模块:")
x_invalid_type_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.int32)
try:
  model(x_invalid_type_nnx)
except AssertionError as e:
  print(f"捕获到预期的 NNX 错误 (无效输入类型):\n{e}\n")

练习 3.1 解决方案

class SimpleMLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: chex.Array) -> chex.Array:
    # TODO: 验证输入 x
    # - 必须是二维的 ([batch, features])
    chex.assert_rank(x, 2)
    # - 特征维度 (轴 1) 必须与 self.linear1.in_features 匹配
    chex.assert_axis_dimension(x, 1, self.linear1.in_features)
    # - 类型必须是 jnp.float32
    chex.assert_type(x, jnp.float32)

    # 前向传播
    x = self.linear1(x)
    x = nnx.relu(x)
    x = self.linear2(x)

    # TODO: 在返回前验证输出 x
    # - 必须是二维的
    chex.assert_rank(x, 2)
    # - 特征维度 (轴 1) 必须与 self.linear2.out_features 匹配
    chex.assert_axis_dimension(x, 1, self.linear2.out_features)

    return x

# SimpleMLP 的测试用例
key_nnx = nnx.Rngs(params=jax.random.key(0)) # 用于有状态操作的 NNX Rngs
din, dmid, dout = 10, 20, 5
batch_size_nnx = 4

model = SimpleMLP(din, dmid, dout, rngs=key_nnx)

print("正在使用有效输入测试 NNX 模块:")
x_valid_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.float32)
output_nnx = model(x_valid_nnx)
print(f"NNX I/O 检查通过。输出形状: {output_nnx.shape}\n")


print("正在使用无效输入秩测试 NNX 模块:")
x_invalid_rank_nnx = jnp.ones((batch_size_nnx, din, 1), dtype=jnp.float32)
try:
  model(x_invalid_rank_nnx)
except AssertionError as e:
  print(f"捕获到预期的 NNX 错误 (无效输入秩):\n{e}\n")

print("正在使用无效输入特征维度测试 NNX 模块:")
x_invalid_feat_nnx = jnp.ones((batch_size_nnx, din + 1), dtype=jnp.float32)
try:
  model(x_invalid_feat_nnx)
except AssertionError as e:
  print(f"捕获到预期的 NNX 错误 (无效输入特征):\n{e}\n")

print("正在使用无效输入类型测试 NNX 模块:")
x_invalid_type_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.int32)
try:
  model(x_invalid_type_nnx)
except AssertionError as e:
  print(f"捕获到预期的 NNX 错误 (无效输入类型):\n{e}\n")
自我反思:

在组合多个层或更改模型配置时,这些断言如何帮助您及早发现错误?它们充当层之间以及模型外部 API 的契约。

---

🏆 恭喜!

您已完成 Chex 练习。您现在应该对以下内容有了更好的了解:
  • 使用核心 Chex 断言来处理形状、类型、秩和 PyTree。
  • Chex 断言在 jax.jitjax.vmap 中的行为方式。
  • @chex.chexify 的用途和用法(及其注意事项)。
  • 使用 @chex.assert_max_traces 检测重新编译问题。
  • 将 Chex 断言集成到 Flax NNX 模块中以进行稳健的模型开发。

持续使用 Chex 可以显著提高 JAX 项目的可靠性和可维护性。

进一步探索(可选):
  • 探索在 Colab 环境之外使用 chex.chexify
  • 探索此处未涵盖的其他 Chex 断言(例如,chex.assert_devices_available)。
  • 如果您编写全面的测试套件,请研究 Chex 测试实用程序,如 @chex.variants
  • 考虑在典型训练循环中何时何地添加 Chex 断言。