Welcome! This notebook contains exercises to help you practice the JAX and Flax NNX debugging techniques covered in the lecture. If you’re a PyTorch user, some ideas will feel familiar while others are JAX-compilation specific. Remember to run the setup cell first!
Let’s start by installing and importing the needed libraries.
!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
import flax
from flax import nnx
import chex
import pdb # Python debugger
import functools # for functools.partial
import optax # used lightly for optimizer examples
chex.set_n_cpu_devices(8) # Simulate an 8-CPU environment; must run before other JAX ops
print(f"Simulated devices: {jax.devices()}")
# Flax v0.11+ note: flax.nnx.Optimizer API changed.
# It now requires `wrt` on construction (e.g., wrt=nnx.Param)
# and update is now `optimizer.update(model, grads)` not `optimizer.update(grads)`.
# Helper to clear chex counters for repeatable examples
chex.clear_trace_counter()
print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}") # NNX ships with Flax
print(f"Chex version: {chex.__version__}")
print(f"Devices: {jax.devices()}")
Because JAX JIT compilation traces your code, plain Python print() inside a JITted function sees tracers, not runtime values. jax.debug.print() is the JAX-aware alternative.
- Find the # YOUR CODE HERE line in compute_and_print.
- Add a jax.debug.print() statement to show the runtime value of z.
- Run the cell. What does the standard print(y) show?
- What do the jax.debug.print statements show for y and z? Why is it different?
@jit
def compute_and_print(x):
    y = x * 10
    print("Standard print (shows tracer):", y)
    jax.debug.print("jax.debug.print (runtime y): {y_val}", y_val=y, ordered=True)
    z = y / 2
    # Exercise 1.1: add another jax.debug.print to inspect the runtime value of z
    # Give it a descriptive message and use ordered=True.
    # YOUR CODE HERE
    return z
input_val = jnp.array(5.0)
print(f"Input value: {input_val}\n")
output_val = compute_and_print(input_val)
print(f"\nFinal output: {output_val}")
@jit
def compute_and_print_solution(x):
    y = x * 10
    print("Standard print (shows tracer):", y)
    jax.debug.print("jax.debug.print (runtime y): {y_val}", y_val=y, ordered=True)
    z = y / 2
    jax.debug.print("jax.debug.print (runtime z): {z_val}", z_val=z, ordered=True) # Solution
    return z
input_val = jnp.array(5.0)
print(f"Input value: {input_val}\n")
output_val = compute_and_print_solution(input_val)
print(f"\nFinal output: {output_val}")
The standard print shows a tracer object (e.g., Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>) because it runs during tracing. jax.debug.print shows concrete values (e.g., y = 50.0, z = 25.0) because it’s embedded in the compiled graph and executed with data.
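The trace-time versus run-time distinction can be made visible by calling a jitted function twice. The sketch below is illustrative only: `scaled` and `trace_log` are hypothetical names, not part of the exercise. It assumes both calls use inputs of the same shape and dtype, so the second call reuses the compiled function instead of retracing.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records each Python-level execution of the body


@jax.jit
def scaled(x):
    # Plain Python side effect: runs only while JAX traces the function.
    trace_log.append(type(x).__name__)
    y = x * 10
    # Embedded in the compiled computation: runs on every call with real data.
    jax.debug.print("runtime y: {y_val}", y_val=y)
    return y / 2


first = scaled(jnp.array(5.0))   # traces and compiles, then runs
second = scaled(jnp.array(7.0))  # same shape and dtype: reuses the compiled function

print(f"Python body executed {len(trace_log)} time(s); saw type: {trace_log[0]}")
```

The Python-level append happens once, while the jax.debug.print message appears for each call. This is why a plain print() is a poor tool for inspecting values in jitted code.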
jax.debug.breakpoint() is the JAX-friendly equivalent of pdb.set_trace() for transformed functions. It pauses execution and gives you a (jaxdb) prompt.
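- Find the # YOUR CODE HERE line in interact_with_values.
- Insert jax.debug.breakpoint() at the marked spot.
- Run the cell. At the (jaxdb) prompt:
  - Inspect y by typing p y + Enter.
  - Continue by typing c + Enter.
- Note that stepping commands such as n or s are not available.
@jit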
def interact_with_values(x):
    y = jnp.sin(x)
    jax.debug.print("Value of y before breakpoint: {y_val}", y_val=y)
    # Exercise 2.1: place a breakpoint here.
    # YOUR CODE HERE
    z = jnp.cos(y)
    jax.debug.print("Value of z after breakpoint: {z_val}", z_val=z)
    return z
input_angle = jnp.array(0.75)
print("Calling interact_with_values...")
result = interact_with_values(input_angle)
print(f"Result: {result}")
@jit
def interact_with_values_solution(x):
    y = jnp.sin(x)
    jax.debug.print("Value of y before breakpoint: {y_val}", y_val=y)
    jax.debug.breakpoint() # Solution
    z = jnp.cos(y)
    jax.debug.print("Value of z after breakpoint: {z_val}", z_val=z)
    return z
input_angle = jnp.array(0.75)
print("Calling interact_with_values...")
result = interact_with_values_solution(input_angle)
print(f"Result: {result}")
Sometimes you need full Python debugging tools. Inside a with jax.disable_jit(): block, jitted functions run eagerly on concrete values, so standard tools such as pdb work as usual.
- In complex_calculation, add pdb.set_trace() where indicated (# YOUR CODE HERE for 3.1).
- Observe what happens to pdb.set_trace() inside a JITted function.
- Inside the with jax.disable_jit(): block, call complex_calculation with a value (try 0.1 then 5.0 to meet the condition) and threshold=0.5 (# YOUR CODE HERE for 3.2).
- At the pdb prompt, inspect a, b, and c, then type c to continue.
- When would you choose jax.disable_jit() over jax.debug.breakpoint()?
@jit
def complex_calculation(x, threshold):
    a = x * 2.0
    b = jnp.log(a)
    c = b + x
    # Imagine c occasionally becomes NaN and it’s hard to see why.
    # We want to inspect a, b, and c with regular pdb.
    if c > threshold: # This condition can be tricky under JIT
        # Exercise 3.1: add pdb.set_trace() here.
        # It will only fire if this function runs with JIT disabled.
        # YOUR CODE HERE
        print("Inside conditional pdb trace") # Printed only if pdb hits
    d = jnp.sqrt(jnp.abs(c)) # abs avoids NaN from negative sqrt
    return d
value = jnp.array(0.1) # Try 0.1 then 5.0
# Scenario 1: JIT enabled (pdb.set_trace() will be skipped or error)
# print("--- Running with JIT (pdb may be skipped) ---")
# try:
#     result_jit = complex_calculation(value, threshold=0.5)
#     print(f"Result with JIT: {result_jit}")
# except Exception as e:
#     print(f"Scenario 1 error:\n{e}\n")
# Scenario 2: disable JIT
print("\n--- Running with JIT disabled for this block ---")
with jax.disable_jit():
    # Exercise 3.2: call complex_calculation with value and threshold=0.5
    # so your pdb.set_trace() (from 3.1) is triggered.
    # YOUR CODE HERE
    pass # Delete this line
print("disable_jit block finished.")
@jit
def complex_calculation_solution(x, threshold):
    a = x * 2.0
    b = jnp.log(a)
    c = b + x
    if c > threshold:
        pdb.set_trace() # Solution 3.1
        print("Inside conditional pdb trace")
    d = jnp.sqrt(jnp.abs(c))
    return d
value_for_pdb = jnp.array(5.0) # This value satisfies c > threshold
# Scenario 1: JIT enabled (pdb.set_trace() skipped or errors)
print("--- Running with JIT (pdb may be skipped) ---")
try:
    result_jit = complex_calculation_solution(value_for_pdb, threshold=0.5)
    print(f"Result with JIT: {result_jit}")
except Exception as e:
    print(f"Scenario 1 error:\n{e}\n")
print("\n--- Running with JIT disabled for this block ---")
with jax.disable_jit():
    result_no_jit = complex_calculation_solution(value_for_pdb, threshold=0.5) # Solution 3.2
print(f"Result with JIT disabled: {result_no_jit}")
print("disable_jit block finished.")
Use jax.disable_jit() when jax.debug.breakpoint() isn’t enough: for example, when you need full pdb stepping, want to attach an IDE debugger, or aren’t getting enough context from the JAX-aware breakpoint. The trade-off is performance, because the function runs op by op without compilation.
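A Python `if` on a computed value is one concrete case where the difference shows up. The sketch below is a minimal illustration, with `branch_on_value` as a hypothetical stand-in for the exercise function. Under JIT the condition is a tracer, so converting it to a bool raises an error. Inside jax.disable_jit() the same code sees a concrete array and the branch works. The example catches jax.errors.ConcretizationTypeError on the assumption that the specific error raised is that class or a subclass of it.

```python
import jax
import jax.numpy as jnp


@jax.jit
def branch_on_value(x, threshold):
    c = jnp.log(x * 2.0) + x
    if c > threshold:  # a Python `if` needs a concrete boolean
        return jnp.sqrt(jnp.abs(c))
    return c


x = jnp.array(5.0)

# With JIT enabled, `c > threshold` is a tracer and cannot be converted to bool.
jit_failed = False
try:
    branch_on_value(x, 0.5)
except jax.errors.ConcretizationTypeError as err:
    jit_failed = True
    print(f"Under JIT: {type(err).__name__}")

# With JIT disabled, the function runs eagerly on concrete values.
with jax.disable_jit():
    eager_result = branch_on_value(x, 0.5)
print(f"With JIT disabled: {eager_result}")
```

Because the body runs as ordinary Python inside the block, a pdb.set_trace() placed in the branch would also fire there.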
jax_debug_nans Flag
NaNs can be painful to track down. With jax_debug_nans enabled, JAX raises an error at the operation that produced the NaN.
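- Call problematic_function_for_nans with inputs that yield a NaN (e.g., x = jnp.array(-1.0), divisor = jnp.array(1.0)). Note that x = jnp.array(1.0), divisor = jnp.array(0.0) produces inf rather than NaN, which jax_debug_nans does not flag. Run and observe the output.
- Enable NaN checking with jax.config.update("jax_debug_nans", True), then run again and observe the error.
def problematic_function_for_nans(x, divisor):
    # sqrt of a negative x yields NaN; a zero divisor yields inf (or NaN when x == 0)
    return jnp.sqrt(x) / divisor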
# Scenario 1: default behavior (no jax_debug_nans)
# print("--- Scenario 1: No jax_debug_nans ---")
# problematic_function_for_nans(jnp.array(-1.0), jnp.array(1.0))
# problematic_function_for_nans(jnp.array(1.0), jnp.array(0.0))
# Scenario 2: enable jax_debug_nans
# print("--- Scenario 2: Enabling jax_debug_nans ---")
# jax.config.update("jax_debug_nans", True)
# problematic_function_for_nans(jnp.array(-1.0), jnp.array(1.0))
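The two scenarios can be compared side by side in a short sketch. It is illustrative rather than the official solution: `sqrt_then_divide` is a hypothetical stand-in for the exercise function, and the example assumes that a NaN detected under jax_debug_nans surfaces as a FloatingPointError. The flag is reset in a finally block so later cells are unaffected.

```python
import jax
import jax.numpy as jnp


def sqrt_then_divide(x, divisor):
    return jnp.sqrt(x) / divisor


# Default behavior: the NaN propagates silently.
silent = sqrt_then_divide(jnp.array(-1.0), jnp.array(1.0))
print(f"Without jax_debug_nans: {silent}")

# With the flag enabled, JAX checks results for NaN and raises.
jax.config.update("jax_debug_nans", True)
nan_error_raised = False
try:
    sqrt_then_divide(jnp.array(-1.0), jnp.array(1.0))
except FloatingPointError as err:
    nan_error_raised = True
    print(f"With jax_debug_nans: {err}")
finally:
    # Restore the default so the extra checks do not slow down later code.
    jax.config.update("jax_debug_nans", False)
```

The error message names the operation that produced the NaN, which is the information that is hard to recover from a silently propagated value.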