This notebook is designed to accompany the "Leveraging the JAX AI Stack" lecture. You'll get hands-on experience with core JAX concepts, Flax NNX for model building, Optax for optimization, and Orbax for checkpointing.
The exercises will guide you through implementing key components, drawing parallels to PyTorch where appropriate, to solidify your understanding.
Let's get started!
# @title Setup: Install and Import Libraries
# Install necessary libraries
!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
import flax
from flax import nnx
import optax
import orbax.checkpoint as ocp # For Orbax
from typing import Any, Dict, Tuple # For type hints
# Helper to print PyTrees more nicely for demonstration
import pprint
import os # For Orbax directory management
import shutil # For cleaning up Orbax directory
print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"Orbax version: {ocp.__version__}")
# Global JAX PRNG key for reproducibility in exercises
# Students can learn to split this key for different operations.
main_key = jax.random.key(0)
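# Quick illustration (the 'demo_*' names are only used here): splitting a key
# yields independent subkeys, while reusing the same key reproduces the same values.
demo_key = jax.random.key(42)
demo_k1, demo_k2 = jax.random.split(demo_key)
print("Same key -> same sample:", jax.random.normal(demo_k1, ()) == jax.random.normal(demo_k1, ()))
print("Different subkeys -> different samples:", jax.random.normal(demo_k1, ()) != jax.random.normal(demo_k2, ()))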
# Instructions for Exercise 1
key_ex1, main_key = jax.random.split(main_key) # Split the main key
# 1. Create JAX arrays a and b
# TODO: Create array 'a' (2x2 random normal) and 'b' (2x2 ones)
a = None # Placeholder
b = None # Placeholder
print("Array a:\n", a)
print("Array b:\n", b)
# 2. Perform element-wise addition
# TODO: Add a and b
c = None # Placeholder
print("Element-wise sum c = a + b:\n", c)
# 3. Perform matrix multiplication
# TODO: Matrix multiply a and b
d = None # Placeholder
print("Matrix product d = a @ b:\n", d)
# 4. Demonstrate immutability
# original_a_id = id(a)
# print(f"Original id(a): {original_a_id}")
# TODO: Perform an operation that reassigns 'a', e.g., a = a + 1
# a_new_ref = None # Placeholder
# new_a_id = id(a_new_ref)
# print(f"New id(a) after 'a = a + 1': {new_a_id}")
# TODO: Check if original_a_id is different from new_a_id
# print(f"IDs are different: {None}") # Placeholder
# @title Solution 1: JAX Core & NumPy API
key_ex1_sol, main_key = jax.random.split(main_key)
# 1. Create JAX arrays a and b
a_sol = jax.random.normal(key_ex1_sol, (2, 2))
b_sol = jnp.ones((2, 2))
print("Array a:\n", a_sol)
print("Array b:\n", b_sol)
# 2. Perform element-wise addition
c_sol = a_sol + b_sol
print("Element-wise sum c = a + b:\n", c_sol)
# 3. Perform matrix multiplication
d_sol = jnp.dot(a_sol, b_sol) # or d = a @ b
print("Matrix product d = a @ b:\n", d_sol)
# 4. Demonstrate immutability
original_a_id_sol = id(a_sol)
print(f"Original id(a_sol): {original_a_id_sol}")
a_sol_new_ref = a_sol + 1 # This creates a new array and rebinds the Python variable.
new_a_id_sol = id(a_sol_new_ref)
print(f"New id(a_sol_new_ref) after 'a_sol = a_sol + 1': {new_a_id_sol}")
print(f"IDs are different: {original_a_id_sol != new_a_id_sol}")
print("This shows that the original array was not modified in-place; a new array was created.")
Use the `%timeit` magic command in Colab (in separate cells) to compare their execution speeds. Remember that the first call to a JIT-compiled function includes compilation time.
# Instructions for Exercise 2
key_ex2_main, main_key = jax.random.split(main_key)
key_ex2_x, key_ex2_w, key_ex2_b = jax.random.split(key_ex2_main, 3)
# 1. Define the Python function
def compute_heavy_stuff(x, w, b):
# TODO: Implement the operations
y1 = None # Placeholder
y2 = None # Placeholder
y3 = None # Placeholder
result = None # Placeholder
return result
# 2. Create a JIT-compiled version
# TODO: Use jax.jit to compile compute_heavy_stuff
fast_compute_heavy_stuff = None # Placeholder
# 3. Create dummy data
dim1, dim2, dim3 = 500, 1000, 500
x_data = jax.random.normal(key_ex2_x, (dim1, dim2))
w_data = jax.random.normal(key_ex2_w, (dim2, dim3))
b_data = jax.random.normal(key_ex2_b, (dim3,))
# 4. Call both functions
result_original = None # Placeholder compute_heavy_stuff(x_data, w_data, b_data)
result_fast_first_call = None # Placeholder fast_compute_heavy_stuff(x_data, w_data, b_data) # First call (compiles)
result_fast_second_call = None # Placeholder fast_compute_heavy_stuff(x_data, w_data, b_data) # Second call (uses compiled)
print(f"Result (original): {result_original}")
print(f"Result (fast, 1st call): {result_fast_first_call}")
print(f"Result (fast, 2nd call): {result_fast_second_call}")
# if result_original is not None and result_fast_first_call is not None:
# assert jnp.allclose(result_original, result_fast_first_call), "Results should match!"
# print("
Results from original and JIT-compiled functions match.")
# 5. Optional: Timing (use %timeit in separate cells for accuracy)
# print("
To see the speed difference, run these in separate cells:")
# print("%timeit compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()")
# print("%timeit fast_compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()")
# @title Solution 2: `jax.jit` (Just-In-Time Compilation)
key_ex2_sol_main, main_key = jax.random.split(main_key)
key_ex2_sol_x, key_ex2_sol_w, key_ex2_sol_b = jax.random.split(key_ex2_sol_main, 3)
# 1. Define the Python function
def compute_heavy_stuff_sol(x, w, b):
y = jnp.dot(x, w)
y = y + b
y = jnp.tanh(y)
result = jnp.sum(y)
return result
# 2. Create a JIT-compiled version
fast_compute_heavy_stuff_sol = jax.jit(compute_heavy_stuff_sol)
# 3. Create dummy data
dim1_sol, dim2_sol, dim3_sol = 500, 1000, 500
x_data_sol = jax.random.normal(key_ex2_sol_x, (dim1_sol, dim2_sol))
w_data_sol = jax.random.normal(key_ex2_sol_w, (dim2_sol, dim3_sol))
b_data_sol = jax.random.normal(key_ex2_sol_b, (dim3_sol,))
# 4. Call both functions
# Call original once to ensure it's not timed with any JAX overhead if it were the first JAX op
result_original_sol = compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
# First call to JITed function includes compilation time
result_fast_sol_first_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
# Subsequent calls use the cached compiled code
result_fast_sol_second_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
print(f"Result (original): {result_original_sol}")
print(f"Result (fast, 1st call): {result_fast_sol_first_call}")
print(f"Result (fast, 2nd call): {result_fast_sol_second_call}")
assert jnp.allclose(result_original_sol, result_fast_sol_first_call), "Results should match!"
print("\nResults from original and JIT-compiled functions match.")
# 5. Optional: Timing
# To accurately measure, run these in separate Colab cells:
# Cell 1:
# %timeit compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
# Cell 2:
# %timeit fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
# You should observe that the JIT-compiled version is significantly faster after the initial compilation.
print("\nTo see the speed difference, run the %timeit commands (provided in comments above) in separate cells.")
# Instructions for Exercise 3
# 1. Define the scalar_loss function
def scalar_loss(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:
# TODO: Implement the prediction and loss calculation
y_pred = None # Placeholder
loss = None # Placeholder
return loss
# 2. Create the gradient function using jax.grad
# TODO: Gradient of scalar_loss w.r.t. 'params' (argnums=0)
compute_gradients = None # Placeholder
# 3. Initialize dummy data
params_init = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
x_input_data = jnp.array([1.0, 2.0, 3.0])
y_target_data = jnp.array([7.0, 9.0, 11.0]) # Targets follow y = 2x + 5, so the initial params (w=2, b=1) give a non-zero loss
# 4. Call the gradient function
gradients = None # Placeholder compute_gradients(params_init, x_input_data, y_target_data)
print("Initial params:", params_init)
print("Gradients w.r.t params:\n", gradients)
# Expected gradients (manual calculation for y_pred = wx+b, loss = mean((y_pred - y_true)^2)):
# dL/dw = mean(2 * (wx+b - y_true) * x)
# dL/db = mean(2 * (wx+b - y_true) * 1)
# For params_init={'w': 2.0, 'b': 1.0}, x=[1,2,3], y_true=[7,9,11]
# x=1: y_pred = 2*1+1 = 3. Error = 3-7 = -4. dL/dw_i_term = 2*(-4)*1 = -8. dL/db_i_term = 2*(-4)*1 = -8
# x=2: y_pred = 2*2+1 = 5. Error = 5-9 = -4. dL/dw_i_term = 2*(-4)*2 = -16. dL/db_i_term = 2*(-4)*1 = -8
# x=3: y_pred = 2*3+1 = 7. Error = 7-11 = -4. dL/dw_i_term = 2*(-4)*3 = -24. dL/db_i_term = 2*(-4)*1 = -8
# Mean gradients: dL/dw = (-8-16-24)/3 = -48/3 = -16. dL/db = (-8-8-8)/3 = -24/3 = -8.
# if gradients is not None:
# assert jnp.isclose(gradients['w'], -16.0)
# assert jnp.isclose(gradients['b'], -8.0)
# print("
Gradients match expected values.")
# @title Solution 3: `jax.grad` (Automatic Differentiation)
# 1. Define the scalar_loss function
def scalar_loss_sol(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:
y_pred = params['w'] * x + params['b']
loss = jnp.mean((y_pred - y_true)**2)
return loss
# 2. Create the gradient function using jax.grad
# Gradient of scalar_loss w.r.t. 'params' (which is the 0-th argument)
compute_gradients_sol = jax.grad(scalar_loss_sol, argnums=0)
# 3. Initialize dummy data
params_init_sol = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
x_input_data_sol = jnp.array([1.0, 2.0, 3.0])
y_target_data_sol = jnp.array([7.0, 9.0, 11.0])
# 4. Call the gradient function
gradients_sol = compute_gradients_sol(params_init_sol, x_input_data_sol, y_target_data_sol)
print("Initial params:", params_init_sol)
print("Gradients w.r.t params:\n", pprint.pformat(gradients_sol))
# Verify with expected values (calculated in instructions)
expected_dL_dw = -16.0
expected_dL_db = -8.0
assert jnp.isclose(gradients_sol['w'], expected_dL_dw), f"Grad w.r.t 'w' is {gradients_sol['w']}, expected {expected_dL_dw}"
assert jnp.isclose(gradients_sol['b'], expected_dL_db), f"Grad w.r.t 'b' is {gradients_sol['b']}, expected {expected_dL_db}"
print("\nGradients match expected values.")
# Instructions for Exercise 4
key_ex4_main, main_key = jax.random.split(main_key)
key_ex4_vec, key_ex4_mat, key_ex4_bias = jax.random.split(key_ex4_main, 3)
# 1. Define apply_affine for a single vector
def apply_affine(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
# TODO: Compute jnp.dot(matrix, vector) + bias
result = None # Placeholder
return result
# 2. Prepare data
batch_size = 4
input_features = 3
output_features = 2
# batch_of_vectors: (batch_size, input_features)
# single_matrix: (output_features, input_features)
# single_bias: (output_features,)
batch_of_vectors = jax.random.normal(key_ex4_vec, (batch_size, input_features))
single_matrix = jax.random.normal(key_ex4_mat, (output_features, input_features))
single_bias = jax.random.normal(key_ex4_bias, (output_features,))
# 3. Use jax.vmap to create batched_apply_affine
# TODO: Specify in_axes correctly: vector is batched, matrix and bias are not. out_axes should be 0.
batched_apply_affine = None # Placeholder jax.vmap(apply_affine, in_axes=(..., ... , ...), out_axes=...)
# 4. Test batched_apply_affine
result_vmap = None # Placeholder batched_apply_affine(batch_of_vectors, single_matrix, single_bias)
print("Batch of vectors shape:", batch_of_vectors.shape)
print("Single matrix shape:", single_matrix.shape)
print("Single bias shape:", single_bias.shape)
if result_vmap is not None:
print("Result using vmap shape:", result_vmap.shape) # Expected: (batch_size, output_features)
# For comparison, a manual loop (less efficient):
# manual_results = []
# for i in range(batch_size):
# manual_results.append(apply_affine(batch_of_vectors[i], single_matrix, single_bias))
# result_manual_loop = jnp.stack(manual_results)
# assert jnp.allclose(result_vmap, result_manual_loop)
# print("vmap result matches manual loop result.")
else:
print("result_vmap is None.")
# @title Solution 4: `jax.vmap` (Automatic Vectorization)
key_ex4_sol_main, main_key = jax.random.split(main_key)
key_ex4_sol_vec, key_ex4_sol_mat, key_ex4_sol_bias = jax.random.split(key_ex4_sol_main, 3)
# 1. Define apply_affine for a single vector
def apply_affine_sol(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
return jnp.dot(matrix, vector) + bias
# 2. Prepare data
batch_size_sol = 4
input_features_sol = 3
output_features_sol = 2
batch_of_vectors_sol = jax.random.normal(key_ex4_sol_vec, (batch_size_sol, input_features_sol))
single_matrix_sol = jax.random.normal(key_ex4_sol_mat, (output_features_sol, input_features_sol))
single_bias_sol = jax.random.normal(key_ex4_sol_bias, (output_features_sol,))
# 3. Use jax.vmap to create batched_apply_affine
# Vector is batched along axis 0, matrix and bias are not batched (broadcasted).
# out_axes=0 means the output will also be batched along its first axis.
batched_apply_affine_sol = jax.vmap(apply_affine_sol, in_axes=(0, None, None), out_axes=0)
# 4. Test batched_apply_affine
result_vmap_sol = batched_apply_affine_sol(batch_of_vectors_sol, single_matrix_sol, single_bias_sol)
print("Batch of vectors shape:", batch_of_vectors_sol.shape)
print("Single matrix shape:", single_matrix_sol.shape)
print("Single bias shape:", single_bias_sol.shape)
print("Result using vmap shape:", result_vmap_sol.shape) # Expected: (batch_size, output_features)
assert result_vmap_sol.shape == (batch_size_sol, output_features_sol)
# For comparison, a manual loop (less efficient):
manual_results_sol = []
for i in range(batch_size_sol):
manual_results_sol.append(apply_affine_sol(batch_of_vectors_sol[i], single_matrix_sol, single_bias_sol))
result_manual_loop_sol = jnp.stack(manual_results_sol)
assert jnp.allclose(result_vmap_sol, result_manual_loop_sol)
print("\nvmap result matches manual loop result, demonstrating correct vectorization.")
# Instructions for Exercise 5
key_ex5_model_init, main_key = jax.random.split(main_key)
# 1. & 2. & 3. Define the SimpleNNXModel
class SimpleNNXModel(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
# TODO: Define an nnx.Linear layer named 'dense_layer'
# self.dense_layer = nnx.Linear(...)
self.some_attribute = None # Placeholder, remove later
pass # Remove this placeholder if class is not empty
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
# TODO: Pass input x through the dense_layer
# return self.dense_layer(x)
return x # Placeholder
# 4. Instantiate the model
model_din = 3
model_dout = 2
# TODO: Create nnx.Rngs for parameter initialization. Use 'params' as the key name.
model_rngs = None # Placeholder nnx.Rngs(params=key_ex5_model_init)
my_model = None # Placeholder SimpleNNXModel(din=model_din, dout=model_dout, rngs=model_rngs)
# 5. Test with dummy data
dummy_batch_size = 4
dummy_input_ex5 = jnp.ones((dummy_batch_size, model_din))
model_output = None # Placeholder
if my_model is not None:
model_output = my_model(dummy_input_ex5)
print(f"Model output shape: {model_output.shape}")
print(f"Model output:\n{model_output}")
_, model_state = nnx.split(my_model)
print(f"\nModel state (parameters, etc.):")
pprint.pprint(model_state)
else:
print("my_model is None.")
# @title Solution 5: Flax NNX - Defining a Model
key_ex5_sol_model_init, main_key = jax.random.split(main_key)
# 1. & 2. & 3. Define the SimpleNNXModel
class SimpleNNXModel_Sol(nnx.Module): # Renamed for solution cell
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
# nnx.Linear will use the 'params' key from rngs by default for its parameters
self.dense_layer = nnx.Linear(din, dout, rngs=rngs)
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
return self.dense_layer(x)
# 4. Instantiate the model
model_din_sol = 3
model_dout_sol = 2
# Create nnx.Rngs for parameter initialization.
# 'params' is the default key nnx.Linear looks for in the rngs object.
model_rngs_sol = nnx.Rngs(params=key_ex5_sol_model_init)
my_model_sol = SimpleNNXModel_Sol(din=model_din_sol, dout=model_dout_sol, rngs=model_rngs_sol)
# 5. Test with dummy data
dummy_batch_size_sol = 4
dummy_input_ex5_sol = jnp.ones((dummy_batch_size_sol, model_din_sol))
model_output_sol = my_model_sol(dummy_input_ex5_sol)
print(f"Model output shape: {model_output_sol.shape}")
print(f"Model output:\n{model_output_sol}")
# model_state_sol = my_model_sol.get_state()
_, model_state_sol = nnx.split(my_model_sol)
print(f"\nModel state (parameters, etc.):")
nnx.display(model_state_sol)
# Check that parameters are present
assert 'dense_layer' in model_state_sol, "Key 'dense_layer' not in model_state"
assert 'kernel' in model_state_sol['dense_layer'], "Key 'kernel' not in model_state['dense_layer']"
assert 'bias' in model_state_sol['dense_layer'], "Key 'bias' not in model_state['dense_layer']"
print("\nModel parameters (kernel and bias for dense_layer) are present in the state.")
# Instructions for Exercise 6
# 1. Assume my_model_sol is available from Exercise 5 solution
# (If running standalone, re-instantiate it)
if 'my_model_sol' not in globals():
print("Re-initializing model from Ex5 solution for Ex6.")
key_ex6_model_init, main_key = jax.random.split(main_key)
_model_din_ex6 = 3
_model_dout_ex6 = 2
_model_rngs_ex6 = nnx.Rngs(params=key_ex6_model_init)
# Use solution class name if defined, otherwise student's class name
_ModelClass = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel
model_for_opt = _ModelClass(din=_model_din_ex6, dout=_model_dout_ex6, rngs=_model_rngs_ex6)
print("Model for optimizer created.")
else:
model_for_opt = my_model_sol # Use the one from previous solution
print("Using model 'my_model_sol' from previous exercise for 'model_for_opt'.")
# 2. Create an Optax optimizer
learning_rate = 0.001
# TODO: Create an optax.adam optimizer transform
optax_tx = None # Placeholder optax.adam(...)
# 3. Create an nnx.Optimizer wrapper
# TODO: Wrap the model (model_for_opt) and the optax transform (optax_tx)
# The `wrt` argument is now required to specify what to differentiate with respect to.
nnx_optimizer = None # Placeholder nnx.Optimizer(...)
# 4. Print the optimizer and its state
print("\nFlax NNX Optimizer wrapper:")
nnx.display(nnx_optimizer)
print("\nInitial Optimizer State (Optax state, e.g., Adam's momentum):")
if nnx_optimizer is not None and hasattr(nnx_optimizer, 'opt_state'):
pprint.pprint(nnx_optimizer.opt_state)
# if hasattr(nnx_optimizer, 'opt_state'):
# adam_state = nnx_optimizer.opt_state
# assert len(adam_state) > 0 and hasattr(adam_state[0], 'count')
# print("
Optimizer state structure looks plausible for Adam.")
else:
print("nnx_optimizer or its state is None or not structured as expected.")
# @title Solution 6: Optax & Flax NNX - Creating an Optimizer
# 1. Use my_model_sol from Exercise 5 solution
# If not run sequentially, ensure my_model_sol is defined:
if 'my_model_sol' not in globals():
print("Re-initializing model from Ex5 solution for Ex6.")
key_ex6_sol_model_init, main_key = jax.random.split(main_key)
_model_din_sol_ex6 = 3
_model_dout_sol_ex6 = 2
_model_rngs_sol_ex6 = nnx.Rngs(params=key_ex6_sol_model_init)
# Ensure SimpleNNXModel_Sol is used
my_model_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex6, dout=_model_dout_sol_ex6, rngs=_model_rngs_sol_ex6)
print("Model for optimizer re-created as 'my_model_sol'.")
else:
print("Using model 'my_model_sol' from previous exercise.")
# 2. Create an Optax optimizer
learning_rate_sol = 0.001
# Create an optax.adam optimizer transform
optax_tx_sol = optax.adam(learning_rate=learning_rate_sol)
# 3. Create an nnx.Optimizer wrapper
# This links the model and the Optax optimizer.
# The optimizer state will be initialized based on the model's parameters.
nnx_optimizer_sol = nnx.Optimizer(my_model_sol, optax_tx_sol, wrt=nnx.Param)
# 4. Print the optimizer and its state
print("\nFlax NNX Optimizer wrapper:")
nnx.display(nnx_optimizer_sol) # Shows the model it's associated with and the Optax transform
print("\nInitial Optimizer State (Optax state, e.g., Adam's momentum):")
# nnx.Optimizer stores the actual Optax state in its .opt_state attribute.
# This state is a PyTree that matches the structure of the model's parameters.
pprint.pprint(nnx_optimizer_sol.opt_state)
# Verify the structure of the optimizer state for Adam (count, mu, nu for each param)
assert hasattr(nnx_optimizer_sol, 'opt_state'), "Optax opt_state not found in nnx.Optimizer"
# For optax.adam, opt_state is typically a tuple like (ScaleByAdamState(count, mu, nu), EmptyState()).
adam_optax_internal_state = nnx_optimizer_sol.opt_state
assert len(adam_optax_internal_state) > 0 and hasattr(adam_optax_internal_state[0], 'count'), "Adam 'count' state not found."
# The first element of the tuple holds the parameter-specific states 'mu' and 'nu'.
if len(adam_optax_internal_state) > 0 and hasattr(adam_optax_internal_state[0], 'mu'):
param_specific_state = adam_optax_internal_state[0]
assert 'dense_layer' in param_specific_state.mu and 'kernel' in param_specific_state.mu['dense_layer'], "Adam 'mu' state for kernel not found."
print("\nOptimizer state structure looks correct for Adam.")
else:
print("\nWarning: Optimizer state structure for Adam might be different or not fully verified.")
# Instructions for Exercise 7
key_ex7_main, main_key = jax.random.split(main_key)
key_ex7_x, key_ex7_y = jax.random.split(key_ex7_main, 2)
# 1. Use model and optimizer from previous exercises' solutions
# Ensure my_model_sol and nnx_optimizer_sol are available
if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():
print("Re-initializing model and optimizer from Ex5/Ex6 solutions for Ex7.")
key_ex7_model_fallback, main_key = jax.random.split(main_key)
_model_din_ex7 = 3
_model_dout_ex7 = 2
_model_rngs_ex7 = nnx.Rngs(params=key_ex7_model_fallback)
# Ensure SimpleNNXModel_Sol is used
my_model_ex7 = SimpleNNXModel_Sol(din=_model_din_ex7, dout=_model_dout_ex7, rngs=_model_rngs_ex7)
_optax_tx_ex7 = optax.adam(learning_rate=0.001)
nnx_optimizer_ex7 = nnx.Optimizer(my_model_ex7, _optax_tx_ex7, wrt=nnx.Param)
print("Model and optimizer re-created for Ex7.")
else:
my_model_ex7 = my_model_sol
nnx_optimizer_ex7 = nnx_optimizer_sol
print("Using 'my_model_sol' and 'nnx_optimizer_sol' for 'my_model_ex7' and 'nnx_optimizer_ex7'.")
# 2. & 3. Define the train_step function
# TODO: Decorate with @nnx.jit
# def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer, # Type hint with base nnx.Module
# x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:
# TODO: Define inner loss_fn_for_grad(current_model_state_for_grad_fn)
# def loss_fn_for_grad(model_in_grad_fn: nnx.Module): # Type hint with base nnx.Module
# y_pred = model_in_grad_fn(x_batch)
# loss = jnp.mean((y_pred - y_batch)**2)
# return loss
# return jnp.array(0.0) # Placeholder
# TODO: Compute loss value and gradients using nnx.value_and_grad
# loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg) # Pass model_arg
# TODO: Update the optimizer (which updates the model_arg in-place)
# optimizer_arg.update(model_arg, grads)
# return loss_value
# return jnp.array(0.0) # Placeholder defined train_step function
# For the student to define:
# Make sure the function signature is correct for nnx.jit
@nnx.jit
def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer,
x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:
# Placeholder implementation for student
def loss_fn_for_grad(model_in_grad_fn: nnx.Module):
# y_pred = model_in_grad_fn(x_batch)
# loss = jnp.mean((y_pred - y_batch)**2)
# return loss
return jnp.array(0.0) # Student TODO: replace this
# loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)
# optimizer_arg.update(model_arg, grads)
# return loss_value
return jnp.array(-1.0) # Student TODO: replace this
# 4. Create dummy data
batch_s = 8
# Access features_in and features_out carefully
_din_from_model_ex7 = my_model_ex7.dense_layer.in_features if hasattr(my_model_ex7, 'dense_layer') else 3
_dout_from_model_ex7 = my_model_ex7.dense_layer.out_features if hasattr(my_model_ex7, 'dense_layer') else 2
x_batch_data = jax.random.normal(key_ex7_x, (batch_s, _din_from_model_ex7))
y_batch_data = jax.random.normal(key_ex7_y, (batch_s, _dout_from_model_ex7))
# Optional: Store initial param value for comparison
initial_kernel_val = None
_, _current_model_state_ex7 = nnx.split(my_model_ex7)
if 'dense_layer' in _current_model_state_ex7:
initial_kernel_val = _current_model_state_ex7['dense_layer']['kernel'].value[0,0].copy()
print(f"Initial kernel value (sample): {initial_kernel_val}")
# 5. Call the train_step
loss_after_step = train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data)
if loss_after_step.item() != -1.0: # -1.0 is the placeholder return value; any other value means train_step is implemented
print(f"Loss after one training step: {loss_after_step}")
else:
print("Student needs to implement `train_step` function.")
# 6. Optional: Verify parameter change (uncomment once train_step is implemented)
# updated_kernel_val = None
# _, updated_model_state = nnx.split(my_model_ex7) # Get state again after update
# if 'dense_layer' in updated_model_state:
# updated_kernel_val = updated_model_state['dense_layer']['kernel'].value[0,0]
# print(f"Updated kernel value (sample): {updated_kernel_val}")
# if initial_kernel_val is not None and updated_kernel_val is not None:
# assert not jnp.allclose(initial_kernel_val, updated_kernel_val), "Kernel parameter did not change!"
# print("Kernel parameter changed as expected after the training step.")
# else:
# print("Could not verify kernel change (initial or updated value was None).")
# @title Solution 7: Training Step with Flax NNX and Optax
key_ex7_sol_main, main_key = jax.random.split(main_key)
key_ex7_sol_x, key_ex7_sol_y = jax.random.split(key_ex7_sol_main, 2)
# 1. Use model and optimizer from previous exercises' solutions
if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():
print("Re-initializing model and optimizer from Ex5/Ex6 solutions for Ex7 solution.")
key_ex7_sol_model_fallback, main_key = jax.random.split(main_key)
_model_din_sol_ex7 = 3
_model_dout_sol_ex7 = 2
_model_rngs_sol_ex7 = nnx.Rngs(params=key_ex7_sol_model_fallback)
# Ensure SimpleNNXModel_Sol is used for the solution
my_model_sol_ex7 = SimpleNNXModel_Sol(din=_model_din_sol_ex7, dout=_model_dout_sol_ex7, rngs=_model_rngs_sol_ex7)
_optax_tx_sol_ex7 = optax.adam(learning_rate=0.001)
nnx_optimizer_sol_ex7 = nnx.Optimizer(my_model_sol_ex7, _optax_tx_sol_ex7, wrt=nnx.Param)
print("Model and optimizer re-created for Ex7 solution.")
else:
# If solutions are run sequentially, these will be the correct instances
my_model_sol_ex7 = my_model_sol
nnx_optimizer_sol_ex7 = nnx_optimizer_sol
print("Using 'my_model_sol' and 'nnx_optimizer_sol' for Ex7 solution.")
# 2. & 3. Define the train_step function
@nnx.jit # Decorate with @nnx.jit for JIT compilation
def train_step_sol(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer,
x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:
# Define inner loss_fn_for_grad. It takes the model as its first argument.
# It captures x_batch and y_batch from the outer scope.
def loss_fn_for_grad(model_in_grad_fn: nnx.Module):
y_pred = model_in_grad_fn(x_batch) # Use the model passed to this inner function
loss = jnp.mean((y_pred - y_batch)**2)
return loss
# Compute loss value and gradients using nnx.value_and_grad.
# This will differentiate loss_fn_for_grad with respect to its first argument (model_in_grad_fn).
# We pass the current state of our model (model_arg) to it.
loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)
# Update the optimizer. This updates the model_arg (which nnx_optimizer_sol_ex7 references) in-place.
optimizer_arg.update(model_arg, grads)
return loss_value
# 4. Create dummy data
batch_s_sol = 8
# Ensure din and dout match the model instantiation from Ex5/Ex6
# my_model_sol_ex7.dense_layer is an nnx.Linear object
din_from_model_sol = my_model_sol_ex7.dense_layer.in_features
dout_from_model_sol = my_model_sol_ex7.dense_layer.out_features
x_batch_data_sol = jax.random.normal(key_ex7_sol_x, (batch_s_sol, din_from_model_sol))
y_batch_data_sol = jax.random.normal(key_ex7_sol_y, (batch_s_sol, dout_from_model_sol))
# Optional: Store initial param value for comparison
initial_kernel_val_sol = None
_, current_model_state_sol = nnx.split(my_model_sol_ex7)
if 'dense_layer' in current_model_state_sol:
initial_kernel_val_sol = current_model_state_sol['dense_layer']['kernel'].value[0,0].copy()
print(f"Initial kernel value (sample): {initial_kernel_val_sol}")
# 5. Call the train_step
# First call will JIT compile the train_step_sol function.
loss_after_step_sol = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)
print(f"Loss after one training step (1st call, JIT): {loss_after_step_sol}")
# Second call to show it's faster (though %timeit is better for measurement)
loss_after_step_sol_2 = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)
print(f"Loss after one training step (2nd call, cached): {loss_after_step_sol_2}")
# 6. Optional: Verify parameter change
updated_kernel_val_sol = None
_, updated_model_state_sol = nnx.split(my_model_sol_ex7) # Get state again after update
if 'dense_layer' in updated_model_state_sol:
updated_kernel_val_sol = updated_model_state_sol['dense_layer']['kernel'].value[0,0]
print(f"Updated kernel value (sample): {updated_kernel_val_sol}")
if initial_kernel_val_sol is not None and updated_kernel_val_sol is not None:
assert not jnp.allclose(initial_kernel_val_sol, updated_kernel_val_sol), "Kernel parameter did not change!"
print("Kernel parameter changed as expected after the training step.")
else:
print("Could not verify kernel change (initial or updated value was None).")
# Instructions for Exercise 8
# import orbax.checkpoint as ocp # Already imported
# import os, shutil # Already imported
# 1. Use model and optimizer from previous exercise solution
if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():
print("Re-initializing model and optimizer from Ex7 solution for Ex8.")
key_ex8_model_fallback, main_key = jax.random.split(main_key)
_model_din_ex8 = 3
_model_dout_ex8 = 2
_model_rngs_ex8 = nnx.Rngs(params=key_ex8_model_fallback)
_ModelClassEx8 = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel
model_to_save = _ModelClassEx8(din=_model_din_ex8, dout=_model_dout_ex8, rngs=_model_rngs_ex8)
_optax_tx_ex8 = optax.adam(learning_rate=0.001)
optimizer_to_save = nnx.Optimizer(model_to_save, _optax_tx_ex8, wrt=nnx.Param)
print("Model and optimizer re-created for Ex8.")
else:
model_to_save = my_model_sol_ex7
optimizer_to_save = nnx_optimizer_sol_ex7
print("Using model and optimizer from Ex7 solution for Ex8.")
# 2. Define checkpoint directory
# TODO: Define checkpoint_dir
checkpoint_dir = None # Placeholder e.g., "/tmp/my_nnx_checkpoint_exercise/"
# if checkpoint_dir and os.path.exists(checkpoint_dir):
# shutil.rmtree(checkpoint_dir) # Clean up previous runs for safety
# if checkpoint_dir:
# os.makedirs(checkpoint_dir, exist_ok=True)
# 3. Create Orbax CheckpointManager
# TODO: Create options and manager
# options = ocp.CheckpointManagerOptions(...)
# mngr = ocp.CheckpointManager(...)
options = None
mngr = None
# 4. Bundle states
# current_step = 100 # Example step
# TODO: Get model_state and optimizer_state
# _, model_state_to_save = nnx.split(model_to_save)
# The Optax state is accessed via the optimizer's .opt_state attribute.
# opt_state_to_save = optimizer_to_save.opt_state
# save_bundle = {
# 'model': model_state_to_save,
# 'optimizer': opt_state_to_save,
# 'step': current_step
# }
save_bundle = None
# 5. Save the checkpoint
# if mngr and save_bundle:
# TODO: Save checkpoint
# mngr.save(...)
# mngr.wait_until_finished()
# print(f"Checkpoint saved at step {current_step} to {checkpoint_dir}")
# else:
# print("Checkpoint manager or save_bundle not initialized.")
# --- Restoration ---
# 6.a Create new model and Optax transform (for restoration)
# key_ex8_restore_model, main_key = jax.random.split(main_key)
# din_restore = model_to_save.dense_layer.in_features if hasattr(model_to_save, 'dense_layer') else 3
# dout_restore = model_to_save.dense_layer.out_features if hasattr(model_to_save, 'dense_layer') else 2
# _ModelClassRestore = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel
# restored_model = _ModelClassRestore(
# din=din_restore, dout=dout_restore,
# rngs=nnx.Rngs(params=key_ex8_restore_model) # New key for different initial params
# )
# restored_optax_tx = optax.adam(learning_rate=0.001) # Same Optax config
restored_model = None
restored_optax_tx = None
# 6.b Restore the checkpoint
# loaded_bundle = None
# if mngr:
# TODO: Restore checkpoint
# latest_step = mngr.latest_step()
# if latest_step is not None:
# loaded_bundle = mngr.restore(...)
# print(f"Checkpoint restored from step {latest_step}")
# else:
# print("No checkpoint found to restore.")
# else:
# print("Checkpoint manager not initialized for restore.")
# 6.c Apply loaded states
# if loaded_bundle and restored_model:
# TODO: Update restored_model state
# nnx.update(restored_model, ...)
# print("Restored model state applied.")
# TODO: Create new nnx.Optimizer and assign its state
# restored_optimizer = nnx.Optimizer(...)
# restored_optimizer.opt_state = ...
# print("Restored optimizer state applied.")
# else:
# print("Loaded_bundle or restored_model is None, cannot apply states.")
restored_optimizer = None
# 7. Verify restoration (uncomment once the checkpoint has been saved and restored)
# original_kernel = save_bundle['model']['dense_layer']['kernel']
# _, restored_model_state = nnx.split(restored_model)
# kernel_after_restore = restored_model_state['dense_layer']['kernel']
# assert jnp.array_equal(original_kernel.value, kernel_after_restore.value), \
# "Model kernel parameters differ after restoration!"
# print("\nModel parameters successfully restored and verified (kernel match).")
# # Verify optimizer state (e.g., Adam's 'mu' for a specific parameter)
# original_opt_state_adam_mu_kernel = save_bundle['optimizer'][0].mu['dense_layer']['kernel'].value
# restored_opt_state_adam_mu_kernel = restored_optimizer.opt_state[0].mu['dense_layer']['kernel'].value
# assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \
# "Optimizer Adam mu for kernel differs!"
# print("Optimizer state (sample mu) successfully restored and verified.")
# 8. Clean up
# if mngr:
# mngr.close()
# if checkpoint_dir and os.path.exists(checkpoint_dir):
# shutil.rmtree(checkpoint_dir)
# print(f"Cleaned up checkpoint directory: {checkpoint_dir}")
# @title Solution 8: Orbax - Saving and Restoring Checkpoints
# 1. Use model and optimizer from previous exercise solution
if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():
print("Re-initializing model and optimizer from Ex7 solution for Ex8 solution.")
key_ex8_sol_model_fallback, main_key = jax.random.split(main_key)
_model_din_sol_ex8 = 3
_model_dout_sol_ex8 = 2
_model_rngs_sol_ex8 = nnx.Rngs(params=key_ex8_sol_model_fallback)
# Ensure SimpleNNXModel_Sol is used for the solution
model_to_save_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex8,
dout=_model_dout_sol_ex8,
rngs=_model_rngs_sol_ex8)
_optax_tx_sol_ex8 = optax.adam(learning_rate=0.001) # Store the transform for later
optimizer_to_save_sol = nnx.Optimizer(model_to_save_sol, _optax_tx_sol_ex8, wrt=nnx.Param)
print("Model and optimizer re-created for Ex8 solution.")
else:
model_to_save_sol = my_model_sol_ex7
optimizer_to_save_sol = nnx_optimizer_sol_ex7
# We need the optax transform used to create the optimizer for restoration
_optax_tx_sol_ex8 = optimizer_to_save_sol.tx # Access the original Optax transform
print("Using model and optimizer from Ex7 solution for Ex8 solution.")
# 2. Define checkpoint directory
checkpoint_dir_sol = "/tmp/my_nnx_checkpoint_exercise_solution/"
if os.path.exists(checkpoint_dir_sol):
shutil.rmtree(checkpoint_dir_sol) # Clean up previous runs
os.makedirs(checkpoint_dir_sol, exist_ok=True)
print(f"Orbax checkpoint directory: {checkpoint_dir_sol}")
# 3. Create Orbax CheckpointManager
options_sol = ocp.CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1)
mngr_sol = ocp.CheckpointManager(checkpoint_dir_sol, options=options_sol)
# 4. Bundle states
current_step_sol = 100 # Example step
_, model_state_to_save_sol = nnx.split(model_to_save_sol)
# The Optax state is a PyTree directly available in the optimizer's .opt_state attribute.
opt_state_to_save_sol = optimizer_to_save_sol.opt_state
save_bundle_sol = {
'model': model_state_to_save_sol,
'optimizer': opt_state_to_save_sol,
'step': current_step_sol
}
print("\nState bundle to be saved:")
pprint.pprint(f"Model state keys: {model_state_to_save_sol.keys()}")
pprint.pprint(f"Optimizer state type: {type(opt_state_to_save_sol)}")
# 5. Save the checkpoint
mngr_sol.save(current_step_sol, args=ocp.args.StandardSave(save_bundle_sol))
mngr_sol.wait_until_finished()
print(f"\nCheckpoint saved at step {current_step_sol} to {checkpoint_dir_sol}")
# --- Restoration ---
# 6.a Create new model and Optax transform (for restoration)
key_ex8_sol_restore_model, main_key = jax.random.split(main_key)
# Ensure din/dout are correctly obtained from the saved model's structure if possible
# Assuming model_to_save_sol is SimpleNNXModel_Sol which has a dense_layer
din_restore_sol = model_to_save_sol.dense_layer.in_features
dout_restore_sol = model_to_save_sol.dense_layer.out_features
restored_model_sol = SimpleNNXModel_Sol( # Use the solution's model class
din=din_restore_sol, dout=dout_restore_sol,
rngs=nnx.Rngs(params=key_ex8_sol_restore_model) # New key for different initial params
)
# We need the original Optax transform definition for the new nnx.Optimizer
# _optax_tx_sol_ex8 was stored earlier, or can be re-created if config is known
restored_optax_tx_sol = _optax_tx_sol_ex8
# Print a param from new model BEFORE restoration to show it's different
_, state_before_restore_sol = nnx.split(restored_model_sol)
print(f"\nSample kernel from 'restored_model_sol' BEFORE restoration:")
nnx.display(state_before_restore_sol['dense_layer']['kernel'])
# 6.b Restore the checkpoint
loaded_bundle_sol = None
latest_step_sol = mngr_sol.latest_step()
if latest_step_sol is not None:
# For NNX, we are restoring raw PyTrees, StandardRestore is suitable.
loaded_bundle_sol = mngr_sol.restore(latest_step_sol,
args=ocp.args.StandardRestore(save_bundle_sol))
print(f"\nCheckpoint restored from step {latest_step_sol}")
print(f"Loaded bundle contains keys: {loaded_bundle_sol.keys()}")
else:
raise ValueError("No checkpoint found to restore.")
# 6.c Apply loaded states
assert loaded_bundle_sol is not None, "Loaded bundle is None"
nnx.update(restored_model_sol, loaded_bundle_sol['model'])
print("Restored model state applied to 'restored_model_sol'.")
# Create new nnx.Optimizer with the restored_model and original optax_tx
restored_optimizer_sol = nnx.Optimizer(restored_model_sol, restored_optax_tx_sol,
wrt=nnx.Param)
# Now assign the loaded Optax state PyTree
restored_optimizer_sol.opt_state = loaded_bundle_sol['optimizer']
print("Restored optimizer state applied to 'restored_optimizer_sol'.")
# 7. Verify restoration
original_kernel_sol = save_bundle_sol['model']['dense_layer']['kernel']
_, restored_model_state = nnx.split(restored_model_sol)
kernel_after_restore_sol = restored_model_state['dense_layer']['kernel']
assert jnp.array_equal(original_kernel_sol.value, kernel_after_restore_sol.value), \
"Model kernel parameters differ after restoration!"
print("\nModel parameters successfully restored and verified (kernel match).")
# Verify optimizer state (e.g., Adam's 'mu' for a specific parameter)
original_opt_state_adam_mu_kernel = save_bundle_sol['optimizer'][0].mu['dense_layer']['kernel'].value
restored_opt_state_adam_mu_kernel = restored_optimizer_sol.opt_state[0].mu['dense_layer']['kernel'].value
assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \
"Optimizer Adam mu for kernel differs!"
print("Optimizer state (sample mu) successfully restored and verified.")
# 8. Clean up
mngr_sol.close()
if os.path.exists(checkpoint_dir_sol):
shutil.rmtree(checkpoint_dir_sol)
print(f"Cleaned up checkpoint directory: {checkpoint_dir_sol}")
You've now had a hands-on introduction to:
- Core JAX: the NumPy-style API, `jax.jit`, `jax.grad`, and `jax.vmap`
- Flax NNX for defining and managing models
- Optax for optimization
- Orbax for checkpointing
This forms a strong foundation for developing high-performance machine learning models with the JAX ecosystem.
For further learning, refer to the official documentation for JAX, Flax NNX, Optax, and Orbax.
Don't forget to provide feedback on the training session: https://goo.gle/jax-training-feedback