Welcome to this hands-on exercise! We will learn how to save and load JAX/Flax NNX models — an essential skill for any serious machine learning project.
Training deep learning models takes time, and checkpointing protects that investment by letting you save progress and pick up where you left off. The key NNX building blocks we will rely on are:
- nnx.Module: the base class for creating stateful components.
- nnx.Variable: variable types such as nnx.Param and nnx.BatchStat declare learnable parameters and other state.
- nnx.State: a JAX pytree (similar to a nested dict) holding the values of every nnx.Variable in a module; this is the object Orbax reads and writes.
- nnx.split(module): splits a module into its static structure (GraphDef) and its dynamic state (nnx.State), so you can pull out the state to save.
- nnx.merge(graphdef, state): rebuilds a module instance from a GraphDef and an nnx.State, typically used after restoring.
- nnx.update(module, state): updates an existing module's state in place, also used after restoring.

Orbax is the standard checkpointing library for JAX, designed to be both robust and extensible. The pieces of it we will use are:

- ocp.CheckpointManager: a high-level manager that simplifies maintaining multiple checkpoints during training (e.g. keeping only the most recent N versions); we use it heavily below.
- ocp.args: a namespace of arguments describing how to save and restore (e.g. ocp.args.StandardSave, ocp.args.StandardRestore, ocp.args.Composite).

Let's get started!
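First, a minimal sketch of how the NNX pieces above fit together, with no Orbax involved. It is runnable once the Setup cell below has installed the libraries; the Counter module is a made-up illustration, not part of the exercises.
# Minimal sketch: the split / merge / update lifecycle (hypothetical Counter module).
import jax.numpy as jnp
from flax import nnx

class Counter(nnx.Module):
    def __init__(self):
        self.count = nnx.Param(jnp.array(0))  # nnx.Param is one kind of nnx.Variable

counter = Counter()
counter.count.value += 1

# Split into static structure (GraphDef) and dynamic state (nnx.State).
graphdef, state = nnx.split(counter)

# Rebuild an equivalent module from the two pieces...
rebuilt = nnx.merge(graphdef, state)
assert rebuilt.count.value == 1

# ...or push the state into an existing instance in place.
fresh = Counter()
nnx.update(fresh, state)
assert fresh.count.value == 1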
# @title Setup: Install and Import Libraries
# Install necessary libraries
!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
import flax
from flax import nnx
import orbax.checkpoint as ocp
import optax
import os
import shutil # For cleaning up directories
import chex # For faking devices
# Suppress some JAX warnings for cleaner output in the notebook
import warnings
warnings.filterwarnings("ignore", message="No GPU/TPU found, falling back to CPU.")
warnings.filterwarnings("ignore", message="Custom node type GlobalDeviceArray is not handled by Pytree traversal.") # Orbax/NNX interactions
print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")
print(f"Orbax version: {ocp.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"Chex version: {chex.__version__}")
# --- Setup for Distributed Exercises ---
# Simulate an environment with 8 CPUs for distributed examples
# This allows us to test sharding logic even on a single-CPU Colab machine.
try:
chex.set_n_cpu_devices(8)
except RuntimeError as e:
print(f"Note: Could not set_n_cpu_devices (may have been set already): {e}")
print(f"Number of JAX devices available: {jax.device_count()}")
print(f"Available devices: {jax.devices()}")
# Helper function to clean up checkpoint directories
def cleanup_ckpt_dir(ckpt_dir):
if os.path.exists(ckpt_dir):
shutil.rmtree(ckpt_dir)
print(f"Cleaned up checkpoint directory: {ckpt_dir}")
# Create a default checkpoint directory for exercises
CKPT_BASE_DIR = '/tmp/nnx_orbax_workshop_checkpoints'
if not os.path.exists(CKPT_BASE_DIR):
os.makedirs(CKPT_BASE_DIR)
print(f"Base checkpoint directory: {CKPT_BASE_DIR}")
# --- Define the NNX Module ---
class SimpleLinear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
key_w, key_b = rngs.params(), rngs.params() # Example of splitting keys if needed, or use one key for multiple params
# TODO: Define self.weight as an nnx.Param with shape (din, dout)
# self.weight = ...
# TODO: Define self.bias as an nnx.Param with shape (dout,)
# self.bias = ...
def __call__(self, x: jax.Array) -> jax.Array:
# TODO: Implement the forward pass
# return ...
# --- Instantiate the Model ---
din, dout = 10, 5
# TODO: Create an nnx.Rngs object for parameter initialization
# rngs = ...
# TODO: Instantiate SimpleLinear
# model = ...
print(f"Model created. Weight shape: {model.weight.value.shape}, Bias shape: {model.bias.value.shape}")
# --- Setup CheckpointManager ---
ckpt_dir_ex1 = os.path.join(CKPT_BASE_DIR, 'ex1_basic_save')
cleanup_ckpt_dir(ckpt_dir_ex1) # Clean up from previous runs
# TODO: Create CheckpointManagerOptions
# options = ...
# TODO: Instantiate CheckpointManager
# mngr = ...
# --- Split the model to get the state ---
# TODO: Split the model into graphdef and state_to_save
# _graphdef, state_to_save = ...
# Alternatively, for just the state: state_to_save = nnx.state(model)
# print(f"State to save: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, state_to_save)}")
# --- Save the state ---
step = 100
# TODO: Save the state_to_save at the given step. Use ocp.args.StandardSave.
# mngr.save(...)
# TODO: Wait for saving to complete
# mngr.wait_until_finished()
print(f"Checkpoint saved for step {step} in {ckpt_dir_ex1}.")
print(f"Available checkpoints: {mngr.all_steps()}")
# TODO: Close the manager
# mngr.close()
# @title Exercise 1: Solution
# --- Define the NNX Module ---
class SimpleLinear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
# Parameters defined using nnx.Param (a type of nnx.Variable)
self.weight = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
self.bias = nnx.Param(jnp.zeros((dout,)))
def __call__(self, x: jax.Array) -> jax.Array:
# Parameters used directly via self.weight, self.bias
return x @ self.weight.value + self.bias.value
# --- Instantiate the Model ---
din, dout = 10, 5
rngs = nnx.Rngs(params=jax.random.key(0)) # NNX requires explicit RNG management
model = SimpleLinear(din=din, dout=dout, rngs=rngs)
print(f"Model created. Weight shape: {model.weight.value.shape}, Bias shape: {model.bias.value.shape}")
# --- Setup CheckpointManager ---
ckpt_dir_ex1 = os.path.join(CKPT_BASE_DIR, 'ex1_basic_save')
cleanup_ckpt_dir(ckpt_dir_ex1)
options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=1)
mngr = ocp.CheckpointManager(ckpt_dir_ex1, options=options)
# --- Split the model to get the state ---
_graphdef, state_to_save = nnx.split(model)
# Alternatively: state_to_save = nnx.state(model)
print(f"State to save structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), state_to_save)}")
# --- Save the state ---
step = 100
mngr.save(step, args=ocp.args.StandardSave(state_to_save))
mngr.wait_until_finished() # Ensure save completes if async
print(f"Checkpoint saved for step {step} in {ckpt_dir_ex1}.")
print(f"Available checkpoints: {mngr.all_steps()}")
mngr.close() # Clean up resources
# Ensure the SimpleLinear class definition from Exercise 1 is available
# --- Re-open CheckpointManager ---
# TODO: Instantiate CheckpointManager for ckpt_dir_ex1 (no need for options if just restoring)
# mngr_restore = ...
# --- Create Abstract Model for Restoration ---
def create_abstract_model():
# Use dummy RNG key/inputs for abstract creation
# TODO: Return an instance of SimpleLinear, same din/dout as before
# return ...
# TODO: Create the abstract_model using nnx.eval_shape
# abstract_model = ...
# --- Split Abstract Model to get Abstract State Structure ---
# TODO: Split the abstract_model to get graphdef_for_restore and abstract_state
# graphdef_for_restore, abstract_state = ...
print(f"Abstract state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else x, abstract_state)}")
# --- Restore the State ---
# TODO: Get the latest step to restore
# step_to_restore = ...
if step_to_restore is not None:
# TODO: Restore the state using mngr_restore.restore() and ocp.args.StandardRestore with abstract_state
# restored_state = mngr_restore.restore(...)
# --- Reconstruct the Model ---
# TODO: Reconstruct the model using nnx.merge with graphdef_for_restore and restored_state
# restored_model = ...
print(f"Model restored from step {step_to_restore}.")
# You can now use 'restored_model'
print(f"Restored bias (first 3 values): {restored_model.bias.value[:3]}")
# Alternative: Update an existing model instance
# model_to_update = SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(99))) # Fresh instance
# nnx.update(model_to_update, restored_state)
# print(f"Updated model bias (first 3 values): {model_to_update.bias.value[:3]}")
else:
print("No checkpoint found to restore.")
# TODO: Close the manager
# mngr_restore.close()
# @title Exercise 2: Solution
# Ensure the SimpleLinear class definition from Exercise 1 is available
# --- Re-open CheckpointManager ---
mngr_restore = ocp.CheckpointManager(ckpt_dir_ex1) # Re-open manager
# --- Create Abstract Model for Restoration ---
def create_abstract_model():
# Use dummy RNG key/inputs for abstract creation
return SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(0))) # din, dout from Ex1
abstract_model = nnx.eval_shape(create_abstract_model)
# --- Split Abstract Model to get Abstract State Structure ---
graphdef_for_restore, abstract_state = nnx.split(abstract_model)
# abstract_state now contains ShapeDtypeStruct leaves
print(f"Abstract state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else x, abstract_state)}")
# --- Restore the State ---
step_to_restore = mngr_restore.latest_step()
if step_to_restore is not None:
restored_state = mngr_restore.restore(step_to_restore,
args=ocp.args.StandardRestore(abstract_state))
print(f"Restored state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), restored_state)}")
# --- Reconstruct the Model ---
restored_model = nnx.merge(graphdef_for_restore, restored_state)
print(f"Model restored from step {step_to_restore}.")
# You can now use 'restored_model'
print(f"Restored bias (first 3 values): {restored_model.bias.value[:3]}")
# Compare with original model's bias (optional, if 'model' from Ex1 is still in scope)
# print(f"Original bias (first 3 values): {model.bias.value[:3]}")
# chex.assert_trees_all_close(restored_model.bias.value, model.bias.value)
# Alternative: Update an existing model instance
model_to_update = SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(99))) # Fresh instance
# Initialize with different values to see update working
model_to_update.bias.value = jnp.ones_like(model_to_update.bias.value) * 55.0
print(f"Bias before update: {model_to_update.bias.value[:3]}")
nnx.update(model_to_update, restored_state)
print(f"Updated model bias (first 3 values): {model_to_update.bias.value[:3]}")
if 'model' in globals(): # Check if original model exists
chex.assert_trees_all_close(model_to_update.bias.value, model.bias.value)
else:
print("No checkpoint found to restore.")
mngr_restore.close()
# Ensure SimpleLinear class definition is available
# --- Instantiate Model and Optimizer ---
rngs_ex3 = nnx.Rngs(params=jax.random.key(1))
model_ex3 = SimpleLinear(din=10, dout=5, rngs=rngs_ex3)
# TODO: Create an Optax optimizer (e.g., Adam)
# tx = ...
# TODO: Create an nnx.Optimizer, wrapping the model and tx
# optimizer = ...
# Simulate a few "training" steps to populate optimizer state
# For a real scenario, this would involve gradients and updates
if hasattr(optimizer, 'step') and hasattr(optimizer.step, 'value'): # Check for NNX Optimizer structure
optimizer.step.value += 10 # Simulate 10 steps
# In a real training loop, calling the optimizer's update with gradients would advance the step and update the optimizer state.
# For this exercise, just advancing step is enough to see it saved/restored.
# Let's also change a parameter slightly to see it saved
original_bias_val_ex3 = model_ex3.bias.value.copy()
model_ex3.bias.value = model_ex3.bias.value * 0.5 + 0.1
print(f"Optimizer step: {optimizer.step.value}")
print(f"Bias modified. Original first val: {original_bias_val_ex3[0]}, New first val: {model_ex3.bias.value[0]}")
else:
print("Skipping optimizer step update as structure might differ from expected nnx.Optimizer.")
# --- Setup CheckpointManager for Composite Save ---
ckpt_dir_ex3 = os.path.join(CKPT_BASE_DIR, 'ex3_composite_save')
cleanup_ckpt_dir(ckpt_dir_ex3)
# TODO: Instantiate CheckpointManager for ckpt_dir_ex3
# mngr_comp = ...
# --- Extract States for Saving ---
# TODO: Extract model parameters state from optimizer.model using nnx.split with nnx.Param filter
# _graphdef_params, params_state = ...
# TODO: Extract the full optimizer state tree using nnx.state()
# optimizer_state_tree = ...
print(f"Parameter state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, params_state)}")
print(f"Optimizer state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, optimizer_state_tree)}")
# --- Save Composite State ---
current_step_val = 0
if hasattr(optimizer, 'step') and hasattr(optimizer.step, 'value'):
current_step_val = optimizer.step.value
else: # Fallback for safety, though nnx.Optimizer should have .step
current_step_val = 10
# TODO: Define save_items dictionary for 'params' and 'optimizer'
# Each item should be wrapped with ocp.args.StandardSave
# save_items = {
# 'params': ...,
# 'optimizer': ...
# }
# TODO: Save using mngr_comp.save() and ocp.args.Composite
# mngr_comp.save(...)
# TODO: Wait and close the manager
# mngr_comp.wait_until_finished()
# print(f"Composite checkpoint saved for step {current_step_val} in {ckpt_dir_ex3}.")
# print(f"Available checkpoints: {mngr_comp.all_steps()}")
# mngr_comp.close()
# @title Exercise 3: Solution
# Ensure SimpleLinear class definition is available
# --- Instantiate Model and Optimizer ---
rngs_ex3 = nnx.Rngs(params=jax.random.key(1))
model_ex3 = SimpleLinear(din=10, dout=5, rngs=rngs_ex3)
tx = optax.adam(learning_rate=1e-3)
optimizer = nnx.Optimizer(model_ex3, tx, wrt=nnx.Param)
# Simulate a few "training" steps to populate optimizer state
# For a real scenario, this would involve gradients and updates
optimizer.step.value += 10 # Simulate 10 steps
original_bias_val_ex3 = model_ex3.bias.value.copy()
# Simulate a parameter update that would happen during training
model_ex3.bias.value = model_ex3.bias.value * 0.5 + 0.1 # Arbitrary change
print(f"Optimizer step: {optimizer.step.value}")
print(f"Bias modified. Original first val: {original_bias_val_ex3[0]}, New first val: {model_ex3.bias.value[0]}")
# --- Setup CheckpointManager for Composite Save ---
ckpt_dir_ex3 = os.path.join(CKPT_BASE_DIR, 'ex3_composite_save')
cleanup_ckpt_dir(ckpt_dir_ex3)
mngr_comp = ocp.CheckpointManager(ckpt_dir_ex3, options=ocp.CheckpointManagerOptions(max_to_keep=3))
# --- Extract States for Saving ---
# Extract model parameters (e.g., using nnx.split(model, nnx.Param))
_graphdef_params, params_state = nnx.split(model_ex3, nnx.Param)
# Extract optimizer state (nnx.state(optimizer))
optimizer_state_tree = nnx.state(optimizer)
print(f"Parameter state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), params_state)}")
print(f"Optimizer state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), optimizer_state_tree)}")
# Note: optimizer_state_tree also contains the model's state, since the nnx.Optimizer holds a reference to the model
# --- Save Composite State ---
current_step_val = optimizer.step.value # Get current step from optimizer
# Save using Composite args
save_items = {
'params': ocp.args.StandardSave(params_state),
'optimizer': ocp.args.StandardSave(optimizer_state_tree)
}
# Can generate args per item using orbax_utils too
mngr_comp.save(current_step_val, args=ocp.args.Composite(**save_items))
mngr_comp.wait_until_finished()
print(f"Composite checkpoint saved for step {current_step_val} in {ckpt_dir_ex3}.")
print(f"Available checkpoints: {mngr_comp.all_steps()}")
mngr_comp.close()
# Ensure SimpleLinear class definition is available
# --- Re-open CheckpointManager ---
# TODO: Instantiate CheckpointManager for ckpt_dir_ex3
# mngr_comp_restore = ...
# --- Create Abstract Model and Optimizer ---
def create_abstract_model_and_optimizer():
rngs_abs = nnx.Rngs(params=jax.random.key(0)) # Dummy key for abstract creation
# TODO: Create abstract model. Model class: SimpleLinear(din=10, dout=5, ...)
# abs_model = SimpleLinear(...)
# TODO: Create abstract optimizer. Pass abs_model and an optax.adam instance.
# abs_opt = nnx.Optimizer(...)
# return abs_model, abs_opt
# TODO: Call the function to get abstract model and optimizer
# abs_model_restore, abs_optimizer_restore = ...
# --- Get Abstract States ---
# TODO: Get abstract parameter state from abs_model_restore (filter with nnx.Param)
# _graphdef_abs_params, abs_params_state = ...
# TODO: Get abstract optimizer state from abs_optimizer_restore
# abs_optimizer_state = ...
print(f"Abstract params state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, abs_params_state)}")
print(f"Abstract optimizer state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, abs_optimizer_state)}")
# --- Restore Composite State ---
# TODO: Get the latest step
# step_to_restore_comp = ...
if step_to_restore_comp is not None:
# TODO: Define restore_targets dictionary for 'params' and 'optimizer'
# Each item should be wrapped with ocp.args.StandardRestore and its corresponding abstract state.
# restore_targets = {
# 'params': ...,
# 'optimizer': ...
# }
# TODO: Restore items using mngr_comp_restore.restore() and ocp.args.Composite
# restored_items = mngr_comp_restore.restore(...)
# --- Instantiate and Update Concrete Model/Optimizer ---
# TODO: Create a fresh SimpleLinear model instance (use a new RNG key, e.g., key(2))
# fresh_model = ...
# TODO: Create a fresh nnx.Optimizer instance with fresh_model and a new optax.adam instance
# fresh_optimizer = ...
# Store pre-update values for comparison
pre_update_bias = fresh_model.bias.value.copy()
pre_update_opt_step = fresh_optimizer.step.value
# TODO: Update fresh_model with restored_items['params'] using nnx.update()
# nnx.update(...)
# TODO: Update fresh_optimizer with restored_items['optimizer'] using nnx.update()
# nnx.update(...)
print(f"Restored and updated. Optimizer step: {fresh_optimizer.step.value}")
print(f"Fresh model bias before update (first val): {pre_update_bias[0]}")
print(f"Fresh model bias after update (first val): {fresh_model.bias.value[0]}")
print(f"Original bias from Ex3 (first val): {model_ex3.bias.value[0]}") # model_ex3 is from previous cell
# Verification
# chex.assert_trees_all_close(fresh_model.bias.value, model_ex3.bias.value) # Compare with the state that was saved
# assert fresh_optimizer.step.value == optimizer.step.value # Compare with optimizer state that was saved
else:
print("No composite checkpoint found.")
# TODO: Close the manager
# mngr_comp_restore.close()
# @title Exercise 4: Solution
# Ensure SimpleLinear class definition is available
# --- Re-open CheckpointManager ---
mngr_comp_restore = ocp.CheckpointManager(ckpt_dir_ex3)
# --- Create Abstract Model and Optimizer ---
def create_abstract_model_and_optimizer():
rngs_abs = nnx.Rngs(params=jax.random.key(0)) # Dummy key for abstract creation
# Create abstract model
abs_model = SimpleLinear(din=10, dout=5, rngs=rngs_abs)
# Create abstract optimizer
abs_opt = nnx.Optimizer(abs_model, optax.adam(1e-3), wrt=nnx.Param)
return abs_model, abs_opt
abs_model_restore, abs_optimizer_restore = create_abstract_model_and_optimizer()
# --- Get Abstract States ---
_graphdef_abs_params, abs_params_state = nnx.split(abs_model_restore, nnx.Param)
abs_optimizer_state = nnx.state(abs_optimizer_restore)
print(f"Abstract params state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), abs_params_state)}")
print(f"Abstract optimizer state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), abs_optimizer_state)}")
# --- Restore Composite State ---
step_to_restore_comp = mngr_comp_restore.latest_step()
if step_to_restore_comp is not None:
restore_targets = {
'params': ocp.args.StandardRestore(abs_params_state),
'optimizer': ocp.args.StandardRestore(abs_optimizer_state)
}
restored_items = mngr_comp_restore.restore(step_to_restore_comp, args=ocp.args.Composite(**restore_targets))
# --- Instantiate and Update Concrete Model/Optimizer ---
# Create fresh instances
fresh_rngs = nnx.Rngs(params=jax.random.key(2)) # Use a different key for the fresh model
fresh_model = SimpleLinear(din=10, dout=5, rngs=fresh_rngs)
fresh_optimizer = nnx.Optimizer(fresh_model, optax.adam(1e-3), wrt=nnx.Param) # Matching optax optimizer
# Store pre-update values for comparison
pre_update_bias = fresh_model.bias.value.copy()
pre_update_opt_step = fresh_optimizer.step.value
# Update using restored states
nnx.update(fresh_model, restored_items['params'])
nnx.update(fresh_optimizer, restored_items['optimizer'])
print(f"Restored and updated. Optimizer step: {fresh_optimizer.step.value}")
print(f"Fresh model bias before update (first val): {pre_update_bias[0]}") # Will be from key(2)
print(f"Fresh model bias after update (first val): {fresh_model.bias.value[0]}") # Should match model_ex3 bias
# Verification (model_ex3 and optimizer are from the previous cell where they were saved)
chex.assert_trees_all_close(fresh_model.bias.value, model_ex3.bias.value)
assert fresh_optimizer.step.value == optimizer.step.value
print("Verification successful: Restored model parameters and optimizer step match the saved state.")
else:
print("No composite checkpoint found.")
mngr_comp_restore.close()
In Exercise 5, the module initializes its parameters in __init__ and then shards them across the device mesh.
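As a warm-up, independent of the exercise code below, here is a minimal sketch of that pattern on a plain array: build a 1-D mesh, describe the layout with a PartitionSpec, and shard with jax.device_put. The names demo_mesh and x are illustrative only.
# Minimal sketch: shard an already-initialized array across a 1-D 'data' mesh.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.experimental import mesh_utils  # already imported in Setup; repeated so the sketch stands alone

n = jax.device_count()
demo_mesh = Mesh(mesh_utils.create_device_mesh((n,)), axis_names=('data',))
x = jnp.arange(n * 4.0).reshape(n, 4)  # first dimension divisible by the device count
x_sharded = jax.device_put(x, NamedSharding(demo_mesh, PartitionSpec('data', None)))
print(x_sharded.sharding)  # rows are split across the 'data' axis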
# --- Setup JAX Mesh ---
num_devices = jax.device_count()
# If num_devices is 1 after chex.set_n_cpu_devices(8), it means JAX didn't pick up the fakes.
# This can happen if JAX initializes its backends before chex runs.
# Forcing a rerun of this cell or restarting runtime and running setup first might help.
print(f"Using {num_devices} devices for sharding.")
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices=device_mesh, axis_names=('data',)) # 1D mesh
print(mesh)
# --- Define Sharded NNX Module ---
class ShardedSimpleLinear(nnx.Module):
def __init__(self, din: int, dout: int, mesh: Mesh, *, rngs: nnx.Rngs):
self.din = din
self.dout = dout
self.mesh = mesh
key_w, key_b = rngs.params(), rngs.params()
# Initialize as regular JAX arrays first
initial_weight = jax.random.uniform(key_w, (din, dout))
initial_bias = jnp.zeros((dout,))
# TODO: Define PartitionSpec for weight (shard dout across 'data' axis)
# e.g., PartitionSpec(None, 'data') means not sharded on dim 0, sharded on dim 1
# weight_pspec = ...
# TODO: Define PartitionSpec for bias (shard along 'data' axis)
# bias_pspec = ...
# TODO: Create NamedSharding for weight and bias using self.mesh and the pspecs
# weight_sharding = NamedSharding(...)
# bias_sharding = NamedSharding(...)
# TODO: Shard the initial arrays using jax.device_put and the NamedSharding
# sharded_weight_value = jax.device_put(...)
# sharded_bias_value = jax.device_put(...)
# TODO: Assign these sharded arrays to nnx.Param attributes
# self.weight = nnx.Param(sharded_weight_value)
# self.bias = nnx.Param(sharded_bias_value)
# Alternative (more direct with nnx.Variable metadata if supported well for this case):
# self.weight = nnx.Param(initial_weight, sharding=weight_sharding) # This depends on NNX API
# For this exercise, jax.device_put is explicit and clear.
def __call__(self, x: jax.Array) -> jax.Array:
# x is assumed to be replicated or appropriately sharded for the matmul
# For simplicity, assume x is replicated if din is not sharded, or sharded compatibly.
return x @ self.weight.value + self.bias.value
# --- Instantiate Sharded Model within Mesh context ---
din_s, dout_s = 8, num_devices * 2 # Ensure dout is divisible by num_devices for even sharding
rngs_sharded = nnx.Rngs(params=jax.random.key(3))
# TODO: Instantiate ShardedSimpleLinear within the mesh context
# with mesh:
# sharded_model = ...
# print(f"Sharded model created. Weight sharding: {sharded_model.weight.value.sharding}")
# print(f"Sharded model bias sharding: {sharded_model.bias.value.sharding}")
# --- Setup CheckpointManager for Sharded Save ---
ckpt_dir_ex5 = os.path.join(CKPT_BASE_DIR, 'ex5_sharded_save')
cleanup_ckpt_dir(ckpt_dir_ex5)
# TODO: Instantiate CheckpointManager
# mngr_sharded_save = ...
# --- Split and Save Sharded State ---
# TODO: Split the sharded_model
# _graphdef_sharded, sharded_state_to_save = ...
# print(f"Sharded state to save (bias type): {type(sharded_state_to_save['bias'].value)}")
# print(f"Sharded state to save (bias sharding): {sharded_state_to_save['bias'].value.sharding}")
# current_step_sharded = 200
# TODO: Save the sharded_state_to_save
# mngr_sharded_save.save(...)
# TODO: Wait and close
# mngr_sharded_save.wait_until_finished()
# print(f"Sharded checkpoint saved for step {current_step_sharded} in {ckpt_dir_ex5}.")
# mngr_sharded_save.close()
# @title Exercise 5: Solution
# --- Setup JAX Mesh ---
num_devices = jax.device_count()
if num_devices == 1:  # We asked chex to fake 8 CPU devices but JAX only sees 1
    print("Warning: JAX might not be using the faked CPU devices. Restart the runtime and run the Setup cell first if sharding tests fail.")
print(f"Using {num_devices} devices for sharding.")
# Ensure a 1D mesh for simplicity, using all available (or faked) devices.
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices=device_mesh, axis_names=('data',)) # 1D mesh for 'data' parallelism
print(mesh)
# --- Define Sharded NNX Module ---
class ShardedSimpleLinear(nnx.Module):
def __init__(self, din: int, dout: int, mesh: Mesh, *, rngs: nnx.Rngs):
self.din = din
self.dout = dout
self.mesh = mesh # Store mesh for creating NamedSharding
key_w, key_b = rngs.params(), rngs.params()
initial_weight = jax.random.uniform(key_w, (din, dout))
initial_bias = jnp.zeros((dout,))
# Define PartitionSpec for sharding
# Shard weight's second dimension (dout) across the 'data' mesh axis
weight_pspec = PartitionSpec(None, 'data')
# Shard bias's only dimension (dout) across the 'data' mesh axis
bias_pspec = PartitionSpec('data',)
# Create NamedSharding from PartitionSpec and mesh
weight_sharding = NamedSharding(self.mesh, weight_pspec)
bias_sharding = NamedSharding(self.mesh, bias_pspec)
# Shard the initial arrays using jax.device_put
# This ensures the arrays are created with the specified sharding
sharded_weight_value = jax.device_put(initial_weight, weight_sharding)
sharded_bias_value = jax.device_put(initial_bias, bias_sharding)
self.weight = nnx.Param(sharded_weight_value)
self.bias = nnx.Param(sharded_bias_value)
# Note: Flax NNX aims to allow sharding annotations directly in nnx.Variable metadata
# e.g., using nnx.spmd.with_partitioning or passing sharding to nnx.Param.
# Explicit jax.device_put is also a valid way to get sharded arrays into the state.
def __call__(self, x: jax.Array) -> jax.Array:
return x @ self.weight.value + self.bias.value
# --- Instantiate Sharded Model within Mesh context ---
din_s, dout_s = 8, num_devices * 2 # Make dout divisible by num_devices
rngs_sharded = nnx.Rngs(params=jax.random.key(3))
with mesh: # Operations within this context are aware of the mesh
sharded_model = ShardedSimpleLinear(din_s, dout_s, mesh, rngs=rngs_sharded)
print(f"Sharded model created. Weight sharding: {sharded_model.weight.value.sharding}")
print(f"Sharded model bias sharding: {sharded_model.bias.value.sharding}")
# --- Setup CheckpointManager for Sharded Save ---
ckpt_dir_ex5 = os.path.join(CKPT_BASE_DIR, 'ex5_sharded_save')
cleanup_ckpt_dir(ckpt_dir_ex5)
mngr_sharded_save = ocp.CheckpointManager(ckpt_dir_ex5, options=ocp.CheckpointManagerOptions(max_to_keep=1))
# --- Split and Save Sharded State ---
# The live state already contains sharded jax.Array objects
_graphdef_sharded, sharded_state_to_save = nnx.split(sharded_model)
print(f"Sharded state to save (bias type): {type(sharded_state_to_save['bias'].value)}")
print(f"Sharded state to save (bias sharding): {sharded_state_to_save['bias'].value.sharding}")
# The arrays in sharded_state_to_save are jax.Array values that carry their shardings
current_step_sharded = 200
# Orbax handles sharded-array saving under the hood
mngr_sharded_save.save(current_step_sharded, args=ocp.args.StandardSave(sharded_state_to_save))
mngr_sharded_save.wait_until_finished()
print(f"Sharded checkpoint saved for step {current_step_sharded} in {ckpt_dir_ex5}.")
mngr_sharded_save.close()
Orbax also offers some more advanced capabilities. We won't exercise them fully here, but they are worth knowing about. For example, you can save arbitrary JSON metadata alongside your model state:
metadata = {'version': '1.0', 'dataset_info': 'imagenet_split_train'}
save_args = ocp.args.Composite(
params=ocp.args.StandardSave(params_state),
metadata=ocp.args.JsonSave(metadata)
)
mngr.save(step, args=save_args)
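Restoring such a checkpoint mirrors the save call. Here is a hedged sketch, assuming abstract_params_state describes the same pytree structure as params_state (e.g. obtained via nnx.eval_shape and nnx.split, as in Exercise 2):
# Hypothetical restore counterpart of the metadata save above.
restored = mngr.restore(step, args=ocp.args.Composite(
    params=ocp.args.StandardRestore(abstract_params_state),
    metadata=ocp.args.JsonRestore(),
))
print(restored['metadata']['version'])  # -> '1.0'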
For deeper details, refer to the official Orbax and Flax NNX documentation.
Keep practicing, and enjoy working with JAX!
Feedback is welcome via https://goo.gle/jax-training-feedback.