Welcome! In these hands-on exercises you'll learn how to save and restore JAX/Flax NNX models—a must-have skill for any serious ML project.
Training deep models takes time, and checkpoints protect that investment: they let you pause and resume training, recover from crashes or preemptions, and keep trained weights around for evaluation and deployment. Along the way we'll lean on a handful of Flax NNX building blocks:
- nnx.Module: base class for building these stateful components.
- nnx.Variable: types like nnx.Param or nnx.BatchStat declare learnable or tracked state.
- nnx.State: a JAX pytree (nested dict-like) holding the values of all nnx.Variables—this is what we save/load with Orbax.
- nnx.split(module): separates a module into static structure (GraphDef) and dynamic state (nnx.State) so you can pull out what to save.
- nnx.merge(graphdef, state): reconstructs a module from GraphDef and nnx.State, typically used when restoring.
- nnx.update(module, state): in-place update of an existing module's state, also handy after restore.

Orbax is the standard JAX checkpointing library: robust, extensible, and designed for distributed setups. We'll use two pieces of it throughout:
- ocp.CheckpointManager: high-level helper that simplifies managing multiple checkpoints (e.g., keep the latest N). We'll use this everywhere.
- ocp.args: arg objects describing how to save/restore (ocp.args.StandardSave, StandardRestore, Composite, etc.).

Let's get started!
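First, a tiny self-contained sketch of the split/merge/update round-trip that everything below builds on. The toy Scale module is purely illustrative and not used in the exercises.

from flax import nnx
import jax.numpy as jnp

class Scale(nnx.Module):
    """Toy module used only to illustrate split/merge/update."""
    def __init__(self):
        self.factor = nnx.Param(jnp.array(2.0))

    def __call__(self, x):
        return x * self.factor.value

m = Scale()
graphdef, state = nnx.split(m)   # static structure + dynamic state (the part we checkpoint)
m2 = nnx.merge(graphdef, state)  # rebuild an equivalent module from the two pieces
nnx.update(m, state)             # or refresh an existing module's state in place
print(m2(jnp.array(3.0)))        # 6.0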
# @title Setup: Install and Import Libraries
# Install necessary libraries
!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
import flax
from flax import nnx
import orbax.checkpoint as ocp
import optax
import os
import shutil # For cleaning up directories
import chex # For faking devices
# Suppress some JAX warnings for cleaner output in the notebook
import warnings
warnings.filterwarnings("ignore", message="No GPU/TPU found, falling back to CPU.")
warnings.filterwarnings("ignore", message="Custom node type GlobalDeviceArray is not handled by Pytree traversal.") # Orbax/NNX interactions
print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")
print(f"Orbax version: {ocp.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"Chex version: {chex.__version__}")
# --- Setup for Distributed Exercises ---
# Simulate an environment with 8 CPUs for distributed examples
# This allows us to test sharding logic even on a single-CPU Colab machine.
try:
chex.set_n_cpu_devices(8)
except RuntimeError as e:
print(f"Note: Could not set_n_cpu_devices (may have been set already): {e}")
print(f"Number of JAX devices available: {jax.device_count()}")
print(f"Available devices: {jax.devices()}")
# Helper function to clean up checkpoint directories
def cleanup_ckpt_dir(ckpt_dir):
if os.path.exists(ckpt_dir):
shutil.rmtree(ckpt_dir)
print(f"Cleaned up checkpoint directory: {ckpt_dir}")
# Create a default checkpoint directory for exercises
CKPT_BASE_DIR = '/tmp/nnx_orbax_workshop_checkpoints'
if not os.path.exists(CKPT_BASE_DIR):
os.makedirs(CKPT_BASE_DIR)
print(f"Base checkpoint directory: {CKPT_BASE_DIR}")
Exercise 1: Saving model state

In this exercise you'll save the nnx.State of a small nnx.Module, using:
- nnx.split() to extract the nnx.State pytree.
- ocp.CheckpointManager to manage the checkpoint directory.
- mngr.save() with ocp.args.StandardSave to persist the state.

Your tasks:
- Define a class SimpleLinear inheriting from nnx.Module.
- Declare weight and bias as nnx.Param; initialize with JAX random functions (e.g., jax.random.uniform, jnp.zeros) and manage keys via nnx.Rngs.
- Implement the forward pass y = x @ weight + bias.
- Instantiate SimpleLinear.
- Create ocp.CheckpointManagerOptions (e.g., max_to_keep=3).
- Create an ocp.CheckpointManager with the directory and options.
- Use nnx.split(model) to get graphdef and state_to_save.
- Call mngr.save(), wrapping state_to_save with ocp.args.StandardSave.
- Call mngr.wait_until_finished() to ensure async saves complete.
- Call mngr.close() to close the manager.

# --- Define the NNX Module ---
class SimpleLinear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
key_w, key_b = rngs.params(), rngs.params() # Example of splitting keys if needed, or use one key for multiple params
# TODO: Define self.weight as an nnx.Param with shape (din, dout)
# self.weight = ...
# TODO: Define self.bias as an nnx.Param with shape (dout,)
# self.bias = ...
def __call__(self, x: jax.Array) -> jax.Array:
# TODO: Implement the forward pass
# return ...
# --- Instantiate the Model ---
din, dout = 10, 5
# TODO: Create an nnx.Rngs object for parameter initialization
# rngs = ...
# TODO: Instantiate SimpleLinear
# model = ...
print(f"Model created. Weight shape: {model.weight.value.shape}, Bias shape: {model.bias.value.shape}")
# --- Setup CheckpointManager ---
ckpt_dir_ex1 = os.path.join(CKPT_BASE_DIR, 'ex1_basic_save')
cleanup_ckpt_dir(ckpt_dir_ex1) # Clean up from previous runs
# TODO: Create CheckpointManagerOptions
# options = ...
# TODO: Instantiate CheckpointManager
# mngr = ...
# --- Split the model to get the state ---
# TODO: Split the model into graphdef and state_to_save
# _graphdef, state_to_save = ...
# Alternatively, for just the state: state_to_save = nnx.state(model)
# print(f"State to save: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, state_to_save)}")
# --- Save the state ---
step = 100
# TODO: Save the state_to_save at the given step. Use ocp.args.StandardSave.
# mngr.save(...)
# TODO: Wait for saving to complete
# mngr.wait_until_finished()
print(f"Checkpoint saved for step {step} in {ckpt_dir_ex1}.")
print(f"Available checkpoints: {mngr.all_steps()}")
# TODO: Close the manager
# mngr.close()
# @title Exercise 1: Solution
# --- Define the NNX Module ---
class SimpleLinear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
# Parameters defined using nnx.Param (a type of nnx.Variable)
self.weight = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
self.bias = nnx.Param(jnp.zeros((dout,)))
def __call__(self, x: jax.Array) -> jax.Array:
# Parameters used directly via self.weight, self.bias
return x @ self.weight.value + self.bias.value
# --- Instantiate the Model ---
din, dout = 10, 5
rngs = nnx.Rngs(params=jax.random.key(0)) # NNX requires explicit RNG management
model = SimpleLinear(din=din, dout=dout, rngs=rngs)
print(f"Model created. Weight shape: {model.weight.value.shape}, Bias shape: {model.bias.value.shape}")
# --- Setup CheckpointManager ---
ckpt_dir_ex1 = os.path.join(CKPT_BASE_DIR, 'ex1_basic_save')
cleanup_ckpt_dir(ckpt_dir_ex1)
options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=1)
mngr = ocp.CheckpointManager(ckpt_dir_ex1, options=options)
# --- Split the model to get the state ---
_graphdef, state_to_save = nnx.split(model)
# Alternatively: state_to_save = nnx.state(model)
print(f"State to save structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), state_to_save)}")
# --- Save the state ---
step = 100
mngr.save(step, args=ocp.args.StandardSave(state_to_save))
mngr.wait_until_finished() # Ensure save completes if async
print(f"Checkpoint saved for step {step} in {ckpt_dir_ex1}.")
print(f"Available checkpoints: {mngr.all_steps()}")
mngr.close() # Clean up resources
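If you're curious what was written, you can peek at the checkpoint directory. The exact on-disk layout is an Orbax implementation detail and varies between versions, so treat this purely as an inspection aid:

# Walk the checkpoint directory written above (layout is an Orbax implementation detail)
for root, dirs, files in os.walk(ckpt_dir_ex1):
    print(f"{root}: dirs={dirs}, files={files}")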
Exercise 2: Restoring model state

Here you'll restore the nnx.State saved in Exercise 1, using:
- nnx.eval_shape() to build an abstract_state (a pytree of ShapeDtypeStruct) describing what to restore.
- abstract_state with ocp.args.StandardRestore in mngr.restore().
- nnx.merge(graphdef, restored_state) to rebuild the module, or nnx.update() to refresh an existing one.

Your tasks:
- Open a CheckpointManager that points to Exercise 1 (ckpt_dir_ex1).
- Write create_abstract_model() that returns a SimpleLinear instance for nnx.eval_shape(). Remember, eval_shape only cares about structure/dtype, not values.
- Build abstract_model = nnx.eval_shape(create_abstract_model).
- Split abstract_model: graphdef_for_restore, abstract_state = nnx.split(abstract_model) to get the ShapeDtypeStruct template.
- Use mngr.latest_step() to find the most recent checkpoint step.
- Restore with mngr.restore(step_to_restore, args=ocp.args.StandardRestore(abstract_state)).
- Rebuild with restored_model = nnx.merge(graphdef_for_restore, restored_state).
- Inspect restored_model.bias.value to verify.

# Ensure the SimpleLinear class definition from Exercise 1 is available
# --- Re-open CheckpointManager ---
# TODO: Instantiate CheckpointManager for ckpt_dir_ex1 (no need for options if just restoring)
# mngr_restore = ...
# --- Create Abstract Model for Restoration ---
def create_abstract_model():
# Use dummy RNG key/inputs for abstract creation
# TODO: Return an instance of SimpleLinear, same din/dout as before
# return ...
# TODO: Create the abstract_model using nnx.eval_shape
# abstract_model = ...
# --- Split Abstract Model to get Abstract State Structure ---
# TODO: Split the abstract_model to get graphdef_for_restore and abstract_state
# graphdef_for_restore, abstract_state = ...
print(f"Abstract state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else x, abstract_state)}")
# --- Restore the State ---
# TODO: Get the latest step to restore
# step_to_restore = ...
if step_to_restore is not None:
# TODO: Restore the state using mngr_restore.restore() and ocp.args.StandardRestore with abstract_state
# restored_state = mngr_restore.restore(...)
# --- Reconstruct the Model ---
# TODO: Reconstruct the model using nnx.merge with graphdef_for_restore and restored_state
# restored_model = ...
print(f"Model restored from step {step_to_restore}.")
# You can now use 'restored_model'
print(f"Restored bias (first 3 values): {restored_model.bias.value[:3]}")
# Alternative: Update an existing model instance
# model_to_update = SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(99))) # Fresh instance
# nnx.update(model_to_update, restored_state)
# print(f"Updated model bias (first 3 values): {model_to_update.bias.value[:3]}")
else:
print("No checkpoint found to restore.")
# TODO: Close the manager
# mngr_restore.close()
# @title Exercise 2: Solution
# Ensure the SimpleLinear class definition from Exercise 1 is available
# --- Re-open CheckpointManager ---
mngr_restore = ocp.CheckpointManager(ckpt_dir_ex1) # Re-open manager
# --- Create Abstract Model for Restoration ---
def create_abstract_model():
# Use dummy RNG key/inputs for abstract creation
return SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(0))) # din, dout from Ex1
abstract_model = nnx.eval_shape(create_abstract_model)
# --- Split Abstract Model to get Abstract State Structure ---
graphdef_for_restore, abstract_state = nnx.split(abstract_model)
# abstract_state now contains ShapeDtypeStruct leaves
print(f"Abstract state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else x, abstract_state)}")
# --- Restore the State ---
step_to_restore = mngr_restore.latest_step()
if step_to_restore is not None:
restored_state = mngr_restore.restore(step_to_restore,
args=ocp.args.StandardRestore(abstract_state))
print(f"Restored state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), restored_state)}")
# --- Reconstruct the Model ---
restored_model = nnx.merge(graphdef_for_restore, restored_state)
print(f"Model restored from step {step_to_restore}.")
# You can now use 'restored_model'
print(f"Restored bias (first 3 values): {restored_model.bias.value[:3]}")
# Compare with original model's bias (optional, if 'model' from Ex1 is still in scope)
# print(f"Original bias (first 3 values): {model.bias.value[:3]}")
# chex.assert_trees_all_close(restored_model.bias.value, model.bias.value)
# Alternative: Update an existing model instance
model_to_update = SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(99))) # Fresh instance
# Initialize with different values to see update working
model_to_update.bias.value = jnp.ones_like(model_to_update.bias.value) * 55.0
print(f"Bias before update: {model_to_update.bias.value[:3]}")
nnx.update(model_to_update, restored_state)
print(f"Updated model bias (first 3 values): {model_to_update.bias.value[:3]}")
if 'model' in globals(): # Check if original model exists
chex.assert_trees_all_close(model_to_update.bias.value, model.bias.value)
else:
print("No checkpoint found to restore.")
mngr_restore.close()
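As an optional sanity check (not part of the exercise), you can run a forward pass with the restored model; x_check below is an arbitrary input used only for illustration:

# Only meaningful if a checkpoint was actually restored above
if step_to_restore is not None:
    x_check = jnp.ones((2, din))  # din = 10 from Exercise 1
    y_check = restored_model(x_check)
    print(f"Restored model output shape: {y_check.shape}")  # expected (2, 5)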
Exercise 3: Saving model and optimizer state together

In this exercise you'll checkpoint both the model parameters and the optimizer state in one go, using:
- nnx.Optimizer to manage model params and Optax optimizer state.
- nnx.split(model, nnx.Param) to extract the model parameters.
- nnx.state(optimizer) to extract the optimizer state.
- ocp.args.Composite to save multiple named items (params, optimizer state) in one checkpoint.

Your tasks:
- Reuse the SimpleLinear definition and instantiate a new model.
- Create an Optax optimizer (e.g., optax.adam(learning_rate=1e-3)).
- Wrap the model and optimizer in nnx.Optimizer.
- Simulate a few training steps (optimizer.step.value += 1, etc.).
- Set up a CheckpointManager under ckpt_dir_ex3.
- Extract the parameter state: _graphdef_params, params_state = nnx.split(model_ex3, nnx.Param). (Note: optimizer.model no longer exists; split the model directly.)
- Extract the optimizer state: optimizer_state_tree = nnx.state(optimizer) (includes internal state like momentum plus the optimizer step).
- Build a save_items dict with names (e.g., 'params', 'optimizer') mapped to the corresponding pytrees wrapped in ocp.args.StandardSave().
- Call mngr.save(step, args=ocp.args.Composite(**save_items)), using the optimizer's current step.

# Ensure SimpleLinear class definition is available
# --- Instantiate Model and Optimizer ---
rngs_ex3 = nnx.Rngs(params=jax.random.key(1))
model_ex3 = SimpleLinear(din=10, dout=5, rngs=rngs_ex3)
# TODO: Create an Optax optimizer (e.g., Adam)
# tx = ...
# TODO: Create an nnx.Optimizer, wrapping the model and tx
# optimizer = ...
# Simulate a few "training" steps to populate optimizer state
# For a real scenario, this would involve gradients and updates
if hasattr(optimizer, 'step') and hasattr(optimizer.step, 'value'): # Check for NNX Optimizer structure
optimizer.step.value += 10 # Simulate 10 steps
# In a real training loop you'd compute gradients and apply them via the optimizer.
# For this exercise, just advancing the step is enough to see it saved/restored.
# Let's also change a parameter slightly to see it saved
original_bias_val_ex3 = model_ex3.bias.value.copy()
model_ex3.bias.value = model_ex3.bias.value * 0.5 + 0.1
print(f"Optimizer step: {optimizer.step.value}")
print(f"Bias modified. Original first val: {original_bias_val_ex3[0]}, New first val: {model_ex3.bias.value[0]}")
else:
print("Skipping optimizer step update as structure might differ from expected nnx.Optimizer.")
# --- Setup CheckpointManager for Composite Save ---
ckpt_dir_ex3 = os.path.join(CKPT_BASE_DIR, 'ex3_composite_save')
cleanup_ckpt_dir(ckpt_dir_ex3)
# TODO: Instantiate CheckpointManager for ckpt_dir_ex3
# mngr_comp = ...
# --- Extract States for Saving ---
# TODO: Extract model parameters state from optimizer.model using nnx.split with nnx.Param filter
# _graphdef_params, params_state = ...
# TODO: Extract the full optimizer state tree using nnx.state()
# optimizer_state_tree = ...
print(f"Parameter state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, params_state)}")
print(f"Optimizer state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, optimizer_state_tree)}")
# --- Save Composite State ---
current_step_val = 0
if hasattr(optimizer, 'step') and hasattr(optimizer.step, 'value'):
current_step_val = optimizer.step.value
else: # Fallback for safety, though nnx.Optimizer should have .step
current_step_val = 10
# TODO: Define save_items dictionary for 'params' and 'optimizer'
# Each item should be wrapped with ocp.args.StandardSave
# save_items = {
# 'params': ...,
# 'optimizer': ...
# }
# TODO: Save using mngr_comp.save() and ocp.args.Composite
# mngr_comp.save(...)
# TODO: Wait and close the manager
# mngr_comp.wait_until_finished()
# print(f"Composite checkpoint saved for step {current_step_val} in {ckpt_dir_ex3}.")
# print(f"Available checkpoints: {mngr_comp.all_steps()}")
# mngr_comp.close()
# @title Exercise 3: Solution
# Ensure SimpleLinear class definition is available
# --- Instantiate Model and Optimizer ---
rngs_ex3 = nnx.Rngs(params=jax.random.key(1))
model_ex3 = SimpleLinear(din=10, dout=5, rngs=rngs_ex3)
tx = optax.adam(learning_rate=1e-3)
optimizer = nnx.Optimizer(model_ex3, tx, wrt=nnx.Param)
# Simulate a few "training" steps to populate optimizer state
# For a real scenario, this would involve gradients and updates
optimizer.step.value += 10 # Simulate 10 steps
original_bias_val_ex3 = model_ex3.bias.value.copy()
# Simulate a parameter update that would happen during training
model_ex3.bias.value = model_ex3.bias.value * 0.5 + 0.1 # Arbitrary change
print(f"Optimizer step: {optimizer.step.value}")
print(f"Bias modified. Original first val: {original_bias_val_ex3[0]}, New first val: {model_ex3.bias.value[0]}")
# --- Setup CheckpointManager for Composite Save ---
ckpt_dir_ex3 = os.path.join(CKPT_BASE_DIR, 'ex3_composite_save')
cleanup_ckpt_dir(ckpt_dir_ex3)
mngr_comp = ocp.CheckpointManager(ckpt_dir_ex3, options=ocp.CheckpointManagerOptions(max_to_keep=3))
# --- Extract States for Saving ---
# Extract model parameters (e.g., using nnx.split(model, nnx.Param))
_graphdef_params, params_state = nnx.split(model_ex3, nnx.Param)
# Extract optimizer state (nnx.state(optimizer))
optimizer_state_tree = nnx.state(optimizer)
print(f"Parameter state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), params_state)}")
print(f"Optimizer state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), optimizer_state_tree)}")
# Note: optimizer_state_tree contains the Optax internals (e.g., Adam moments) plus the step counter
# --- Save Composite State ---
current_step_val = optimizer.step.value # Get current step from optimizer
# Save using Composite args
save_items = {
'params': ocp.args.StandardSave(params_state),
'optimizer': ocp.args.StandardSave(optimizer_state_tree)
}
# Can generate args per item using orbax_utils too
mngr_comp.save(current_step_val, args=ocp.args.Composite(**save_items))
mngr_comp.wait_until_finished()
print(f"Composite checkpoint saved for step {current_step_val} in {ckpt_dir_ex3}.")
print(f"Available checkpoints: {mngr_comp.all_steps()}")
mngr_comp.close()
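Exercise 3 only simulates training by bumping the step counter. For reference, a real training step could look roughly like the sketch below. It assumes the newer nnx.Optimizer API (constructed with wrt=nnx.Param and updated via optimizer.update(model, grads)), and all demo_* names are illustrative; a throwaway model is used so the state we just saved is not disturbed.

# Throwaway model/optimizer so we don't touch model_ex3 before Exercise 4's verification
demo_model = SimpleLinear(din=10, dout=5, rngs=nnx.Rngs(params=jax.random.key(42)))
demo_optimizer = nnx.Optimizer(demo_model, optax.adam(1e-3), wrt=nnx.Param)

@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        preds = model(x)
        return jnp.mean((preds - y) ** 2)  # simple MSE loss, for illustration only
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)  # applies the Adam update and advances optimizer.step
    return loss

x_demo = jax.random.normal(jax.random.key(10), (4, 10))
y_demo = jax.random.normal(jax.random.key(11), (4, 5))
loss = train_step(demo_model, demo_optimizer, x_demo, y_demo)
print(f"Demo loss: {loss:.4f}, demo optimizer step: {demo_optimizer.step.value}")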
Exercise 4: Restoring model and optimizer state together

Now restore the composite checkpoint from Exercise 3, using:
- nnx.eval_shape to build abstract restore targets.
- ocp.args.Composite with ocp.args.StandardRestore to restore multiple named items.
- nnx.update() to apply the restored state to fresh instances.

Your tasks:
- Re-open the CheckpointManager (ckpt_dir_ex3).
- Write create_abstract_model_and_optimizer():
  - Use nnx.eval_shape to create an abstract model (e.g., SimpleLinear).
  - Use nnx.eval_shape to create an abstract nnx.Optimizer given the abstract model and a new Optax optimizer.
  - Return abs_model and abs_optimizer.
- Call it to get abs_model and abs_optimizer.
- Get the abstract parameter state: _graphdef_abs_params, abs_params_state = nnx.split(abs_model, nnx.Param).
- Get the abstract optimizer state: abs_optimizer_state = nnx.state(abs_optimizer).
- Build a restore_targets dict for ocp.args.Composite with the same keys as save ('params', 'optimizer') and values wrapped in ocp.args.StandardRestore().
- Call mngr_comp.restore(step, args=ocp.args.Composite(**restore_targets)) to get restored_items.
- Create fresh SimpleLinear and nnx.Optimizer instances.
- Apply nnx.update(fresh_model, restored_items['params']).
- Apply nnx.update(fresh_optimizer, restored_items['optimizer']).

# Ensure SimpleLinear class definition is available
# --- Re-open CheckpointManager ---
# TODO: Instantiate CheckpointManager for ckpt_dir_ex3
# mngr_comp_restore = ...
# --- Create Abstract Model and Optimizer ---
def create_abstract_model_and_optimizer():
rngs_abs = nnx.Rngs(params=jax.random.key(0)) # Dummy key for abstract creation
# TODO: Create abstract model. Model class: SimpleLinear(din=10, dout=5, ...)
# abs_model = SimpleLinear(...)
# TODO: Create abstract optimizer. Pass abs_model and an optax.adam instance.
# abs_opt = nnx.Optimizer(...)
# return abs_model, abs_opt
# TODO: Call the function to get abstract model and optimizer
# abs_model_restore, abs_optimizer_restore = ...
# --- Get Abstract States ---
# TODO: Get abstract parameter state from abs_model_restore (filter with nnx.Param)
# _graphdef_abs_params, abs_params_state = ...
# TODO: Get abstract optimizer state from abs_optimizer_restore
# abs_optimizer_state = ...
print(f"Abstract params state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, abs_params_state)}")
print(f"Abstract optimizer state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, abs_optimizer_state)}")
# --- Restore Composite State ---
# TODO: Get the latest step
# step_to_restore_comp = ...
if step_to_restore_comp is not None:
# TODO: Define restore_targets dictionary for 'params' and 'optimizer'
# Each item should be wrapped with ocp.args.StandardRestore and its corresponding abstract state.
# restore_targets = {
# 'params': ...,
# 'optimizer': ...
# }
# TODO: Restore items using mngr_comp_restore.restore() and ocp.args.Composite
# restored_items = mngr_comp_restore.restore(...)
# --- Instantiate and Update Concrete Model/Optimizer ---
# TODO: Create a fresh SimpleLinear model instance (use a new RNG key, e.g., key(2))
# fresh_model = ...
# TODO: Create a fresh nnx.Optimizer instance with fresh_model and a new optax.adam instance
# fresh_optimizer = ...
# Store pre-update values for comparison
pre_update_bias = fresh_model.bias.value.copy()
pre_update_opt_step = fresh_optimizer.step.value
# TODO: Update fresh_model with restored_items['params'] using nnx.update()
# nnx.update(...)
# TODO: Update fresh_optimizer with restored_items['optimizer'] using nnx.update()
# nnx.update(...)
print(f"Restored and updated. Optimizer step: {fresh_optimizer.step.value}")
print(f"Fresh model bias before update (first val): {pre_update_bias[0]}")
print(f"Fresh model bias after update (first val): {fresh_model.bias.value[0]}")
print(f"Original bias from Ex3 (first val): {model_ex3.bias.value[0]}") # model_ex3 is from previous cell
# Verification
# chex.assert_trees_all_close(fresh_model.bias.value, model_ex3.bias.value) # Compare with the state that was saved
# assert fresh_optimizer.step.value == optimizer.step.value # Compare with optimizer state that was saved
else:
print("No composite checkpoint found.")
# TODO: Close the manager
# mngr_comp_restore.close()
# @title Exercise 4: Solution
# Ensure SimpleLinear class definition is available
# --- Re-open CheckpointManager ---
mngr_comp_restore = ocp.CheckpointManager(ckpt_dir_ex3)
# --- Create Abstract Model and Optimizer ---
def create_abstract_model_and_optimizer():
rngs_abs = nnx.Rngs(params=jax.random.key(0)) # Dummy key for abstract creation
# Create abstract model
abs_model = SimpleLinear(din=10, dout=5, rngs=rngs_abs)
# Create abstract optimizer
abs_opt = nnx.Optimizer(abs_model, optax.adam(1e-3), wrt=nnx.Param)
return abs_model, abs_opt
abs_model_restore, abs_optimizer_restore = create_abstract_model_and_optimizer()
# --- Get Abstract States ---
_graphdef_abs_params, abs_params_state = nnx.split(abs_model_restore, nnx.Param)
abs_optimizer_state = nnx.state(abs_optimizer_restore)
print(f"Abstract params state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), abs_params_state)}")
print(f"Abstract optimizer state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), abs_optimizer_state)}")
# --- Restore Composite State ---
step_to_restore_comp = mngr_comp_restore.latest_step()
if step_to_restore_comp is not None:
restore_targets = {
'params': ocp.args.StandardRestore(abs_params_state),
'optimizer': ocp.args.StandardRestore(abs_optimizer_state)
}
restored_items = mngr_comp_restore.restore(step_to_restore_comp, args=ocp.args.Composite(**restore_targets))
# --- Instantiate and Update Concrete Model/Optimizer ---
# Create fresh instances
fresh_rngs = nnx.Rngs(params=jax.random.key(2)) # Use a different key for the fresh model
fresh_model = SimpleLinear(din=10, dout=5, rngs=fresh_rngs)
fresh_optimizer = nnx.Optimizer(fresh_model, optax.adam(1e-3), wrt=nnx.Param) # Matching optax optimizer
# Store pre-update values for comparison
pre_update_bias = fresh_model.bias.value.copy()
pre_update_opt_step = fresh_optimizer.step.value
# Update using restored states
nnx.update(fresh_model, restored_items['params'])
nnx.update(fresh_optimizer, restored_items['optimizer'])
print(f"Restored and updated. Optimizer step: {fresh_optimizer.step.value}")
print(f"Fresh model bias before update (first val): {pre_update_bias[0]}") # Will be from key(2)
print(f"Fresh model bias after update (first val): {fresh_model.bias.value[0]}") # Should match model_ex3 bias
# Verification (model_ex3 and optimizer are from the previous cell where they were saved)
chex.assert_trees_all_close(fresh_model.bias.value, model_ex3.bias.value)
assert fresh_optimizer.step.value == optimizer.step.value
print("Verification successful: Restored model parameters and optimizer step match the saved state.")
else:
print("No composite checkpoint found.")
mngr_comp_restore.close()
Exercise 5: Saving sharded model state

Define a ShardedSimpleLinear whose parameters live on a device mesh: initialize the parameters in __init__, then shard them (e.g., with jax.device_put and a NamedSharding) before wrapping them in nnx.Param, and save the resulting sharded nnx.State with Orbax as before.
# --- Setup JAX Mesh ---
num_devices = jax.device_count()
# If num_devices is 1 after chex.set_n_cpu_devices(8), it means JAX didn't pick up the fakes.
# This can happen if JAX initializes its backends before chex runs.
# Forcing a rerun of this cell or restarting runtime and running setup first might help.
print(f"Using {num_devices} devices for sharding.")
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices=device_mesh, axis_names=('data',)) # 1D mesh
print(mesh)
# --- Define Sharded NNX Module ---
class ShardedSimpleLinear(nnx.Module):
def __init__(self, din: int, dout: int, mesh: Mesh, *, rngs: nnx.Rngs):
self.din = din
self.dout = dout
self.mesh = mesh
key_w, key_b = rngs.params(), rngs.params()
# Initialize as regular JAX arrays first
initial_weight = jax.random.uniform(key_w, (din, dout))
initial_bias = jnp.zeros((dout,))
# TODO: Define PartitionSpec for weight (shard dout across 'data' axis)
# e.g., PartitionSpec(None, 'data') means not sharded on dim 0, sharded on dim 1
# weight_pspec = ...
# TODO: Define PartitionSpec for bias (shard along 'data' axis)
# bias_pspec = ...
# TODO: Create NamedSharding for weight and bias using self.mesh and the pspecs
# weight_sharding = NamedSharding(...)
# bias_sharding = NamedSharding(...)
# TODO: Shard the initial arrays using jax.device_put and the NamedSharding
# sharded_weight_value = jax.device_put(...)
# sharded_bias_value = jax.device_put(...)
# TODO: Assign these sharded arrays to nnx.Param attributes
# self.weight = nnx.Param(sharded_weight_value)
# self.bias = nnx.Param(sharded_bias_value)
# Alternative (more direct with nnx.Variable metadata if supported well for this case):
# self.weight = nnx.Param(initial_weight, sharding=weight_sharding) # This depends on NNX API
# For this exercise, jax.device_put is explicit and clear.
def __call__(self, x: jax.Array) -> jax.Array:
# x is assumed to be replicated or appropriately sharded for the matmul
# For simplicity, assume x is replicated if din is not sharded, or sharded compatibly.
return x @ self.weight.value + self.bias.value
# --- Instantiate Sharded Model within Mesh context ---
din_s, dout_s = 8, num_devices * 2 # Ensure dout is divisible by num_devices for even sharding
rngs_sharded = nnx.Rngs(params=jax.random.key(3))
# TODO: Instantiate ShardedSimpleLinear within the mesh context
# with mesh:
# sharded_model = ...
# print(f"Sharded model created. Weight sharding: {sharded_model.weight.value.sharding}")
# print(f"Sharded model bias sharding: {sharded_model.bias.value.sharding}")
# --- Setup CheckpointManager for Sharded Save ---
ckpt_dir_ex5 = os.path.join(CKPT_BASE_DIR, 'ex5_sharded_save')
cleanup_ckpt_dir(ckpt_dir_ex5)
# TODO: Instantiate CheckpointManager
# mngr_sharded_save = ...
# --- Split and Save Sharded State ---
# TODO: Split the sharded_model
# _graphdef_sharded, sharded_state_to_save = ...
# print(f"Sharded state to save (bias type): {type(sharded_state_to_save['bias'].value)}")
# print(f"Sharded state to save (bias sharding): {sharded_state_to_save['bias'].value.sharding}")
# current_step_sharded = 200
# TODO: Save the sharded_state_to_save
# mngr_sharded_save.save(...)
# TODO: Wait and close
# mngr_sharded_save.wait_until_finished()
# print(f"Sharded checkpoint saved for step {current_step_sharded} in {ckpt_dir_ex5}.")
# mngr_sharded_save.close()
# @title Exercise 5: Solution
# --- Setup JAX Mesh ---
num_devices = jax.device_count()
if num_devices == 1:  # We asked chex for 8 fake CPU devices but JAX only sees 1
    print("Warning: JAX might not be using the faked CPU devices. Restart the runtime and run the Setup cell first if sharding tests fail.")
print(f"Using {num_devices} devices for sharding.")
# Ensure a 1D mesh for simplicity, using all available (or faked) devices.
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices=device_mesh, axis_names=('data',)) # 1D mesh for 'data' parallelism
print(mesh)
# --- Define Sharded NNX Module ---
class ShardedSimpleLinear(nnx.Module):
def __init__(self, din: int, dout: int, mesh: Mesh, *, rngs: nnx.Rngs):
self.din = din
self.dout = dout
self.mesh = mesh # Store mesh for creating NamedSharding
key_w, key_b = rngs.params(), rngs.params()
initial_weight = jax.random.uniform(key_w, (din, dout))
initial_bias = jnp.zeros((dout,))
# Define PartitionSpec for sharding
# Shard weight's second dimension (dout) across the 'data' mesh axis
weight_pspec = PartitionSpec(None, 'data')
# Shard bias's only dimension (dout) across the 'data' mesh axis
bias_pspec = PartitionSpec('data',)
# Create NamedSharding from PartitionSpec and mesh
weight_sharding = NamedSharding(self.mesh, weight_pspec)
bias_sharding = NamedSharding(self.mesh, bias_pspec)
# Shard the initial arrays using jax.device_put
# This ensures the arrays are created with the specified sharding
sharded_weight_value = jax.device_put(initial_weight, weight_sharding)
sharded_bias_value = jax.device_put(initial_bias, bias_sharding)
self.weight = nnx.Param(sharded_weight_value)
self.bias = nnx.Param(sharded_bias_value)
# Note: Flax NNX aims to allow sharding annotations directly in nnx.Variable metadata
# e.g., using nnx.spmd.with_partitioning or passing sharding to nnx.Param.
# Explicit jax.device_put is also a valid way to get sharded arrays into the state.
def __call__(self, x: jax.Array) -> jax.Array:
return x @ self.weight.value + self.bias.value
# --- Instantiate Sharded Model within Mesh context ---
din_s, dout_s = 8, num_devices * 2 # Make dout divisible by num_devices
rngs_sharded = nnx.Rngs(params=jax.random.key(3))
with mesh: # Operations within this context are aware of the mesh
sharded_model = ShardedSimpleLinear(din_s, dout_s, mesh, rngs=rngs_sharded)
print(f"Sharded model created. Weight sharding: {sharded_model.weight.value.sharding}")
print(f"Sharded model bias sharding: {sharded_model.bias.value.sharding}")
# --- Setup CheckpointManager for Sharded Save ---
ckpt_dir_ex5 = os.path.join(CKPT_BASE_DIR, 'ex5_sharded_save')
cleanup_ckpt_dir(ckpt_dir_ex5)
mngr_sharded_save = ocp.CheckpointManager(ckpt_dir_ex5, options=ocp.CheckpointManagerOptions(max_to_keep=1))
# --- Split and Save Sharded State ---
# The live state already contains sharded jax.Array objects
_graphdef_sharded, sharded_state_to_save = nnx.split(sharded_model)
print(f"Sharded state to save (bias type): {type(sharded_state_to_save['bias'].value)}")
print(f"Sharded state to save (bias sharding): {sharded_state_to_save['bias'].value.sharding}")
# The arrays in sharded_state_to_save are jax.Arrays that carry explicit shardings
current_step_sharded = 200
# Orbax handles sharded-array saving under the hood
mngr_sharded_save.save(current_step_sharded, args=ocp.args.StandardSave(sharded_state_to_save))
mngr_sharded_save.wait_until_finished()
print(f"Sharded checkpoint saved for step {current_step_sharded} in {ckpt_dir_ex5}.")
mngr_sharded_save.close()
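Restoring a sharded checkpoint works like Exercise 2, except the abstract restore target should carry the desired shardings so Orbax places each shard on the right device. A minimal sketch (not part of the exercise; names such as abstract_sharded_state are illustrative) that restores back into the live sharded_model:

mngr_sharded_restore = ocp.CheckpointManager(ckpt_dir_ex5)

# Build abstract leaves that carry the target shardings, reusing the live model's
# state purely as a shape/dtype/sharding template.
abstract_sharded_state = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding),
    nnx.state(sharded_model),
)
restored_sharded_state = mngr_sharded_restore.restore(
    mngr_sharded_restore.latest_step(),
    args=ocp.args.StandardRestore(abstract_sharded_state),
)
nnx.update(sharded_model, restored_sharded_state)
print(f"Restored bias sharding: {sharded_model.bias.value.sharding}")
mngr_sharded_restore.close()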
Orbax has a few more advanced capabilities that we won't exercise in full here but that are worth knowing about. For example, you can store arbitrary JSON metadata alongside your state in a composite checkpoint:
metadata = {'version': '1.0', 'dataset_info': 'imagenet_split_train'}
save_args = ocp.args.Composite(
params=ocp.args.StandardSave(params_state),
metadata=ocp.args.JsonSave(metadata)
)
mngr.save(step, args=save_args)
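The matching restore call mirrors the save (sketch only; abstract_params_state stands in for whatever abstract pytree you would normally pass to StandardRestore):

restored = mngr.restore(
    step,
    args=ocp.args.Composite(
        params=ocp.args.StandardRestore(abstract_params_state),
        metadata=ocp.args.JsonRestore(),
    ),
)
print(restored['metadata'])  # e.g. {'version': '1.0', 'dataset_info': 'imagenet_split_train'}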
For deeper details, refer to the official Orbax and Flax NNX documentation.
Keep practicing, and have fun with JAX!
Feedback is welcome at https://goo.gle/jax-training-feedback.