Getting Started with Checkpointing Using Flax NNX and Orbax

Welcome to this hands-on workshop! We will learn how to save and load JAX/Flax NNX models, an essential skill for any serious machine learning project.

Why Checkpoint?

Training deep learning models takes time. Checkpoints let you:

  • Save training progress (model parameters, optimizer state) so you can resume after an interruption.
  • Keep the model at different stages for later analysis or inference.
  • Add fault tolerance to long training runs.

A Quick Flax NNX Refresher

  • Stateful modules: NNX modules are Python classes that hold their own state (parameters and other attributes); if you know PyTorch, this will feel familiar.
  • nnx.Module: the base class for creating these stateful components.
  • nnx.Variable: variable types such as nnx.Param and nnx.BatchStat declare learnable parameters and other state.
  • nnx.State: a JAX pytree (similar to a nested dict) holding the values of all nnx.Variables in a module; it is what Orbax reads and writes (see the toy sketch just below).
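
A minimal sketch of these pieces (the Scale module below is a toy example, not part of the exercises; it assumes the imports from the Setup cell further down):

# A toy nnx.Module with a single nnx.Param, and a peek at its nnx.State
class Scale(nnx.Module):
  def __init__(self, dim: int, *, rngs: nnx.Rngs):
    self.scale = nnx.Param(jnp.ones((dim,)))

  def __call__(self, x: jax.Array) -> jax.Array:
    return x * self.scale.value

toy = Scale(3, rngs=nnx.Rngs(0))
toy_state = nnx.state(toy)  # an nnx.State pytree holding the Param's value
print(jax.tree_util.tree_map(jnp.shape, toy_state))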

The Functional Bridge

  • nnx.split(module): splits a module into its static structure (GraphDef) and dynamic state (nnx.State), making it easy to extract the state you want to save.
  • nnx.merge(graphdef, state): rebuilds a module instance from a GraphDef and an nnx.State, typically used after restoring.
  • nnx.update(module, state): updates an existing module's state in place, also used after restoring (a short round-trip sketch follows this list).
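
A quick sketch of the round trip, reusing the toy Scale module above:

graphdef, state = nnx.split(toy)          # static structure + dynamic state
toy_rebuilt = nnx.merge(graphdef, state)  # rebuild an equivalent module instance
nnx.update(toy, state)                    # or write a state back into an existing module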

Orbax: The Checkpointing Library for JAX

Orbax is the standard checkpointing library for JAX, designed to be both robust and extensible.

  • ocp.CheckpointManager: a high-level manager that simplifies keeping track of multiple checkpoints during training (e.g. retaining only the most recent N versions); we will use it heavily below.
  • ocp.args: a namespace of argument types describing how items are saved/restored (e.g. ocp.args.StandardSave, ocp.args.StandardRestore, ocp.args.Composite); a minimal usage sketch follows this list.
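
As a preview of the pattern the exercises build toward, here is a minimal sketch (the path '/tmp/ckpt_demo' and step 0 are placeholders, and toy_state is the nnx.State from the sketch above; imports come from the Setup cell below):

# Minimal save/restore round trip with CheckpointManager (sketch)
mngr = ocp.CheckpointManager('/tmp/ckpt_demo',
                             options=ocp.CheckpointManagerOptions(max_to_keep=2))
mngr.save(0, args=ocp.args.StandardSave(toy_state))  # save an nnx.State pytree
mngr.wait_until_finished()                           # saving can be asynchronous
restored = mngr.restore(0, args=ocp.args.StandardRestore(toy_state))  # template pytree (concrete or abstract)
mngr.close()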

Let's get started!

# @title Setup: Install and Import Libraries
# Install necessary libraries
!pip install -q jax-ai-stack==2025.9.3

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
import flax
from flax import nnx
import orbax.checkpoint as ocp
import optax
import os
import shutil # For cleaning up directories
import chex # For faking devices

# Suppress some JAX warnings for cleaner output in the notebook
import warnings
warnings.filterwarnings("ignore", message="No GPU/TPU found, falling back to CPU.")
warnings.filterwarnings("ignore", message="Custom node type GlobalDeviceArray is not handled by Pytree traversal.") # Orbax/NNX interactions

print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")
print(f"Orbax version: {ocp.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"Chex version: {chex.__version__}")

# --- Setup for Distributed Exercises ---
# Simulate an environment with 8 CPUs for distributed examples
# This allows us to test sharding logic even on a single-CPU Colab machine.
try:
  chex.set_n_cpu_devices(8)
except RuntimeError as e:
  print(f"Note: Could not set_n_cpu_devices (may have been set already): {e}")

print(f"Number of JAX devices available: {jax.device_count()}")
print(f"Available devices: {jax.devices()}")

# Helper function to clean up checkpoint directories
def cleanup_ckpt_dir(ckpt_dir):
  if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)
    print(f"Cleaned up checkpoint directory: {ckpt_dir}")

# Create a default checkpoint directory for exercises
CKPT_BASE_DIR = '/tmp/nnx_orbax_workshop_checkpoints'
if not os.path.exists(CKPT_BASE_DIR):
  os.makedirs(CKPT_BASE_DIR)

print(f"Base checkpoint directory: {CKPT_BASE_DIR}")

Exercise 1: Basic Checkpointing — Saving an nnx.State

Goal: Learn to save the state of a simple Flax NNX module with Orbax.

Topics

  • Define an nnx.Module.
  • Instantiate the nnx.Module with initial parameters.
  • Extract the nnx.State pytree with nnx.split().
  • Configure an ocp.CheckpointManager.
  • Save the state with mngr.save() and ocp.args.StandardSave.

Steps

  1. Write a SimpleLinear layer that subclasses nnx.Module.
    • In __init__, declare the weight matrix and bias vector as nnx.Param, initialize them with JAX random functions (e.g. jax.random.uniform, jnp.zeros), and manage random keys with nnx.Rngs.
    • Implement the forward pass in __call__: y = x @ weight + bias.
  2. Instantiate SimpleLinear.
  3. Choose a directory for the checkpoints.
  4. Create ocp.CheckpointManagerOptions (e.g. max_to_keep=3).
  5. Construct an ocp.CheckpointManager from the directory and the options.
  6. Call nnx.split(model) to get graphdef and state_to_save.
  7. Call mngr.save() at a given training step (e.g. step 100), wrapping state_to_save in ocp.args.StandardSave.
  8. Call mngr.wait_until_finished() to make sure the asynchronous save completes.
  9. Finally, call mngr.close() to shut down the manager.
# --- Define the NNX Module ---
class SimpleLinear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key_w, key_b = rngs.params(), rngs.params() # Example of splitting keys if needed, or use one key for multiple params
    # TODO: Define self.weight as an nnx.Param with shape (din, dout)
    # self.weight = ...
    # TODO: Define self.bias as an nnx.Param with shape (dout,)
    # self.bias = ...

  def __call__(self, x: jax.Array) -> jax.Array:
    # TODO: Implement the forward pass
    # return ...

# --- Instantiate the Model ---
din, dout = 10, 5
# TODO: Create an nnx.Rngs object for parameter initialization
# rngs = ...
# TODO: Instantiate SimpleLinear
# model = ...

print(f"Model created. Weight shape: {model.weight.value.shape}, Bias shape: {model.bias.value.shape}")

# --- Setup CheckpointManager ---
ckpt_dir_ex1 = os.path.join(CKPT_BASE_DIR, 'ex1_basic_save')
cleanup_ckpt_dir(ckpt_dir_ex1) # Clean up from previous runs

# TODO: Create CheckpointManagerOptions
# options = ...
# TODO: Instantiate CheckpointManager
# mngr = ...

# --- Split the model to get the state ---
# TODO: Split the model into graphdef and state_to_save
# _graphdef, state_to_save = ...
# Alternatively, for just the state: state_to_save = nnx.state(model)
# print(f"State to save: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, state_to_save)}")

# --- Save the state ---
step = 100
# TODO: Save the state_to_save at the given step. Use ocp.args.StandardSave.
# mngr.save(...)
# TODO: Wait for saving to complete
# mngr.wait_until_finished()

print(f"Checkpoint saved for step {step} in {ckpt_dir_ex1}.")
print(f"Available checkpoints: {mngr.all_steps()}")

# TODO: Close the manager
# mngr.close()
# @title Exercise 1: Solution
# --- Define the NNX Module ---
class SimpleLinear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    # Parameters defined using nnx.Param (a type of nnx.Variable)
    self.weight = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
    self.bias = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x: jax.Array) -> jax.Array:
    # Parameters used directly via self.weight, self.bias
    return x @ self.weight.value + self.bias.value

# --- Instantiate the Model ---
din, dout = 10, 5
rngs = nnx.Rngs(params=jax.random.key(0)) # NNX requires explicit RNG management
model = SimpleLinear(din=din, dout=dout, rngs=rngs)

print(f"Model created. Weight shape: {model.weight.value.shape}, Bias shape: {model.bias.value.shape}")

# --- Setup CheckpointManager ---
ckpt_dir_ex1 = os.path.join(CKPT_BASE_DIR, 'ex1_basic_save')
cleanup_ckpt_dir(ckpt_dir_ex1)

options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=1)
mngr = ocp.CheckpointManager(ckpt_dir_ex1, options=options)

# --- Split the model to get the state ---
_graphdef, state_to_save = nnx.split(model)
# Alternatively: state_to_save = nnx.state(model)
print(f"State to save structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), state_to_save)}")

# --- Save the state ---
step = 100
mngr.save(step, args=ocp.args.StandardSave(state_to_save))
mngr.wait_until_finished() # Ensure save completes if async

print(f"Checkpoint saved for step {step} in {ckpt_dir_ex1}.")
print(f"Available checkpoints: {mngr.all_steps()}")

mngr.close() # Clean up resources

Exercise 2: Basic Checkpointing — Restoring an nnx.State

Goal: Learn to restore model state from a checkpoint with Orbax.

Topics

  • Create an "abstract" model template with nnx.eval_shape().
  • Split the abstract model to get abstract_state (a pytree of ShapeDtypeStructs).
  • Restore the state with mngr.restore(), combining abstract_state with ocp.args.StandardRestore.
  • Rebuild the model with nnx.merge(graphdef, restored_state).
  • Or update an existing model instance with nnx.update().

Steps

  1. Re-open a CheckpointManager pointing at the Exercise 1 directory (ckpt_dir_ex1).
  2. Write create_abstract_model(), which returns a SimpleLinear instance, for use with nnx.eval_shape().
    • Inside the function, use dummy RNGs and input shapes; eval_shape only cares about structure and dtypes, not concrete values.
  3. Call abstract_model = nnx.eval_shape(create_abstract_model).
  4. Split abstract_model: graphdef_for_restore, abstract_state = nnx.split(abstract_model), yielding the ShapeDtypeStruct template for restoration.
  5. Get the latest checkpoint step with mngr.latest_step().
  6. If a checkpoint exists, restore the state with mngr.restore(step_to_restore, args=ocp.args.StandardRestore(abstract_state)).
  7. Rebuild the model with restored_model = nnx.merge(graphdef_for_restore, restored_state).
  8. (Optional) Print values such as restored_model.bias.value to verify.
  9. Close the manager.
# Ensure the SimpleLinear class definition from Exercise 1 is available

# --- Re-open CheckpointManager ---
# TODO: Instantiate CheckpointManager for ckpt_dir_ex1 (no need for options if just restoring)
# mngr_restore = ...

# --- Create Abstract Model for Restoration ---
def create_abstract_model():
  # Use dummy RNG key/inputs for abstract creation
  # TODO: Return an instance of SimpleLinear, same din/dout as before
  # return ...

# TODO: Create the abstract_model using nnx.eval_shape
# abstract_model = ...

# --- Split Abstract Model to get Abstract State Structure ---
# TODO: Split the abstract_model to get graphdef_for_restore and abstract_state
# graphdef_for_restore, abstract_state = ...
print(f"Abstract state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else x, abstract_state)}")


# --- Restore the State ---
# TODO: Get the latest step to restore
# step_to_restore = ...

if step_to_restore is not None:
  # TODO: Restore the state using mngr_restore.restore() and ocp.args.StandardRestore with abstract_state
  # restored_state = mngr_restore.restore(...)

  # --- Reconstruct the Model ---
  # TODO: Reconstruct the model using nnx.merge with graphdef_for_restore and restored_state
  # restored_model = ...
  print(f"Model restored from step {step_to_restore}.")
  # You can now use 'restored_model'
  print(f"Restored bias (first 3 values): {restored_model.bias.value[:3]}")

  # Alternative: Update an existing model instance
  # model_to_update = SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(99))) # Fresh instance
  # nnx.update(model_to_update, restored_state)
  # print(f"Updated model bias (first 3 values): {model_to_update.bias.value[:3]}")
else:
  print("No checkpoint found to restore.")

# TODO: Close the manager
# mngr_restore.close()
# @title Exercise 2: Solution

# Ensure the SimpleLinear class definition from Exercise 1 is available

# --- Re-open CheckpointManager ---
mngr_restore = ocp.CheckpointManager(ckpt_dir_ex1) # Re-open manager

# --- Create Abstract Model for Restoration ---
def create_abstract_model():
  # Use dummy RNG key/inputs for abstract creation
  return SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(0))) # din, dout from Ex1

abstract_model = nnx.eval_shape(create_abstract_model)

# --- Split Abstract Model to get Abstract State Structure ---
graphdef_for_restore, abstract_state = nnx.split(abstract_model)
# abstract_state now contains ShapeDtypeStruct leaves
print(f"Abstract state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else x, abstract_state)}")

# --- Restore the State ---
step_to_restore = mngr_restore.latest_step()

if step_to_restore is not None:
  restored_state = mngr_restore.restore(step_to_restore,
      args=ocp.args.StandardRestore(abstract_state))
  print(f"Restored state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), restored_state)}")

  # --- Reconstruct the Model ---
  restored_model = nnx.merge(graphdef_for_restore, restored_state)
  print(f"Model restored from step {step_to_restore}.")
  # You can now use 'restored_model'
  print(f"Restored bias (first 3 values): {restored_model.bias.value[:3]}")

  # Compare with original model's bias (optional, if 'model' from Ex1 is still in scope)
  # print(f"Original bias (first 3 values): {model.bias.value[:3]}")
  # chex.assert_trees_all_close(restored_model.bias.value, model.bias.value)

  # Alternative: Update an existing model instance
  model_to_update = SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(99))) # Fresh instance
  # Initialize with different values to see update working
  model_to_update.bias.value = jnp.ones_like(model_to_update.bias.value) * 55.0
  print(f"Bias before update: {model_to_update.bias.value[:3]}")
  nnx.update(model_to_update, restored_state)
  print(f"Updated model bias (first 3 values): {model_to_update.bias.value[:3]}")
  if 'model' in globals(): # Check if original model exists
    chex.assert_trees_all_close(model_to_update.bias.value, model.bias.value)
else:
  print("No checkpoint found to restore.")

mngr_restore.close()

Exercise 3: Saving Model Parameters and Optimizer State Together

Goal: Save the model parameters and the optimizer state in the same checkpoint.

Topics

  • Use nnx.Optimizer to manage the model parameters together with an Optax optimizer's state.
  • Extract the model parameters (e.g. nnx.split(model, nnx.Param)).
  • Extract the full optimizer state (nnx.state(optimizer)).
  • Save multiple named items (parameters, optimizer state) in one checkpoint with ocp.args.Composite.

Steps

  1. Reuse the SimpleLinear definition and instantiate a fresh model.
  2. Create an Optax optimizer (e.g. optax.adam(learning_rate=1e-3)).
  3. Wrap the model and the Optax optimizer in an nnx.Optimizer.
  4. (Optional) Simulate a few training steps to populate the optimizer's internal state (e.g. momentum); no real data is needed, just bump the step counter.
    • The step count is available as optimizer.step.value, e.g. optimizer.step.value += 1.
  5. Create a new CheckpointManager in a new directory (ckpt_dir_ex3).
  6. Extract the model parameters: graphdef_params, params_state = nnx.split(model_ex3, nnx.Param). Note that the optimizer.model attribute has been removed; split the original model variable directly.
  7. Extract the full optimizer state: optimizer_state_tree = nnx.state(optimizer), which includes the internal state (e.g. momentum) as well as the optimizer's own step count.
  8. Define a save_items dict whose keys are names (e.g. 'params', 'optimizer') and whose values are the corresponding pytrees (params_state, optimizer_state_tree) wrapped in ocp.args.StandardSave().
  9. Save with mngr.save(step, args=ocp.args.Composite(**save_items)), using the optimizer's current step as the step number.
  10. Wait for the save to finish and close the manager.
# Ensure SimpleLinear class definition is available
# --- Instantiate Model and Optimizer ---
rngs_ex3 = nnx.Rngs(params=jax.random.key(1))
model_ex3 = SimpleLinear(din=10, dout=5, rngs=rngs_ex3)

# TODO: Create an Optax optimizer (e.g., Adam)
# tx = ...
# TODO: Create an nnx.Optimizer, wrapping the model and tx
# optimizer = ...

# Simulate a few "training" steps to populate optimizer state
# For a real scenario, this would involve gradients and updates
if hasattr(optimizer, 'step') and hasattr(optimizer.step, 'value'): # Check for NNX Optimizer structure
  optimizer.step.value += 10 # Simulate 10 steps
  # In a real loop, optimizer.update(...) would apply gradients and refresh the optimizer state
  # For this exercise, just advancing step is enough to see it saved/restored.
  # Let's also change a parameter slightly to see it saved
  original_bias_val_ex3 = model_ex3.bias.value.copy()
  model_ex3.bias.value = model_ex3.bias.value * 0.5 + 0.1
  print(f"Optimizer step: {optimizer.step.value}")
  print(f"Bias modified. Original first val: {original_bias_val_ex3[0]}, New first val: {model_ex3.bias.value[0]}")
else:
  print("Skipping optimizer step update as structure might differ from expected nnx.Optimizer.")


# --- Setup CheckpointManager for Composite Save ---
ckpt_dir_ex3 = os.path.join(CKPT_BASE_DIR, 'ex3_composite_save')
cleanup_ckpt_dir(ckpt_dir_ex3)
# TODO: Instantiate CheckpointManager for ckpt_dir_ex3
# mngr_comp = ...

# --- Extract States for Saving ---
# TODO: Extract model parameters state from optimizer.model using nnx.split with nnx.Param filter
# _graphdef_params, params_state = ...
# TODO: Extract the full optimizer state tree using nnx.state()
# optimizer_state_tree = ...

print(f"Parameter state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, params_state)}")
print(f"Optimizer state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, optimizer_state_tree)}")

# --- Save Composite State ---
current_step_val = 0
if hasattr(optimizer, 'step') and hasattr(optimizer.step, 'value'):
  current_step_val = optimizer.step.value
else: # Fallback for safety, though nnx.Optimizer should have .step
  current_step_val = 10


# TODO: Define save_items dictionary for 'params' and 'optimizer'
# Each item should be wrapped with ocp.args.StandardSave
# save_items = {
#     'params': ...,
#     'optimizer': ...
# }

# TODO: Save using mngr_comp.save() and ocp.args.Composite
# mngr_comp.save(...)
# TODO: Wait and close the manager
# mngr_comp.wait_until_finished()
# print(f"Composite checkpoint saved for step {current_step_val} in {ckpt_dir_ex3}.")
# print(f"Available checkpoints: {mngr_comp.all_steps()}")
# mngr_comp.close()
# @title Exercise 3: Solution

# Ensure SimpleLinear class definition is available
# --- Instantiate Model and Optimizer ---
rngs_ex3 = nnx.Rngs(params=jax.random.key(1))
model_ex3 = SimpleLinear(din=10, dout=5, rngs=rngs_ex3)

tx = optax.adam(learning_rate=1e-3)
optimizer = nnx.Optimizer(model_ex3, tx, wrt=nnx.Param)

# Simulate a few "training" steps to populate optimizer state
# For a real scenario, this would involve gradients and updates
optimizer.step.value += 10 # Simulate 10 steps
original_bias_val_ex3 = model_ex3.bias.value.copy()
# Simulate a parameter update that would happen during training
model_ex3.bias.value = model_ex3.bias.value * 0.5 + 0.1 # Arbitrary change
print(f"Optimizer step: {optimizer.step.value}")
print(f"Bias modified. Original first val: {original_bias_val_ex3[0]}, New first val: {model_ex3.bias.value[0]}")

# --- Setup CheckpointManager for Composite Save ---
ckpt_dir_ex3 = os.path.join(CKPT_BASE_DIR, 'ex3_composite_save')
cleanup_ckpt_dir(ckpt_dir_ex3)
mngr_comp = ocp.CheckpointManager(ckpt_dir_ex3, options=ocp.CheckpointManagerOptions(max_to_keep=3))

# --- Extract States for Saving ---
# Extract model parameters (e.g., using nnx.split(model, nnx.Param))
_graphdef_params, params_state = nnx.split(model_ex3, nnx.Param)
# Extract optimizer state (nnx.state(optimizer))
optimizer_state_tree = nnx.state(optimizer)

print(f"Parameter state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), params_state)}")
print(f"Optimizer state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), optimizer_state_tree)}")
# Note: optimizer_state_tree also contains the model's state within optimizer.model_variables

# --- Save Composite State ---
current_step_val = optimizer.step.value # Get current step from optimizer

# Save using Composite args
save_items = {
  'params': ocp.args.StandardSave(params_state),
  'optimizer': ocp.args.StandardSave(optimizer_state_tree)
}

# Can generate args per item using orbax_utils too
mngr_comp.save(current_step_val, args=ocp.args.Composite(**save_items))
mngr_comp.wait_until_finished()
print(f"Composite checkpoint saved for step {current_step_val} in {ckpt_dir_ex3}.")
print(f"Available checkpoints: {mngr_comp.all_steps()}")
mngr_comp.close()

Exercise 4: Restoring Model Parameters and Optimizer State

Goal: Restore both the model parameters and the optimizer state from a composite checkpoint.

Topics

  • Create abstract versions of the model and the optimizer with nnx.eval_shape.
  • Obtain abstract templates for the parameter state and the optimizer state.
  • Restore multiple items with ocp.args.Composite and ocp.args.StandardRestore.
  • Instantiate fresh, concrete model and optimizer instances.
  • Write the restored state back into those instances with nnx.update().

Steps

  1. Re-open the CheckpointManager from Exercise 3 (ckpt_dir_ex3).
  2. Define create_abstract_model_and_optimizer():
    • Inside, use nnx.eval_shape to create the abstract model (e.g. SimpleLinear) from a creation lambda.
    • Then use nnx.eval_shape to create an abstract nnx.Optimizer, passing the abstract model and a new Optax optimizer.
    • Return abs_model and abs_optimizer.
  3. Call the function to get abs_model and abs_optimizer.
  4. Get the abstract parameter state: graphdef_abs_params, abs_params_state = nnx.split(abs_model, nnx.Param).
  5. Get the abstract optimizer state: abs_optimizer_state = nnx.state(abs_optimizer).
  6. Find the latest step to restore.
  7. If a checkpoint exists, build a restore_targets dict for ocp.args.Composite with the same keys used when saving ('params', 'optimizer') and values wrapped in ocp.args.StandardRestore() around the abstract states.
  8. Restore with mngr_comp.restore(step, args=ocp.args.Composite(**restore_targets)), yielding the dict restored_items.
  9. Create fresh SimpleLinear and nnx.Optimizer instances.
  10. Update the model in place with nnx.update(fresh_model, restored_items['params']).
  11. Update the optimizer with nnx.update(fresh_optimizer, restored_items['optimizer']).
  12. Verify by checking the optimizer step count and a parameter value.
  13. Close the manager.
# Ensure SimpleLinear class definition is available
# --- Re-open CheckpointManager ---
# TODO: Instantiate CheckpointManager for ckpt_dir_ex3
# mngr_comp_restore = ...

# --- Create Abstract Model and Optimizer ---
def create_abstract_model_and_optimizer():
  rngs_abs = nnx.Rngs(params=jax.random.key(0)) # Dummy key for abstract creation
  # TODO: Create abstract model. Model class: SimpleLinear(din=10, dout=5, ...)
  # abs_model = SimpleLinear(...)

  # TODO: Create abstract optimizer. Pass abs_model and an optax.adam instance.
  # abs_opt = nnx.Optimizer(...)
  # return abs_model, abs_opt

# TODO: Call the function to get abstract model and optimizer
# abs_model_restore, abs_optimizer_restore = ...

# --- Get Abstract States ---
# TODO: Get abstract parameter state from abs_model_restore (filter with nnx.Param)
# _graphdef_abs_params, abs_params_state = ...
# TODO: Get abstract optimizer state from abs_optimizer_restore
# abs_optimizer_state = ...

print(f"Abstract params state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, abs_params_state)}")
print(f"Abstract optimizer state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, abs_optimizer_state)}")

# --- Restore Composite State ---
# TODO: Get the latest step
# step_to_restore_comp = ...

if step_to_restore_comp is not None:
  # TODO: Define restore_targets dictionary for 'params' and 'optimizer'
  # Each item should be wrapped with ocp.args.StandardRestore and its corresponding abstract state.
  # restore_targets = {
  #    'params': ...,
  #    'optimizer': ...
  # }
  # TODO: Restore items using mngr_comp_restore.restore() and ocp.args.Composite
  # restored_items = mngr_comp_restore.restore(...)

  # --- Instantiate and Update Concrete Model/Optimizer ---
  # TODO: Create a fresh SimpleLinear model instance (use a new RNG key, e.g., key(2))
  # fresh_model = ...
  # TODO: Create a fresh nnx.Optimizer instance with fresh_model and a new optax.adam instance
  # fresh_optimizer = ...

  # Store pre-update values for comparison
  pre_update_bias = fresh_model.bias.value.copy()
  pre_update_opt_step = fresh_optimizer.step.value

  # TODO: Update fresh_model with restored_items['params'] using nnx.update()
  # nnx.update(...)
  # TODO: Update fresh_optimizer with restored_items['optimizer'] using nnx.update()
  # nnx.update(...)

  print(f"Restored and updated. Optimizer step: {fresh_optimizer.step.value}")
  print(f"Fresh model bias before update (first val): {pre_update_bias[0]}")
  print(f"Fresh model bias after update (first val): {fresh_model.bias.value[0]}")
  print(f"Original bias from Ex3 (first val): {model_ex3.bias.value[0]}") # model_ex3 is from previous cell

  # Verification
  # chex.assert_trees_all_close(fresh_model.bias.value, model_ex3.bias.value) # Compare with the state that was saved
  # assert fresh_optimizer.step.value == optimizer.step.value # Compare with optimizer state that was saved
else:
  print("No composite checkpoint found.")

# TODO: Close the manager
# mngr_comp_restore.close()
# @title Exercise 4: Solution

# Ensure SimpleLinear class definition is available
# --- Re-open CheckpointManager ---
mngr_comp_restore = ocp.CheckpointManager(ckpt_dir_ex3)

# --- Create Abstract Model and Optimizer ---
def create_abstract_model_and_optimizer():
  rngs_abs = nnx.Rngs(params=jax.random.key(0)) # Dummy key for abstract creation
  # Create abstract model
  abs_model = SimpleLinear(din=10, dout=5, rngs=rngs_abs)
  # Create abstract optimizer
  abs_opt = nnx.Optimizer(abs_model, optax.adam(1e-3), wrt=nnx.Param)
  return abs_model, abs_opt

abs_model_restore, abs_optimizer_restore = create_abstract_model_and_optimizer()

# --- Get Abstract States ---
_graphdef_abs_params, abs_params_state = nnx.split(abs_model_restore, nnx.Param)
abs_optimizer_state = nnx.state(abs_optimizer_restore)

print(f"Abstract params state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), abs_params_state)}")
print(f"Abstract optimizer state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), abs_optimizer_state)}")

# --- Restore Composite State ---
step_to_restore_comp = mngr_comp_restore.latest_step()

if step_to_restore_comp is not None:
  restore_targets = {
    'params': ocp.args.StandardRestore(abs_params_state),
    'optimizer': ocp.args.StandardRestore(abs_optimizer_state)
  }
  restored_items = mngr_comp_restore.restore(step_to_restore_comp, args=ocp.args.Composite(**restore_targets))

  # --- Instantiate and Update Concrete Model/Optimizer ---
  # Create fresh instances
  fresh_rngs = nnx.Rngs(params=jax.random.key(2)) # Use a different key for the fresh model
  fresh_model = SimpleLinear(din=10, dout=5, rngs=fresh_rngs)
  fresh_optimizer = nnx.Optimizer(fresh_model, optax.adam(1e-3), wrt=nnx.Param) # Matching optax optimizer

  # Store pre-update values for comparison
  pre_update_bias = fresh_model.bias.value.copy()
  pre_update_opt_step = fresh_optimizer.step.value

  # Update using restored states
  nnx.update(fresh_model, restored_items['params'])
  nnx.update(fresh_optimizer, restored_items['optimizer'])

  print(f"Restored and updated. Optimizer step: {fresh_optimizer.step.value}")
  print(f"Fresh model bias before update (first val): {pre_update_bias[0]}") # Will be from key(2)
  print(f"Fresh model bias after update (first val): {fresh_model.bias.value[0]}") # Should match model_ex3 bias

  # Verification (model_ex3 and optimizer are from the previous cell where they were saved)
  chex.assert_trees_all_close(fresh_model.bias.value, model_ex3.bias.value)
  assert fresh_optimizer.step.value == optimizer.step.value
  print("Verification successful: Restored model parameters and optimizer step match the saved state.")
else:
  print("No composite checkpoint found.")

mngr_comp_restore.close()

Exercise 5: Distributed Checkpointing — Saving a Sharded State

Goal: Understand how to save model state that is distributed across multiple devices; Orbax handles already-sharded JAX arrays efficiently.

Topics

  • Set up a JAX device Mesh.
  • Define PartitionSpecs to describe how arrays are sharded.
  • Create sharded parameters inside an nnx.Module: one approach is to initialize the parameters first, then shard them with jax.device_put + NamedSharding and write them back; NNX also supports annotating sharding directly in nnx.Variable metadata.
  • Save the sharded state: as long as the JAX arrays in the state pytree are sharded, Orbax saves them transparently.

Steps

  1. Determine the number of devices and create a device mesh (e.g. a 1D mesh over all devices).
  2. Modify SimpleLinear (or create ShardedSimpleLinear) to shard the parameters after initializing them in __init__.
    • Shard the weight matrix (din, dout) along the dout dimension (e.g. PartitionSpec(None, 'data')).
    • Shard the bias vector (dout,) along its single dimension (PartitionSpec('data')).
    • Apply the sharding:
      • Create a NamedSharding from the PartitionSpec and the mesh.
      • Use jax.device_put(param_value, named_sharding) to get the sharded JAX array.
      • Write the sharded arrays back into the nnx.Param's .value.
  3. Instantiate the sharded model inside the mesh context manager (with mesh:) so that operations are mesh-aware.
  4. Create a CheckpointManager in a new directory ckpt_dir_ex5.
  5. Split the sharded model to get its state: graphdef_sharded, sharded_state_to_save = nnx.split(sharded_model). The arrays in it should be jax.Arrays carrying sharding information.
  6. Call mngr.save() to save sharded_state_to_save; from Orbax's point of view the flow is identical to the unsharded case.
  7. Wait for completion and close.
# --- Setup JAX Mesh ---
num_devices = jax.device_count()
# If num_devices is 1 after chex.set_n_cpu_devices(8), it means JAX didn't pick up the fakes.
# This can happen if JAX initializes its backends before chex runs.
# Forcing a rerun of this cell or restarting runtime and running setup first might help.
print(f"Using {num_devices} devices for sharding.")
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices=device_mesh, axis_names=('data',)) # 1D mesh
print(mesh)

# --- Define Sharded NNX Module ---
class ShardedSimpleLinear(nnx.Module):
  def __init__(self, din: int, dout: int, mesh: Mesh, *, rngs: nnx.Rngs):
    self.din = din
    self.dout = dout
    self.mesh = mesh

    key_w, key_b = rngs.params(), rngs.params()

    # Initialize as regular JAX arrays first
    initial_weight = jax.random.uniform(key_w, (din, dout))
    initial_bias = jnp.zeros((dout,))

    # TODO: Define PartitionSpec for weight (shard dout across 'data' axis)
    # e.g., PartitionSpec(None, 'data') means not sharded on dim 0, sharded on dim 1
    # weight_pspec = ...
    # TODO: Define PartitionSpec for bias (shard along 'data' axis)
    # bias_pspec = ...

    # TODO: Create NamedSharding for weight and bias using self.mesh and the pspecs
    # weight_sharding = NamedSharding(...)
    # bias_sharding = NamedSharding(...)

    # TODO: Shard the initial arrays using jax.device_put and the NamedSharding
    # sharded_weight_value = jax.device_put(...)
    # sharded_bias_value = jax.device_put(...)

    # TODO: Assign these sharded arrays to nnx.Param attributes
    # self.weight = nnx.Param(sharded_weight_value)
    # self.bias = nnx.Param(sharded_bias_value)

    # Alternative (more direct with nnx.Variable metadata if supported well for this case):
    # self.weight = nnx.Param(initial_weight, sharding=weight_sharding) # This depends on NNX API
    # For this exercise, jax.device_put is explicit and clear.

  def __call__(self, x: jax.Array) -> jax.Array:
    # x is assumed to be replicated or appropriately sharded for the matmul
    # For simplicity, assume x is replicated if din is not sharded, or sharded compatibly.
    return x @ self.weight.value + self.bias.value

# --- Instantiate Sharded Model within Mesh context ---
din_s, dout_s = 8, num_devices * 2 # Ensure dout is divisible by num_devices for even sharding
rngs_sharded = nnx.Rngs(params=jax.random.key(3))

# TODO: Instantiate ShardedSimpleLinear within the mesh context
# with mesh:
#   sharded_model = ...

# print(f"Sharded model created. Weight sharding: {sharded_model.weight.value.sharding}")
# print(f"Sharded model bias sharding: {sharded_model.bias.value.sharding}")


# --- Setup CheckpointManager for Sharded Save ---
ckpt_dir_ex5 = os.path.join(CKPT_BASE_DIR, 'ex5_sharded_save')
cleanup_ckpt_dir(ckpt_dir_ex5)
# TODO: Instantiate CheckpointManager
# mngr_sharded_save = ...

# --- Split and Save Sharded State ---
# TODO: Split the sharded_model
# _graphdef_sharded, sharded_state_to_save = ...

# print(f"Sharded state to save (bias type): {type(sharded_state_to_save['bias'].value)}")
# print(f"Sharded state to save (bias sharding): {sharded_state_to_save['bias'].value.sharding}")

# current_step_sharded = 200
# TODO: Save the sharded_state_to_save
# mngr_sharded_save.save(...)
# TODO: Wait and close
# mngr_sharded_save.wait_until_finished()
# print(f"Sharded checkpoint saved for step {current_step_sharded} in {ckpt_dir_ex5}.")
# mngr_sharded_save.close()
# @title Exercise 5: Solution

# --- Setup JAX Mesh ---
num_devices = jax.device_count()
if num_devices == 1: # If we faked 8 CPU devices but still only see 1
     print("Warning: JAX might not be using the faked CPU devices. Restart runtime and run Setup cell first if sharding tests fail.")
print(f"Using {num_devices} devices for sharding.")
# Ensure a 1D mesh for simplicity, using all available (or faked) devices.
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices=device_mesh, axis_names=('data',)) # 1D mesh for 'data' parallelism
print(mesh)

# --- Define Sharded NNX Module ---
class ShardedSimpleLinear(nnx.Module):
  def __init__(self, din: int, dout: int, mesh: Mesh, *, rngs: nnx.Rngs):
    self.din = din
    self.dout = dout
    self.mesh = mesh # Store mesh for creating NamedSharding

    key_w, key_b = rngs.params(), rngs.params()

    initial_weight = jax.random.uniform(key_w, (din, dout))
    initial_bias = jnp.zeros((dout,))

    # Define PartitionSpec for sharding
    # Shard weight's second dimension (dout) across the 'data' mesh axis
    weight_pspec = PartitionSpec(None, 'data')
    # Shard bias's only dimension (dout) across the 'data' mesh axis
    bias_pspec = PartitionSpec('data',)

    # Create NamedSharding from PartitionSpec and mesh
    weight_sharding = NamedSharding(self.mesh, weight_pspec)
    bias_sharding = NamedSharding(self.mesh, bias_pspec)

    # Shard the initial arrays using jax.device_put
    # This ensures the arrays are created with the specified sharding
    sharded_weight_value = jax.device_put(initial_weight, weight_sharding)
    sharded_bias_value = jax.device_put(initial_bias, bias_sharding)

    self.weight = nnx.Param(sharded_weight_value)
    self.bias = nnx.Param(sharded_bias_value)
    # Note: Flax NNX aims to allow sharding annotations directly in nnx.Variable metadata
    # e.g., using nnx.with_partitioning or passing sharding metadata to nnx.Param.
    # Explicit jax.device_put is also a valid way to get sharded arrays into the state.

  def __call__(self, x: jax.Array) -> jax.Array:
    return x @ self.weight.value + self.bias.value

# --- Instantiate Sharded Model within Mesh context ---
din_s, dout_s = 8, num_devices * 2 # Make dout divisible by num_devices
rngs_sharded = nnx.Rngs(params=jax.random.key(3))

with mesh: # Operations within this context are aware of the mesh
  sharded_model = ShardedSimpleLinear(din_s, dout_s, mesh, rngs=rngs_sharded)

print(f"Sharded model created. Weight sharding: {sharded_model.weight.value.sharding}")
print(f"Sharded model bias sharding: {sharded_model.bias.value.sharding}")

# --- Setup CheckpointManager for Sharded Save ---
ckpt_dir_ex5 = os.path.join(CKPT_BASE_DIR, 'ex5_sharded_save')
cleanup_ckpt_dir(ckpt_dir_ex5)
mngr_sharded_save = ocp.CheckpointManager(ckpt_dir_ex5, options=ocp.CheckpointManagerOptions(max_to_keep=1))

# --- Split and Save Sharded State ---
# The live state already contains sharded jax.Array objects
_graphdef_sharded, sharded_state_to_save = nnx.split(sharded_model)

print(f"Sharded state to save (bias type): {type(sharded_state_to_save['bias'].value)}")
print(f"Sharded state to save (bias sharding): {sharded_state_to_save['bias'].value.sharding}")
# The arrays in sharded_state_to_save are jax.Array objects that carry sharding information

current_step_sharded = 200
# Orbax handles sharded-array saving under the hood
mngr_sharded_save.save(current_step_sharded, args=ocp.args.StandardSave(sharded_state_to_save))
mngr_sharded_save.wait_until_finished()
print(f"Sharded checkpoint saved for step {current_step_sharded} in {ckpt_dir_ex5}.")
mngr_sharded_save.close()

Advanced Orbax Features and Best Practices (Brief)

Orbax has a few more advanced capabilities that we will not exercise fully here but that are worth knowing about:

  • Asynchronous checkpointing: manager.save() can run in the background; call manager.wait_until_finished() before your program exits or whenever you need the checkpoint immediately. This keeps the training loop from blocking and improves throughput. The examples in this tutorial all call wait_until_finished().
  • Atomicity: CheckpointManager writes checkpoints atomically, so a crash mid-training never leaves a corrupted checkpoint behind; Orbax handles this for you.
  • Saving non-pytree data (metadata): sometimes you want to save training configuration, dataset iterator info, a model version, etc. Use ocp.args.JsonSave inside ocp.args.Composite to save dict-like data as JSON alongside the model pytrees, and ocp.args.JsonRestore to restore it.
  • TensorStore backend: for very large models or cloud storage, Orbax can use TensorStore for more efficient, parallelizable I/O on individual shards. This is usually transparent and may be enabled by default in some JAX environments.

Concept Example

metadata = {'version': '1.0', 'dataset_info': 'imagenet_split_train'}
save_args = ocp.args.Composite(
  params=ocp.args.StandardSave(params_state),
  metadata=ocp.args.JsonSave(metadata)
)
mngr.save(step, args=save_args)
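
The matching restore side might look like this (a sketch; abstract_params_state stands in for whatever abstract template you built for the 'params' item):

restore_args = ocp.args.Composite(
  params=ocp.args.StandardRestore(abstract_params_state),
  metadata=ocp.args.JsonRestore()
)
restored = mngr.restore(step, args=restore_args)
print(restored['metadata'])  # e.g. {'version': '1.0', 'dataset_info': ...}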

Key Takeaways

  • Flax NNX provides a Pythonic, stateful way to define models.
  • Orbax is the standard way to checkpoint NNX State pytrees.
  • The general flow:
    • Save: nnx.split -> mngr.save.
    • Restore: nnx.eval_shape -> get abstract_state -> mngr.restore -> nnx.merge or nnx.update.
  • CheckpointManager makes it easy to manage multiple checkpoints.
  • Use ocp.args.Composite when saving multiple objects (e.g. model parameters + optimizer state).
  • When restoring sharded/distributed data, the abstract restore target must carry the correct sharding information; if the abstract state includes shardings, StandardRestore handles the rest (see the sketch below).
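
For example, a restore sketch for Exercise 5's sharded checkpoint might look like this (mngr_sharded_restore and as_abstract_with_sharding are hypothetical names; mesh, sharded_model, and ckpt_dir_ex5 come from Exercise 5):

mngr_sharded_restore = ocp.CheckpointManager(ckpt_dir_ex5)

def as_abstract_with_sharding(leaf):
  # Weight (2D): shard dout (dim 1); bias (1D): shard its only dimension.
  pspec = PartitionSpec(None, 'data') if leaf.ndim == 2 else PartitionSpec('data',)
  return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype,
                              sharding=NamedSharding(mesh, pspec))

# Abstract template whose leaves carry the target shardings
abstract_sharded_state = jax.tree_util.tree_map(
    as_abstract_with_sharding, nnx.state(sharded_model))

restored_sharded_state = mngr_sharded_restore.restore(
    mngr_sharded_restore.latest_step(),
    args=ocp.args.StandardRestore(abstract_sharded_state))
nnx.update(sharded_model, restored_sharded_state)  # write the restored shards back
mngr_sharded_restore.close()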

Congratulations!

You have worked through the core workflow of checkpointing Flax NNX models with Orbax, from basic save/restore to optimizer state and distributed (sharded) scenarios.

For more detail, see the official documentation:

  • Orbax: https://orbax.readthedocs.io
  • Flax NNX: (Part of the Flax documentation) https://flax.readthedocs.io
  • JAX: https://jax.readthedocs.io

Keep practicing, and have fun with JAX!

Feedback is welcome at https://goo.gle/jax-training-feedback.