Colab Notebook: Debugging JAX and Flax NNX - Exercises

Welcome! This notebook contains exercises to help you practice the JAX and Flax NNX debugging techniques covered in the lecture. If you’re a PyTorch user, some ideas will feel familiar while others are JAX-compilation specific. Remember to run the setup cell first!

Let’s start by installing and importing the needed libraries.

!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
import flax
from flax import nnx
import chex
import pdb # Python debugger
import functools # for functools.partial
import optax # used lightly for optimizer examples

chex.set_n_cpu_devices(8) # Simulate an 8-CPU environment; must run before other JAX ops
print(f"Simulated devices: {jax.devices()}")

# Flax v0.11+ note: flax.nnx.Optimizer API changed.
# It now requires `wrt` on construction (e.g., wrt=nnx.Param)
# and update is now `optimizer.update(model, grads)` not `optimizer.update(grads)`.

# Helper to clear chex counters for repeatable examples
chex.clear_trace_counter()

print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}") # NNX ships with Flax
print(f"Chex version: {chex.__version__}")
print(f"Devices: {jax.devices()}")

1. Print Debugging in JAX: jax.debug.print()

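Because JAX JIT compilation traces your code with abstract placeholder values, a plain Python print() inside a JITted function runs only at trace time and sees tracers, not runtime values. jax.debug.print() is the JAX-aware alternative: it is compiled into the function and prints the actual values each time the function runs.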

Exercise 1.1:

  1. Complete the # YOUR CODE HERE line in compute_and_print below by adding a jax.debug.print() statement that shows the runtime value of z.
  2. Run the cell and inspect the output:
    • What does the standard print(y) show?
    • What do the jax.debug.print statements show for y and z? Why is the output different?
@jit
def compute_and_print(x):
  y = x * 10
  print("Standard print (shows tracer):", y)
  jax.debug.print("jax.debug.print (runtime y): {y_val}", y_val=y, ordered=True)

  z = y / 2
  # Exercise 1.1: add another jax.debug.print to inspect the runtime value of z
  # Give it a descriptive message and use ordered=True.
  # YOUR CODE HERE

  return z

input_val = jnp.array(5.0)
print(f"Input value: {input_val}\n")
output_val = compute_and_print(input_val)
print(f"\nFinal output: {output_val}")

Solution (for Exercise 1.1, after you try):

@jit
def compute_and_print_solution(x):
  y = x * 10
  print("Standard print (shows tracer):", y)
  jax.debug.print("jax.debug.print (runtime y): {y_val}", y_val=y, ordered=True)

  z = y / 2
  jax.debug.print("jax.debug.print (runtime z): {z_val}", z_val=z, ordered=True) # Solution

  return z

input_val = jnp.array(5.0)
print(f"Input value: {input_val}\n")
output_val = compute_and_print_solution(input_val)
print(f"\nFinal output: {output_val}")

The standard print runs only while the function is being traced (on the first call for a given input shape and dtype), so it shows a tracer object, something like Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(...)>; the exact text varies between JAX versions. jax.debug.print shows concrete values (y = 50.0, z = 25.0) because it is staged into the compiled program and runs every time the function executes.
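The trace-time versus run-time distinction can be made visible with a Python side effect. In this sketch, scale and trace_log are hypothetical names: the list is appended to by ordinary Python code, which only runs while JAX traces the function, whereas jax.debug.print runs on every call.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled by ordinary Python code, which runs only while tracing

@jax.jit
def scale(x):
  trace_log.append("traced")           # Python side effect: trace time only
  y = x * 10
  jax.debug.print("runtime y: {}", y)  # staged into the compiled program
  return y

out_a = scale(jnp.array(1.0))  # first call: trace, compile, run
out_b = scale(jnp.array(2.0))  # same shape and dtype: reuses the compiled function
print(len(trace_log))          # 1
scale(jnp.ones(3))             # new input shape: JAX traces again
print(len(trace_log))          # 2
```

jax.debug.print fires on all three calls, but the list only grows when JAX retraces, which is why the standard print in compute_and_print appears once and shows a tracer.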

2. Interactive Debugging inside JIT: jax.debug.breakpoint()

jax.debug.breakpoint() is the JAX-friendly equivalent of pdb.set_trace() for transformed functions. It pauses execution at runtime and gives you a (jdb) prompt where you can inspect concrete values.

Exercise 2.1:

  1. Complete the # YOUR CODE HERE line in interact_with_values below by adding jax.debug.breakpoint() at the marked spot.
  2. Run the cell.
  3. When execution pauses at the (jdb) prompt:
    • Inspect y by typing p y and pressing Enter.
    • Continue by typing c and pressing Enter.
  Note: jdb supports only a subset of pdb commands (e.g., n and s are not available).
@jit
def interact_with_values(x):
  y = jnp.sin(x)
  jax.debug.print("Value of y before breakpoint: {y_val}", y_val=y)

  # Exercise 2.1: place a breakpoint here.
  # YOUR CODE HERE

  z = jnp.cos(y)
  jax.debug.print("Value of z after breakpoint: {z_val}", z_val=z)
  return z

input_angle = jnp.array(0.75)
print("Calling interact_with_values...")
result = interact_with_values(input_angle)
print(f"Result: {result}")

Solution (for Exercise 2.1, after you try):

@jit
def interact_with_values_solution(x):
  y = jnp.sin(x)
  jax.debug.print("Value of y before breakpoint: {y_val}", y_val=y)

  jax.debug.breakpoint() # Solution

  z = jnp.cos(y)
  jax.debug.print("Value of z after breakpoint: {z_val}", z_val=z)
  return z

input_angle = jnp.array(0.75)
print("Calling interact_with_values...")
result = interact_with_values_solution(input_angle)
print(f"Result: {result}")
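Often you only want to stop when something has gone wrong. A minimal sketch of a conditional breakpoint: wrap jax.debug.breakpoint() in one branch of jax.lax.cond so the (jdb) prompt opens only when a value is non-finite. The helper breakpoint_if_nonfinite and the example function safe_log are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

def breakpoint_if_nonfinite(x):
  # Pause only when x contains NaN or inf; otherwise do nothing
  is_finite = jnp.isfinite(x).all()

  def true_fn(x):
    pass  # all values finite: no breakpoint

  def false_fn(x):
    jax.debug.breakpoint()  # opens the (jdb) prompt at runtime

  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def safe_log(x):
  y = jnp.log(x)
  breakpoint_if_nonfinite(y)  # checkpoint placed right after a risky op
  return y

print(safe_log(jnp.array([1.0, 2.0])))  # finite input: runs straight through
# safe_log(jnp.array([-1.0, 2.0]))      # log(-1) is NaN: would pause at (jdb)
```

Both branches are compiled, but the breakpoint only executes when the predicate selects false_fn, so the check costs little when values are healthy.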

3. Back to Basics: Temporarily Disable JIT with jax.disable_jit()

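Sometimes you need full Python debugging tools. Inside a with jax.disable_jit(): block, JITted functions run eagerly, op by op, on concrete values, so pdb, print(), and Python control flow on array values behave as in ordinary Python.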
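A minimal sketch of that difference, using a hypothetical function clip_if_large: under JIT, a Python if on a traced value fails during tracing, while under jax.disable_jit() the same code sees a concrete array and runs normally.

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip_if_large(x, threshold):
  # Python `if` needs a concrete bool; under JIT, x is an abstract tracer
  if x > threshold:
    return threshold
  return x

jit_error_name = None
try:
  clip_if_large(jnp.array(5.0), 0.5)
except jax.errors.ConcretizationTypeError as e:
  # Recent JAX versions raise TracerBoolConversionError, a subclass
  jit_error_name = type(e).__name__
print("Under JIT:", jit_error_name)

with jax.disable_jit():
  # Runs eagerly: x is a concrete array, so the `if` (and pdb) work normally
  result = clip_if_large(jnp.array(5.0), 0.5)
print("With JIT disabled:", result)
```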

Exercises 3.1 and 3.2:

  1. In complex_calculation, add pdb.set_trace() where indicated (# YOUR CODE HERE for 3.1).
  2. First, uncomment Scenario 1 (and comment out Scenario 2), then run the cell. See what happens when a JITted function uses Python control flow and pdb.set_trace() on traced values.
  3. Then comment Scenario 1 out again.
  4. In Scenario 2, inside the with jax.disable_jit(): block, call complex_calculation with value and threshold=0.5 (# YOUR CODE HERE for 3.2). Try value = 0.1 first, then 5.0; only 5.0 satisfies the condition c > threshold.
  5. When pdb triggers:
    • Inspect a, b, and c.
    • Type c to continue.
  6. Reflection: when would you reach for jax.disable_jit() over jax.debug.breakpoint()?
@jit
def complex_calculation(x, threshold):
  a = x * 2.0
  b = jnp.log(a)
  c = b + x
  # Imagine c occasionally becomes NaN and it’s hard to see why.
  # We want to inspect a, b, and c with regular pdb.
  if c > threshold: # Python `if` on a traced value raises an error under JIT
      # Exercise 3.1: add pdb.set_trace() here.
      # It will only fire if this function runs with JIT disabled.
      # YOUR CODE HERE
      print("Inside conditional pdb trace") # Runs whenever this branch executes (after you continue from pdb)
  d = jnp.sqrt(jnp.abs(c)) # abs avoids NaN from negative sqrt
  return d

value = jnp.array(0.1) # Try 0.1 then 5.0

# Scenario 1: JIT enabled (the Python `if` on a traced value raises an error
# during tracing, so pdb.set_trace() is never reached)
# print("--- Running with JIT enabled ---")
# try:
#   result_jit = complex_calculation(value, threshold=0.5)
#   print(f"Result with JIT: {result_jit}")
# except Exception as e:
#   print(f"Scenario 1 error:\n{e}\n")

# Scenario 2: disable JIT
print("\n--- Running with JIT disabled for this block ---")
with jax.disable_jit():
  # Exercise 3.2: call complex_calculation with value and threshold=0.5
  # so your pdb.set_trace() (from 3.1) is triggered.
  # YOUR CODE HERE
  pass # Delete this line

print("disable_jit block finished.")

Solution (for Exercises 3.1 and 3.2, after you try):

@jit
def complex_calculation_solution(x, threshold):
  a = x * 2.0
  b = jnp.log(a)
  c = b + x
  if c > threshold:
      pdb.set_trace() # Solution 3.1
      print("Inside conditional pdb trace")
  d = jnp.sqrt(jnp.abs(c))
  return d

value_for_pdb = jnp.array(5.0) # This value satisfies c > threshold

# Scenario 1: JIT enabled (the Python `if` on a traced value raises an error
# during tracing, so pdb.set_trace() is never reached)
print("--- Running with JIT enabled ---")
try:
  result_jit = complex_calculation_solution(value_for_pdb, threshold=0.5)
  print(f"Result with JIT: {result_jit}")
except Exception as e:
  print(f"Scenario 1 error:\n{e}\n")

print("\n--- Running with JIT disabled for this block ---")
with jax.disable_jit():
  result_no_jit = complex_calculation_solution(value_for_pdb, threshold=0.5) # Solution 3.2
  print(f"Result with JIT disabled: {result_no_jit}")
print("disable_jit block finished.")

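Use jax.disable_jit() when jax.debug.breakpoint() isn't enough: for example, when you need full pdb stepping (n, s), want to use an IDE debugger, or need Python control flow on array values, as in complex_calculation. The trade-off is performance: with JIT disabled, every operation is dispatched eagerly, one at a time, which can be much slower.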

4. Automatic NaN Hunting: jax_debug_nans Flag

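NaNs can be painful to track down. With the jax_debug_nans flag enabled, JAX checks the outputs of each computation for NaNs and raises a FloatingPointError as soon as one appears; for a JITted function, it re-runs the computation op by op to point at the exact operation that produced the NaN.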

Exercises 4.1 and 4.2:

  1. In Scenario 1, uncomment the sample calls (or write your own) to problematic_function_for_nans. x = jnp.array(-1.0), divisor = jnp.array(1.0) produces a NaN; note that x = jnp.array(1.0), divisor = jnp.array(0.0) produces inf, not NaN. Run the cell and observe that nothing fails: the result silently comes back as nan (or inf).
  2. In Scenario 2:
    • Uncomment jax.config.update("jax_debug_nans", True) and the call that follows it.
    • Run the cell again; this time JAX raises a FloatingPointError whose stack trace points to the offending operation (jnp.sqrt).
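def problematic_function_for_nans(x, divisor):
  # sqrt of a negative x yields NaN; dividing a positive value by zero yields inf
  # (x == 0 with divisor == 0 gives 0/0, which is NaN)
  return jnp.sqrt(x) / divisor

# Scenario 1: default behavior (no jax_debug_nans): no error, just a nan/inf result
# print("--- Scenario 1: No jax_debug_nans ---")
# print(problematic_function_for_nans(jnp.array(-1.0), jnp.array(1.0)))  # nan
# print(problematic_function_for_nans(jnp.array(1.0), jnp.array(0.0)))   # inf, not NaN

# Scenario 2: enable jax_debug_nans
# print("--- Scenario 2: Enabling jax_debug_nans ---")
# jax.config.update("jax_debug_nans", True)
# try:
#   problematic_function_for_nans(jnp.array(-1.0), jnp.array(1.0))
# except FloatingPointError as e:
#   print(f"Caught: {e}")
# finally:
#   jax.config.update("jax_debug_nans", False)  # restore the default for later cells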
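Because division by zero produces inf rather than NaN, jax_debug_nans alone will not catch it; JAX has a companion flag, jax_debug_infs, for that case. The sketch below checks which flag fires for which input. raises_with_flag is a hypothetical helper, and it switches each flag back off in a finally clause so the check does not leak into later cells.

```python
import jax
import jax.numpy as jnp

def problematic_function_for_nans(x, divisor):
  return jnp.sqrt(x) / divisor

def raises_with_flag(flag, x, divisor):
  """Return True if the call raises FloatingPointError while `flag` is enabled."""
  jax.config.update(flag, True)
  try:
    problematic_function_for_nans(x, divisor)
    return False
  except FloatingPointError:
    return True
  finally:
    jax.config.update(flag, False)  # restore the default for later cells

nan_case = (jnp.array(-1.0), jnp.array(1.0))  # sqrt(-1) -> nan
inf_case = (jnp.array(1.0), jnp.array(0.0))   # 1 / 0 -> inf

print(raises_with_flag("jax_debug_nans", *nan_case))  # True
print(raises_with_flag("jax_debug_nans", *inf_case))  # False: inf is not NaN
print(raises_with_flag("jax_debug_infs", *inf_case))  # True
```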