Optax with Flax NNX: Exercises for PyTorch Users

This Colab notebook contains a series of exercises designed to help you, a PyTorch user, get hands-on experience with Optax, the primary optimization library in the JAX ecosystem, specifically for training Flax NNX models. We will cover everything from the basics of setting up an optimizer to advanced techniques like learning rate scheduling, per-parameter optimization, and sharding for distributed training.

Setup

First, let's install the necessary libraries and set up a simulated multi-device environment. We'll use chex to simulate having 8 CPU devices, which will allow us to explore distributed training concepts without needing multiple physical GPUs/TPUs.
!pip install -q jax-ai-stack==2025.9.3

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import chex
import optax
import flax
from flax import nnx

# Simulate an environment with 8 CPU devices for sharding exercises
try:
  chex.set_n_cpu_devices(8)
except RuntimeError as e:
  print(f"Could not set n_cpu_devices: {e}")
  print("Sharding exercises may not work as intended. Continuing anyway.")

# Helper to check available devices
print(f"JAX is running on: {jax.default_backend()}")
print(f"Number of available devices: {jax.device_count()}")
print(f"Device details: {jax.devices()}")
print(f"Flax version: {flax.__version__}")

Exercise 1: The Basic Training Loop

Concept: This exercise covers the fundamental workflow of training a Flax NNX model with Optax. You will:
  1. Define a simple MLP model using flax.nnx.Module.
  2. Instantiate the model and a basic optax.adam optimizer using flax.nnx.Optimizer.
  3. Write a Mean Squared Error (MSE) loss function.
  4. Create a complete, JIT-compiled training step function that takes the model and optimizer as arguments, calculates the loss, computes gradients using flax.nnx.value_and_grad, and updates the model parameters using optimizer.update(model, grads).

This process mirrors the standard "instantiate, calculate loss, backpropagate, step" cycle in PyTorch but introduces the JAX/Optax equivalents: nnx.Optimizer, nnx.value_and_grad, and optimizer.update().
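
If it helps to see the correspondence at a glance, here is a rough, hedged mapping on a toy one-layer model; the names (toy_model, toy_loss, etc.) are invented for illustration, and the nnx.Optimizer/update usage follows the same API as the exercises below.

import jax.numpy as jnp
import optax
from flax import nnx

toy_model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))                        # ~ torch.nn.Linear(4, 2)
toy_opt = nnx.Optimizer(toy_model, optax.sgd(1e-2), wrt=nnx.Param)    # ~ torch.optim.SGD(model.parameters(), lr=1e-2)

x = jnp.ones((8, 4))
y = jnp.zeros((8, 2))

def toy_loss(m):                                                      # ~ loss = criterion(model(x), y)
  return jnp.mean((m(x) - y) ** 2)

loss, grads = nnx.value_and_grad(toy_loss)(toy_model)                 # ~ loss.backward()
toy_opt.update(toy_model, grads)                                      # ~ optimizer.step(); no zero_grad() needed, grads are recomputed fresh
print(loss)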

Instructions

Complete the TODO sections in the following code cell to implement the basic training loop.

# @title Exercise 1: Implement the Basic Training Loop
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from typing import Sequence

# 1. Define the Model
class SimpleMLP(nnx.Module):
  """A simple Multi-Layer Perceptron."""
  def __init__(self, features: Sequence[int], *, rngs: nnx.Rngs):
    self.layers = []
    for i in range(len(features) - 1):
        self.layers.append(nnx.Linear(features[i], features[i+1], rngs=rngs))
        if i < len(features) - 2:
            self.layers.append(nnx.relu)

  def __call__(self, x: jax.Array):
    for layer in self.layers:
        x = layer(x)
    return x

# 2. Define the Loss Function
def mse_loss(model: SimpleMLP, x_batch: jax.Array, y_batch: jax.Array) -> jax.Array:
  """Calculates the Mean Squared Error loss."""
  # TODO: Get predictions from the model and calculate the MSE.
  # Hint: The model is callable, e.g., model(x_batch).
  # YOUR CODE HERE
  return loss

# 3. Define the Training Step
@nnx.jit
def train_step(model: SimpleMLP, optimizer: nnx.Optimizer, x_batch: jax.Array, y_batch: jax.Array):
  """Performs a single training step."""
  # TODO: Use nnx.value_and_grad to get both the loss and the gradients.
  # You'll need a loss function closure that takes only the model as an argument.
  def loss_fn_for_grad(model_to_train):
      return mse_loss(model_to_train, x_batch, y_batch)

  loss_val, grads = # YOUR CODE HERE

  # TODO: Update the optimizer with the gradients.
  # YOUR CODE HERE

  # optimizer.update() modifies the model and optimizer state in place; nnx.jit
  # propagates these updates back to the passed-in objects, and we also return
  # them here to keep the training loop explicit.
  return model, optimizer, loss_val

# --- Boilerplate for running the exercise ---
# Create dummy data
key = jax.random.key(42)
key_model, key_data = jax.random.split(key)
din, dmid, dout = 10, 20, 5
x_dummy = jax.random.normal(key_data, (32, din))
y_dummy = jax.random.normal(key_data, (32, dout))

# Instantiate model and optimizer
model = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))
opt = optax.adam(learning_rate=1e-3)
optimizer = nnx.Optimizer(model, opt, wrt=nnx.Param)

# Training Loop
print("Starting basic training loop...")
for i in range(101):
  model, optimizer, loss = train_step(model, optimizer, x_dummy, y_dummy)
  if i % 20 == 0:
    # The .value attribute is used to get the raw value from a State variable
    print(f"Step {optimizer.step.value}, Loss: {loss:.4f}")
print("Basic training loop finished.")
# Verify that the optimizer performed 101 update steps
assert optimizer.step.value == 101
# @title Solution 1
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from typing import Sequence

# 1. Define the Model
class SimpleMLP(nnx.Module):
  """A simple Multi-Layer Perceptron."""
  def __init__(self, features: Sequence[int], *, rngs: nnx.Rngs):
    self.layers = []
    for i in range(len(features) - 1):
        self.layers.append(nnx.Linear(features[i], features[i+1], rngs=rngs))
        if i < len(features) - 2:
            self.layers.append(nnx.relu)

  def __call__(self, x: jax.Array):
    for layer in self.layers:
        x = layer(x)
    return x

# 2. Define the Loss Function
def mse_loss(model: SimpleMLP, x_batch: jax.Array, y_batch: jax.Array) -> jax.Array:
  """Calculates the Mean Squared Error loss."""
  predictions = model(x_batch)
  loss = jnp.mean((predictions - y_batch) ** 2)
  return loss

# 3. Define the Training Step
@nnx.jit
def train_step(model: SimpleMLP, optimizer: nnx.Optimizer, x_batch: jax.Array, y_batch: jax.Array):
  """Performs a single training step."""
  # A closure to capture the current batch of data
  def loss_fn_for_grad(model_to_train: SimpleMLP):
    return mse_loss(model_to_train, x_batch, y_batch)

  # Compute loss and gradients
  loss_val, grads = nnx.value_and_grad(loss_fn_for_grad)(model)

  # Update the optimizer's state and model parameters
  optimizer.update(model, grads)

  return model, optimizer, loss_val

# --- Boilerplate for running the exercise ---
# Create dummy data
key = jax.random.key(42)
key_model, key_data = jax.random.split(key)
din, dmid, dout = 10, 20, 5
x_dummy = jax.random.normal(key_data, (32, din))
y_dummy = jax.random.normal(key_data, (32, dout))

# Instantiate model and optimizer
model = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))
opt = optax.adam(learning_rate=1e-3)
optimizer = nnx.Optimizer(model, opt, wrt=nnx.Param)

# Training Loop
print("Starting basic training loop...")
for i in range(101):
  model, optimizer, loss = train_step(model, optimizer, x_dummy, y_dummy)
  if i % 20 == 0:
    print(f"Step {optimizer.step.value}, Loss: {loss:.4f}")
print("Basic training loop finished.")
assert optimizer.step.value == 101

Exercise 2: Composing Gradient Transformations

Concept: A core philosophy of Optax is composability. Instead of monolithic optimizers, Optax provides small, chainable "gradient transformations." This exercise demonstrates how to build a custom optimization pipeline by chaining multiple transformations together.

You will add gradient clipping and weight decay to the Adam optimizer, creating a more robust optimization rule. This is analogous to combining features that are built-in flags in a PyTorch optimizer (for example, weight_decay in torch.optim.AdamW), but here you explicitly build the chain.
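
Under the hood, every Optax optimizer (including the chain you are about to build) is a GradientTransformation: a pair of pure functions, init(params) and update(grads, state, params). The sketch below applies such a chain by hand to a made-up params dict, just to show the bookkeeping that nnx.Optimizer handles for you.

import jax.numpy as jnp
import optax

tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.add_decayed_weights(1e-4),
    optax.adam(1e-3),
)

# Toy parameters and gradients, purely for illustration.
params = {'w': jnp.ones((3,)), 'b': jnp.zeros((3,))}
grads = {'w': jnp.full((3,), 10.0), 'b': jnp.ones((3,))}

state = tx.init(params)
updates, state = tx.update(grads, state, params)   # params are needed by add_decayed_weights
new_params = optax.apply_updates(params, updates)

print(updates['w'])     # clipped, decayed, Adam-scaled update, not the raw gradient
print(new_params['w'])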

Instructions

Complete the TODO section to create a chained Optax transformation.

# @title Exercise 2: Build a Chained Optimizer
import jax
import jax.numpy as jnp
import optax
from flax import nnx

# --- Using the same model and data setup from Exercise 1 ---
key = jax.random.key(42)
key_model, key_data = jax.random.split(key)
din, dmid, dout = 10, 20, 5
x_dummy = jax.random.normal(key_data, (32, din))
y_dummy = jax.random.normal(key_data, (32, dout))

model_chained = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))

# Define hyperparameters
learning_rate = 1e-3
max_grad_norm = 1.0
weight_decay = 1e-4

# TODO: Create a chained Optax transformation.
# The desired order is:
# 1. Clip gradients by their global norm (optax.clip_by_global_norm).
# 2. Add weight decay (optax.add_decayed_weights).
# 3. Apply Adam optimizer updates (optax.adam).
# Hint: Use optax.chain(...) and pass the transformations as positional arguments.
opt_chained = optax.chain(
    # YOUR CODE HERE
)


# --- Boilerplate for running the exercise ---
# The train_step and mse_loss from Exercise 1 can be reused directly!
optimizer_chained = nnx.Optimizer(model_chained, opt_chained, wrt=nnx.Param)

print("Starting training with chained optimizer...")
for i in range(101):
  model_chained, optimizer_chained, loss = train_step(model_chained, optimizer_chained, x_dummy, y_dummy)
  if i % 20 == 0:
    print(f"Step {optimizer_chained.step.value}, Loss: {loss:.4f}")
print("Chained optimizer training finished.")
# @title Solution 2
import jax
import jax.numpy as jnp
import optax
from flax import nnx

# --- Using the same model and data setup from Exercise 1 ---
key = jax.random.key(42)
key_model, key_data = jax.random.split(key)
din, dmid, dout = 10, 20, 5
x_dummy = jax.random.normal(key_data, (32, din))
y_dummy = jax.random.normal(key_data, (32, dout))

model_chained = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))

# Define hyperparameters
max_grad_norm = 1.0
weight_decay = 1e-4
learning_rate = 1e-3

# Create a chained Optax transformation
opt_chained = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.add_decayed_weights(weight_decay),
    optax.adam(learning_rate)
)

# --- Boilerplate for running the exercise ---
# The train_step and mse_loss from Exercise 1 can be reused directly!
optimizer_chained = nnx.Optimizer(model_chained, opt_chained, wrt=nnx.Param)

print("Starting training with chained optimizer...")
for i in range(101):
  model_chained, optimizer_chained, loss = train_step(model_chained, optimizer_chained, x_dummy, y_dummy)
  if i % 20 == 0:
    print(f"Step {optimizer_chained.step.value}, Loss: {loss:.4f}")
print("Chained optimizer training finished.")

Exercise 3: Learning Rate Scheduling

Concept: Dynamically adjusting the learning rate during training is a crucial technique. In Optax, you don't call an external scheduler.step() as you would in PyTorch. Instead, the schedule is baked directly into the optimizer definition.

This exercise asks you to create a learning rate schedule and pass it to your optimizer; Optax then applies it automatically at each step. You will implement a warmup-cosine-decay schedule, a very common and effective choice.
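
Concretely, an Optax schedule is just a function from the integer step count to a learning rate, so you can call it directly to inspect it. A small sketch with illustrative numbers (not the exercise's hyperparameters):

import optax

sched = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-3, warmup_steps=50, decay_steps=450, end_value=1e-5)

for step in (0, 25, 50, 250, 499):
  print(step, float(sched(step)))

# Passing the schedule where a float learning rate would go makes Optax
# evaluate it automatically at every optimizer step.
opt = optax.adam(learning_rate=sched)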

Instructions

Complete the TODO sections to define a learning rate schedule and use it in an Adam optimizer.
# @title Exercise 3: Implement a Learning Rate Schedule
import jax
import jax.numpy as jnp
import optax
from flax import nnx
import matplotlib.pyplot as plt

# --- Using the same model and data setup from Exercise 1 ---
key = jax.random.key(42)
key_model, key_data = jax.random.split(key)
din, dmid, dout = 10, 20, 5
x_dummy = jax.random.normal(key_data, (32, din))
y_dummy = jax.random.normal(key_data, (32, dout))

model_scheduled = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))

# --- Scheduling Hyperparameters ---
total_training_steps = 500
warmup_fraction = 0.1
peak_lr = 1e-3
final_lr = 1e-5

# TODO: Define a warmup-cosine-decay learning rate schedule.
# Hint: Use optax.warmup_cosine_decay_schedule.
# It needs an initial value, a peak value, warmup steps, and decay steps.
warmup_steps = # YOUR CODE HERE
decay_steps = # YOUR CODE HERE

lr_schedule_fn = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=decay_steps,
    end_value=final_lr
)


# TODO: Create an Adam optimizer that uses this schedule.
# Hint: Simply pass the schedule function as the `learning_rate` argument.
opt_scheduled = # YOUR CODE HERE


# --- Boilerplate for running the exercise ---
optimizer_scheduled = nnx.Optimizer(model_scheduled, opt_scheduled, wrt=nnx.Param)

# Training Loop
print("Starting training with scheduled LR...")
lrs = []
for i in range(total_training_steps):
  # The LR is updated automatically inside train_step
  model_scheduled, optimizer_scheduled, loss = train_step(model_scheduled, optimizer_scheduled, x_dummy, y_dummy)
  # We can extract the current LR for plotting by re-evaluating the schedule
  # at the current step count; in a real run you might not do this every step.
  current_lr = lr_schedule_fn(optimizer_scheduled.step.value)
  lrs.append(current_lr)
  if i % 50 == 0:
    print(f"Step {optimizer_scheduled.step.value}, Loss: {loss:.5f}, LR: {current_lr:.6f}")
print("Scheduled LR training finished.")

# Plot the learning rate over time
plt.figure(figsize=(10, 4))
plt.plot(lrs)
plt.title("Learning Rate Schedule")
plt.xlabel("Training Step")
plt.ylabel("Learning Rate")
plt.grid(True)
plt.show()
# @title Solution 3
import jax
import jax.numpy as jnp
import optax
from flax import nnx
import matplotlib.pyplot as plt

# --- Using the same model and data setup from Exercise 1 ---
key = jax.random.key(42)
key_model, key_data = jax.random.split(key)
din, dmid, dout = 10, 20, 5
x_dummy = jax.random.normal(key_data, (32, din))
y_dummy = jax.random.normal(key_data, (32, dout))

model_scheduled = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))

# --- Scheduling Hyperparameters ---
total_training_steps = 500
warmup_fraction = 0.1
peak_lr = 1e-3
final_lr = 1e-5

# Define a warmup-cosine-decay learning rate schedule
warmup_steps = int(total_training_steps * warmup_fraction)
decay_steps = total_training_steps - warmup_steps

lr_schedule_fn = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=decay_steps,
    end_value=final_lr
)

# Create an Adam optimizer that uses this schedule
opt_scheduled = optax.adam(learning_rate=lr_schedule_fn)

# --- Boilerplate for running the exercise ---
optimizer_scheduled = nnx.Optimizer(model_scheduled, opt_scheduled, wrt=nnx.Param)

# Training Loop
print("Starting training with scheduled LR...")
lrs = []
for i in range(total_training_steps):
  model_scheduled, optimizer_scheduled, loss = train_step(model_scheduled, optimizer_scheduled, x_dummy, y_dummy)
  current_lr = lr_schedule_fn(optimizer_scheduled.step.value)
  lrs.append(current_lr)
  if i % 50 == 0:
    print(f"Step {optimizer_scheduled.step.value}, Loss: {loss:.5f}, LR: {current_lr:.6f}")
print("Scheduled LR training finished.")

# Plot the learning rate over time
plt.figure(figsize=(10, 4))
plt.plot(lrs)
plt.title("Learning Rate Schedule")
plt.xlabel("Training Step")
plt.ylabel("Learning Rate")
plt.grid(True)
plt.show()
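
Re-evaluating the schedule by hand, as done above, is fine for plotting. If you would rather read the learning rate the optimizer actually applied, optax.inject_hyperparams stores it in the optimizer state. A minimal sketch with plain Optax and a made-up params dict (with nnx.Optimizer the same wrapped transformation can be passed in place of optax.adam, but the path to the hyperparams inside the wrapped state depends on the Flax version, so it is not shown here):

import jax.numpy as jnp
import optax

sched = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-3, warmup_steps=50, decay_steps=450, end_value=1e-5)

# Wrap the optimizer factory so its hyperparameters become part of its state.
tx = optax.inject_hyperparams(optax.adam)(learning_rate=sched)

params = {'w': jnp.ones((3,))}
state = tx.init(params)
updates, state = tx.update({'w': jnp.ones((3,))}, state, params)

# The LR used for this step is recorded in the state rather than recomputed by hand.
print(state.hyperparams['learning_rate'])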

Exercise 4: Sharding the Model and Optimizer State

Concept: JAX provides fine-grained control over how data and model parameters are distributed across devices. This is done by explicitly annotating PyTrees (like model parameters or optimizer state) with sharding information; a short standalone sketch of these building blocks follows the list below.

In this exercise, you will:

  1. Create a 2D device Mesh from our 8 simulated CPUs.
  2. Define a sharded MLP where the kernel of a linear layer is sharded across the 'model' axis of the mesh (Model Parallelism).
  3. Create a sharded optimizer whose state (e.g., Adam's momentum and variance vectors) automatically inherits the same sharding as the corresponding model parameters.
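
If Mesh and PartitionSpec are new to you, here is a short standalone sketch of the underlying JAX pieces on a plain array, before we apply the same idea to NNX state; the array contents and axis choices are made up for illustration.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A Mesh names the axes of a device grid; a PartitionSpec maps array axes to mesh axes.
demo_devices = np.array(jax.devices()).reshape(2, 4)
demo_mesh = Mesh(demo_devices, axis_names=('data', 'model'))

x = jnp.arange(32.0).reshape(8, 4)
# Shard rows across the 'data' axis; replicate columns (None) across 'model'.
x_sharded = jax.device_put(x, NamedSharding(demo_mesh, P('data', None)))
print(x_sharded.sharding)
jax.debug.visualize_array_sharding(x_sharded)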

Instructions

Complete the TODO sections to shard your model and optimizer.
# @title Exercise 4: Sharding Model and Optimizer
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from jax.sharding import Mesh, PartitionSpec as P
import numpy as np

# Ensure we have our 8 simulated devices
if jax.device_count() != 8:
    print("Warning: This exercise expects 8 devices. Sharding may not behave as expected.")

# 1. Create a device mesh
# We'll create a 2x4 mesh, with a 'data' axis for data parallelism
# and a 'model' axis for model parallelism.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=('data', 'model'))
print("Created 2x4 device mesh:")
print(mesh)

# 2. Define a sharded model
class ShardedMLP(nnx.Module):
  def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):
    # TODO: Shard the kernel of the second linear layer.
    # The goal is to split the kernel's columns across the 'model' axis.
    # This is a form of model parallelism.
    # - The first dimension (input features) should be replicated.
    # - The second dimension (output features) should be sharded.
    # - The bias should also be sharded along the 'model' axis.
    # - All other parameters can be replicated (the default).
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.relu = nnx.relu
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    # Keep linear1 replicated (the default when no sharding is annotated)
    self.linear1.kernel.sharding = P(None, None) # or just P()
    self.linear1.bias.sharding = P(None) # or just P()

    # Shard linear2 for model parallelism
    # YOUR CODE HERE - Replicate rows, shard columns
    # YOUR CODE HERE - Shard the bias vector

  def __call__(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
    return x

# 3. Create sharded model and optimizer within the mesh context
@nnx.jit
def create_sharded_model_and_optimizer():
  key = jax.random.key(0)
  model = ShardedMLP(16, 32, 64, rngs=nnx.Rngs(key))
  optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

  # The sharding annotations on the model are automatically picked up.
  # Now, we need to ensure the optimizer state gets the same shardings.
  # nnx.Optimizer automatically infers this from the model's parameters!
  # We just need to use jax.lax.with_sharding_constraint to enforce it
  # during JIT compilation.

  # Shard model state based on annotations
  model_state = nnx.state(model)
  model_shardings = nnx.spmd.get_partition_spec(model_state)
  sharded_model_state = jax.lax.with_sharding_constraint(model_state, model_shardings)
  nnx.update(model, sharded_model_state)

  # TODO: Shard the optimizer state.
  # The process is identical to sharding the model, but you need to filter
  # for the optimizer's state using nnx.optimizer.OptState.
  # YOUR CODE HERE
  opt_shardings = nnx.spmd.get_partition_spec(opt_state_to_shard)
  sharded_opt_state = jax.lax.with_sharding_constraint(
      opt_state_to_shard, opt_shardings
  )
  nnx.update(optimizer, sharded_opt_state)

  return model, optimizer

# Run the creation function within the mesh context manager
with mesh:
  sharded_model, sharded_optimizer = create_sharded_model_and_optimizer()


# --- Verification ---
print("\n--- Verifying Shardings ---")
# Get the sharded state back from the JIT call
final_model_state = nnx.state(sharded_model)
final_opt_state = nnx.state(sharded_optimizer, nnx.optimizer.OptState)

# Check the sharding of the second linear layer's kernel in the model
l2_kernel_sharding = final_model_state['linear2']['kernel'].sharding
print(f"\nModel's linear2.kernel sharding: {l2_kernel_sharding}")
assert l2_kernel_sharding == P(None, 'model')

# Check the sharding of the corresponding first moment (mu) in the optimizer state
# The optimizer state PyTree mirrors the parameter PyTree structure.
# For optax.adam, the state is (ScaleByAdamState, EmptyState); the first element holds mu and nu.
adam_state = final_opt_state['opt_state'][0]
l2_kernel_momentum_sharding = adam_state.mu['linear2']['kernel'].sharding
print(f"Optimizer's momentum for linear2.kernel sharding: {l2_kernel_momentum_sharding}")
assert l2_kernel_momentum_sharding == P(None, 'model')

print("\nSuccessfully verified that optimizer state sharding matches model parameter sharding.")
# @title Solution 4
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from jax.sharding import Mesh, PartitionSpec as P
import numpy as np

# Ensure we have our 8 simulated devices
if jax.device_count() != 8:
    print("Warning: This exercise expects 8 devices. Sharding may not behave as expected.")

# 1. Create a device mesh
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=('data', 'model'))
print("Created 2x4 device mesh:")
print(mesh)

# 2. Define a sharded model
class ShardedMLP(nnx.Module):
  def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.relu = nnx.relu
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    # Keep linear1 replicated (this is the default when no sharding is annotated)
    self.linear1.kernel.sharding = P() # Replicated on all axes
    self.linear1.bias.sharding = P()   # Replicated on all axes

    # Shard linear2 for model parallelism
    # Shard the output dimension of the kernel and the bias
    self.linear2.kernel.sharding = P(None, 'model') # Replicate rows, shard columns
    self.linear2.bias.sharding = P('model')         # Shard the bias vector

  def __call__(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
    return x

# 3. Create sharded model and optimizer within the mesh context
@nnx.jit
def create_sharded_model_and_optimizer():
  key = jax.random.key(0)
  model = ShardedMLP(16, 32, 64, rngs=nnx.Rngs(key))
  optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

  # Shard model state based on annotations
  model_state = nnx.state(model)
  model_shardings = nnx.spmd.get_partition_spec(model_state)
  sharded_model_state = jax.lax.with_sharding_constraint(model_state, model_shardings)
  nnx.update(model, sharded_model_state)

  # Shard the optimizer state
  # Filter for the optimizer's state (step and Optax's internal state)
  opt_state_to_shard = nnx.state(optimizer, nnx.optimizer.OptState)
  # Infer the sharding specification from the parameter shardings
  opt_shardings = nnx.spmd.get_partition_spec(opt_state_to_shard)
  # Apply the sharding constraint
  sharded_opt_state = jax.lax.with_sharding_constraint(
      opt_state_to_shard, opt_shardings
  )
  nnx.update(optimizer, sharded_opt_state)

  return model, optimizer

# Run the creation function within the mesh context manager
with mesh:
  sharded_model, sharded_optimizer = create_sharded_model_and_optimizer()


# --- Verification ---
print("\n--- Verifying Shardings ---")
# Get the sharded state back from the JIT call
final_model_state = nnx.state(sharded_model)
final_opt_state = nnx.state(sharded_optimizer, nnx.optimizer.OptState)

# Check the sharding of the second linear layer's kernel in the model
l2_kernel_sharding = final_model_state['linear2']['kernel'].sharding
print(f"\nModel's linear2.kernel sharding: {l2_kernel_sharding}")
assert l2_kernel_sharding == P(None, 'model')

# Check the sharding of the corresponding momentum (m) in the optimizer state
# The optimizer state PyTree mirrors the parameter PyTree structure.
# For optax.adam, the state is a tuple of (ScaleByAdamState, EmptyState);
# ScaleByAdamState holds the first and second moments (mu, nu).
adam_state = final_opt_state['opt_state'][0]
l2_kernel_momentum_sharding = adam_state.mu['linear2']['kernel'].sharding
print(f"Optimizer's momentum for linear2.kernel sharding: {l2_kernel_momentum_sharding}")
assert l2_kernel_momentum_sharding == P(None, 'model')

print("\nSuccessfully verified that optimizer state sharding matches model parameter sharding.")