Introduction to Checkpointing with Flax NNX and Orbax

Welcome! In these hands-on exercises you'll learn how to save and restore JAX/Flax NNX models—a must-have skill for any serious ML project.

Why checkpoint?

Training deep models takes time. Checkpoints help you:

  • Save training progress (model params, optimizer state) so you can resume after interruptions.
  • Keep snapshots at different stages for analysis or inference.
  • Add fault tolerance to long-running jobs.
  • Share artifacts across teams or hardware.
  • Enable reproducibility by pinning exact model/state versions.

Quick refresher on Flax NNX

  • Stateful modules: An NNX module is a Python class that stores its own state (parameters, stats, etc.); PyTorch users will feel at home.
  • nnx.Module: Base class for building these stateful components.
  • nnx.Variable: Types like nnx.Param or nnx.BatchStat declare learnable or tracked state.
  • nnx.State: A JAX pytree (nested dict-like) holding the values of all nnx.Variables—this is what we save/load with Orbax.

Functional bridge

  • nnx.split(module): Separates a module into static structure (GraphDef) and dynamic state (nnx.State) so you can pull out what to save.
  • nnx.merge(graphdef, state): Reconstructs a module from GraphDef and nnx.State, typically used when restoring.
  • nnx.update(module, state): In-place update of an existing module's state, also handy after restore (a minimal round trip of all three is sketched below).
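
To make the bridge concrete, here is a minimal sketch using a hypothetical one-parameter Bias module (the imports match the Setup cell below):

# Minimal sketch of the functional bridge (Bias is an illustrative toy module)
from flax import nnx
import jax.numpy as jnp

class Bias(nnx.Module):
  def __init__(self, dout: int):
    self.bias = nnx.Param(jnp.zeros((dout,)))

m = Bias(4)
graphdef, state = nnx.split(m)   # static structure + dynamic state (a pytree)
m2 = nnx.merge(graphdef, state)  # rebuild an equivalent module from the two pieces
nnx.update(m, state)             # or write the state back into an existing module in place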

Orbax: JAX's checkpointing library

Orbax is the standard JAX checkpointing library: robust, extensible, and designed for distributed setups.

  • ocp.CheckpointManager: High-level helper that simplifies managing multiple checkpoints (e.g., keep the latest N). We'll use this everywhere.
  • ocp.args: Arg objects describing how to save/restore (ocp.args.StandardSave, StandardRestore, Composite, etc.); a minimal manager lifecycle is sketched just below.
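
As a quick preview of the lifecycle the exercises walk through step by step (a minimal sketch; the directory and pytree here are illustrative placeholders):

# Minimal CheckpointManager lifecycle (illustrative preview)
import jax.numpy as jnp
import orbax.checkpoint as ocp

state_pytree = {'w': jnp.ones((2, 2))}                  # any JAX pytree works, e.g. an nnx.State
options = ocp.CheckpointManagerOptions(max_to_keep=3)   # keep only the newest 3 checkpoints
mngr = ocp.CheckpointManager('/tmp/preview_ckpts', options=options)
mngr.save(0, args=ocp.args.StandardSave(state_pytree))  # save at step 0
mngr.wait_until_finished()                              # saves may run asynchronously
mngr.close()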

Let's get started!

# @title Setup: Install and Import Libraries
# Install necessary libraries
!pip install -q jax-ai-stack==2025.9.3

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
import flax
from flax import nnx
import orbax.checkpoint as ocp
import optax
import os
import shutil # For cleaning up directories
import chex # For faking devices

# Suppress some JAX warnings for cleaner output in the notebook
import warnings
warnings.filterwarnings("ignore", message="No GPU/TPU found, falling back to CPU.")
warnings.filterwarnings("ignore", message="Custom node type GlobalDeviceArray is not handled by Pytree traversal.") # Orbax/NNX interactions

print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")
print(f"Orbax version: {ocp.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"Chex version: {chex.__version__}")

# --- Setup for Distributed Exercises ---
# Simulate an environment with 8 CPUs for distributed examples
# This allows us to test sharding logic even on a single-CPU Colab machine.
try:
  chex.set_n_cpu_devices(8)
except RuntimeError as e:
  print(f"Note: Could not set_n_cpu_devices (may have been set already): {e}")

print(f"Number of JAX devices available: {jax.device_count()}")
print(f"Available devices: {jax.devices()}")

# Helper function to clean up checkpoint directories
def cleanup_ckpt_dir(ckpt_dir):
  if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)
    print(f"Cleaned up checkpoint directory: {ckpt_dir}")

# Create a default checkpoint directory for exercises
CKPT_BASE_DIR = '/tmp/nnx_orbax_workshop_checkpoints'
if not os.path.exists(CKPT_BASE_DIR):
  os.makedirs(CKPT_BASE_DIR)

print(f"Base checkpoint directory: {CKPT_BASE_DIR}")

Exercise 1: Basic checkpoint — save nnx.State

Goal: Save the state of a simple Flax NNX module with Orbax.

Topics

  • Define an nnx.Module.
  • Instantiate the module with initial parameters.
  • Use nnx.split() to extract the nnx.State pytree.
  • Configure ocp.CheckpointManager.
  • Call mngr.save() with ocp.args.StandardSave to persist the state.

Steps

  1. Implement SimpleLinear inheriting from nnx.Module.
    • In __init__, declare weight matrix and bias vector as nnx.Param; initialize with JAX random functions (e.g., jax.random.uniform, jnp.zeros) and manage keys via nnx.Rngs.
    • Implement __call__: y = x @ weight + bias.
  2. Instantiate SimpleLinear.
  3. Pick a checkpoint directory.
  4. Create ocp.CheckpointManagerOptions (e.g., max_to_keep=3).
  5. Construct ocp.CheckpointManager with the directory and options.
  6. Call nnx.split(model) to get graphdef and state_to_save.
  7. At a chosen training step (e.g., step 100) call mngr.save(), wrapping state_to_save with ocp.args.StandardSave.
  8. Call mngr.wait_until_finished() to ensure async saves complete.
  9. Finally call mngr.close() to close the manager.
# --- Define the NNX Module ---
class SimpleLinear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key_w, key_b = rngs.params(), rngs.params() # Example of splitting keys if needed, or use one key for multiple params
    # TODO: Define self.weight as an nnx.Param with shape (din, dout)
    # self.weight = ...
    # TODO: Define self.bias as an nnx.Param with shape (dout,)
    # self.bias = ...

  def __call__(self, x: jax.Array) -> jax.Array:
    # TODO: Implement the forward pass
    # return ...
    raise NotImplementedError  # Remove this line once the forward pass is implemented

# --- Instantiate the Model ---
din, dout = 10, 5
# TODO: Create an nnx.Rngs object for parameter initialization
# rngs = ...
# TODO: Instantiate SimpleLinear
# model = ...

print(f"Model created. Weight shape: {model.weight.value.shape}, Bias shape: {model.bias.value.shape}")

# --- Setup CheckpointManager ---
ckpt_dir_ex1 = os.path.join(CKPT_BASE_DIR, 'ex1_basic_save')
cleanup_ckpt_dir(ckpt_dir_ex1) # Clean up from previous runs

# TODO: Create CheckpointManagerOptions
# options = ...
# TODO: Instantiate CheckpointManager
# mngr = ...

# --- Split the model to get the state ---
# TODO: Split the model into graphdef and state_to_save
# _graphdef, state_to_save = ...
# Alternatively, for just the state: state_to_save = nnx.state(model)
# print(f"State to save: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, state_to_save)}")

# --- Save the state ---
step = 100
# TODO: Save the state_to_save at the given step. Use ocp.args.StandardSave.
# mngr.save(...)
# TODO: Wait for saving to complete
# mngr.wait_until_finished()

print(f"Checkpoint saved for step {step} in {ckpt_dir_ex1}.")
print(f"Available checkpoints: {mngr.all_steps()}")

# TODO: Close the manager
# mngr.close()
# @title Exercise 1: Solution
# --- Define the NNX Module ---
class SimpleLinear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    # Parameters defined using nnx.Param (a type of nnx.Variable)
    self.weight = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
    self.bias = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x: jax.Array) -> jax.Array:
    # Parameters used directly via self.weight, self.bias
    return x @ self.weight.value + self.bias.value

# --- Instantiate the Model ---
din, dout = 10, 5
rngs = nnx.Rngs(params=jax.random.key(0)) # NNX requires explicit RNG management
model = SimpleLinear(din=din, dout=dout, rngs=rngs)

print(f"Model created. Weight shape: {model.weight.value.shape}, Bias shape: {model.bias.value.shape}")

# --- Setup CheckpointManager ---
ckpt_dir_ex1 = os.path.join(CKPT_BASE_DIR, 'ex1_basic_save')
cleanup_ckpt_dir(ckpt_dir_ex1)

options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=1)
mngr = ocp.CheckpointManager(ckpt_dir_ex1, options=options)

# --- Split the model to get the state ---
_graphdef, state_to_save = nnx.split(model)
# Alternatively: state_to_save = nnx.state(model)
print(f"State to save structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), state_to_save)}")

# --- Save the state ---
step = 100
mngr.save(step, args=ocp.args.StandardSave(state_to_save))
mngr.wait_until_finished() # Ensure save completes if async

print(f"Checkpoint saved for step {step} in {ckpt_dir_ex1}.")
print(f"Available checkpoints: {mngr.all_steps()}")

mngr.close() # Clean up resources

Exercise 2: Basic checkpoint — restore nnx.State

Goal: Restore model state from a checkpoint with Orbax.

Topics

  • Create an “abstract” model template with nnx.eval_shape().
  • Split the abstract model to get abstract_state (a pytree of ShapeDtypeStruct).
  • Combine abstract_state with ocp.args.StandardRestore in mngr.restore().
  • Rebuild the model via nnx.merge(graphdef, restored_state).
  • Alternatively update an existing instance with nnx.update().

Steps

  1. Re-open the CheckpointManager that points to Exercise 1 (ckpt_dir_ex1).
  2. Write create_abstract_model() that returns a SimpleLinear instance for nnx.eval_shape().
    • Use dummy RNG/inputs—eval_shape only cares about structure/dtype, not values.
  3. Call abstract_model = nnx.eval_shape(create_abstract_model).
  4. Split abstract_model: graphdef_for_restore, abstract_state = nnx.split(abstract_model) to get the ShapeDtypeStruct template.
  5. Use mngr.latest_step() to find the most recent checkpoint step.
  6. If one exists, call mngr.restore(step_to_restore, args=ocp.args.StandardRestore(abstract_state)).
  7. Rebuild with restored_model = nnx.merge(graphdef_for_restore, restored_state).
  8. (Optional) Print values such as restored_model.bias.value to verify.
  9. Close the manager.
# Ensure the SimpleLinear class definition from Exercise 1 is available

# --- Re-open CheckpointManager ---
# TODO: Instantiate CheckpointManager for ckpt_dir_ex1 (no need for options if just restoring)
# mngr_restore = ...

# --- Create Abstract Model for Restoration ---
def create_abstract_model():
  # Use dummy RNG key/inputs for abstract creation
  # TODO: Return an instance of SimpleLinear, same din/dout as before
  # return ...
  raise NotImplementedError  # Remove this line once implemented

# TODO: Create the abstract_model using nnx.eval_shape
# abstract_model = ...

# --- Split Abstract Model to get Abstract State Structure ---
# TODO: Split the abstract_model to get graphdef_for_restore and abstract_state
# graphdef_for_restore, abstract_state = ...
print(f"Abstract state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else x, abstract_state)}")


# --- Restore the State ---
# TODO: Get the latest step to restore
# step_to_restore = ...

if step_to_restore is not None:
  # TODO: Restore the state using mngr_restore.restore() and ocp.args.StandardRestore with abstract_state
  # restored_state = mngr_restore.restore(...)

  # --- Reconstruct the Model ---
  # TODO: Reconstruct the model using nnx.merge with graphdef_for_restore and restored_state
  # restored_model = ...
  print(f"Model restored from step {step_to_restore}.")
  # You can now use 'restored_model'
  print(f"Restored bias (first 3 values): {restored_model.bias.value[:3]}")

  # Alternative: Update an existing model instance
  # model_to_update = SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(99))) # Fresh instance
  # nnx.update(model_to_update, restored_state)
  # print(f"Updated model bias (first 3 values): {model_to_update.bias.value[:3]}")
else:
  print("No checkpoint found to restore.")

# TODO: Close the manager
# mngr_restore.close()
# @title Exercise 2: Solution

# Ensure the SimpleLinear class definition from Exercise 1 is available

# --- Re-open CheckpointManager ---
mngr_restore = ocp.CheckpointManager(ckpt_dir_ex1) # Re-open manager

# --- Create Abstract Model for Restoration ---
def create_abstract_model():
  # Use dummy RNG key/inputs for abstract creation
  return SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(0))) # din, dout from Ex1

abstract_model = nnx.eval_shape(create_abstract_model)

# --- Split Abstract Model to get Abstract State Structure ---
graphdef_for_restore, abstract_state = nnx.split(abstract_model)
# abstract_state now contains ShapeDtypeStruct leaves
print(f"Abstract state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else x, abstract_state)}")

# --- Restore the State ---
step_to_restore = mngr_restore.latest_step()

if step_to_restore is not None:
  restored_state = mngr_restore.restore(step_to_restore,
      args=ocp.args.StandardRestore(abstract_state))
  print(f"Restored state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), restored_state)}")

  # --- Reconstruct the Model ---
  restored_model = nnx.merge(graphdef_for_restore, restored_state)
  print(f"Model restored from step {step_to_restore}.")
  # You can now use 'restored_model'
  print(f"Restored bias (first 3 values): {restored_model.bias.value[:3]}")

  # Compare with original model's bias (optional, if 'model' from Ex1 is still in scope)
  # print(f"Original bias (first 3 values): {model.bias.value[:3]}")
  # chex.assert_trees_all_close(restored_model.bias.value, model.bias.value)

  # Alternative: Update an existing model instance
  model_to_update = SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(99))) # Fresh instance
  # Initialize with different values to see update working
  model_to_update.bias.value = jnp.ones_like(model_to_update.bias.value) * 55.0
  print(f"Bias before update: {model_to_update.bias.value[:3]}")
  nnx.update(model_to_update, restored_state)
  print(f"Updated model bias (first 3 values): {model_to_update.bias.value[:3]}")
  if 'model' in globals(): # Check if original model exists
    chex.assert_trees_all_close(model_to_update.bias.value, model.bias.value)
else:
  print("No checkpoint found to restore.")

mngr_restore.close()

Exercise 3: Save model parameters and optimizer state

Goal: Save both model parameters and optimizer state in one checkpoint.

Topics

  • Use nnx.Optimizer to manage model params and Optax optimizer state.
  • Extract model parameters (e.g., nnx.split(model, nnx.Param)).
  • Extract full optimizer state (nnx.state(optimizer)).
  • Use ocp.args.Composite to save multiple named items (params, optimizer state) in one checkpoint.

Steps

  1. Reuse the SimpleLinear definition and instantiate a new model.
  2. Create an Optax optimizer (e.g., optax.adam(learning_rate=1e-3)).
  3. Wrap model + tx with nnx.Optimizer.
  4. (Optional) Simulate a few training steps by bumping the step counter; no real data or gradients are needed (in real training the Optax internals, e.g. momentum, would change too).
    • Access with optimizer.step.value += 1, etc.
  5. Create a new CheckpointManager under ckpt_dir_ex3.
  6. Extract model parameters: _graphdef_params, params_state = nnx.split(model_ex3, nnx.Param). (Note: optimizer.model no longer exists; split the model directly.)
  7. Extract the full optimizer state: optimizer_state_tree = nnx.state(optimizer) (includes internal state like momentum plus the optimizer step).
  8. Define save_items dict with names (e.g., 'params', 'optimizer') mapped to the corresponding pytrees wrapped in ocp.args.StandardSave().
  9. Call mngr.save(step, args=ocp.args.Composite(**save_items)), using the optimizer's current step.
  10. Wait for completion and close the manager.
# Ensure SimpleLinear class definition is available
# --- Instantiate Model and Optimizer ---
rngs_ex3 = nnx.Rngs(params=jax.random.key(1))
model_ex3 = SimpleLinear(din=10, dout=5, rngs=rngs_ex3)

# TODO: Create an Optax optimizer (e.g., Adam)
# tx = ...
# TODO: Create an nnx.Optimizer, wrapping the model and tx
# optimizer = ...

# Simulate a few "training" steps to populate optimizer state
# For a real scenario, this would involve gradients and updates
if hasattr(optimizer, 'step') and hasattr(optimizer.step, 'value'): # Check for NNX Optimizer structure
  optimizer.step.value += 10 # Simulate 10 steps
  # In a real training loop you would compute grads and call optimizer.update(...), which updates the optimizer state.
  # For this exercise, just advancing step is enough to see it saved/restored.
  # Let's also change a parameter slightly to see it saved
  original_bias_val_ex3 = model_ex3.bias.value.copy()
  model_ex3.bias.value = model_ex3.bias.value * 0.5 + 0.1
  print(f"Optimizer step: {optimizer.step.value}")
  print(f"Bias modified. Original first val: {original_bias_val_ex3[0]}, New first val: {model_ex3.bias.value[0]}")
else:
  print("Skipping optimizer step update as structure might differ from expected nnx.Optimizer.")


# --- Setup CheckpointManager for Composite Save ---
ckpt_dir_ex3 = os.path.join(CKPT_BASE_DIR, 'ex3_composite_save')
cleanup_ckpt_dir(ckpt_dir_ex3)
# TODO: Instantiate CheckpointManager for ckpt_dir_ex3
# mngr_comp = ...

# --- Extract States for Saving ---
# TODO: Extract model parameters state from model_ex3 using nnx.split with the nnx.Param filter
# _graphdef_params, params_state = ...
# TODO: Extract the full optimizer state tree using nnx.state()
# optimizer_state_tree = ...

print(f"Parameter state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, params_state)}")
print(f"Optimizer state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, optimizer_state_tree)}")

# --- Save Composite State ---
current_step_val = 0
if hasattr(optimizer, 'step') and hasattr(optimizer.step, 'value'):
  current_step_val = optimizer.step.value
else: # Fallback for safety, though nnx.Optimizer should have .step
  current_step_val = 10


# TODO: Define save_items dictionary for 'params' and 'optimizer'
# Each item should be wrapped with ocp.args.StandardSave
# save_items = {
#     'params': ...,
#     'optimizer': ...
# }

# TODO: Save using mngr_comp.save() and ocp.args.Composite
# mngr_comp.save(...)
# TODO: Wait and close the manager
# mngr_comp.wait_until_finished()
# print(f"Composite checkpoint saved for step {current_step_val} in {ckpt_dir_ex3}.")
# print(f"Available checkpoints: {mngr_comp.all_steps()}")
# mngr_comp.close()
# @title Exercise 3: Solution

# Ensure SimpleLinear class definition is available
# --- Instantiate Model and Optimizer ---
rngs_ex3 = nnx.Rngs(params=jax.random.key(1))
model_ex3 = SimpleLinear(din=10, dout=5, rngs=rngs_ex3)

tx = optax.adam(learning_rate=1e-3)
optimizer = nnx.Optimizer(model_ex3, tx, wrt=nnx.Param)

# Simulate a few "training" steps to populate optimizer state
# For a real scenario, this would involve gradients and updates
optimizer.step.value += 10 # Simulate 10 steps
original_bias_val_ex3 = model_ex3.bias.value.copy()
# Simulate a parameter update that would happen during training
model_ex3.bias.value = model_ex3.bias.value * 0.5 + 0.1 # Arbitrary change
print(f"Optimizer step: {optimizer.step.value}")
print(f"Bias modified. Original first val: {original_bias_val_ex3[0]}, New first val: {model_ex3.bias.value[0]}")

# --- Setup CheckpointManager for Composite Save ---
ckpt_dir_ex3 = os.path.join(CKPT_BASE_DIR, 'ex3_composite_save')
cleanup_ckpt_dir(ckpt_dir_ex3)
mngr_comp = ocp.CheckpointManager(ckpt_dir_ex3, options=ocp.CheckpointManagerOptions(max_to_keep=3))

# --- Extract States for Saving ---
# Extract model parameters (e.g., using nnx.split(model, nnx.Param))
_graphdef_params, params_state = nnx.split(model_ex3, nnx.Param)
# Extract optimizer state (nnx.state(optimizer))
optimizer_state_tree = nnx.state(optimizer)

print(f"Parameter state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), params_state)}")
print(f"Optimizer state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), optimizer_state_tree)}")
# Note: with the wrt-based nnx.Optimizer, optimizer_state_tree holds the optimizer's own state (e.g., Adam moments and the step counter); the model parameters are saved separately above

# --- Save Composite State ---
current_step_val = optimizer.step.value # Get current step from optimizer

# Save using Composite args
save_items = {
  'params': ocp.args.StandardSave(params_state),
  'optimizer': ocp.args.StandardSave(optimizer_state_tree)
}

# Composite bundles the per-item save args into a single checkpoint
mngr_comp.save(current_step_val, args=ocp.args.Composite(**save_items))
mngr_comp.wait_until_finished()
print(f"Composite checkpoint saved for step {current_step_val} in {ckpt_dir_ex3}.")
print(f"Available checkpoints: {mngr_comp.all_steps()}")
mngr_comp.close()

Exercise 4: Restore model parameters and optimizer state

Goal: Restore both model parameters and optimizer state from a composite checkpoint.

Topics

  • Create abstract versions of model and optimizer with nnx.eval_shape.
  • Get abstract templates for parameter state and optimizer state.
  • Restore multiple items using ocp.args.Composite with ocp.args.StandardRestore.
  • Instantiate fresh concrete model and optimizer.
  • Write restored state back with nnx.update().

Steps

  1. Re-open the Exercise 3 CheckpointManager (ckpt_dir_ex3).
  2. Define create_abstract_model_and_optimizer():
    • Inside, create an abstract/template model (e.g., SimpleLinear); nnx.eval_shape can be used, or a freshly initialized instance, since only shapes and dtypes matter for the restore template.
    • Then create a matching nnx.Optimizer from that model and a new Optax optimizer.
    • Return abs_model and abs_optimizer.
  3. Call the function to obtain abs_model and abs_optimizer.
  4. Get abstract parameter state: _graphdef_abs_params, abs_params_state = nnx.split(abs_model, nnx.Param).
  5. Get abstract optimizer state: abs_optimizer_state = nnx.state(abs_optimizer).
  6. Find the latest step to restore.
  7. If present, build restore_targets dict for ocp.args.Composite with the same keys as save ('params', 'optimizer') and values wrapped in ocp.args.StandardRestore().
  8. Call mngr_comp.restore(step, args=ocp.args.Composite(**restore_targets)) to get restored_items.
  9. Instantiate fresh SimpleLinear and nnx.Optimizer.
  10. Update with nnx.update(fresh_model, restored_items['params']).
  11. Update with nnx.update(fresh_optimizer, restored_items['optimizer']).
  12. Verify via optimizer step and a parameter value.
  13. Close the manager.
# Ensure SimpleLinear class definition is available
# --- Re-open CheckpointManager ---
# TODO: Instantiate CheckpointManager for ckpt_dir_ex3
# mngr_comp_restore = ...

# --- Create Abstract Model and Optimizer ---
def create_abstract_model_and_optimizer():
  rngs_abs = nnx.Rngs(params=jax.random.key(0)) # Dummy key for abstract creation
  # TODO: Create abstract model. Model class: SimpleLinear(din=10, dout=5, ...)
  # abs_model = SimpleLinear(...)

  # TODO: Create abstract optimizer. Pass abs_model and an optax.adam instance.
  # abs_opt = nnx.Optimizer(...)
  # return abs_model, abs_opt

# TODO: Call the function to get abstract model and optimizer
# abs_model_restore, abs_optimizer_restore = ...

# --- Get Abstract States ---
# TODO: Get abstract parameter state from abs_model_restore (filter with nnx.Param)
# _graphdef_abs_params, abs_params_state = ...
# TODO: Get abstract optimizer state from abs_optimizer_restore
# abs_optimizer_state = ...

print(f"Abstract params state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, abs_params_state)}")
print(f"Abstract optimizer state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, abs_optimizer_state)}")

# --- Restore Composite State ---
# TODO: Get the latest step
# step_to_restore_comp = ...

if step_to_restore_comp is not None:
  # TODO: Define restore_targets dictionary for 'params' and 'optimizer'
  # Each item should be wrapped with ocp.args.StandardRestore and its corresponding abstract state.
  # restore_targets = {
  #    'params': ...,
  #    'optimizer': ...
  # }
  # TODO: Restore items using mngr_comp_restore.restore() and ocp.args.Composite
  # restored_items = mngr_comp_restore.restore(...)

  # --- Instantiate and Update Concrete Model/Optimizer ---
  # TODO: Create a fresh SimpleLinear model instance (use a new RNG key, e.g., key(2))
  # fresh_model = ...
  # TODO: Create a fresh nnx.Optimizer instance with fresh_model and a new optax.adam instance
  # fresh_optimizer = ...

  # Store pre-update values for comparison
  pre_update_bias = fresh_model.bias.value.copy()
  pre_update_opt_step = fresh_optimizer.step.value

  # TODO: Update fresh_model with restored_items['params'] using nnx.update()
  # nnx.update(...)
  # TODO: Update fresh_optimizer with restored_items['optimizer'] using nnx.update()
  # nnx.update(...)

  print(f"Restored and updated. Optimizer step: {fresh_optimizer.step.value}")
  print(f"Fresh model bias before update (first val): {pre_update_bias[0]}")
  print(f"Fresh model bias after update (first val): {fresh_model.bias.value[0]}")
  print(f"Original bias from Ex3 (first val): {model_ex3.bias.value[0]}") # model_ex3 is from previous cell

  # Verification
  # chex.assert_trees_all_close(fresh_model.bias.value, model_ex3.bias.value) # Compare with the state that was saved
  # assert fresh_optimizer.step.value == optimizer.step.value # Compare with optimizer state that was saved
else:
  print("No composite checkpoint found.")

# TODO: Close the manager
# mngr_comp_restore.close()
# @title Exercise 4: Solution

# Ensure SimpleLinear class definition is available
# --- Re-open CheckpointManager ---
mngr_comp_restore = ocp.CheckpointManager(ckpt_dir_ex3)

# --- Create Abstract Model and Optimizer ---
def create_abstract_model_and_optimizer():
  rngs_abs = nnx.Rngs(params=jax.random.key(0)) # Dummy key for abstract creation
  # Create a template model; a concrete instance works as a restore target since only shapes/dtypes matter
  abs_model = SimpleLinear(din=10, dout=5, rngs=rngs_abs)
  # Create a matching optimizer to serve as the optimizer-state template
  abs_opt = nnx.Optimizer(abs_model, optax.adam(1e-3), wrt=nnx.Param)
  return abs_model, abs_opt

abs_model_restore, abs_optimizer_restore = create_abstract_model_and_optimizer()

# --- Get Abstract States ---
_graphdef_abs_params, abs_params_state = nnx.split(abs_model_restore, nnx.Param)
abs_optimizer_state = nnx.state(abs_optimizer_restore)

print(f"Abstract params state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), abs_params_state)}")
print(f"Abstract optimizer state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), abs_optimizer_state)}")

# --- Restore Composite State ---
step_to_restore_comp = mngr_comp_restore.latest_step()

if step_to_restore_comp is not None:
  restore_targets = {
    'params': ocp.args.StandardRestore(abs_params_state),
    'optimizer': ocp.args.StandardRestore(abs_optimizer_state)
  }
  restored_items = mngr_comp_restore.restore(step_to_restore_comp, args=ocp.args.Composite(**restore_targets))

  # --- Instantiate and Update Concrete Model/Optimizer ---
  # Create fresh instances
  fresh_rngs = nnx.Rngs(params=jax.random.key(2)) # Use a different key for the fresh model
  fresh_model = SimpleLinear(din=10, dout=5, rngs=fresh_rngs)
  fresh_optimizer = nnx.Optimizer(fresh_model, optax.adam(1e-3), wrt=nnx.Param) # Matching optax optimizer

  # Store pre-update values for comparison
  pre_update_bias = fresh_model.bias.value.copy()
  pre_update_opt_step = fresh_optimizer.step.value

  # Update using restored states
  nnx.update(fresh_model, restored_items['params'])
  nnx.update(fresh_optimizer, restored_items['optimizer'])

  print(f"Restored and updated. Optimizer step: {fresh_optimizer.step.value}")
  print(f"Fresh model bias before update (first val): {pre_update_bias[0]}") # Will be from key(2)
  print(f"Fresh model bias after update (first val): {fresh_model.bias.value[0]}") # Should match model_ex3 bias

  # Verification (model_ex3 and optimizer are from the previous cell where they were saved)
  chex.assert_trees_all_close(fresh_model.bias.value, model_ex3.bias.value)
  assert fresh_optimizer.step.value == optimizer.step.value
  print("Verification successful: Restored model parameters and optimizer step match the saved state.")
else:
  print("No composite checkpoint found.")

mngr_comp_restore.close()

Exercise 5: Distributed checkpointing — save sharded state

Goal: Understand how to save model state that is distributed across multiple devices; Orbax handles sharded JAX arrays efficiently.

Topics

  • Set up a JAX device Mesh.
  • Define PartitionSpecs for arrays to specify how they are sharded.
  • Create sharded parameters inside an nnx.Module: one approach is to initialize the parameters first, shard them with jax.device_put + NamedSharding, and write them back; NNX also supports annotating sharding directly in nnx.Variable metadata.
  • Save sharded state: as long as the JAX arrays in the state pytree are sharded, Orbax saves them transparently.

Steps

  1. Determine the number of devices and create a device mesh (e.g., a 1D mesh over all devices).
  2. Modify SimpleLinear (or create ShardedSimpleLinear) so that parameters are sharded after initialization in __init__.
    • Shard the weight matrix (din, dout) along the dout dimension (e.g., PartitionSpec(None, 'data')).
    • Shard the bias vector (dout,) along its only dimension (PartitionSpec('data')).
    • Apply the sharding:
      • Create a NamedSharding from the PartitionSpec and the mesh.
      • Use jax.device_put(param_value, named_sharding) to obtain the sharded JAX array.
      • Write these sharded arrays back into the nnx.Param's .value.
  3. Instantiate the sharded model inside the mesh context manager (with mesh:) so operations are mesh-aware.
  4. Create a CheckpointManager in a new directory, ckpt_dir_ex5.
  5. Split the sharded model to get its state: _graphdef_sharded, sharded_state_to_save = nnx.split(sharded_model). The arrays in it should be jax.Arrays carrying sharding information.
  6. Call mngr.save() to save sharded_state_to_save; from Orbax's perspective the flow is the same as the unsharded case.
  7. Wait for completion and close the manager.
# --- Setup JAX Mesh ---
num_devices = jax.device_count()
# If num_devices is 1 after chex.set_n_cpu_devices(8), it means JAX didn't pick up the fakes.
# This can happen if JAX initializes its backends before chex runs.
# Forcing a rerun of this cell or restarting runtime and running setup first might help.
print(f"Using {num_devices} devices for sharding.")
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices=device_mesh, axis_names=('data',)) # 1D mesh
print(mesh)

# --- Define Sharded NNX Module ---
class ShardedSimpleLinear(nnx.Module):
  def __init__(self, din: int, dout: int, mesh: Mesh, *, rngs: nnx.Rngs):
    self.din = din
    self.dout = dout
    self.mesh = mesh

    key_w, key_b = rngs.params(), rngs.params()

    # Initialize as regular JAX arrays first
    initial_weight = jax.random.uniform(key_w, (din, dout))
    initial_bias = jnp.zeros((dout,))

    # TODO: Define PartitionSpec for weight (shard dout across 'data' axis)
    # e.g., PartitionSpec(None, 'data') means not sharded on dim 0, sharded on dim 1
    # weight_pspec = ...
    # TODO: Define PartitionSpec for bias (shard along 'data' axis)
    # bias_pspec = ...

    # TODO: Create NamedSharding for weight and bias using self.mesh and the pspecs
    # weight_sharding = NamedSharding(...)
    # bias_sharding = NamedSharding(...)

    # TODO: Shard the initial arrays using jax.device_put and the NamedSharding
    # sharded_weight_value = jax.device_put(...)
    # sharded_bias_value = jax.device_put(...)

    # TODO: Assign these sharded arrays to nnx.Param attributes
    # self.weight = nnx.Param(sharded_weight_value)
    # self.bias = nnx.Param(sharded_bias_value)

    # Alternative (more direct with nnx.Variable metadata if supported well for this case):
    # self.weight = nnx.Param(initial_weight, sharding=weight_sharding) # This depends on NNX API
    # For this exercise, jax.device_put is explicit and clear.

  def __call__(self, x: jax.Array) -> jax.Array:
    # x is assumed to be replicated or appropriately sharded for the matmul
    # For simplicity, assume x is replicated if din is not sharded, or sharded compatibly.
    return x @ self.weight.value + self.bias.value

# --- Instantiate Sharded Model within Mesh context ---
din_s, dout_s = 8, num_devices * 2 # Ensure dout is divisible by num_devices for even sharding
rngs_sharded = nnx.Rngs(params=jax.random.key(3))

# TODO: Instantiate ShardedSimpleLinear within the mesh context
# with mesh:
#   sharded_model = ...

# print(f"Sharded model created. Weight sharding: {sharded_model.weight.value.sharding}")
# print(f"Sharded model bias sharding: {sharded_model.bias.value.sharding}")


# --- Setup CheckpointManager for Sharded Save ---
ckpt_dir_ex5 = os.path.join(CKPT_BASE_DIR, 'ex5_sharded_save')
cleanup_ckpt_dir(ckpt_dir_ex5)
# TODO: Instantiate CheckpointManager
# mngr_sharded_save = ...

# --- Split and Save Sharded State ---
# TODO: Split the sharded_model
# _graphdef_sharded, sharded_state_to_save = ...

# print(f"Sharded state to save (bias type): {type(sharded_state_to_save['bias'].value)}")
# print(f"Sharded state to save (bias sharding): {sharded_state_to_save['bias'].value.sharding}")

# current_step_sharded = 200
# TODO: Save the sharded_state_to_save
# mngr_sharded_save.save(...)
# TODO: Wait and close
# mngr_sharded_save.wait_until_finished()
# print(f"Sharded checkpoint saved for step {current_step_sharded} in {ckpt_dir_ex5}.")
# mngr_sharded_save.close()
# @title Exercise 5: Solution

# --- Setup JAX Mesh ---
num_devices = jax.device_count()
if num_devices == 1: # If we faked 8 CPU devices but only see 1
  print("Warning: JAX might not be using the faked CPU devices. Restart the runtime and run the Setup cell first if sharding tests fail.")
print(f"Using {num_devices} devices for sharding.")
# Ensure a 1D mesh for simplicity, using all available (or faked) devices.
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices=device_mesh, axis_names=('data',)) # 1D mesh for 'data' parallelism
print(mesh)

# --- Define Sharded NNX Module ---
class ShardedSimpleLinear(nnx.Module):
  def __init__(self, din: int, dout: int, mesh: Mesh, *, rngs: nnx.Rngs):
    self.din = din
    self.dout = dout
    self.mesh = mesh # Store mesh for creating NamedSharding

    key_w, key_b = rngs.params(), rngs.params()

    initial_weight = jax.random.uniform(key_w, (din, dout))
    initial_bias = jnp.zeros((dout,))

    # Define PartitionSpec for sharding
    # Shard weight's second dimension (dout) across the 'data' mesh axis
    weight_pspec = PartitionSpec(None, 'data')
    # Shard bias's only dimension (dout) across the 'data' mesh axis
    bias_pspec = PartitionSpec('data',)

    # Create NamedSharding from PartitionSpec and mesh
    weight_sharding = NamedSharding(self.mesh, weight_pspec)
    bias_sharding = NamedSharding(self.mesh, bias_pspec)

    # Shard the initial arrays using jax.device_put
    # This ensures the arrays are created with the specified sharding
    sharded_weight_value = jax.device_put(initial_weight, weight_sharding)
    sharded_bias_value = jax.device_put(initial_bias, bias_sharding)

    self.weight = nnx.Param(sharded_weight_value)
    self.bias = nnx.Param(sharded_bias_value)
    # Note: Flax NNX also supports sharding annotations directly in nnx.Variable metadata,
    # e.g., via nnx.with_partitioning applied to an initializer.
    # Explicit jax.device_put is also a valid way to get sharded arrays into the state.

  def __call__(self, x: jax.Array) -> jax.Array:
    return x @ self.weight.value + self.bias.value

# --- Instantiate Sharded Model within Mesh context ---
din_s, dout_s = 8, num_devices * 2 # Make dout divisible by num_devices
rngs_sharded = nnx.Rngs(params=jax.random.key(3))

with mesh: # Operations within this context are aware of the mesh
  sharded_model = ShardedSimpleLinear(din_s, dout_s, mesh, rngs=rngs_sharded)

print(f"Sharded model created. Weight sharding: {sharded_model.weight.value.sharding}")
print(f"Sharded model bias sharding: {sharded_model.bias.value.sharding}")

# --- Setup CheckpointManager for Sharded Save ---
ckpt_dir_ex5 = os.path.join(CKPT_BASE_DIR, 'ex5_sharded_save')
cleanup_ckpt_dir(ckpt_dir_ex5)
mngr_sharded_save = ocp.CheckpointManager(ckpt_dir_ex5, options=ocp.CheckpointManagerOptions(max_to_keep=1))

# --- Split and Save Sharded State ---
# The live state already contains sharded jax.Array objects
_graphdef_sharded, sharded_state_to_save = nnx.split(sharded_model)

print(f"Sharded state to save (bias type): {type(sharded_state_to_save['bias'].value)}")
print(f"Sharded state to save (bias sharding): {sharded_state_to_save['bias'].value.sharding}")
# The arrays in sharded_state_to_save are jax.Array objects that carry sharding information

current_step_sharded = 200
# Orbax handles sharded-array saving under the hood
mngr_sharded_save.save(current_step_sharded, args=ocp.args.StandardSave(sharded_state_to_save))
mngr_sharded_save.wait_until_finished()
print(f"Sharded checkpoint saved for step {current_step_sharded} in {ckpt_dir_ex5}.")
mngr_sharded_save.close()

Advanced Orbax features and best practices (brief)

Orbax has several more advanced capabilities that we won't exercise here, but that are worth knowing about:

  • Asynchronous checkpointing: manager.save() can run in the background; call manager.wait_until_finished() before your program exits or whenever you need the checkpoint immediately. This keeps the training loop from blocking and improves throughput. All examples in this tutorial call wait_until_finished().
  • Atomicity: CheckpointManager writes checkpoints atomically, so a crash mid-training does not leave corrupted files; Orbax handles this for you.
  • Saving non-pytree data (metadata): you often want to save training configs, dataset iterator info, model versions, and so on. Inside ocp.args.Composite, use ocp.args.JsonSave to store dict-like data as JSON alongside the model pytree, and ocp.args.JsonRestore to read it back (see the sketches below).
  • TensorStore backend: for very large models or cloud storage, Orbax can use TensorStore for more efficient, parallelizable I/O on individual shards; this is usually transparent and may be enabled by default in some JAX environments.

Example (conceptual)

metadata = {'version': '1.0', 'dataset_info': 'imagenet_split_train'}
save_args = ocp.args.Composite(
  params=ocp.args.StandardSave(params_state),
  metadata=ocp.args.JsonSave(metadata)
)
mngr.save(step, args=save_args)
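
On the restore side, the same item names are requested with matching restore args. A rough sketch, assuming the same mngr and step as above and an abstract_params_state template like the one built in Exercise 4:

restored = mngr.restore(step, args=ocp.args.Composite(
    params=ocp.args.StandardRestore(abstract_params_state),
    metadata=ocp.args.JsonRestore()
))
config = restored['metadata']  # plain Python dict, e.g. {'version': '1.0', ...}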

Key takeaways

  • Flax NNX offers a Pythonic, stateful way to define models.
  • Orbax is the standard way to checkpoint NNX State pytrees.
  • The common flow (recapped in the sketch after this list):
    • Save: nnx.split -> mngr.save.
    • Restore: nnx.eval_shape -> get abstract_state -> mngr.restore -> nnx.merge or nnx.update.
  • CheckpointManager makes it easy to manage multiple checkpoints.
  • Use ocp.args.Composite to save multiple items (e.g., model params + optimizer state).
  • For sharded/distributed restores, the abstract target must carry the correct sharding information; if the abstract state includes sharding, StandardRestore handles it.
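
Putting the flow together in one place, here is a compact recap (it reuses SimpleLinear and a CheckpointManager mngr as set up in Exercises 1 and 2):

# Save: extract the state and hand it to Orbax
_graphdef, state = nnx.split(model)
mngr.save(step, args=ocp.args.StandardSave(state))
mngr.wait_until_finished()

# Restore: build an abstract template, restore into it, then rebuild the module
abstract_model = nnx.eval_shape(
    lambda: SimpleLinear(din=10, dout=5, rngs=nnx.Rngs(params=jax.random.key(0))))
graphdef, abstract_state = nnx.split(abstract_model)
restored_state = mngr.restore(mngr.latest_step(), args=ocp.args.StandardRestore(abstract_state))
restored_model = nnx.merge(graphdef, restored_state)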

Congratulations!

You have worked through the core flow of checkpointing Flax NNX models with Orbax, from basic save/restore to optimizer state and distributed (sharded) scenarios.

For deeper details, see the official documentation:

  • Orbax: https://orbax.readthedocs.io
  • Flax NNX: (Part of the Flax documentation) https://flax.readthedocs.io
  • JAX: https://jax.readthedocs.io

Keep practicing, and have fun with JAX!

Feedback is welcome at https://goo.gle/jax-training-feedback.