Welcome! This notebook contains exercises to help you practice using Chex with JAX and Flax NNX, based on the concepts covered in the lecture.
Goal: Solidify your understanding of how Chex enhances reliability and debuggability in JAX-based projects.
Instructions: Fill in the `TODO` sections with your code. Each exercise cell is followed by a reference solution. Let's get started!
# Run this cell first to install and import necessary libraries.
!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
import chex
import flax
from flax import nnx
import functools # For functools.partial
def reset_trace_counter() -> None:
    """Reset Chex's trace counter before re-running ``assert_max_traces`` exercises.

    Running the cells in order is normally all that is needed. If you re-run
    cells aggressively, some JAX versions may also need their internal caches
    reset.
    """
    chex.clear_trace_counter()


# Environment report: library versions and the active JAX backend.
print(f"JAX version: {jax.__version__}")
print(f"Chex version: {chex.__version__}")
print(f"Flax version: {flax.__version__}")
print(f"Running on: {jax.default_backend()}")
**Exercise 1.1: `chex.assert_shape` and `chex.assert_type`.** Add Chex assertions to the `process_data_v1` function below:
# - Assert that `input_array` has a shape of (3, None).
# - Assert that `input_array` has a `jnp.float32` dtype.
# - Assert that `output_array` has a shape of (3, 1).
def process_data_v1(input_array: chex.Array) -> chex.Array:
    """Processes an array, asserting shapes and types.

    Exercise scaffold: replace each ``<TODO>`` with the arguments described in
    the comment above it.
    """
    # TODO: Assert input_array shape is (3, None)
    # (``None`` in an expected shape lets that dimension have any size.)
    chex.assert_shape(input_array, <TODO>)
    # TODO: Assert input_array type is jnp.float32
    chex.assert_type(<TODO>)
    # Simulate some processing that reduces the last dimension to 1
    output_array = input_array[:, :1] * 2.0
    # TODO: Assert output_array shape is (3, 1)
    # (Already filled in as a worked example.)
    chex.assert_shape(output_array, (3, 1))
    return output_array
# Test cases: one valid call, then two inputs that must be rejected.
key = jax.random.PRNGKey(0)
valid_input = jax.random.normal(key, (3, 5), dtype=jnp.float32)

print("Testing with valid input...")
result = process_data_v1(valid_input)
print(f"Successfully processed valid input. Output shape: {result.shape}\n")

print("Testing with invalid shape input...")
invalid_shape_input = jax.random.normal(key, (4, 5), dtype=jnp.float32)
try:
    process_data_v1(invalid_shape_input)
except AssertionError as e:
    print(f"Caught expected error for invalid shape:\n{e}\n")
else:
    # Reached only if the shape assertion is missing or too loose.
    print("UNEXPECTED: invalid shape input was accepted (no AssertionError).\n")

print("Testing with invalid type input...")
invalid_type_input = jnp.ones((3, 5), dtype=jnp.int32)
try:
    process_data_v1(invalid_type_input)
except AssertionError as e:
    print(f"Caught expected error for invalid type: {e}\n")
else:
    # Reached only if the dtype assertion is missing or too loose.
    print("UNEXPECTED: invalid type input was accepted (no AssertionError).\n")
def process_data_v1(input_array: chex.Array) -> chex.Array:
    """Double the first column of a ``(3, N)`` float32 array, with Chex checks.

    Args:
      input_array: array expected to have shape ``(3, N)`` for any ``N`` and
        dtype ``float32``.

    Returns:
      ``2.0 * input_array[:, :1]``, with shape ``(3, 1)``.

    Raises:
      AssertionError: if a shape or dtype check fails.
    """
    # Input contract: three rows, any number of columns, float32 values.
    chex.assert_shape(input_array, (3, None))
    chex.assert_type(input_array, expected_types=jnp.float32)

    first_column = input_array[:, :1]
    doubled = first_column * 2.0

    # Output contract: slicing keeps exactly one column.
    chex.assert_shape(doubled, (3, 1))
    return doubled
# Test cases: one valid call, then two inputs that must be rejected.
key = jax.random.PRNGKey(0)
valid_input = jax.random.normal(key, (3, 5), dtype=jnp.float32)

print("Testing with valid input...")
result = process_data_v1(valid_input)
print(f"Successfully processed valid input. Output shape: {result.shape}\n")

print("Testing with invalid shape input...")
invalid_shape_input = jax.random.normal(key, (4, 5), dtype=jnp.float32)
try:
    process_data_v1(invalid_shape_input)
except AssertionError as e:
    print(f"Caught expected error for invalid shape:\n{e}\n")
else:
    # Reached only if the shape assertion is missing or too loose.
    print("UNEXPECTED: invalid shape input was accepted (no AssertionError).\n")

print("Testing with invalid type input...")
invalid_type_input = jnp.ones((3, 5), dtype=jnp.int32)
try:
    process_data_v1(invalid_type_input)
except AssertionError as e:
    print(f"Caught expected error for invalid type: {e}\n")
else:
    # Reached only if the dtype assertion is missing or too loose.
    print("UNEXPECTED: invalid type input was accepted (no AssertionError).\n")
**Exercise: `chex.assert_rank` and `chex.assert_scalar`.** Add Chex assertions to the `process_data_v2` function:
# - Assert that `matrix_input` is a 2D array (rank 2).
# - Assert that `scalar_input` is a scalar.
# - Assert that `result` is also a 2D array.
def process_data_v2(matrix_input: chex.Array, scalar_input: chex.Array) -> chex.Array:
    """Processes a matrix and a scalar.

    Exercise scaffold: replace each ``<TODO>`` with the arguments described in
    the comment above it. Computes ``matrix_input * scalar_input + 1.0``.
    """
    # TODO: Assert matrix_input has rank 2
    chex.assert_rank(matrix_input, <TODO>)
    # TODO: Assert scalar_input is a scalar
    # (In the test cases below, the Python float 5.0 should pass and the
    # one-element array jnp.array([5.0]) should be rejected.)
    chex.assert_scalar(<TODO>)
    result = matrix_input * scalar_input + 1.0
    # TODO: Assert result has rank 2
    chex.assert_rank(result, <TODO>)
    return result
# Test cases: one valid call, then an invalid-rank and a non-scalar input
# that must both be rejected.
matrix = jnp.ones((3, 4))
scalar = 5.0  # Python float
not_a_scalar = jnp.array([5.0])  # one-element array, not a scalar
not_a_matrix = jnp.ones((3, 4, 1))

print("Testing with valid rank/scalar inputs...")
try:
    res_valid = process_data_v2(matrix, scalar)
except AssertionError as e:
    print(f"Caught unexpected error for valid rank/scalar:\n{e}\n")
else:
    print(f"Successfully processed valid rank/scalar. Result shape: {res_valid.shape}\n")

print("Testing with invalid rank input...")
try:
    res_invalid_rank = process_data_v2(not_a_matrix, scalar)
except AssertionError as e:
    print(f"Caught expected error for invalid rank:\n{e}\n")
else:
    # Reached only if the rank assertion is missing; report *this* result's shape.
    print(f"Successfully processed invalid rank. Result shape: {res_invalid_rank.shape}\n")

print("Testing with non-scalar input...")
try:
    res_non_scalar = process_data_v2(matrix, not_a_scalar)
except AssertionError as e:
    print(f"Caught expected error for non-scalar:\n{e}\n")
else:
    # Reached only if the scalar assertion is missing; report *this* result's shape.
    print(f"Successfully processed non-scalar. Result shape: {res_non_scalar.shape}\n")
def process_data_v2(matrix_input: chex.Array, scalar_input: chex.Array) -> chex.Array:
    """Return ``matrix_input * scalar_input + 1.0`` after rank/scalar checks.

    Args:
      matrix_input: array that must be rank 2.
      scalar_input: value that must pass ``chex.assert_scalar``. Per the Chex
        docs this means a Python ``int``/``float``. Arrays are rejected,
        including one-element arrays.

    Returns:
      A rank-2 array with the same shape as ``matrix_input``.

    Raises:
      AssertionError: if a rank or scalar check fails.
    """
    chex.assert_rank(matrix_input, expected_ranks=2)
    chex.assert_scalar(scalar_input)

    scaled = matrix_input * scalar_input
    result = scaled + 1.0

    # Broadcasting against a scalar must not change the rank.
    chex.assert_rank(result, expected_ranks=2)
    return result
# Test cases: one valid call, then an invalid-rank and a non-scalar input
# that must both be rejected.
matrix = jnp.ones((3, 4))
scalar = 5.0  # Python float
not_a_scalar = jnp.array([5.0])  # one-element array, not a scalar
not_a_matrix = jnp.ones((3, 4, 1))

print("Testing with valid rank/scalar inputs...")
try:
    res_valid = process_data_v2(matrix, scalar)
except AssertionError as e:
    print(f"Caught unexpected error for valid rank/scalar:\n{e}\n")
else:
    print(f"Successfully processed valid rank/scalar. Result shape: {res_valid.shape}\n")

print("Testing with invalid rank input...")
try:
    res_invalid_rank = process_data_v2(not_a_matrix, scalar)
except AssertionError as e:
    print(f"Caught expected error for invalid rank:\n{e}\n")
else:
    # Reached only if the rank assertion is missing; report *this* result's shape.
    print(f"Successfully processed invalid rank. Result shape: {res_invalid_rank.shape}\n")

print("Testing with non-scalar input...")
try:
    res_non_scalar = process_data_v2(matrix, not_a_scalar)
except AssertionError as e:
    print(f"Caught expected error for non-scalar:\n{e}\n")
else:
    # Reached only if the scalar assertion is missing; report *this* result's shape.
    print(f"Successfully processed non-scalar. Result shape: {res_non_scalar.shape}\n")
asserttreesallclose, asserttreeallfinite)def process_pytree(tree1, tree2):
"""
Checks if two PyTrees are close and if the first tree is finite.
Returns a new tree where elements are tree1 + tree2.
"""
# TODO: Assert tree1 and tree2 are (close to) equal. Use a small tolerance.
chex.assert_trees_all_close(<TODO> rtol=1e-5, atol=1e-8)
# TODO: Assert all elements in tree1 are finite (not NaN or Inf).
chex.assert_tree_all_finite(<TODO>)
# Perform some operation
return jax.tree_util.tree_map(lambda x, y: x + y, tree1, tree2)
# Test cases: one valid pair of trees, then a non-close pair and a tree
# containing NaN, both of which must be rejected.
tree_a = {'params': {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}}
tree_b_close = {'params': {'w': jnp.array([1.000001, 2.000001]), 'b': jnp.array(0.500001)}}
tree_c_not_close = {'params': {'w': jnp.array([1.1, 2.1]), 'b': jnp.array(0.6)}}
tree_d_nan = {'params': {'w': jnp.array([1.0, jnp.nan]), 'b': jnp.array(0.5)}}

print("Testing with close and finite PyTrees...")
result_valid = process_pytree(tree_a, tree_b_close)
print("Successfully processed valid PyTrees.\n")

print("Testing with non-close PyTrees...")
try:
    process_pytree(tree_a, tree_c_not_close)
except AssertionError as e:
    print(f"Caught expected error for non-close trees:\n\n{e}\n")
else:
    print("UNEXPECTED: non-close PyTrees were accepted (no AssertionError).\n")

print("Testing with non-finite PyTree...")
try:
    # tree_d_nan contains a NaN, so process_pytree must reject it. Which
    # assertion reports the problem depends on the order of the checks
    # inside process_pytree.
    process_pytree(tree_d_nan, tree_b_close)
except AssertionError as e:
    print(f"Caught expected error for non-finite tree:\n\n{e}\n")
else:
    print("UNEXPECTED: non-finite PyTree was accepted (no AssertionError).\n")
def process_pytree(tree1, tree2):
    """Validate two PyTrees and return their leaf-wise sum.

    Args:
      tree1: PyTree whose leaves must all be finite (no NaN or Inf).
      tree2: PyTree with the same structure as ``tree1``. Its leaves must be
        close to ``tree1``'s (``rtol=1e-5``, ``atol=1e-8``).

    Returns:
      A PyTree of the same structure whose leaves are ``tree1 + tree2``.

    Raises:
      AssertionError: if ``tree1`` has a non-finite leaf, or the trees are not
        close.
    """
    # Check finiteness first. A NaN in tree1 also makes the trees "not close",
    # so with the opposite order the finiteness assertion is never reached.
    chex.assert_tree_all_finite(tree1)
    chex.assert_trees_all_close(tree1, tree2, rtol=1e-5, atol=1e-8)
    return jax.tree_util.tree_map(lambda x, y: x + y, tree1, tree2)
# Test cases: one valid pair of trees, then a non-close pair and a tree
# containing NaN, both of which must be rejected.
tree_a = {'params': {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}}
tree_b_close = {'params': {'w': jnp.array([1.000001, 2.000001]), 'b': jnp.array(0.500001)}}
tree_c_not_close = {'params': {'w': jnp.array([1.1, 2.1]), 'b': jnp.array(0.6)}}
tree_d_nan = {'params': {'w': jnp.array([1.0, jnp.nan]), 'b': jnp.array(0.5)}}

print("Testing with close and finite PyTrees...")
result_valid = process_pytree(tree_a, tree_b_close)
print("Successfully processed valid PyTrees.\n")

print("Testing with non-close PyTrees...")
try:
    process_pytree(tree_a, tree_c_not_close)
except AssertionError as e:
    print(f"Caught expected error for non-close trees:\n\n{e}\n")
else:
    print("UNEXPECTED: non-close PyTrees were accepted (no AssertionError).\n")

print("Testing with non-finite PyTree...")
try:
    # tree_d_nan contains a NaN, so process_pytree must reject it. Which
    # assertion reports the problem depends on the order of the checks
    # inside process_pytree.
    process_pytree(tree_d_nan, tree_b_close)
except AssertionError as e:
    print(f"Caught expected error for non-finite tree:\n\n{e}\n")
else:
    print("UNEXPECTED: non-finite PyTree was accepted (no AssertionError).\n")
**Using Chex assertions with `jax.jit` and `jax.vmap`.**
# Task: apply `@jax.jit` to the `process_data_v1` function from Exercise 1.1.
@jax.jit
def process_data_jitted(input_array: chex.Array) -> chex.Array:
    """JIT-compiled version of process_data_v1 with its Chex assertions."""
    # The assertions mirror those in process_data_v1. They are repeated inline
    # so that this jitted function is self-contained.
    # These shape/dtype checks run while JAX traces the function: on the first
    # call, and again whenever the input shape or dtype changes.
    chex.assert_shape(input_array, (3, None))
    chex.assert_type(input_array, jnp.float32)
    output_array = input_array[:, :1] * 2.0
    chex.assert_shape(output_array, (3, 1))
    return output_array
# Test cases for JIT version
key = jax.random.PRNGKey(1) # Use a different key for potentially different values
valid_input_jit = jax.random.normal(key, (3, 5), dtype=jnp.float32)
print("Testing JITted function with valid input...")
# First call will compile
result_jit = process_data_jitted(<TODO>)
print(f"Successfully processed JITted valid input. Output shape: {result_jit.shape}")
# Second call uses cached compilation
result_jit_cached = process_data_jitted(<TODO> * 2)
print(f"Successfully processed JITted valid input (cached). Output shape: {result_jit_cached.shape}\n")
print("Testing JITted function with invalid shape input...")
invalid_shape_input_jit = jax.random.normal(key, (4, 5), dtype=jnp.float32)
try:
process_data_jitted(<TODO>)
except AssertionError as e:
print(f"Caught expected JITted error for invalid shape:\n\n{e}\n")
@jax.jit
def process_data_jitted(input_array: chex.Array) -> chex.Array:
    """JIT-compiled variant of ``process_data_v1`` with inline Chex checks.

    The shape and dtype assertions run while JAX traces this function, which
    happens once per new input shape/dtype. A mismatch therefore raises
    ``AssertionError`` at call time, before any compiled code runs.
    """
    chex.assert_shape(input_array, (3, None))
    chex.assert_type(input_array, jnp.float32)

    leading_column = input_array[:, :1]
    output_array = leading_column * 2.0

    chex.assert_shape(output_array, (3, 1))
    return output_array
# Test cases for JIT version
key = jax.random.PRNGKey(1)  # Use a different key for potentially different values
valid_input_jit = jax.random.normal(key, (3, 5), dtype=jnp.float32)

print("Testing JITted function with valid input...")
# First call traces the function (running the Chex assertions) and compiles it.
result_jit = process_data_jitted(valid_input_jit)
print(f"Successfully processed JITted valid input. Output shape: {result_jit.shape}")
# Same shape and dtype as the first call, so the cached compilation is reused.
result_jit_cached = process_data_jitted(valid_input_jit * 2)
print(f"Successfully processed JITted valid input (cached). Output shape: {result_jit_cached.shape}\n")

print("Testing JITted function with invalid shape input...")
# A new input shape forces a retrace, so the shape assertion runs again.
invalid_shape_input_jit = jax.random.normal(key, (4, 5), dtype=jnp.float32)
try:
    process_data_jitted(invalid_shape_input_jit)
except AssertionError as e:
    print(f"Caught expected JITted error for invalid shape:\n\n{e}\n")
else:
    print("UNEXPECTED: JITted function accepted invalid shape input (no AssertionError).\n")
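Static Chex assertions (shape, rank, dtype) work seamlessly inside JIT-compiled functions. They run while JAX traces the function and check the arguments' shapes and dtypes, not their concrete values. So a mismatch is reported as soon as the function is called with a new input signature, before any compiled code runs. Assertions that inspect values (e.g. `chex.assert_tree_all_finite`) need concrete values, so inside `jax.jit` they require `chex.chexify`.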
---
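**Exercise: Chex assertions with `@jax.vmap`.** Each item in the batch has shape `(10,)`.
- Write `process_single_item_vmap`, a function that processes one item, and assert that `item` has shape `(10,)`.
- The function should double the item's values.
- Assert that the result (the output of `process_single_item_vmap`) also has shape `(10,)`.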
# - Use `jax.vmap` to create `process_batch`.
# - Before calling `process_batch`, assert the batch input has shape (BATCH_SIZE, 10).
# - After calling `process_batch`, assert the batch output has shape (BATCH_SIZE, 10).
BATCH_SIZE = 5
ITEM_SIZE = 10


def process_single_item_vmap(item: chex.Array) -> chex.Array:
    """Processes a single item, asserting its shape."""
    # Under jax.vmap this function sees one unbatched item, so the batch axis
    # does not appear in the shapes checked here.
    # TODO: Assert shape of a SINGLE item is (ITEM_SIZE,)
    chex.assert_shape(item, <TODO>)
    result = item * 2.0
    # TODO: Assert shape of single item output is (ITEM_SIZE,)
    chex.assert_shape(result, <TODO>)
    return result


# TODO: Vectorize the process_single_item_vmap function using jax.vmap
# (in_axes=0 / out_axes=0: map over, and stack results along, the leading axis).
process_batch = jax.vmap(<TODO>, in_axes=0, out_axes=0)
# Test cases
key = jax.random.PRNGKey(2)
valid_batch_input = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE))
invalid_batch_input_item_shape = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE + 1))
print("Testing vmap with valid batch input...")
# TODO: Assert shape of the full BATCHED input BEFORE vmap
chex.assert_shape(valid_batch_input, <TODO>)
batch_output = process_batch(valid_batch_input)
# TODO: Assert shape of the full BATCHED output AFTER vmap
chex.assert_shape(batch_output, <TODO>)
print(f"Vmap assertion passed. Output shape: {batch_output.shape}\n")
print("Testing vmap with invalid item shape in batch (error from inside vmap)...")
try:
# This will fail inside the vmapped function 'process_single_item_vmap'
process_batch(invalid_batch_input_item_shape)
except AssertionError as e:
print(f"Caught expected vmap error (from inner function):\n{e}\n")
print("Testing vmap with invalid batch shape (error from outer assertion)...")
invalid_batch_input_outer_shape = jax.random.normal(key, (BATCH_SIZE + 1, ITEM_SIZE))
try:
# This will fail the assertion *before* calling process_batch
chex.assert_shape(invalid_batch_input_outer_shape, (BATCH_SIZE, ITEM_SIZE)) # This line will fail
process_batch(invalid_batch_input_outer_shape)
except AssertionError as e:
print(f"Caught expected vmap error (from outer assertion):\n{e}\n")
BATCH_SIZE = 5
ITEM_SIZE = 10


def process_single_item_vmap(item: chex.Array) -> chex.Array:
    """Double one unbatched item, checking its shape on the way in and out.

    Under ``jax.vmap`` this function receives a single slice of the batch, so
    the expected shape is ``(ITEM_SIZE,)``, not ``(BATCH_SIZE, ITEM_SIZE)``.
    """
    per_item_shape = (ITEM_SIZE,)
    chex.assert_shape(item, per_item_shape)
    doubled = item * 2.0
    chex.assert_shape(doubled, per_item_shape)
    return doubled


# Map over the leading (batch) axis of the input and stack outputs along axis 0.
process_batch = jax.vmap(process_single_item_vmap, in_axes=0, out_axes=0)
# Test cases: one valid batch, then a bad per-item shape (caught inside vmap)
# and a bad batch size (caught by the outer assertion).
key = jax.random.PRNGKey(2)
valid_batch_input = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE))
invalid_batch_input_item_shape = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE + 1))

print("Testing vmap with valid batch input...")
# Assert shape of the full BATCHED input BEFORE vmap
chex.assert_shape(valid_batch_input, (BATCH_SIZE, ITEM_SIZE))
batch_output = process_batch(valid_batch_input)
# Assert shape of the full BATCHED output AFTER vmap
chex.assert_shape(batch_output, (BATCH_SIZE, ITEM_SIZE))
print(f"Vmap assertion passed. Output shape: {batch_output.shape}\n")

print("Testing vmap with invalid item shape in batch (error from inside vmap)...")
try:
    # This will fail inside the vmapped function 'process_single_item_vmap'
    process_batch(invalid_batch_input_item_shape)
except AssertionError as e:
    print(f"Caught expected vmap error (from inner function):\n{e}\n")
else:
    print("UNEXPECTED: invalid item shape was accepted inside vmap (no AssertionError).\n")

print("Testing vmap with invalid batch shape (error from outer assertion)...")
invalid_batch_input_outer_shape = jax.random.normal(key, (BATCH_SIZE + 1, ITEM_SIZE))
try:
    # This will fail the assertion *before* calling process_batch
    chex.assert_shape(invalid_batch_input_outer_shape, (BATCH_SIZE, ITEM_SIZE))  # This line will fail
    process_batch(invalid_batch_input_outer_shape)
except AssertionError as e:
    print(f"Caught expected vmap error (from outer assertion):\n{e}\n")
else:
    print("UNEXPECTED: invalid batch shape was accepted (no AssertionError).\n")
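**Exercise: Validating Flax NNX module inputs and outputs in the `__call__` method.**
In the `SimpleMLP` module:
- At the start of `__call__`, validate the input `x`:
  - It must be 2D (`[batch, features]`).
  - The feature dimension (axis 1) must match `self.linear1.in_features`.
  - Its type must be `jnp.float32`.
- At the end of `__call__`, validate the output `x` before returning it:
  - It must be 2D.
  - The feature dimension (axis 1) must match `self.linear2.out_features`.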
class SimpleMLP(nnx.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) whose ``__call__`` validates I/O.

    Exercise scaffold: replace each ``<TODO>`` in ``__call__``.
    """

    def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
        # din -> dmid -> dout; both layers draw their initial parameters from rngs.
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x: chex.Array) -> chex.Array:
        # TODO: Validate input x
        # - Must be 2D ([batch, features])
        chex.assert_rank(<TODO>)
        # - Feature dimension (axis 1) must match self.linear1.in_features
        chex.assert_axis_dimension(x, 1, <TODO>)
        # - Type must be jnp.float32
        chex.assert_type(x, <TODO>)
        # Forward pass
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        # TODO: Validate output x before returning
        # - Must be 2D
        chex.assert_rank(<TODO>)
        # - Feature dimension (axis 1) must match self.linear2.out_features
        #   (already filled in as a worked example)
        chex.assert_axis_dimension(x, 1, self.linear2.out_features)
        return x
# Test cases for SimpleMLP: one valid input, then three inputs (wrong rank,
# wrong feature count, wrong dtype) that must each be rejected.
key_nnx = nnx.Rngs(params=jax.random.key(0))  # NNX Rngs for stateful operations
din, dmid, dout = 10, 20, 5
batch_size_nnx = 4
model = SimpleMLP(din, dmid, dout, rngs=key_nnx)

print("Testing NNX Module with valid input:")
x_valid_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.float32)
output_nnx = model(x_valid_nnx)
print(f"NNX I/O Check passed. Output shape: {output_nnx.shape}\n")

print("Testing NNX Module with invalid input rank:")
x_invalid_rank_nnx = jnp.ones((batch_size_nnx, din, 1), dtype=jnp.float32)
try:
    model(x_invalid_rank_nnx)
except AssertionError as e:
    print(f"Caught expected NNX error (invalid input rank):\n{e}\n")
else:
    print("UNEXPECTED: invalid input rank was accepted (no AssertionError).\n")

print("Testing NNX Module with invalid input feature dimension:")
x_invalid_feat_nnx = jnp.ones((batch_size_nnx, din + 1), dtype=jnp.float32)
try:
    model(x_invalid_feat_nnx)
except AssertionError as e:
    print(f"Caught expected NNX error (invalid input features):\n{e}\n")
else:
    print("UNEXPECTED: invalid input feature dimension was accepted (no AssertionError).\n")

print("Testing NNX Module with invalid input type:")
x_invalid_type_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.int32)
try:
    model(x_invalid_type_nnx)
except AssertionError as e:
    print(f"Caught expected NNX error (invalid input type):\n{e}\n")
else:
    print("UNEXPECTED: invalid input type was accepted (no AssertionError).\n")
class SimpleMLP(nnx.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear) with Chex I/O contracts.

    Args:
      din: number of input features.
      dmid: width of the hidden layer.
      dout: number of output features.
      rngs: NNX random-number streams used to initialise both linear layers.
    """

    def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x: chex.Array) -> chex.Array:
        """Map a ``[batch, din]`` float32 array to a ``[batch, dout]`` array.

        Raises:
          AssertionError: if the input or output violates the rank, feature
            size, or (input) dtype contract.
        """
        # Input contract: rank 2, din features, float32.
        self._check_2d_features(x, self.linear1.in_features)
        chex.assert_type(x, jnp.float32)

        hidden = nnx.relu(self.linear1(x))
        y = self.linear2(hidden)

        # Output contract: rank 2, dout features.
        self._check_2d_features(y, self.linear2.out_features)
        return y

    @staticmethod
    def _check_2d_features(array: chex.Array, num_features: int) -> None:
        """Assert ``array`` is rank 2 with ``num_features`` entries along axis 1."""
        chex.assert_rank(array, 2)
        chex.assert_axis_dimension(array, 1, num_features)
# Test cases for SimpleMLP: one valid input, then three inputs (wrong rank,
# wrong feature count, wrong dtype) that must each be rejected.
key_nnx = nnx.Rngs(params=jax.random.key(0))  # NNX Rngs for stateful operations
din, dmid, dout = 10, 20, 5
batch_size_nnx = 4
model = SimpleMLP(din, dmid, dout, rngs=key_nnx)

print("Testing NNX Module with valid input:")
x_valid_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.float32)
output_nnx = model(x_valid_nnx)
print(f"NNX I/O Check passed. Output shape: {output_nnx.shape}\n")

print("Testing NNX Module with invalid input rank:")
x_invalid_rank_nnx = jnp.ones((batch_size_nnx, din, 1), dtype=jnp.float32)
try:
    model(x_invalid_rank_nnx)
except AssertionError as e:
    print(f"Caught expected NNX error (invalid input rank):\n{e}\n")
else:
    print("UNEXPECTED: invalid input rank was accepted (no AssertionError).\n")

print("Testing NNX Module with invalid input feature dimension:")
x_invalid_feat_nnx = jnp.ones((batch_size_nnx, din + 1), dtype=jnp.float32)
try:
    model(x_invalid_feat_nnx)
except AssertionError as e:
    print(f"Caught expected NNX error (invalid input features):\n{e}\n")
else:
    print("UNEXPECTED: invalid input feature dimension was accepted (no AssertionError).\n")

print("Testing NNX Module with invalid input type:")
x_invalid_type_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.int32)
try:
    model(x_invalid_type_nnx)
except AssertionError as e:
    print(f"Caught expected NNX error (invalid input type):\n{e}\n")
else:
    print("UNEXPECTED: invalid input type was accepted (no AssertionError).\n")
How would these assertions help you catch bugs early when composing multiple layers or changing model configurations? They act as contracts between layers and for the model's external API.
---
**Summary.** In this notebook you practiced using Chex:
- with `jax.jit` and `jax.vmap`;
- with `@chex.chexify` (and its caveats);
- with `@chex.assert_max_traces`.

Using Chex consistently can significantly improve the reliability and maintainability of your JAX projects.

**Further exploration (optional):**
- Try `chex.chexify` outside of a Colab environment.
- Explore Chex's device assertions (e.g. `chex.assert_devices_available`).
- Use `@chex.variants` if you write comprehensive test suites.