Colab 笔记本:调试 JAX 和 Flax NNX - 练习

欢迎!本笔记本包含练习,可帮助您练习讲座中讨论的 JAX 和 Flax NNX 调试技术。如果您是 PyTorch 用户,您会发现一些概念很熟悉,而其他概念则是 JAX 编译特有的。请记住首先运行设置单元格!

首先,让我们安装必要的库并导入它们。

!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
import flax
from flax import nnx
import chex
import pdb # Python 的内置调试器
import functools # 用于 functools.partial
import optax # 用于优化器,但我们不会进行深度训练

chex.set_n_cpu_devices(8) # 模拟一个有 8 个 CPU 的环境。这必须在任何 JAX 操作之前完成
print(f"模拟设备: {jax.devices()}")

# Flax v0.11+ 注意:flax.nnx.Optimizer API 已更改。
# 它现在在构造时需要一个 `wrt` 参数 (例如,wrt=nnx.Param)
# 并且更新调用现在是 `optimizer.update(model, grads)` 而不是 `optimizer.update(grads)`。

# 用于为可重复示例清除 chex 跟踪计数器的辅助函数
chex.clear_trace_counter()

print(f"JAX 版本: {jax.__version__}")
print(f"Flax 版本: {flax.__version__}") # NNX 是 flax 的一部分
print(f"Chex 版本: {chex.__version__}")
print(f"设备: {jax.devices()}")

1. JAX 中的“打印调试”:jax.debug.print()

JAX 的 JIT 编译意味着标准 Python 的 print() 在 JIT 编译的函数内部的行为有所不同。它在编译期间看到的是跟踪器,而不是运行时值。jax.debug.print() 是 JAX 感知的替代方案。

练习 1.1:

  1. 取消注释并完成上面 compute_and_print 函数中的 `# YOUR CODE HERE` 行。
  2. 添加一个 `jax.debug.print()` 语句以显示 z 的运行时值。
  3. 运行单元格。观察输出。
- 标准的 `print(y)` 显示了什么? - `jax.debug.print` 语句为 y 和 z 显示了什么?为什么这会有所不同?
@jit
def compute_and_print(x):
  y = x * 10
  print("标准打印 (看到跟踪器):", y)
  jax.debug.print("jax.debug.print (看到 y 的运行时值): {y_val}", y_val=y, ordered=True)

  z = y / 2
  # 练习 1.1: 在此处添加另一个 jax.debug.print 以查看 'z' 的运行时值
  # 确保给它一个描述性消息并使用 ordered=True 参数。
  # 你的代码在这里

  return z

input_val = jnp.array(5.0)
print(f"输入值: {input_val}\n")
output_val = compute_and_print(input_val)
print(f"\n最终输出: {output_val}")

解决方案(练习 1.1,尝试后):

@jit
def compute_and_print_solution(x):
  y = x * 10
  print("标准打印 (看到跟踪器):", y)
  jax.debug.print("jax.debug.print (看到 y 的运行时值): {y_val}", y_val=y, ordered=True)

  z = y / 2
  jax.debug.print("jax.debug.print (看到 z 的运行时值): {z_val}", z_val=z, ordered=True) # 解决方案

  return z

input_val = jnp.array(5.0)
print(f"输入值: {input_val}\n")
output_val = compute_and_print_solution(input_val)
print(f"\n最终输出: {output_val}")

标准打印显示一个跟踪器对象(例如,Tracedwith)。这是因为它在 JAX 的跟踪阶段执行。jax.debug.print 显示具体的数值(例如,y 为 50.0,z 为 25.0),因为它被嵌入到编译的计算图中并在运行时使用数据执行。

2. JIT 中的交互式调试:jax.debug.breakpoint()

jax.debug.breakpoint() 是 JAX 中与 pdb.set_trace() 等效的函数,用于在转换后的函数内部使用。它会暂停执行并为您提供一个 (jaxdb) 提示符。

练习 2.1:

  1. 取消注释并完成上面 interact_with_values 函数中的 `# YOUR CODE HERE` 行。
  2. 在指定的位置添加 `jax.debug.breakpoint()`。
  3. 运行单元格。
  4. 当执行在 (jaxdb) 提示符处暂停时:
- 通过键入 `p y` 并按 Enter 键来检查 y 的值。 - 通过键入 `c` 并按 Enter 键来继续执行。
  1. 请注意,jaxdb 具有 pdb 命令的子集(例如,步进 n 或 s 不可用)。
@jit
def interact_with_values(x):
  y = jnp.sin(x)
  jax.debug.print("断点前 y 的值: {y_val}", y_val=y)

  # 练习 2.1: 在此处放置断点。
  # 你的代码在这里

  z = jnp.cos(y)
  jax.debug.print("断点后 z 的值: {z_val}", z_val=z)
  return z

input_angle = jnp.array(0.75)
print("正在调用 interact_with_values...")
result = interact_with_values(input_angle)
print(f"结果: {result}")

解决方案(练习 2.1,尝试后):

@jit
def interact_with_values_solution(x):
  y = jnp.sin(x)
  jax.debug.print("断点前 y 的值: {y_val}", y_val=y)

  jax.debug.breakpoint() # 解决方案

  z = jnp.cos(y)
  jax.debug.print("断点后 z 的值: {z_val}", z_val=z)
  return z

input_angle = jnp.array(0.75)
print("正在调用 interact_with_values...")
result = interact_with_values_solution(input_angle)
print(f"结果: {result}")

3. 回到基础:使用 jax.disable_jit() 临时禁用 JIT

有时,您需要标准 Python 调试工具的全部功能。`jax.disable_jit()` 允许 JAX 函数以 Eager 模式执行。

练习 3.1 和 3.2:

  1. complex_calculation 中,在指示的位置添加 `pdb.set_trace()`(# YOUR CODE HERE for 3.1)。
  2. 首先,尝试按原样运行单元格(取消注释场景 1 并注释掉场景 2 的调用)。观察在 JIT 编译的函数中使用 `pdb.set_trace()` 会发生什么。
  3. 然后,注释掉场景 1。
  4. 在场景 2 中,在 `with jax.disable_jit():` 块内,使用值(首先尝试 0.1,然后是 5.0 以确保满足条件)和 `threshold=0.5` 调用 `complex_calculation`(其中 # YOUR CODE HERE for 3.2)。
  5. 当 pdb 触发时:
- 检查 a、b 和 c。 - 键入 `c` 以继续。
  1. 思考:您会在什么时候使用 `jax.disable_jit()` 而不是 `jax.debug.breakpoint()`?
@jit
def complex_calculation(x, threshold):
  a = x * 2.0
  b = jnp.log(a)
  c = b + x
  # 想象一下 'c' 有时会变成 NaN,而且很难看出原因。
  # 我们想使用标准的 pdb 检查 'a'、'b' 和 'c'。
  if c > threshold: # 这个条件在 JIT 下可能很棘手
      # 练习 3.1: 在此处添加一个 pdb.set_trace()。
      # 只有在此函数调用禁用 JIT 时它才会起作用。
      # 你的代码在这里
      print("在条件 pdb 跟踪内") # 如果 pdb 被命中,则会打印此内容
  d = jnp.sqrt(jnp.abs(c)) # abs 以避免负数的 sqrt 产生 NaN
  return d

value = jnp.array(0.1) # 尝试 0.1 然后 5.0

# 场景 1:启用 JIT (pdb.set_trace() 将被跳过或可能出错)
# print("--- 正在使用 JIT 运行 (pdb 可能会被跳过) ---")
# try:
#   result_jit = complex_calculation(value, threshold=0.5)
#   print(f"使用 JIT 的结果: {result_jit}")
# except Exception as e:
#   print(f"场景 1 错误:
{e}
")

# 场景 2:禁用 JIT
print("\n--- 正在禁用此块的 JIT 运行 ---")
with jax.disable_jit():
  # 练习 3.2: 在此处使用值和 threshold=0.5 调用 complex_calculation
  # 以便触发您的 pdb.set_trace() (来自练习 3.1)。
  # 你的代码在这里
  pass # 删除此行

print("已完成 disable_jit 块。")

解决方案(练习 3.1 和 3.2,尝试后):

@jit
def complex_calculation_solution(x, threshold):
  a = x * 2.0
  b = jnp.log(a)
  c = b + x
  if c > threshold:
      pdb.set_trace() # 解决方案 3.1
      print("在条件 pdb 跟踪内")
  d = jnp.sqrt(jnp.abs(c))
  return d

value_for_pdb = jnp.array(5.0) # 此值将触发条件 c > threshold

# 场景 1:启用 JIT (pdb.set_trace() 将被跳过或可能出错)
print("--- 正在使用 JIT 运行 (pdb 可能会被跳过) ---")
try:
  result_jit = complex_calculation_solution(value_for_pdb, threshold=0.5)
  print(f"使用 JIT 的结果: {result_jit}")
except Exception as e
  print(f"场景 1 错误:\n{e}\n")

print("\n--- 正在禁用此块的 JIT 运行 ---")
with jax.disable_jit():
  result_no_jit = complex_calculation_solution(value_for_pdb, threshold=0.5) # 解决方案 3.2
  print(f"禁用 JIT 的结果: {result_no_jit}")
print("已完成 disable_jit 块。")

You'd use jax.disable_jit() when jax.debug.breakpoint() is insufficient, e.g., when you need the full pdb features (like stepping), want to use an IDE debugger, or when jax.debug.breakpoint() itself doesn't give enough context. The trade-off is performance loss.

4. Automatic NaN Hunting: jaxdebugnans Flag

NaNs can be a nightmare. jaxdebugnans helps JAX pinpoint the exact operation causing them。

Exercise 4.1 & 4.2:

  1. In Scenario 1, uncomment the example call or create your own call to problematicfunctionfor_nans that results in a NaN (e.g., x = jnp.array(-1.0), divisor = jnp.array(1.0) or x = jnp.array(1.0), divisor = jnp.array(0.0)). Run and observe the error.
  2. In Scenario 2:
    • Uncomment the line jax.config.update(