Introduction

Welcome to the JAX AI Stack Exercises!

This notebook is designed to accompany the "Leveraging the JAX AI Stack" lecture. You'll get hands-on experience with core JAX concepts, Flax NNX for model building, Optax for optimization, and Orbax for checkpointing.

The exercises will guide you through implementing key components, drawing parallels to PyTorch where appropriate, to solidify your understanding.

Let's get started!

# @title Setup: Install and Import Libraries
# Install necessary libraries
!pip install -q jax-ai-stack==2025.9.3

import jax
import jax.numpy as jnp
import flax
from flax import nnx
import optax
import orbax.checkpoint as ocp # For Orbax
from typing import Any, Dict, Tuple # For type hints

# Helper to print PyTrees more nicely for demonstration
import pprint
import os # For Orbax directory management
import shutil # For cleaning up Orbax directory

print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"Orbax version: {ocp.__version__}")

# Global JAX PRNG key for reproducibility in exercises
# Students can learn to split this key for different operations.
main_key = jax.random.key(0)

Exercise 1: JAX Core & NumPy API

Goal: Get familiar with jax.numpy and JAX's functional programming style.

Instructions:

  1. Create two JAX arrays, a (a 2x2 matrix of random numbers) and b (a 2x2 matrix of ones) using jax.numpy (jnp). You'll need a jax.random.key for creating random numbers.
  2. Perform element-wise addition of a and b.
  3. Perform matrix multiplication of a and b.
  4. Demonstrate JAX's immutability:
     - Store the Python id() of array a.
     - Perform an operation like a = a + 1.
     - Print the new id() of a and observe that it has changed, indicating that a new array was created.
# Instructions for Exercise 1
key_ex1, main_key = jax.random.split(main_key) # Split the main key

# 1. Create JAX arrays a and b
# TODO: Create array 'a' (2x2 random normal) and 'b' (2x2 ones)
a = None # Placeholder
b = None # Placeholder

print("Array a:\n", a)
print("Array b:\n", b)

# 2. Perform element-wise addition
# TODO: Add a and b
c = None # Placeholder
print("Element-wise sum c = a + b:\n", c)

# 3. Perform matrix multiplication
# TODO: Matrix multiply a and b
d = None # Placeholder
print("Matrix product d = a @ b:\n", d)

# 4. Demonstrate immutability
# original_a_id = id(a)
# print(f"Original id(a): {original_a_id}")

# TODO: Perform an operation that reassigns 'a', e.g., a = a + 1
# a_new_ref = None # Placeholder
# new_a_id = id(a_new_ref)
# print(f"New id(a) after 'a = a + 1': {new_a_id}")

# TODO: Check if original_a_id is different from new_a_id
# print(f"IDs are different: {None}") # Placeholder
# @title Solution 1: JAX Core & NumPy API
key_ex1_sol, main_key = jax.random.split(main_key)

# 1. Create JAX arrays a and b
a_sol = jax.random.normal(key_ex1_sol, (2, 2))
b_sol = jnp.ones((2, 2))

print("Array a:\n", a_sol)
print("Array b:\n", b_sol)

# 2. Perform element-wise addition
c_sol = a_sol + b_sol
print("Element-wise sum c = a + b:\n", c_sol)

# 3. Perform matrix multiplication
d_sol = jnp.dot(a_sol, b_sol) # or d = a @ b
print("Matrix product d = a @ b:\n", d_sol)

# 4. Demonstrate immutability
original_a_id_sol = id(a_sol)
print(f"Original id(a_sol): {original_a_id_sol}")

a_sol_new_ref = a_sol + 1 # Creates a new array; a_sol itself is left unchanged.
new_a_id_sol = id(a_sol_new_ref)
print(f"New id after 'a_sol_new_ref = a_sol + 1': {new_a_id_sol}")
print(f"IDs are different: {original_a_id_sol != new_a_id_sol}")
print("This shows that the original array was not modified in-place; a new array was created.")
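Comparing id() values shows that `+` produced a new object, but rebinding a name would do the same for a NumPy array. A more direct check of JAX's immutability, sketched below, is that item assignment raises a `TypeError`, while the functional `.at[...].set(...)` update returns a new array and leaves the original untouched. The names `x` and `y` are illustrative.

```python
import jax.numpy as jnp

x = jnp.zeros((2, 2))

# JAX arrays do not support in-place item assignment.
update_failed = False
try:
    x[0, 0] = 1.0
except TypeError:
    update_failed = True

# The functional alternative returns a new array with the change applied.
y = x.at[0, 0].set(1.0)

print("In-place assignment raised TypeError:", update_failed)
print("x[0, 0] =", x[0, 0], "| y[0, 0] =", y[0, 0])  # x keeps 0.0, y holds 1.0
```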

Exercise 2: jax.jit (Just-In-Time Compilation)

Goal: Understand how to use jax.jit to compile JAX functions for performance.

Instructions:

  1. Define a Python function compute_heavy_stuff(x, w, b) that performs a sequence of jnp operations:
     - y = jnp.dot(x, w)
     - y = y + b
     - y = jnp.tanh(y)
     - result = jnp.sum(y)
     - Return result.
  2. Create a JIT-compiled version of this function, fast_compute_heavy_stuff, using jax.jit.
  3. Create some large dummy JAX arrays for x, w, and b.
  4. Call both the original and JIT-compiled functions with the dummy data.
  5. (Optional) Use the %timeit magic command in Colab (in separate cells) to compare their execution speeds. Remember that the first call to a JIT-compiled function includes compilation time.
# Instructions for Exercise 2
key_ex2_main, main_key = jax.random.split(main_key)
key_ex2_x, key_ex2_w, key_ex2_b = jax.random.split(key_ex2_main, 3)

# 1. Define the Python function
def compute_heavy_stuff(x, w, b):
    # TODO: Implement the operations
    y1 = None # Placeholder
    y2 = None # Placeholder
    y3 = None # Placeholder
    result = None # Placeholder
    return result

# 2. Create a JIT-compiled version
# TODO: Use jax.jit to compile compute_heavy_stuff
fast_compute_heavy_stuff = None # Placeholder

# 3. Create dummy data
dim1, dim2, dim3 = 500, 1000, 500
x_data = jax.random.normal(key_ex2_x, (dim1, dim2))
w_data = jax.random.normal(key_ex2_w, (dim2, dim3))
b_data = jax.random.normal(key_ex2_b, (dim3,))

# 4. Call both functions
result_original = None # Placeholder compute_heavy_stuff(x_data, w_data, b_data)
result_fast_first_call = None # Placeholder fast_compute_heavy_stuff(x_data, w_data, b_data) # First call (compiles)
result_fast_second_call = None # Placeholder fast_compute_heavy_stuff(x_data, w_data, b_data) # Second call (uses compiled)

print(f"Result (original): {result_original}")
print(f"Result (fast, 1st call): {result_fast_first_call}")
print(f"Result (fast, 2nd call): {result_fast_second_call}")

# if result_original is not None and result_fast_first_call is not None:
#   assert jnp.allclose(result_original, result_fast_first_call), "Results should match!"
#   print("\nResults from original and JIT-compiled functions match.")

# 5. Optional: Timing (use %timeit in separate cells for accuracy)
# print("\nTo see the speed difference, run these in separate cells:")
# print("%timeit compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()")
# print("%timeit fast_compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()")
# @title Solution 2: `jax.jit` (Just-In-Time Compilation)
key_ex2_sol_main, main_key = jax.random.split(main_key)
key_ex2_sol_x, key_ex2_sol_w, key_ex2_sol_b = jax.random.split(key_ex2_sol_main, 3)

# 1. Define the Python function
def compute_heavy_stuff_sol(x, w, b):
    y = jnp.dot(x, w)
    y = y + b
    y = jnp.tanh(y)
    result = jnp.sum(y)
    return result

# 2. Create a JIT-compiled version
fast_compute_heavy_stuff_sol = jax.jit(compute_heavy_stuff_sol)

# 3. Create dummy data
dim1_sol, dim2_sol, dim3_sol = 500, 1000, 500
x_data_sol = jax.random.normal(key_ex2_sol_x, (dim1_sol, dim2_sol))
w_data_sol = jax.random.normal(key_ex2_sol_w, (dim2_sol, dim3_sol))
b_data_sol = jax.random.normal(key_ex2_sol_b, (dim3_sol,))

# 4. Call both functions
# Call the original (un-jitted) function. block_until_ready() waits for JAX's asynchronous dispatch to finish.
result_original_sol = compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()

# First call to JITed function includes compilation time
result_fast_sol_first_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()

# Subsequent calls use the cached compiled code
result_fast_sol_second_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()

print(f"Result (original): {result_original_sol}")
print(f"Result (fast, 1st call): {result_fast_sol_first_call}")
print(f"Result (fast, 2nd call): {result_fast_sol_second_call}")

assert jnp.allclose(result_original_sol, result_fast_sol_first_call), "Results should match!"
print("\nResults from original and JIT-compiled functions match.")

# 5. Optional: Timing
# To accurately measure, run these in separate Colab cells:
# Cell 1:
# %timeit compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
# Cell 2:
# %timeit fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()
# The JIT-compiled version is usually faster after the initial compilation, since XLA can fuse the add, tanh and sum; the exact speedup depends on your hardware.
print("\nTo see the speed difference, run the %timeit commands (provided in comments above) in separate cells.")
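The compiled code is cached per input shape and dtype. One way to observe this, using the illustrative function `scaled_tanh_sum` below, is a Python side effect such as incrementing a counter: it runs only while JAX traces the function, so it counts traces rather than calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scaled_tanh_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return jnp.sum(jnp.tanh(2.0 * x))

scaled_tanh_sum(jnp.ones((3,)))   # first call: trace + compile
scaled_tanh_sum(jnp.zeros((3,)))  # same shape and dtype: reuses the compiled code
scaled_tanh_sum(jnp.ones((4,)))   # new shape: triggers a retrace
print(trace_count)  # 2
```

This is also why the first call in the timing comparison above should be excluded: it pays for tracing and compilation once per input signature.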

Exercise 3: jax.grad (Automatic Differentiation)

Goal: Learn to use jax.grad to compute gradients of functions.

Instructions:

  1. Define a Python function scalar_loss(params, x, y_true) that:
     - Takes a dictionary params with keys 'w' and 'b'.
     - Computes y_pred = params['w'] * x + params['b'].
     - Returns a scalar loss, e.g., jnp.mean((y_pred - y_true)**2).
  2. Use jax.grad to create a new function, compute_gradients, that computes the gradient of scalar_loss with respect to its first argument (params).
  3. Initialize some dummy params, x_input, and y_target values.
  4. Call compute_gradients to get the gradients. Print the gradients.
# Instructions for Exercise 3

# 1. Define the scalar_loss function
def scalar_loss(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:
    # TODO: Implement the prediction and loss calculation
    y_pred = None # Placeholder
    loss = None # Placeholder
    return loss

# 2. Create the gradient function using jax.grad
# TODO: Gradient of scalar_loss w.r.t. 'params' (argnums=0)
compute_gradients = None # Placeholder

# 3. Initialize dummy data
params_init = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
x_input_data = jnp.array([1.0, 2.0, 3.0])
y_target_data = jnp.array([7.0, 9.0, 11.0]) # Targets for y = 2x + 5 (gives a non-zero loss with params_init)

# 4. Call the gradient function
gradients = None # Placeholder compute_gradients(params_init, x_input_data, y_target_data)
print("Initial params:", params_init)
print("Gradients w.r.t params:\n", gradients)

# Expected gradients (manual calculation for y_pred = wx+b, loss = mean((y_pred - y_true)^2)):
# dL/dw = mean(2 * (wx+b - y_true) * x)
# dL/db = mean(2 * (wx+b - y_true) * 1)
# For params_init={'w': 2.0, 'b': 1.0}, x=[1,2,3], y_true=[7,9,11]
# x=1: y_pred = 2*1+1 = 3. Error = 3-7 = -4. dL/dw_i_term = 2*(-4)*1 = -8.  dL/db_i_term = 2*(-4)*1 = -8
# x=2: y_pred = 2*2+1 = 5. Error = 5-9 = -4. dL/dw_i_term = 2*(-4)*2 = -16. dL/db_i_term = 2*(-4)*1 = -8
# x=3: y_pred = 2*3+1 = 7. Error = 7-11 = -4. dL/dw_i_term = 2*(-4)*3 = -24. dL/db_i_term = 2*(-4)*1 = -8
# Mean gradients: dL/dw = (-8-16-24)/3 = -48/3 = -16.  dL/db = (-8-8-8)/3 = -24/3 = -8.
# if gradients is not None:
#     assert jnp.isclose(gradients['w'], -16.0)
#     assert jnp.isclose(gradients['b'], -8.0)
#     print("\nGradients match expected values.")
# @title Solution 3: `jax.grad` (Automatic Differentiation)

# 1. Define the scalar_loss function
def scalar_loss_sol(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:
    y_pred = params['w'] * x + params['b']
    loss = jnp.mean((y_pred - y_true)**2)
    return loss

# 2. Create the gradient function using jax.grad
# Gradient of scalar_loss w.r.t. 'params' (which is the 0-th argument)
compute_gradients_sol = jax.grad(scalar_loss_sol, argnums=0)

# 3. Initialize dummy data
params_init_sol = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
x_input_data_sol = jnp.array([1.0, 2.0, 3.0])
y_target_data_sol = jnp.array([7.0, 9.0, 11.0])

# 4. Call the gradient function
gradients_sol = compute_gradients_sol(params_init_sol, x_input_data_sol, y_target_data_sol)
print("Initial params:", params_init_sol)
print("Gradients w.r.t params:\n", pprint.pformat(gradients_sol))

# Verify with expected values (calculated in instructions)
expected_dL_dw = -16.0
expected_dL_db = -8.0
assert jnp.isclose(gradients_sol['w'], expected_dL_dw), f"Grad w.r.t 'w' is {gradients_sol['w']}, expected {expected_dL_dw}"
assert jnp.isclose(gradients_sol['b'], expected_dL_db), f"Grad w.r.t 'b' is {gradients_sol['b']}, expected {expected_dL_db}"
print("\nGradients match expected values.")
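Later exercises use `nnx.value_and_grad`, which follows the same pattern as `jax.value_and_grad`: one call returns both the loss and the gradients. The sketch below reuses this exercise's loss (renamed `mse_loss`) and adds a hand-written gradient-descent step with `jax.tree_util.tree_map`; the learning rate of 0.1 is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp

def mse_loss(params, x, y_true):
    y_pred = params['w'] * x + params['b']
    return jnp.mean((y_pred - y_true) ** 2)

params = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
x = jnp.array([1.0, 2.0, 3.0])
y_true = jnp.array([7.0, 9.0, 11.0])

# One call returns the loss and a gradient PyTree with the same structure as params.
loss, grads = jax.value_and_grad(mse_loss)(params, x, y_true)
print(loss)   # 16.0, since every residual is -4
print(grads)  # w -> -16.0, b -> -8.0 (matches the manual calculation)

# Plain gradient descent, applied leaf by leaf.
learning_rate = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
print(new_params)  # w: 2.0 + 1.6 = 3.6, b: 1.0 + 0.8 = 1.8
```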

Exercise 4: jax.vmap (Automatic Vectorization)

Goal: Use jax.vmap to automatically batch operations.

Instructions:

  1. Define a function apply_affine(vector, matrix, bias) that takes a single 1D vector, a 2D matrix, and a 1D bias. It should compute jnp.dot(matrix, vector) + bias.
  2. You have a batch of vectors (a 2D array where each row is a vector), but a single matrix and a single bias that should be applied to each vector in the batch.
  3. Use jax.vmap to create batched_apply_affine, which efficiently applies apply_affine to each vector in the batch.
     - Hint: in_axes for jax.vmap should specify 0 for the batched vector argument, and None for matrix and bias, since they are not batched (they are broadcast). out_axes should be 0 to indicate that the output is batched along its first axis.
  4. Test batched_apply_affine with sample data.
# Instructions for Exercise 4
key_ex4_main, main_key = jax.random.split(main_key)
key_ex4_vec, key_ex4_mat, key_ex4_bias = jax.random.split(key_ex4_main, 3)

# 1. Define apply_affine for a single vector
def apply_affine(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
    # TODO: Compute jnp.dot(matrix, vector) + bias
    result = None # Placeholder
    return result

# 2. Prepare data
batch_size = 4
input_features = 3
output_features = 2

# batch_of_vectors: (batch_size, input_features)
# single_matrix: (output_features, input_features)
# single_bias: (output_features,)
batch_of_vectors = jax.random.normal(key_ex4_vec, (batch_size, input_features))
single_matrix = jax.random.normal(key_ex4_mat, (output_features, input_features))
single_bias = jax.random.normal(key_ex4_bias, (output_features,))


# 3. Use jax.vmap to create batched_apply_affine
# TODO: Specify in_axes correctly: vector is batched, matrix and bias are not. out_axes should be 0.
batched_apply_affine = None # Placeholder jax.vmap(apply_affine, in_axes=(..., ... , ...), out_axes=...)


# 4. Test batched_apply_affine
result_vmap = None # Placeholder batched_apply_affine(batch_of_vectors, single_matrix, single_bias)
print("Batch of vectors shape:", batch_of_vectors.shape)
print("Single matrix shape:", single_matrix.shape)
print("Single bias shape:", single_bias.shape)
if result_vmap is not None:
    print("Result using vmap shape:", result_vmap.shape) # Expected: (batch_size, output_features)

    # For comparison, a manual loop (less efficient):
    # manual_results = []
    # for i in range(batch_size):
    #     manual_results.append(apply_affine(batch_of_vectors[i], single_matrix, single_bias))
    # result_manual_loop = jnp.stack(manual_results)
    # assert jnp.allclose(result_vmap, result_manual_loop)
    # print("vmap result matches manual loop result.")
else:
    print("result_vmap is None.")
# @title Solution 4: `jax.vmap` (Automatic Vectorization)
key_ex4_sol_main, main_key = jax.random.split(main_key)
key_ex4_sol_vec, key_ex4_sol_mat, key_ex4_sol_bias = jax.random.split(key_ex4_sol_main, 3)

# 1. Define apply_affine for a single vector
def apply_affine_sol(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
    return jnp.dot(matrix, vector) + bias

# 2. Prepare data
batch_size_sol = 4
input_features_sol = 3
output_features_sol = 2

batch_of_vectors_sol = jax.random.normal(key_ex4_sol_vec, (batch_size_sol, input_features_sol))
single_matrix_sol = jax.random.normal(key_ex4_sol_mat, (output_features_sol, input_features_sol))
single_bias_sol = jax.random.normal(key_ex4_sol_bias, (output_features_sol,))

# 3. Use jax.vmap to create batched_apply_affine
# The vector is batched along axis 0; matrix and bias are not batched (None), so the same values are reused for every vector.
# out_axes=0 means the output will also be batched along its first axis.
batched_apply_affine_sol = jax.vmap(apply_affine_sol, in_axes=(0, None, None), out_axes=0)

# 4. Test batched_apply_affine
result_vmap_sol = batched_apply_affine_sol(batch_of_vectors_sol, single_matrix_sol, single_bias_sol)
print("Batch of vectors shape:", batch_of_vectors_sol.shape)
print("Single matrix shape:", single_matrix_sol.shape)
print("Single bias shape:", single_bias_sol.shape)
print("Result using vmap shape:", result_vmap_sol.shape) # Expected: (batch_size, output_features)
assert result_vmap_sol.shape == (batch_size_sol, output_features_sol)

# For comparison, a manual loop (less efficient):
manual_results_sol = []
for i in range(batch_size_sol):
    manual_results_sol.append(apply_affine_sol(batch_of_vectors_sol[i], single_matrix_sol, single_bias_sol))
result_manual_loop_sol = jnp.stack(manual_results_sol)

assert jnp.allclose(result_vmap_sol, result_manual_loop_sol)
print("\nvmap result matches manual loop result, demonstrating correct vectorization.")
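`in_axes` and `out_axes` are not limited to axis 0. As a sketch, suppose the batch were stored column-wise, with each vector a column of an `(input_features, batch)` array; the arrays below are illustrative. Mapping over axis 1 and stacking outputs along axis 1 gives the same result as the closed form `matrix @ columns + bias[:, None]`.

```python
import jax
import jax.numpy as jnp

def apply_affine(vector, matrix, bias):
    return jnp.dot(matrix, vector) + bias

matrix = jnp.arange(6.0).reshape(2, 3)     # (output_features=2, input_features=3)
bias = jnp.array([0.5, -0.5])              # (output_features,)
columns = jnp.arange(12.0).reshape(3, 4)   # 4 vectors stored as columns: (3, 4)

# in_axes=1: map over the columns; out_axes=1: stack the results as columns.
batched = jax.vmap(apply_affine, in_axes=(1, None, None), out_axes=1)
out = batched(columns, matrix, bias)
print(out.shape)  # (2, 4)

expected = matrix @ columns + bias[:, None]
print(jnp.allclose(out, expected))  # True
```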

Exercise 5: Flax NNX - Defining a Model

Goal: Learn to define a simple neural network model using Flax NNX.

Instructions:

  1. Define a Flax NNX model class SimpleNNXModel that inherits from nnx.Module.
  2. In its __init__, define one nnx.Linear layer. The layer should take din (input features) and dout (output features) as arguments. Remember to pass the rngs argument to nnx.Linear for parameter initialization (e.g., rngs=rngs).
  3. Implement the __call__ method (the forward pass), which takes an input x and passes it through the linear layer.
  4. Instantiate your SimpleNNXModel. You'll need to create an nnx.Rngs object using a JAX PRNG key (e.g., nnx.Rngs(params=jax.random.key(seed))). The key name params is conventional for nnx.Linear.
  5. Test your model instance with a dummy input batch. Print the output and the model's state (parameters) using nnx.display().
# Instructions for Exercise 5
key_ex5_model_init, main_key = jax.random.split(main_key)

# 1. & 2. & 3. Define the SimpleNNXModel
class SimpleNNXModel(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        # TODO: Define an nnx.Linear layer named 'dense_layer'
        # self.dense_layer = nnx.Linear(...)
        self.some_attribute = None # Placeholder, remove later
        pass # Remove this placeholder if class is not empty

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # TODO: Pass input x through the dense_layer
        # return self.dense_layer(x)
        return x # Placeholder

# 4. Instantiate the model
model_din = 3
model_dout = 2
# TODO: Create nnx.Rngs for parameter initialization. Use 'params' as the key name.
model_rngs = None # Placeholder nnx.Rngs(params=key_ex5_model_init)
my_model = None # Placeholder SimpleNNXModel(din=model_din, dout=model_dout, rngs=model_rngs)

# 5. Test with dummy data
dummy_batch_size = 4
dummy_input_ex5 = jnp.ones((dummy_batch_size, model_din))

model_output = None # Placeholder
if my_model is not None:
    model_output = my_model(dummy_input_ex5)
    print(f"Model output shape: {model_output.shape}")
    print(f"Model output:\n{model_output}")

    # nnx.Module has no get_state() method; nnx.split returns (graphdef, state).
    _, model_state = nnx.split(my_model)
    print("\nModel state (parameters, etc.):")
    nnx.display(model_state)
else:
    print("my_model is None.")
# @title Solution 5: Flax NNX - Defining a Model
key_ex5_sol_model_init, main_key = jax.random.split(main_key)

# 1. & 2. & 3. Define the SimpleNNXModel
class SimpleNNXModel_Sol(nnx.Module): # Renamed for solution cell
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        # nnx.Linear will use the 'params' key from rngs by default for its parameters
        self.dense_layer = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.dense_layer(x)

# 4. Instantiate the model
model_din_sol = 3
model_dout_sol = 2
# Create nnx.Rngs for parameter initialization.
# 'params' is the default key nnx.Linear looks for in the rngs object.
model_rngs_sol = nnx.Rngs(params=key_ex5_sol_model_init)
my_model_sol = SimpleNNXModel_Sol(din=model_din_sol, dout=model_dout_sol, rngs=model_rngs_sol)

# 5. Test with dummy data
dummy_batch_size_sol = 4
dummy_input_ex5_sol = jnp.ones((dummy_batch_size_sol, model_din_sol))

model_output_sol = my_model_sol(dummy_input_ex5_sol)
print(f"Model output shape: {model_output_sol.shape}")
print(f"Model output:\n{model_output_sol}")

# nnx.split separates the module into a static graph definition and its state (parameters, etc.).
_, model_state_sol = nnx.split(my_model_sol)
print("\nModel state (parameters, etc.):")
nnx.display(model_state_sol)

# Check that parameters are present
assert 'dense_layer' in model_state_sol, "Key 'dense_layer' not in model_state"
assert 'kernel' in model_state_sol['dense_layer'], "Key 'kernel' not in model_state['dense_layer']"
assert 'bias' in model_state_sol['dense_layer'], "Key 'bias' not in model_state['dense_layer']"
print("\nModel parameters (kernel and bias for dense_layer) are present in the state.")

Exercise 6: Optax & Flax NNX - Creating an Optimizer

Goal: Set up an Optax optimizer and wrap it with nnx.Optimizer for use with a Flax NNX model.

Instructions:

  1. Use the SimpleNNXModel_Sol class and the instance my_model_sol from the previous exercise's solution. (If running standalone, re-instantiate it.)
  2. Create an Optax optimizer, for example, optax.adam with a learning rate of 0.001.
  3. Create an nnx.Optimizer instance. This wrapper pairs the Optax optimizer with your Flax NNX model (my_model_sol) and initializes the optimizer state from the model's parameters. Pass wrt=nnx.Param to specify which variables are optimized.
  4. Print the nnx.Optimizer instance and its opt_state attribute to see the initialized optimizer state (e.g., Adam's first- and second-moment estimates, mu and nu).
# Instructions for Exercise 6

# 1. Assume my_model_sol is available from Exercise 5 solution
# (If running standalone, re-instantiate it)
if 'my_model_sol' not in globals():
    print("Re-initializing model from Ex5 solution for Ex6.")
    key_ex6_model_init, main_key = jax.random.split(main_key)
    _model_din_ex6 = 3
    _model_dout_ex6 = 2
    _model_rngs_ex6 = nnx.Rngs(params=key_ex6_model_init)
    # Use solution class name if defined, otherwise student's class name
    _ModelClass = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel
    model_for_opt = _ModelClass(din=_model_din_ex6, dout=_model_dout_ex6, rngs=_model_rngs_ex6)
    print("Model for optimizer created.")
else:
    model_for_opt = my_model_sol # Use the one from previous solution
    print("Using model 'my_model_sol' from previous exercise for 'model_for_opt'.")


# 2. Create an Optax optimizer
learning_rate = 0.001
# TODO: Create an optax.adam optimizer transform
optax_tx = None # Placeholder optax.adam(...)

# 3. Create an nnx.Optimizer wrapper
# TODO: Wrap the model (model_for_opt) and the optax transform (optax_tx)
# The `wrt` argument is required; it specifies which model variables (e.g., nnx.Param) the optimizer updates.
nnx_optimizer = None # Placeholder nnx.Optimizer(...)

# 4. Print the optimizer and its state
print("\nFlax NNX Optimizer wrapper:")
nnx.display(nnx_optimizer)

print("\nInitial Optimizer State (Optax state, e.g., Adam's momentum):")
if nnx_optimizer is not None and hasattr(nnx_optimizer, 'opt_state'):
    pprint.pprint(nnx_optimizer.opt_state)
    # adam_state = nnx_optimizer.opt_state
    # assert len(adam_state) > 0 and hasattr(adam_state[0], 'count')
    # print("\nOptimizer state structure looks plausible for Adam.")
else:
    print("nnx_optimizer or its opt_state is None or not structured as expected.")
# @title Solution 6: Optax & Flax NNX - Creating an Optimizer

# 1. Use my_model_sol from Exercise 5 solution
# If not run sequentially, ensure my_model_sol is defined:
if 'my_model_sol' not in globals():
    print("Re-initializing model from Ex5 solution for Ex6.")
    key_ex6_sol_model_init, main_key = jax.random.split(main_key)
    _model_din_sol_ex6 = 3
    _model_dout_sol_ex6 = 2
    _model_rngs_sol_ex6 = nnx.Rngs(params=key_ex6_sol_model_init)
    # Ensure SimpleNNXModel_Sol is used
    my_model_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex6, dout=_model_dout_sol_ex6, rngs=_model_rngs_sol_ex6)
    print("Model for optimizer re-created as 'my_model_sol'.")
else:
    print("Using model 'my_model_sol' from previous exercise.")


# 2. Create an Optax optimizer
learning_rate_sol = 0.001
# Create an optax.adam optimizer transform
optax_tx_sol = optax.adam(learning_rate=learning_rate_sol)

# 3. Create an nnx.Optimizer wrapper
# This links the model and the Optax optimizer.
# The optimizer state will be initialized based on the model's parameters.
nnx_optimizer_sol = nnx.Optimizer(my_model_sol, optax_tx_sol, wrt=nnx.Param)

# 4. Print the optimizer and its state
print("\nFlax NNX Optimizer wrapper:")
nnx.display(nnx_optimizer_sol) # Shows the optimizer's fields, including the Optax transform and its state

print("\nInitial Optimizer State (Optax state, e.g., Adam's momentum):")
# nnx.Optimizer stores the actual Optax state in its .opt_state attribute.
# Its mu and nu entries are PyTrees that mirror the structure of the model's parameters.
pprint.pprint(nnx_optimizer_sol.opt_state)

# Verify the structure of the optimizer state for Adam (count, mu, nu for each param)
assert hasattr(nnx_optimizer_sol, 'opt_state'), "Optax opt_state not found in nnx.Optimizer"
# For optax.adam, opt_state is a tuple: (ScaleByAdamState(count, mu, nu), EmptyState())
adam_optax_internal_state = nnx_optimizer_sol.opt_state
assert len(adam_optax_internal_state) > 0 and hasattr(adam_optax_internal_state[0], 'count'), "Adam 'count' state not found."
# The parameter-specific states mu and nu live in the first element, ScaleByAdamState.
if hasattr(adam_optax_internal_state[0], 'mu'):
    param_specific_state = adam_optax_internal_state[0]
    assert 'dense_layer' in param_specific_state.mu and 'kernel' in param_specific_state.mu['dense_layer'], "Adam 'mu' state for kernel not found."
    print("\nOptimizer state structure looks correct for Adam.")
else:
    print("\nWarning: Optimizer state structure for Adam might be different or not fully verified.")

Exercise 7: Training Step with Flax NNX and Optax

Goal: Implement a complete JIT-compiled training step for a Flax NNX model using Optax.

Instructions:

  1. You'll need:
     - An instance of your model class (e.g., my_model_sol from the Ex 5/6 solution).
     - An instance of nnx.Optimizer (e.g., nnx_optimizer_sol from the Ex 6 solution).
  2. Define a train_step function decorated with @nnx.jit. This function should take the model, the optimizer, the input x_batch, and the target y_batch as arguments.
  3. Inside train_step:
     - Define an inner loss_fn_for_grad. This function must take the model as its first argument. Inside, it computes the model's predictions for x_batch and then calculates the mean squared error (MSE) against y_batch.
     - Use nnx.value_and_grad(loss_fn_for_grad)(model_arg) to compute both the loss value and the gradients with respect to the model passed to loss_fn_for_grad (model_arg is the model instance passed into train_step).
     - Update the model's parameters (and the optimizer's state) using optimizer_arg.update(model_arg, grads). The update method takes the model and the gradients and updates the model's state in place.
     - Return the computed loss_value.
  4. Create dummy x_batch and y_batch data.
  5. Call your train_step function. Print the returned loss.
  6. (Optional) Verify that the model's parameters have changed after train_step by comparing a parameter value before and after the call.
# Instructions for Exercise 7
key_ex7_main, main_key = jax.random.split(main_key)
key_ex7_x, key_ex7_y = jax.random.split(key_ex7_main, 2)

# 1. Use model and optimizer from previous exercises' solutions
# Ensure my_model_sol and nnx_optimizer_sol are available
if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():
    print("Re-initializing model and optimizer from Ex5/Ex6 solutions for Ex7.")
    key_ex7_model_fallback, main_key = jax.random.split(main_key)
    _model_din_ex7 = 3
    _model_dout_ex7 = 2
    _model_rngs_ex7 = nnx.Rngs(params=key_ex7_model_fallback)
    # Ensure SimpleNNXModel_Sol is used
    my_model_ex7 = SimpleNNXModel_Sol(din=_model_din_ex7, dout=_model_dout_ex7, rngs=_model_rngs_ex7)
    _optax_tx_ex7 = optax.adam(learning_rate=0.001)
    nnx_optimizer_ex7 = nnx.Optimizer(my_model_ex7, _optax_tx_ex7, wrt=nnx.Param)
    print("Model and optimizer re-created for Ex7.")
else:
    my_model_ex7 = my_model_sol
    nnx_optimizer_ex7 = nnx_optimizer_sol
    print("Using 'my_model_sol' and 'nnx_optimizer_sol' for 'my_model_ex7' and 'nnx_optimizer_ex7'.")


# 2. & 3. Define the train_step function
# TODO: Decorate with @nnx.jit
# def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer, # Type hint with base nnx.Module
#                x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:

    # TODO: Define inner loss_fn_for_grad(model_in_grad_fn)
    # def loss_fn_for_grad(model_in_grad_fn: nnx.Module): # Type hint with base nnx.Module
        # y_pred = model_in_grad_fn(x_batch)
        # loss = jnp.mean((y_pred - y_batch)**2)
        # return loss
    #    return jnp.array(0.0) # Placeholder

    # TODO: Compute loss value and gradients using nnx.value_and_grad
    # loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg) # Pass model_arg

    # TODO: Update the optimizer (which updates the model_arg in-place)
    # optimizer_arg.update(model_arg, grads)

    # return loss_value
#    return jnp.array(0.0) # Placeholder defined train_step function

# For the student to define:
# Make sure the function signature is correct for nnx.jit
@nnx.jit
def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer,
               x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:
    # Placeholder implementation for student
    def loss_fn_for_grad(model_in_grad_fn: nnx.Module):
        # y_pred = model_in_grad_fn(x_batch)
        # loss = jnp.mean((y_pred - y_batch)**2)
        # return loss
        return jnp.array(0.0) # Student TODO: replace this

    # loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)
    # optimizer_arg.update(model_arg, grads)
    # return loss_value
    return jnp.array(-1.0) # Student TODO: replace this


# 4. Create dummy data
batch_s = 8
# Read in_features and out_features from the layer, falling back to 3 and 2 if it is not defined yet
_din_from_model_ex7 = my_model_ex7.dense_layer.in_features if hasattr(my_model_ex7, 'dense_layer') else 3
_dout_from_model_ex7 = my_model_ex7.dense_layer.out_features if hasattr(my_model_ex7, 'dense_layer') else 2

x_batch_data = jax.random.normal(key_ex7_x, (batch_s, _din_from_model_ex7))
y_batch_data = jax.random.normal(key_ex7_y, (batch_s, _dout_from_model_ex7))

# Optional: Store initial param value for comparison
initial_kernel_val = None
if hasattr(my_model_ex7, 'dense_layer'):
    # JAX arrays are immutable, so this reference keeps the pre-update value.
    initial_kernel_val = my_model_ex7.dense_layer.kernel.value[0, 0]
print(f"Initial kernel value (sample): {initial_kernel_val}")

# 5. Call the train_step once (the placeholder version returns -1.0)
loss_after_step = train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data)
if loss_after_step.item() != -1.0: # Check whether the student implemented train_step
    print(f"Loss after one training step: {loss_after_step}")
else:
    print("Student needs to implement `train_step` function.")


# # 6. Optional: Verify parameter change
# updated_kernel_val = None
# if hasattr(my_model_ex7, 'dense_layer'):
#     updated_kernel_val = my_model_ex7.dense_layer.kernel.value[0, 0] # Read again after the update
#     print(f"Updated kernel value (sample): {updated_kernel_val}")
#
# if initial_kernel_val is not None and updated_kernel_val is not None:
#     assert not jnp.allclose(initial_kernel_val, updated_kernel_val), "Kernel parameter did not change!"
#     print("Kernel parameter changed as expected after the training step.")
# else:
#     print("Could not verify kernel change (initial or updated value was None).")
# @title Solution 7: Training Step with Flax NNX and Optax
key_ex7_sol_main, main_key = jax.random.split(main_key)
key_ex7_sol_x, key_ex7_sol_y = jax.random.split(key_ex7_sol_main, 2)

# 1. Use model and optimizer from previous exercises' solutions
if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():
    print("Re-initializing model and optimizer from Ex5/Ex6 solutions for Ex7 solution.")
    key_ex7_sol_model_fallback, main_key = jax.random.split(main_key)
    _model_din_sol_ex7 = 3
    _model_dout_sol_ex7 = 2
    _model_rngs_sol_ex7 = nnx.Rngs(params=key_ex7_sol_model_fallback)
    # Ensure SimpleNNXModel_Sol is used for the solution
    my_model_sol_ex7 = SimpleNNXModel_Sol(din=_model_din_sol_ex7, dout=_model_dout_sol_ex7, rngs=_model_rngs_sol_ex7)
    _optax_tx_sol_ex7 = optax.adam(learning_rate=0.001)
    nnx_optimizer_sol_ex7 = nnx.Optimizer(my_model_sol_ex7, _optax_tx_sol_ex7, wrt=nnx.Param)
    print("Model and optimizer re-created for Ex7 solution.")
else:
    # If solutions are run sequentially, these will be the correct instances
    my_model_sol_ex7 = my_model_sol
    nnx_optimizer_sol_ex7 = nnx_optimizer_sol
    print("Using 'my_model_sol' and 'nnx_optimizer_sol' for Ex7 solution.")


# 2. & 3. Define the train_step function
@nnx.jit # Decorate with @nnx.jit for JIT compilation
def train_step_sol(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer,
                   x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:

    # Define inner loss_fn_for_grad. It takes the model as its first argument.
    # It captures x_batch and y_batch from the outer scope.
    def loss_fn_for_grad(model_in_grad_fn: nnx.Module):
        y_pred = model_in_grad_fn(x_batch) # Use the model passed to this inner function
        loss = jnp.mean((y_pred - y_batch)**2)
        return loss

    # Compute loss value and gradients using nnx.value_and_grad.
    # This will differentiate loss_fn_for_grad with respect to its first argument (model_in_grad_fn).
    # We pass the current state of our model (model_arg) to it.
    loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)

    # Update the optimizer: this applies the Optax update to model_arg's parameters in place
    # and advances optimizer_arg's internal state.
    optimizer_arg.update(model_arg, grads)

    return loss_value


# 4. Create dummy data
batch_s_sol = 8
# Ensure din and dout match the model instantiation from Ex5/Ex6
# my_model_sol_ex7.dense_layer is an nnx.Linear object
din_from_model_sol = my_model_sol_ex7.dense_layer.in_features
dout_from_model_sol = my_model_sol_ex7.dense_layer.out_features

x_batch_data_sol = jax.random.normal(key_ex7_sol_x, (batch_s_sol, din_from_model_sol))
y_batch_data_sol = jax.random.normal(key_ex7_sol_y, (batch_s_sol, dout_from_model_sol))

# Optional: Store initial param value for comparison
initial_kernel_val_sol = None
_, current_model_state_sol = nnx.split(my_model_sol_ex7)
if 'dense_layer' in current_model_state_sol:
    initial_kernel_val_sol = current_model_state_sol['dense_layer']['kernel'].value[0, 0].copy()
print(f"Initial kernel value (sample): {initial_kernel_val_sol}")


# 5. Call the train_step
# First call will JIT compile the train_step_sol function.
loss_after_step_sol = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)
print(f"Loss after one training step (1st call, JIT): {loss_after_step_sol}")
# Second call reuses the cached compiled function (use %timeit to actually measure the speedup)
loss_after_step_sol_2 = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)
print(f"Loss after a second training step (cached): {loss_after_step_sol_2}")


# 6. Optional: Verify parameter change
updated_kernel_val_sol = None
_, updated_model_state_sol = nnx.split(my_model_sol_ex7) # Get state again after update
if 'dense_layer' in updated_model_state_sol:
    updated_kernel_val_sol = updated_model_state_sol['dense_layer']['kernel'].value[0, 0]
    print(f"Updated kernel value (sample): {updated_kernel_val_sol}")

if initial_kernel_val_sol is not None and updated_kernel_val_sol is not None:
    assert not jnp.allclose(initial_kernel_val_sol, updated_kernel_val_sol), "Kernel parameter did not change!"
    print("Kernel parameter changed as expected after the training step.")
else:
    print("Could not verify kernel change (initial or updated value was None).")
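The solution's comments say the first call compiles and the second call reuses the cached program. One way to see this without timing anything is to count traces: Python side effects inside a jitted function run only while JAX traces it, not on cached calls. The sketch below uses plain `jax.jit` and a hypothetical counter `trace_count`; `nnx.jit` is expected to cache in a similar way, but this example does not exercise it.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while JAX traces the function

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, skipped on cached calls
    return jnp.sum(2.0 * x)

a = jnp.ones((4,))
scaled_sum(a)              # 1st call: trace + compile
scaled_sum(a + 1.0)        # same shape and dtype: cached, no retrace
print("traces after two calls with shape (4,):", trace_count)

scaled_sum(jnp.ones((5,)))  # new input shape: triggers a new trace
print("traces after a call with shape (5,):", trace_count)
```

For wall-clock measurements, `%timeit` combined with `.block_until_ready()` on the result avoids timing only JAX's asynchronous dispatch.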

Exercise 8: Orbax - Saving and Restoring Checkpoints

Goal: Learn to use Orbax to save and restore JAX PyTrees, specifically Flax NNX model states and Optax optimizer states.

Instructions:

  1. You'll need your model (e.g., my_model_sol_ex7) and optimizer (e.g., nnx_optimizer_sol_ex7) from the previous exercise's solution.
  2. Define a checkpoint directory (e.g., /tmp/my_nnx_checkpoint/).
  3. Create an Orbax CheckpointManagerOptions and then a CheckpointManager.
  4. Bundle the states you want to save into a dictionary. For NNX, use the State returned as the second element of nnx.split(my_model_sol_ex7) for the model, and nnx_optimizer_sol_ex7.opt_state for the optimizer's internal state. Also include a training step counter.
  5. Use checkpoint_manager.save() with ocp.args.StandardSave() to save the bundled state. Call checkpoint_manager.wait_until_finished() to ensure saving completes.
  6. To restore:
     - Create new instances of your model (restored_model) and Optax transform (restored_optax_tx). Give the new model a different PRNG key for its initial parameters so that a successful restoration is visible.
     - Use checkpoint_manager.restore() with ocp.args.StandardRestore() to load the bundled state.
     - Apply the loaded model state to restored_model with nnx.update(restored_model, loaded_bundle['model']).
     - Create a new nnx.Optimizer (restored_optimizer) from restored_model and restored_optax_tx.
     - Assign the loaded optimizer state to the new optimizer: restored_optimizer.opt_state = loaded_bundle['optimizer'].
  7. Verify that a parameter from restored_model matches the corresponding parameter from the original my_model_sol_ex7 (before saving, or from the saved state). Also compare the optimizer states if possible.
  8. Clean up the checkpoint directory.
# Instructions for Exercise 8
# import orbax.checkpoint as ocp # Already imported
# import os, shutil # Already imported

# 1. Use model and optimizer from previous exercise solution
if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():
    print("Re-initializing model and optimizer from Ex7 solution for Ex8.")
    key_ex8_model_fallback, main_key = jax.random.split(main_key)
    _model_din_ex8 = 3
    _model_dout_ex8 = 2
    _model_rngs_ex8 = nnx.Rngs(params=key_ex8_model_fallback)
    _ModelClassEx8 = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel
    model_to_save = _ModelClassEx8(din=_model_din_ex8, dout=_model_dout_ex8, rngs=_model_rngs_ex8)
    _optax_tx_ex8 = optax.adam(learning_rate=0.001)
    optimizer_to_save = nnx.Optimizer(model_to_save, _optax_tx_ex8, wrt=nnx.Param)
    print("Model and optimizer re-created for Ex8.")
else:
    model_to_save = my_model_sol_ex7
    optimizer_to_save = nnx_optimizer_sol_ex7
    print("Using model and optimizer from Ex7 solution for Ex8.")

# 2. Define checkpoint directory
# TODO: Define checkpoint_dir
checkpoint_dir = None # Placeholder e.g., "/tmp/my_nnx_checkpoint_exercise/"
# if checkpoint_dir and os.path.exists(checkpoint_dir):
#    shutil.rmtree(checkpoint_dir) # Clean up previous runs for safety
# if checkpoint_dir:
#    os.makedirs(checkpoint_dir, exist_ok=True)


# 3. Create Orbax CheckpointManager
# TODO: Create options and manager
# options = ocp.CheckpointManagerOptions(...)
# mngr = ocp.CheckpointManager(...)
options = None
mngr = None

# 4. Bundle states
# current_step = 100 # Example step
# TODO: Get model_state and optimizer_state
# _, model_state_to_save = nnx.split(model_to_save)
# The optimizer's Optax state is accessed via the .opt_state attribute.
# opt_state_to_save = optimizer_to_save.opt_state
# save_bundle = {
#     'model': model_state_to_save,
#     'optimizer': opt_state_to_save,
#     'step': current_step
# }
save_bundle = None

# 5. Save the checkpoint
# if mngr and save_bundle:
#   TODO: Save checkpoint
#   mngr.save(...)
#   mngr.wait_until_finished()
#   print(f"Checkpoint saved at step {current_step} to {checkpoint_dir}")
# else:
#   print("Checkpoint manager or save_bundle not initialized.")

# --- Restoration ---
# 6.a Create new model and Optax transform (for restoration)
# key_ex8_restore_model, main_key = jax.random.split(main_key)
# din_restore = model_to_save.dense_layer.in_features if hasattr(model_to_save, 'dense_layer') else 3
# dout_restore = model_to_save.dense_layer.out_features if hasattr(model_to_save, 'dense_layer') else 2
# _ModelClassRestore = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel
# restored_model = _ModelClassRestore(
#     din=din_restore, dout=dout_restore,
#     rngs=nnx.Rngs(params=key_ex8_restore_model) # New key for different initial params
# )
# restored_optax_tx = optax.adam(learning_rate=0.001) # Same Optax config
restored_model = None
restored_optax_tx = None

# 6.b Restore the checkpoint
# loaded_bundle = None
# if mngr:
#   TODO: Restore checkpoint
#   latest_step = mngr.latest_step()
#   if latest_step is not None:
#       loaded_bundle = mngr.restore(...)
#       print(f"Checkpoint restored from step {latest_step}")
#   else:
#       print("No checkpoint found to restore.")
# else:
#   print("Checkpoint manager not initialized for restore.")

# 6.c Apply loaded states
# if loaded_bundle and restored_model:
#   TODO: Update restored_model state
#   nnx.update(restored_model, ...)
#   print("Restored model state applied.")

#   TODO: Create new nnx.Optimizer and assign its state
#   restored_optimizer = nnx.Optimizer(..., wrt=nnx.Param)
#   restored_optimizer.opt_state = ...
#   print("Restored optimizer state applied.")
# else:
#   print("Loaded_bundle or restored_model is None, cannot apply states.")
restored_optimizer = None

# 7. Verify restoration
# original_kernel = save_bundle['model']['dense_layer']['kernel']
# _, restored_model_state = nnx.split(restored_model)
# kernel_after_restore = restored_model_state['dense_layer']['kernel']
# assert jnp.array_equal(original_kernel.value, kernel_after_restore.value), \
#        "Model kernel parameters differ after restoration!"
# print("\nModel parameters successfully restored and verified (kernel match).")

# # Verify optimizer state (e.g., Adam's 'mu' for a specific parameter)
# original_opt_state_adam_mu_kernel = save_bundle['optimizer'][0].mu['dense_layer']['kernel'].value
# restored_opt_state_adam_mu_kernel = restored_optimizer.opt_state[0].mu['dense_layer']['kernel'].value
# assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \
#        "Optimizer Adam mu for kernel differs!"
# print("Optimizer state (sample mu) successfully restored and verified.")


# 8. Clean up
# if mngr:
#   mngr.close()
# if checkpoint_dir and os.path.exists(checkpoint_dir):
#   shutil.rmtree(checkpoint_dir)
#   print(f"Cleaned up checkpoint directory: {checkpoint_dir}")
# @title Solution 8: Orbax - Saving and Restoring Checkpoints

# 1. Use model and optimizer from previous exercise solution
if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():
    print("Re-initializing model and optimizer from Ex7 solution for Ex8 solution.")
    key_ex8_sol_model_fallback, main_key = jax.random.split(main_key)
    _model_din_sol_ex8 = 3
    _model_dout_sol_ex8 = 2
    _model_rngs_sol_ex8 = nnx.Rngs(params=key_ex8_sol_model_fallback)
    # Ensure SimpleNNXModel_Sol is used for the solution
    model_to_save_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex8,
                                           dout=_model_dout_sol_ex8,
                                           rngs=_model_rngs_sol_ex8)
    _optax_tx_sol_ex8 = optax.adam(learning_rate=0.001) # Store the transform for later
    optimizer_to_save_sol = nnx.Optimizer(model_to_save_sol, _optax_tx_sol_ex8, wrt=nnx.Param)
    print("Model and optimizer re-created for Ex8 solution.")
else:
    model_to_save_sol = my_model_sol_ex7
    optimizer_to_save_sol = nnx_optimizer_sol_ex7
    # We need the optax transform used to create the optimizer for restoration
    _optax_tx_sol_ex8 = optimizer_to_save_sol.tx # Access the original Optax transform
    print("Using model and optimizer from Ex7 solution for Ex8 solution.")

# 2. Define checkpoint directory
checkpoint_dir_sol = "/tmp/my_nnx_checkpoint_exercise_solution/"
if os.path.exists(checkpoint_dir_sol):
    shutil.rmtree(checkpoint_dir_sol) # Clean up previous runs
os.makedirs(checkpoint_dir_sol, exist_ok=True)
print(f"Orbax checkpoint directory: {checkpoint_dir_sol}")

# 3. Create Orbax CheckpointManager
options_sol = ocp.CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1)
mngr_sol = ocp.CheckpointManager(checkpoint_dir_sol, options=options_sol)

# 4. Bundle states
current_step_sol = 100 # Example step
_, model_state_to_save_sol = nnx.split(model_to_save_sol)
# The optimizer's Optax state is a PyTree stored in the .opt_state attribute.
opt_state_to_save_sol = optimizer_to_save_sol.opt_state
save_bundle_sol = {
    'model': model_state_to_save_sol,
    'optimizer': opt_state_to_save_sol,
    'step': current_step_sol
}
print("\nState bundle to be saved:")
print(f"Model state keys: {list(model_state_to_save_sol.keys())}")
print(f"Optimizer state type: {type(opt_state_to_save_sol)}")


# 5. Save the checkpoint
mngr_sol.save(current_step_sol, args=ocp.args.StandardSave(save_bundle_sol))
mngr_sol.wait_until_finished()
print(f"\nCheckpoint saved at step {current_step_sol} to {checkpoint_dir_sol}")

# --- Restoration ---
# 6.a Create new model and Optax transform (for restoration)
key_ex8_sol_restore_model, main_key = jax.random.split(main_key)
# Ensure din/dout are correctly obtained from the saved model's structure if possible
# Assuming model_to_save_sol is SimpleNNXModel_Sol which has a dense_layer
din_restore_sol = model_to_save_sol.dense_layer.in_features
dout_restore_sol = model_to_save_sol.dense_layer.out_features

restored_model_sol = SimpleNNXModel_Sol( # Use the solution's model class
    din=din_restore_sol, dout=dout_restore_sol,
    rngs=nnx.Rngs(params=key_ex8_sol_restore_model) # New key for different initial params
)
# We need the original Optax transform definition for the new nnx.Optimizer.
# _optax_tx_sol_ex8 was stored earlier, or can be re-created if the config is known.
restored_optax_tx_sol = _optax_tx_sol_ex8

# Print a param from the new model BEFORE restoration to show that it differs
_, state_before_restore_sol = nnx.split(restored_model_sol)
print("\nSample kernel from 'restored_model_sol' BEFORE restoration:")
nnx.display(state_before_restore_sol['dense_layer']['kernel'])

# 6.b Restore the checkpoint
loaded_bundle_sol = None
latest_step_sol = mngr_sol.latest_step()
if latest_step_sol is not None:
    # For NNX, we are restoring raw PyTrees, so StandardRestore is suitable.
    loaded_bundle_sol = mngr_sol.restore(latest_step_sol,
                                         args=ocp.args.StandardRestore(save_bundle_sol))
    print(f"\nCheckpoint restored from step {latest_step_sol}")
    print(f"Loaded bundle contains keys: {loaded_bundle_sol.keys()}")
else:
    raise ValueError("No checkpoint found to restore.")

# 6.c Apply loaded states
assert loaded_bundle_sol is not None, "Loaded bundle is None"
nnx.update(restored_model_sol, loaded_bundle_sol['model'])
print("Restored model state applied to 'restored_model_sol'.")

# Create new nnx.Optimizer with the restored_model and original optax_tx
restored_optimizer_sol = nnx.Optimizer(restored_model_sol, restored_optax_tx_sol,
                                       wrt=nnx.Param)
# Now assign the loaded Optax state PyTree to the optimizer's .opt_state attribute
restored_optimizer_sol.opt_state = loaded_bundle_sol['optimizer']
print("Restored optimizer state applied to 'restored_optimizer_sol'.")


# 7. Verify restoration
original_kernel_sol = save_bundle_sol['model']['dense_layer']['kernel']
_, restored_model_state = nnx.split(restored_model_sol)
kernel_after_restore_sol = restored_model_state['dense_layer']['kernel']
assert jnp.array_equal(original_kernel_sol.value, kernel_after_restore_sol.value), \
       "Model kernel parameters differ after restoration!"
print("\nModel parameters successfully restored and verified (kernel match).")

# Verify optimizer state (e.g., Adam's 'mu' for a specific parameter)
original_opt_state_adam_mu_kernel = save_bundle_sol['optimizer'][0].mu['dense_layer']['kernel'].value
restored_opt_state_adam_mu_kernel = restored_optimizer_sol.opt_state[0].mu['dense_layer']['kernel'].value
assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \
                       "Optimizer Adam mu for kernel differs!"
print("Optimizer state (sample mu) successfully restored and verified.")


# 8. Clean up
mngr_sol.close()
if os.path.exists(checkpoint_dir_sol):
    shutil.rmtree(checkpoint_dir_sol)
    print(f"Cleaned up checkpoint directory: {checkpoint_dir_sol}")
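The solution spot-checks a single kernel and a single `mu` entry. A stricter check compares every leaf of two PyTrees. The helper below, `trees_allclose`, is a hypothetical utility written for this sketch with `jax.tree.structure` and `jax.tree.leaves`; it could be applied to, for example, `save_bundle_sol['model']` and the state from `nnx.split(restored_model_sol)`, assuming both flatten to the same structure with array leaves.

```python
import jax
import jax.numpy as jnp

def trees_allclose(a, b, rtol=1e-6, atol=1e-6):
    """Hypothetical helper: True if two PyTrees share structure and all leaves match."""
    if jax.tree.structure(a) != jax.tree.structure(b):
        return False
    return all(
        # Check shapes explicitly so broadcasting cannot hide a mismatch.
        x.shape == y.shape and bool(jnp.allclose(x, y, rtol=rtol, atol=atol))
        for x, y in zip(jax.tree.leaves(a), jax.tree.leaves(b))
    )

saved = {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros((2,))}
same = {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros((2,))}
changed = {"kernel": jnp.ones((3, 2)), "bias": jnp.full((2,), 0.5)}
renamed = {"w": jnp.ones((3, 2)), "bias": jnp.zeros((2,))}

print(trees_allclose(saved, same))     # True
print(trees_allclose(saved, changed))  # False: a leaf differs
print(trees_allclose(saved, renamed))  # False: structure differs
```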

Conclusion

Congratulations on completing the JAX AI Stack exercises!

You've now had a hands-on introduction to:

  • Core JAX: jax.numpy, functional programming, jax.jit, jax.grad, jax.vmap.
  • Flax NNX: Defining and instantiating Pythonic neural network models.
  • Optax: Creating and using composable optimizers with Flax NNX.
  • Training Loop: Implementing an end-to-end training step in Flax NNX.
  • Orbax: Saving and restoring model and optimizer states.

This forms a strong foundation for developing high-performance machine learning models with the JAX ecosystem.

For further learning, refer to the official documentation:

  • JAX AI Stack: https://jaxstack.ai
  • JAX: https://jax.dev
  • Flax NNX: https://flax.readthedocs.io
  • Optax: https://optax.readthedocs.io
  • Orbax: https://orbax.readthedocs.io

Don't forget to provide feedback on the training session: https://goo.gle/jax-training-feedback