欢迎!本笔记本包含练习,可帮助您练习讲座中讨论的 JAX 和 Flax NNX 调试技术。如果您是 PyTorch 用户,您会发现一些概念很熟悉,而其他概念则是 JAX 编译特有的。请记住首先运行设置单元格!
首先,让我们安装必要的库并导入它们。
!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
import flax
from flax import nnx
import chex
import pdb # Python 的内置调试器
import functools # 用于 functools.partial
import optax # 用于优化器,但我们不会进行深度训练
chex.set_n_cpu_devices(8) # 模拟一个有 8 个 CPU 的环境。这必须在任何 JAX 操作之前完成
print(f"模拟设备: {jax.devices()}")
# Flax v0.11+ 注意:flax.nnx.Optimizer API 已更改。
# 它现在在构造时需要一个 `wrt` 参数 (例如,wrt=nnx.Param)
# 并且更新调用现在是 `optimizer.update(model, grads)` 而不是 `optimizer.update(grads)`。
# 用于为可重复示例清除 chex 跟踪计数器的辅助函数
chex.clear_trace_counter()
print(f"JAX 版本: {jax.__version__}")
print(f"Flax 版本: {flax.__version__}") # NNX 是 flax 的一部分
print(f"Chex 版本: {chex.__version__}")
print(f"设备: {jax.devices()}")
JAX 的 JIT 编译意味着标准 Python 的 print() 在 JIT 编译的函数内部的行为有所不同。它在编译期间看到的是跟踪器,而不是运行时值。jax.debug.print() 是 JAX 感知的替代方案。
compute_and_print 函数中的 `# YOUR CODE HERE` 行。@jit
def compute_and_print(x):
y = x * 10
print("标准打印 (看到跟踪器):", y)
jax.debug.print("jax.debug.print (看到 y 的运行时值): {y_val}", y_val=y, ordered=True)
z = y / 2
# 练习 1.1: 在此处添加另一个 jax.debug.print 以查看 'z' 的运行时值
# 确保给它一个描述性消息并使用 ordered=True 参数。
# 你的代码在这里
return z
input_val = jnp.array(5.0)
print(f"输入值: {input_val}\n")
output_val = compute_and_print(input_val)
print(f"\n最终输出: {output_val}")
@jit
def compute_and_print_solution(x):
y = x * 10
print("标准打印 (看到跟踪器):", y)
jax.debug.print("jax.debug.print (看到 y 的运行时值): {y_val}", y_val=y, ordered=True)
z = y / 2
jax.debug.print("jax.debug.print (看到 z 的运行时值): {z_val}", z_val=z, ordered=True) # 解决方案
return z
input_val = jnp.array(5.0)
print(f"输入值: {input_val}\n")
output_val = compute_and_print_solution(input_val)
print(f"\n最终输出: {output_val}")
标准打印显示一个跟踪器对象(例如,Traced
interact_with_values 函数中的 `# YOUR CODE HERE` 行。@jit
def interact_with_values(x):
y = jnp.sin(x)
jax.debug.print("断点前 y 的值: {y_val}", y_val=y)
# 练习 2.1: 在此处放置断点。
# 你的代码在这里
z = jnp.cos(y)
jax.debug.print("断点后 z 的值: {z_val}", z_val=z)
return z
input_angle = jnp.array(0.75)
print("正在调用 interact_with_values...")
result = interact_with_values(input_angle)
print(f"结果: {result}")
@jit
def interact_with_values_solution(x):
y = jnp.sin(x)
jax.debug.print("断点前 y 的值: {y_val}", y_val=y)
jax.debug.breakpoint() # 解决方案
z = jnp.cos(y)
jax.debug.print("断点后 z 的值: {z_val}", z_val=z)
return z
input_angle = jnp.array(0.75)
print("正在调用 interact_with_values...")
result = interact_with_values_solution(input_angle)
print(f"结果: {result}")
complex_calculation 中,在指示的位置添加 `pdb.set_trace()`(# YOUR CODE HERE for 3.1)。@jit
def complex_calculation(x, threshold):
a = x * 2.0
b = jnp.log(a)
c = b + x
# 想象一下 'c' 有时会变成 NaN,而且很难看出原因。
# 我们想使用标准的 pdb 检查 'a'、'b' 和 'c'。
if c > threshold: # 这个条件在 JIT 下可能很棘手
# 练习 3.1: 在此处添加一个 pdb.set_trace()。
# 只有在此函数调用禁用 JIT 时它才会起作用。
# 你的代码在这里
print("在条件 pdb 跟踪内") # 如果 pdb 被命中,则会打印此内容
d = jnp.sqrt(jnp.abs(c)) # abs 以避免负数的 sqrt 产生 NaN
return d
value = jnp.array(0.1) # 尝试 0.1 然后 5.0
# 场景 1:启用 JIT (pdb.set_trace() 将被跳过或可能出错)
# print("--- 正在使用 JIT 运行 (pdb 可能会被跳过) ---")
# try:
# result_jit = complex_calculation(value, threshold=0.5)
# print(f"使用 JIT 的结果: {result_jit}")
# except Exception as e:
# print(f"场景 1 错误:
{e}
")
# 场景 2:禁用 JIT
print("\n--- 正在禁用此块的 JIT 运行 ---")
with jax.disable_jit():
# 练习 3.2: 在此处使用值和 threshold=0.5 调用 complex_calculation
# 以便触发您的 pdb.set_trace() (来自练习 3.1)。
# 你的代码在这里
pass # 删除此行
print("已完成 disable_jit 块。")
@jit
def complex_calculation_solution(x, threshold):
a = x * 2.0
b = jnp.log(a)
c = b + x
if c > threshold:
pdb.set_trace() # 解决方案 3.1
print("在条件 pdb 跟踪内")
d = jnp.sqrt(jnp.abs(c))
return d
value_for_pdb = jnp.array(5.0) # 此值将触发条件 c > threshold
# 场景 1:启用 JIT (pdb.set_trace() 将被跳过或可能出错)
print("--- 正在使用 JIT 运行 (pdb 可能会被跳过) ---")
try:
result_jit = complex_calculation_solution(value_for_pdb, threshold=0.5)
print(f"使用 JIT 的结果: {result_jit}")
except Exception as e
print(f"场景 1 错误:\n{e}\n")
print("\n--- 正在禁用此块的 JIT 运行 ---")
with jax.disable_jit():
result_no_jit = complex_calculation_solution(value_for_pdb, threshold=0.5) # 解决方案 3.2
print(f"禁用 JIT 的结果: {result_no_jit}")
print("已完成 disable_jit 块。")
You'd use jax.disable_jit() when jax.debug.breakpoint() is insufficient, e.g., when you need the full pdb features (like stepping), want to use an IDE debugger, or when jax.debug.breakpoint() itself doesn't give enough context. The trade-off is performance loss.