Efficient Data Loading with Grain: JAX/Flax NNX Exercises

Welcome! This Colab notebook contains exercises to help you learn Google's Grain library for efficient data loading in JAX. The exercises are aimed at developers who are familiar with PyTorch and are exploring the JAX ecosystem, including the new Flax NNX API.

Goals of this notebook:
  • Understand Grain's core components: DataSource, Sampler, and Operations.
  • Learn how to use grain.DataLoader for sequential and parallel data loading.
  • Implement custom data transformations.
  • Explore data sharding for distributed training scenarios.
  • See how Grain fits into a conceptual JAX/Flax NNX training loop.
  • Learn how to checkpoint the data iterator state for reproducibility.
  • Simulate a multi-device environment:

    To better demonstrate the parallelism and sharding concepts in Colab, which usually provides only a single CPU/GPU, the notebook starts by configuring JAX to simulate 8 CPU devices, using the XLA_FLAGS environment variable together with chex.set_n_cpu_devices(8).

    Let's get started! Please run the next cell first to set up the environment.
# Environment Setup
# This cell configures the environment to simulate multiple CPU devices
# and installs necessary libraries.
# IMPORTANT: RUN THIS CELL FIRST. If you encounter issues with JAX device
# counts later, try 'Runtime -> Restart runtime' in the Colab menu
# and run this cell again before any others.
import os

# Configure JAX to see 8 virtual CPU devices.
# This must be done before JAX is imported for the first time in a session.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

# Install libraries
!pip install -q jax-ai-stack==2025.9.3
print("Libraries installed.")

# Now, import chex and attempt to set_n_cpu_devices.
# This must be called after setting XLA_FLAGS and before JAX initializes its backends.
import chex

try:
  chex.set_n_cpu_devices(8)
  print("chex.set_n_cpu_devices(8) called successfully.")
except RuntimeError as e:
  print(f"Note on chex.set_n_cpu_devices: {e}")
  print("This usually means JAX was already initialized. The XLA_FLAGS environment variable should still apply.")
  print("If you see issues with device counts, ensure you 'Restart runtime' and run this cell first.")

# Verify JAX device count
import jax
print(f"JAX version: {jax.__version__}")
print(f"JAX found {jax.device_count()} devices.")
print(f"JAX devices: {jax.devices()}")

if jax.device_count() != 8:
  print("\nWARNING: JAX does not see 8 devices. Parallelism/sharding exercises might not behave as expected.")
  print("Please try 'Runtime -> Restart runtime' and run this setup cell again first.")

# Common imports for the exercises
import jax.numpy as jnp
import numpy as np
import grain.python as grain # Main Grain API
from flax import nnx # Flax NNX API
import time # For simulating work and observing performance
import copy # For checkpointing example
from typing import Dict, Any, List # For type hints
import dataclasses # For ShardOptions if needed manually
import functools # For functools.partial
print("Imports complete. Setup finished.")

Introduction to Grain

As highlighted in the lecture, JAX is very fast at numerical computation, especially on accelerators. However, inefficient data loading can drag that speed down. Standard Python data loading is often limited by I/O, CPU-bound transformations, and the Global Interpreter Lock (GIL).

Grain is Google's high-performance data-loading solution for JAX. Its core goals are:
  • Speed: achieved through multiprocessing, shared memory, and prefetching.
  • Determinism: experiments stay reproducible.
  • Flexibility and simplicity: data pipelines are defined declaratively.
  • Focus on the JAX ecosystem: integrates with concepts such as distributed sharding.

Conceptually, Grain's DataLoader plays a role similar to PyTorch's torch.utils.data.DataLoader: it is responsible for reading, transforming, batching, and parallelizing the data.

Core components of the grain.DataLoader API:
1. DataSource: provides random access to individual raw data records (it must implement __len__ and __getitem__).
2. Sampler: determines the order in which records are loaded and provides seeds for random operations, ensuring reproducibility.
3. Operations: a sequence of transformations applied in order (e.g. augmentation, filtering, batching).
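
For orientation, here is a compact, hedged sketch of the three components wired together (a toy pipeline using only the imports from the setup cell; Exercise 1 below builds essentially the same thing step by step):

class TinySource(grain.RandomAccessDataSource):
  """Ten records; each record is a dict with a single integer feature."""
  def __len__(self) -> int:
    return 10
  def __getitem__(self, idx: int) -> Dict[str, Any]:
    return {'x': np.asarray(idx % 10, dtype=np.int32)}

tiny_loader = grain.DataLoader(
    data_source=TinySource(),
    sampler=grain.IndexSampler(          # load order + seed => reproducibility
        num_records=10, shard_options=grain.NoSharding(),
        shuffle=True, num_epochs=1, seed=0),
    operations=[grain.Batch(batch_size=5, drop_remainder=True)],
    worker_count=0)                      # 0 = sequential; > 0 = multiprocessing

for tiny_batch in tiny_loader:
  print(tiny_batch['x'].shape)           # prints (5,) twice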

Now, on to the exercises!

---

Exercise 1: Building your first grain.DataLoader (sequential mode)

Goal: Get familiar with the basic building blocks: DataSource, IndexSampler, a simple MapTransform, and grain.DataLoader in sequential mode (worker_count=0). Steps:
  1. Define MySource, a custom RandomAccessDataSource:
     * __init__: store num_records.
     * __len__: return num_records.
     * __getitem__: given idx, return a dict {'image': image_array, 'label': label_int}. image_array should be a NumPy array of shape (32, 32, 3) with dtype=np.uint8 whose values may depend on idx (e.g. np.ones(...) * (idx % 255)); label_int is an integer (e.g. idx % 10).
     * Handle index wrap-around across epochs: idx = idx % self.num_records.
  2. Instantiate MySource.
  3. Create an IndexSampler that shuffles, runs for 1 epoch, and uses a fixed seed.
  4. Define the operations list:
     * A ConvertToFloat class inheriting from grain.MapTransform that converts the image to np.float32 and normalizes it to [0, 1].
     * grain.Batch to group samples into batches of 64, dropping the remainder.
  5. Instantiate grain.DataLoader with worker_count=0 (debug/sequential mode).
     * Because MySource is in-memory data, use read_options=grain.ReadOptions(num_threads=0) to disable Grain's internal read threads.
  6. Iterate over the DataLoader, fetch the first batch, and print the image and label shapes.
# @title Exercise 1: Student Code
# 1. Define MySource
class MySource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records
  def __len__(self) -> int:
      # TODO: Return the total number of records
      # YOUR CODE HERE
      return 0 # Replace this

  def __getitem__(self, idx: int) -> Dict[str, Any]:
      # TODO: Handle potential index wrap-around for multiple epochs
      # effective_idx = ...
      # YOUR CODE HERE
      effective_idx = idx # Replace this

      # TODO: Simulate loading data: an image and a label
      # image = np.ones(...) * (effective_idx % 255)
      # label = effective_idx % 10
      # YOUR CODE HERE
      image = np.zeros((32,32,3), dtype=np.uint8) # Replace this
      label = 0 # Replace this
      return {'image': image, 'label': label}

# 2. Instantiate MySource
# TODO: Create an instance of MySource
# source = ...
# YOUR CODE HERE
source = None # Replace this

# 3. Create an IndexSampler
# TODO: Create an IndexSampler that shuffles, runs for 1 epoch, and uses seed 42.
# num_records should be len(source).
# index_sampler = grain.IndexSampler(...)
# YOUR CODE HERE
index_sampler = None # Replace this

# 4. Define Operations
# TODO: Define ConvertToFloat transform
class ConvertToFloat(grain.MapTransform):
  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    # TODO: Convert 'image' to float32 and normalize to [0, 1].
    # Keep 'label' as is.
    # YOUR CODE HERE
    image = features['image'] # Replace this
    return {'image': image.astype(np.float32) / 255.0, 'label': features['label']}

# TODO: Create a list of transformations: ConvertToFloat instance, then grain.Batch
# Use grain.Batch(batch_size=64, drop_remainder=True).
# transformations = [...]
# YOUR CODE HERE
transformations = [] # Replace this

# 5. Instantiate DataLoader
# TODO: Create a DataLoader with worker_count=0 and appropriate read_options
# data_loader_sequential = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_sequential = None # Replace this

# 6. Iterate and print batch info
if data_loader_sequential:
  print("DataLoader configured sequentially.")
  data_iterator_seq = iter(data_loader_sequential)
  try:
    first_batch_seq = next(data_iterator_seq)
    print(f"Sequential - First batch image shape: {first_batch_seq['image'].shape}")
    print(f"Sequential - First batch label shape: {first_batch_seq['label'].shape}")
    # Example: Check a value from the first image of the first batch
    print(f"Sequential - Example image value (first item, [0,0,0]): {first_batch_seq['image'][0, 0, 0, 0]}")
    print(f"Sequential - Example label value (first item): {first_batch_seq['label'][0]}")
  except StopIteration:
    print("Sequential DataLoader is empty or exhausted.")
else:
  print("Sequential DataLoader not configured yet.")
# @title Exercise 1: Solution
# 1. Define MySource
class MySource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records

  def __len__(self) -> int:
    return self._num_records

  def __getitem__(self, idx: int) -> Dict[str, Any]:
    effective_idx = idx % self._num_records # Handle wrap-around
    image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
    label = effective_idx % 10
    return {'image': image, 'label': label}

# 2. Instantiate MySource
source = MySource(num_records=1000)
print(f"DataSource created with {len(source)} records.")

# 3. Create an IndexSampler
index_sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=grain.NoSharding(), # No sharding for this exercise
    shuffle=True,
    num_epochs=1, # Run for 1 epoch
    seed=42
    )
print("IndexSampler created.")

# 4. Define Operations
class ConvertToFloat(grain.MapTransform):
  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    # Convert 'image' to float32 and normalize to [0, 1].
    # Keep 'label' as is.
    image = features['image'].astype(np.float32) / 255.0
    return {'image': image, 'label': features['label']}

transformations = [
    ConvertToFloat(),
    grain.Batch(batch_size=64, drop_remainder=True)
    ]
print("Transformations defined.")

# 5. Instantiate DataLoader
data_loader_sequential = grain.DataLoader(
    data_source=source,
    operations=transformations,
    sampler=index_sampler,
    worker_count=0, # Sequential mode
    shard_options=grain.NoSharding(), # Explicitly no sharding for this loader instance
    read_options=grain.ReadOptions(num_threads=0) # Dataset is in-memory
)

# 6. Iterate and print batch info
if data_loader_sequential:
  print("DataLoader configured sequentially.")
  data_iterator_seq = iter(data_loader_sequential)
  try:
    first_batch_seq = next(data_iterator_seq)
    print(f"Sequential - First batch image shape: {first_batch_seq['image'].shape}") # Expected: (64, 32, 32, 3)
    print(f"Sequential - First batch label shape: {first_batch_seq['label'].shape}") # Expected: (64,)
    # Example: Check a value from the first image of the first batch
    print(f"Sequential - Example image value (first item, [0,0,0]): {first_batch_seq['image'][0, 0, 0, 0]}")
    print(f"Sequential - Example label value (first item): {first_batch_seq['label'][0]}")
  except StopIteration:
    print("Sequential DataLoader is empty or exhausted.")
else:
  print("Sequential DataLoader not configured yet.")

---

Exercise 2: Enabling parallelism with worker_count

Goal: Understand how worker_count > 0 enables multiprocessing to speed up data loading. Steps:
  1. Reuse MySource, the IndexSampler (or create a new one, e.g. with num_epochs=None for indefinite epochs), and the transformations from Exercise 1.
  2. To make the benefit of parallelism easier to observe, modify MySource slightly: add a time.sleep(0.01) call (10 ms) inside __getitem__ to simulate per-item I/O or CPU cost.
  3. Create a new grain.DataLoader (e.g. data_loader_parallel), this time setting worker_count to a value greater than 0 (such as 2 or 4). Remember that we are simulating 8 CPUs.
  4. Iterate to fetch the first batch and print its shape information.
  5. (Optional) Compare how long the sequential and parallel loaders take to fetch 10 batches. With the added time.sleep, the parallel loader should be faster.
A note on serialization: when worker_count > 0, Grain uses multiprocessing. This means every component (the DataSource, Sampler, Operations, and any custom transform instances) must be serializable with Python's pickle module. Ordinary classes and functions are fine, but avoid complex closures and unpicklable objects; a small illustration follows below.
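
As a quick illustration of this requirement (a sketch, assuming ConvertToFloat from Exercise 1 is defined in the current session), the snippet below pickles a plain transform successfully and shows how an instance holding a lambda fails, which is the kind of failure a DataLoader worker process would report:

import pickle

class UnpicklableTransform(grain.MapTransform):
  """Holds a lambda attribute, which Python's pickle module cannot serialize."""
  def __init__(self):
    self._fn = lambda image: image  # lambdas are not picklable
  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    features['image'] = self._fn(features['image'])
    return features

print(f"ConvertToFloat picklable: {len(pickle.dumps(ConvertToFloat())) > 0}")
try:
  pickle.dumps(UnpicklableTransform())
except Exception as e:  # the same error would surface when worker_count > 0
  print(f"UnpicklableTransform is NOT picklable: {e}")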
# @title Exercise 2: Student Code
# 1. Reuse/Recreate components (DataSource with simulated work, Sampler, Operations)
# TODO: Define MySourceWithWork, adding time.sleep(0.01) in getitem
class MySourceWithWork(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records

  def __len__(self) -> int:
    # YOUR CODE HERE
    return self._num_records

  def __getitem__(self, idx: int) -> Dict[str, Any]:
    effective_idx = idx % self._num_records
    # TODO: Add time.sleep(0.01) to simulate work
    # YOUR CODE HERE

    image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
    label = effective_idx % 10
    return {'image': image, 'label': label}

# TODO: Instantiate MySourceWithWork
# source_with_work = ...
# YOUR CODE HERE
source_with_work = None # Replace this

# TODO: Create a new IndexSampler (e.g., for indefinite epochs, num_epochs=None)
# Or reuse the one from Ex1 if you reset it or it's for multiple epochs.
# For simplicity, let's create one for indefinite epochs.
# parallel_sampler = grain.IndexSampler(...)
# YOUR CODE HERE
parallel_sampler = None # Replace this

# Transformations can be reused from Exercise 1
# transformations = [ConvertToFloat(), grain.Batch(batch_size=64, drop_remainder=True)]
# (Assuming ConvertToFloat is defined from Ex1 solution)

# 2. Instantiate DataLoader with worker_count > 0
# TODO: Set num_workers (e.g., 4)
# num_workers = ...
# YOUR CODE HERE
num_workers = 0 # Replace this

# TODO: Create data_loader_parallel
# data_loader_parallel = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_parallel = None # Replace this

# 3. Iterate and print batch info
if data_loader_parallel:
  print(f"DataLoader configured with worker_count={num_workers}.")
  data_iterator_parallel = iter(data_loader_parallel)
  try:
    first_batch_parallel = next(data_iterator_parallel)
    print(f"Parallel - First batch image shape: {first_batch_parallel['image'].shape}")
    print(f"Parallel - First batch label shape: {first_batch_parallel['label'].shape}")
  except StopIteration:
    print("Parallel DataLoader is empty or exhausted.")
else:
  print("Parallel DataLoader not configured yet.")

# 4. (Optional) Timing comparison
# Re-create sequential loader with MySourceWithWork for a fair comparison
data_loader_seq_with_work = None
num_batches_to_test = 5 # Small number for a quick test
if source_with_work and transformations and index_sampler: # index_sampler from Ex1
  data_loader_seq_with_work = grain.DataLoader(
      data_source=source_with_work,
      operations=transformations, # Reusing from Ex1
      sampler=index_sampler, # Reusing from Ex1 (ensure it's fresh or allows re-iteration)
      worker_count=0,
      shard_options=grain.NoSharding(),
      read_options=grain.ReadOptions(num_threads=0)
      )

if data_loader_seq_with_work:
    print(f"\nTiming test for {num_batches_to_test} batches:")
    # Sequential
    iterator_seq = iter(data_loader_seq_with_work)
    start_time = time.time()
    try:
      for i in range(num_batches_to_test):
        batch = next(iterator_seq)
        if i == 0: print(f"  Seq batch 1 label sum: {batch['label'].sum()}") # to ensure work is done
    except StopIteration:
      print("Sequential loader exhausted early.")
    end_time = time.time()
    print(f"Sequential ({num_batches_to_test} batches) took: {end_time - start_time:.4f} seconds")

if data_loader_parallel:
    # Parallel
    # Ensure sampler is fresh for parallel loader if it was used above
    # For this optional part, let's use a fresh sampler for the parallel loader
    # to avoid StopIteration if the previous sampler was single-epoch and exhausted.
    fresh_parallel_sampler = grain.IndexSampler(
        num_records=len(source_with_work),
        shard_options=grain.NoSharding(),
        shuffle=True,
        num_epochs=None, # Indefinite
        seed=43 # Different seed or same, for this test it's about speed
    )
    data_loader_parallel_for_timing = grain.DataLoader(
        data_source=source_with_work,
        operations=transformations, # Reusing from Ex1
        sampler=fresh_parallel_sampler,
        worker_count=num_workers if num_workers > 0 else 2, # Ensure parallelism
        shard_options=grain.NoSharding(),
        read_options=grain.ReadOptions(num_threads=0)
    )
    iterator_parallel = iter(data_loader_parallel_for_timing)
    start_time = time.time()
    try:
      for i in range(num_batches_to_test):
        batch = next(iterator_parallel)
        if i == 0: print(f"  Parallel batch 1 label sum: {batch['label'].sum()}") # to ensure work is done
    except StopIteration:
      print("Parallel loader exhausted early.")
    end_time = time.time()
    print(f"Parallel ({num_batches_to_test} batches, {num_workers if num_workers > 0 else 2} workers) took: {end_time - start_time:.4f} seconds")
else:
  print("Skipping parallel timing: data_loader_parallel not configured yet.")
# @title Exercise 2: Solution
# 1. Reuse/Recreate components
# Define MySourceWithWork, adding time.sleep(0.01) in getitem
class MySourceWithWork(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records

  def __len__(self) -> int:
    return self._num_records

  def __getitem__(self, idx: int) -> Dict[str, Any]:
    effective_idx = idx % self._num_records
    time.sleep(0.01) # Simulate 10ms of work per item
    image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
    label = effective_idx % 10
    return {'image': image, 'label': label}

source_with_work = MySourceWithWork(num_records=1000)
print(f"MySourceWithWork created with {len(source_with_work)} records.")

# Sampler for parallel loading (indefinite epochs for robust testing)
parallel_sampler = grain.IndexSampler(
    num_records=len(source_with_work),
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=None, # Run indefinitely
    seed=42
    )
print("Parallel IndexSampler created.")

# Transformations can be reused from Exercise 1 solution
# Ensure ConvertToFloat is defined (it was in Ex1 solution cell)
if 'ConvertToFloat' not in globals(): # Basic check
  class ConvertToFloat(grain.MapTransform): # Redefine if not in current scope
    def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
      image = features['image'].astype(np.float32) / 255.0
      return {'image': image, 'label': features['label']}

  print("Redefined ConvertToFloat for safety.")

transformations_ex2 = [
    ConvertToFloat(),
    grain.Batch(batch_size=64, drop_remainder=True)
    ]
print("Transformations for Ex2 ready.")

# 2. Instantiate DataLoader with worker_count > 0
num_workers = 4 # Use 4 workers; JAX is configured for 8 virtual CPUs
# Max useful workers often related to num CPU cores available.
data_loader_parallel = grain.DataLoader(
    data_source=source_with_work,
    operations=transformations_ex2,
    sampler=parallel_sampler,
    worker_count=num_workers,
    shard_options=grain.NoSharding(),
    read_options=grain.ReadOptions(num_threads=0) # Data source simulates work but is "in-memory"
    )

# 3. Iterate and print batch info
if data_loader_parallel:
  print(f"DataLoader configured with worker_count={num_workers}.")
  data_iterator_parallel = iter(data_loader_parallel)
  try:
    first_batch_parallel = next(data_iterator_parallel)
    print(f"Parallel - First batch image shape: {first_batch_parallel['image'].shape}")
    print(f"Parallel - First batch label shape: {first_batch_parallel['label'].shape}")
  except StopIteration:
    print("Parallel DataLoader is empty or exhausted.")
else:
  print("Parallel DataLoader not configured yet.")

# 4. (Optional) Timing comparison
# Create a fresh IndexSampler for the sequential loader for a fair comparison start
# (num_epochs=1 to match typical test for sequential pass)
seq_sampler_for_timing = grain.IndexSampler(
    num_records=len(source_with_work),
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=1, # Single epoch for this timing test
    seed=42
    )

data_loader_seq_with_work = grain.DataLoader(
    data_source=source_with_work,
    operations=transformations_ex2,
    sampler=seq_sampler_for_timing,
    worker_count=0,
    shard_options=grain.NoSharding(),
    read_options=grain.ReadOptions(num_threads=0)
    )
num_batches_to_test = 5 # Number of batches to fetch for timing
print(f"\nTiming test for {num_batches_to_test} batches (each item has 0.01s simulated work):")

# Sequential
iterator_seq = iter(data_loader_seq_with_work)
start_time_seq = time.time()
try:
  for i in range(num_batches_to_test):
    batch_seq = next(iterator_seq)
    if i == 0 and num_batches_to_test > 0 : print(f" Seq batch 1 label sum: {batch_seq['label'].sum()}") # to ensure work is done
except StopIteration:
  print(f"Sequential loader exhausted before {num_batches_to_test} batches.")
end_time_seq = time.time()
print(f"Sequential ({num_batches_to_test} batches) took: {end_time_seq - start_time_seq:.4f} seconds")

# Parallel
# Use a fresh sampler for the parallel loader for timing to ensure it's not exhausted
# and runs for enough batches.
parallel_sampler_for_timing = grain.IndexSampler(
    num_records=len(source_with_work),
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=None, # Indefinite, or ensure enough for num_batches_to_test
    seed=43 # Can be same or different seed
    )

data_loader_parallel_for_timing = grain.DataLoader(
    data_source=source_with_work,
    operations=transformations_ex2,
    sampler=parallel_sampler_for_timing,
    worker_count=num_workers,
    shard_options=grain.NoSharding(),
    read_options=grain.ReadOptions(num_threads=0)
    )

iterator_parallel_timed = iter(data_loader_parallel_for_timing)
start_time_parallel = time.time()
try:
  for i in range(num_batches_to_test):
    batch_par = next(iterator_parallel_timed)
    if i == 0 and num_batches_to_test > 0 : print(f" Parallel batch 1 label sum: {batch_par['label'].sum()}") # to ensure work is done
except StopIteration:
  print(f"Parallel loader exhausted before {num_batches_to_test} batches.")
end_time_parallel = time.time()
print(f"Parallel ({num_batches_to_test} batches, {num_workers} workers) took: {end_time_parallel - start_time_parallel:.4f} seconds")

if end_time_parallel - start_time_parallel < end_time_seq - start_time_seq:
  print("Parallel loading was faster, as expected!")
else:
  print("Parallel loading was not significantly faster. This might happen for very small num_batches_to_test due to overhead, or if simulated work is too little.")

---

Exercise 3: A custom deterministic transform (MapTransform)

Goal: Implement a deterministic custom data transform. Steps:
  1. Define a custom class OneHotEncodeLabel that inherits from grain.MapTransform:
     * __init__ takes num_classes.
     * map(self, features: Dict[str, Any]) should: read the incoming features dict, convert features['label'] (an integer) into a one-hot np.float32 NumPy vector of length num_classes, store the one-hot vector back into features['label'], and return the modified features dict.
  2. Reuse MySource from Exercise 1 (the version without time.sleep) and the IndexSampler (or create new ones).
  3. Create a new operations list containing:
     * an instance of your OneHotEncodeLabel (e.g. num_classes=10, matching idx % 10 in MySource),
     * the ConvertToFloat transform (if the image has not been converted yet),
     * grain.Batch.
  4. Instantiate a grain.DataLoader (worker_count can be 0 or > 0).
  5. Iterate to fetch the first batch and print the shape of the one-hot labels plus one example label vector.
# @title Exercise 3: Student Code
# 1. Define OneHotEncodeLabel
class OneHotEncodeLabel(grain.MapTransform):
  def __init__(self, num_classes: int):
    # TODO: Store num_classes
    # YOUR CODE HERE
    self._num_classes = 0 # Replace this

  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    label = features['label']
    # TODO: Create one-hot encoded version of the label
    # one_hot_label = np.zeros(...)
    # one_hot_label[label] = 1.0
    # YOUR CODE HERE
    one_hot_label = np.array([label]) # Replace this

    features['label'] = one_hot_label
    return features

# 2. Reuse/Create DataSource and Sampler
# TODO: Instantiate MySource (from Ex1, no sleep)
# source_ex3 = ...
# YOUR CODE HERE
source_ex3 = None # Replace this

# TODO: Instantiate an IndexSampler (e.g., from Ex1, or a new one)
# sampler_ex3 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex3 = None # Replace this

# 3. Create new list of operations
# TODO: Instantiate OneHotEncodeLabel
num_classes_for_ohe = 10
# one_hot_encoder = OneHotEncodeLabel(num_classes=num_classes_for_ohe)
# YOUR CODE HERE
one_hot_encoder = None # Replace this

# TODO: Define transformations_ex3 list including one_hot_encoder,
# ConvertToFloat (if not already applied), and grain.Batch
# (Assuming ConvertToFloat is defined from Ex1 solution)
# transformations_ex3 = [...]
# YOUR CODE HERE
transformations_ex3 = [] # Replace this

# 4. Instantiate DataLoader
# TODO: Create data_loader_ex3
# data_loader_ex3 = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_ex3 = None # Replace this

# 5. Iterate and print batch info
if data_loader_ex3:
  print("DataLoader with OneHotEncodeLabel configured.")
  iterator_ex3 = iter(data_loader_ex3)
  try:
    first_batch_ex3 = next(iterator_ex3)
    print(f"Custom MapTransform - Batch image shape: {first_batch_ex3['image'].shape}")
    print(f"Custom MapTransform - Batch label shape: {first_batch_ex3['label'].shape}") # Expected: (batch_size, num_classes)
    if first_batch_ex3['label'].size > 0:
      print(f"Custom MapTransform - Example one-hot label: {first_batch_ex3['label'][0]}")
  except StopIteration:
    print("DataLoader for Ex3 is empty or exhausted.")
else:
  print("DataLoader for Ex3 not configured yet.")
# @title Exercise 3: Solution
# 1. Define OneHotEncodeLabel
class OneHotEncodeLabel(grain.MapTransform):
  def __init__(self, num_classes: int):
    self._num_classes = num_classes

  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    label_scalar = features['label']
    one_hot_label = np.zeros(self._num_classes, dtype=np.float32)
    one_hot_label[label_scalar] = 1.0

    # Create a new dictionary to avoid modifying the input dict in place if it's reused
    # by other transforms or parts of the pipeline, though often direct modification is fine.
    # For safety and clarity, let's return a new dict or an updated copy.
    updated_features = features.copy()
    updated_features['label'] = one_hot_label
    return updated_features

# 2. Reuse/Create DataSource and Sampler
# Using MySource from Exercise 1 solution (no artificial sleep)
if 'MySource' not in globals(): # Basic check
  class MySource(grain.RandomAccessDataSource): # Redefine if not in current scope
    def __init__(self, num_records: int = 1000):
      self._num_records = num_records
    def __len__(self) -> int:
      return self._num_records

    def __getitem__(self, idx: int) -> Dict[str, Any]:
      effective_idx = idx % self._num_records
      image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
      label = effective_idx % 10
      return {'image': image, 'label': label}
  print("Redefined MySource for Ex3.")

source_ex3 = MySource(num_records=1000)
sampler_ex3 = grain.IndexSampler(
  num_records=len(source_ex3),
  shard_options=grain.NoSharding(),
  shuffle=True,
  num_epochs=1,
  seed=42
  )
print("DataSource and Sampler for Ex3 ready.")

# 3. Create new list of operations
num_classes_for_ohe = 10 # Matches idx % 10 in MySource
one_hot_encoder = OneHotEncodeLabel(num_classes=num_classes_for_ohe)

# Ensure ConvertToFloat is defined
if 'ConvertToFloat' not in globals():
  class ConvertToFloat(grain.MapTransform):
    def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
      image = features['image'].astype(np.float32) / 255.0
      return {'image': image, 'label': features['label']} # Pass label through
  print("Redefined ConvertToFloat for Ex3.")

transformations_ex3 = [
  ConvertToFloat(), # Apply first to have float images
  one_hot_encoder, # Then one-hot encode labels
  grain.Batch(batch_size=64, drop_remainder=True)
  ]
print("Transformations for Ex3 defined.")

# 4. Instantiate DataLoader
data_loader_ex3 = grain.DataLoader(
  data_source=source_ex3,
  operations=transformations_ex3,
  sampler=sampler_ex3,
  worker_count=0, # Can be > 0 as well, OneHotEncodeLabel is picklable
  shard_options=grain.NoSharding(),
  read_options=grain.ReadOptions(num_threads=0)
  )

# 5. Iterate and print batch info
if data_loader_ex3:
  print("DataLoader with OneHotEncodeLabel configured.")
  iterator_ex3 = iter(data_loader_ex3)
  try:
    first_batch_ex3 = next(iterator_ex3)
    print(f"Custom MapTransform - Batch image shape: {first_batch_ex3['image'].shape}")
    print(f"Custom MapTransform - Batch label shape: {first_batch_ex3['label'].shape}") # Expected: (64, 10)
    if first_batch_ex3['label'].size > 0:
      print(f"Custom MapTransform - Example one-hot label (first item): {first_batch_ex3['label'][0]}")
    original_label_example = np.argmax(first_batch_ex3['label'][0])
    print(f"Custom MapTransform - Decoded original label (first item): {original_label_example}")
  except StopIteration:
    print("DataLoader for Ex3 is empty or exhausted.")
else:
  print("DataLoader for Ex3 not configured yet.")

---

Exercise 4: A custom random transform (RandomMapTransform)

Goal: Implement a custom transform that involves randomness and make it reproducible through Grain's mechanisms. Steps:
  1. Define a custom class RandomBrightnessAdjust that inherits from grain.RandomMapTransform. Its random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any] method should:
     * Receive features and rng (a NumPy random generator).
     * Use the provided rng for all random operations, so that under the same seed the same record always gets the same "random" augmentation.
     * Draw a random brightness factor with rng.uniform(0.7, 1.3).
     * Multiply features['image'] (assumed to be float and already normalized) by that factor.
     * Clip the image to [0.0, 1.0] with np.clip().
     * Return the modified features.
  2. Reuse MySource, an IndexSampler (make sure it has a seed), and ConvertToFloat from the previous exercises.
  3. Create an operations list containing ConvertToFloat, your RandomBrightnessAdjust, and grain.Batch.
  4. Instantiate two DataLoaders (dl_run1 and dl_run2) with exactly the same configuration (same data source, same sampler instance or samplers with the same seed, same operations list, same worker_count).
  5. Iterate to fetch the first batch from dl_run1 and print one pixel value.
  6. If necessary (e.g. with num_epochs=1), reset the iterator or recreate the sampler, then fetch the first batch from dl_run2 and print the same pixel value.
  7. Verify: the two pixel values should match, demonstrating that the random augmentation is reproducible.
  8. (Optional) Create a third DataLoader whose IndexSampler uses a different seed and observe that the pixel value changes.
# @title Exercise 4: Student Code
# 1. Define RandomBrightnessAdjust
class RandomBrightnessAdjust(grain.RandomMapTransform):
  def random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any]:
    # TODO: Ensure image is float (e.g. by placing ConvertToFloat before this in ops)
    image = features['image']

    # TODO: Generate a random brightness factor using the provided rng
    # brightness_factor = rng.uniform(...)
    # YOUR CODE HERE
    brightness_factor = 1.0 # Replace this

    # TODO: Apply brightness adjustment and clip
    # adjusted_image = np.clip(...)
    # YOUR CODE HERE
    adjusted_image = image # Replace this

    # Create a new dictionary or update a copy
    updated_features = features.copy()
    updated_features['image'] = adjusted_image
    return updated_features

# 2. Reuse/Create DataSource, Sampler, ConvertToFloat
# TODO: Instantiate MySource (from Ex1)
# source_ex4 = ...
# YOUR CODE HERE
source_ex4 = None # Replace this

# TODO: Instantiate an IndexSampler with a seed (e.g., seed=42, num_epochs=1 or None)
# sampler_ex4_seed42 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex4_seed42 = None # Replace this

# (Assuming ConvertToFloat is defined from Ex1 solution)

# 3. Create list of operations
# TODO: Instantiate RandomBrightnessAdjust
# random_brightness_adjuster = ...
# YOUR CODE HERE
random_brightness_adjuster = None # Replace this

# TODO: Define transformations_ex4 list: ConvertToFloat, random_brightness_adjuster, grain.Batch
# transformations_ex4 = [...]
# YOUR CODE HERE
transformations_ex4 = [] # Replace this

# 4. Instantiate two DataLoaders with the same config
# TODO: Create dl_run1
# dl_run1 = grain.DataLoader(...)
# YOUR CODE HERE
dl_run1 = None # Replace this

# TODO: Create dl_run2 (using the exact same sampler instance or a new one with the same seed)
# dl_run2 = grain.DataLoader(...)
# YOUR CODE HERE
dl_run2 = None # Replace this

# 5. & 6. Iterate and compare
pixel_to_check = (0, 0, 0, 0) # Batch_idx, H, W, C
if dl_run1:
  print("--- Run 1 (seed 42) ---")
  iterator_run1 = iter(dl_run1)
  try:
    batch1_run1 = next(iterator_run1)
    value_run1 = batch1_run1['image'][pixel_to_check]
    print(f"Run 1 - Pixel {pixel_to_check} value: {value_run1}")
  except StopIteration:
    print("dl_run1 exhausted.")
    value_run1 = None
else:
  print("dl_run1 not configured.")
  value_run1 = None

if dl_run2:
  print("\n--- Run 2 (seed 42, same sampler) ---")
  # If sampler_ex4_seed42 was single-epoch and already used by dl_run1,
  # dl_run2 might be empty. For robust test, ensure sampler allows re-iteration
  # or use a new sampler instance with the same seed.
  # If sampler_ex4_seed42 had num_epochs=None, iter(dl_run2) is fine.
  # If num_epochs=1, you might need to re-create sampler_ex4_seed42 for dl_run2
  # or ensure dl_run1 didn't exhaust it (e.g. by not fully iterating it).
  # For this exercise, assume sampler_ex4_seed42 can be re-used or is fresh for dl_run2.
  iterator_run2 = iter(dl_run2)
  try:
    batch1_run2 = next(iterator_run2)
    value_run2 = batch1_run2['image'][pixel_to_check]
    print(f"Run 2 - Pixel {pixel_to_check} value: {value_run2}")

    # 7. Verify
    if value_run1 is not None and value_run2 is not None:
      if np.allclose(value_run1, value_run2):
        print("\nSUCCESS: Pixel values are identical. Randomness is reproducible!")
      else:
        print(f"\nFAILURE: Pixel values differ. value1={value_run1}, value2={value_run2}")
  except StopIteration:
    print("dl_run2 exhausted. This might happen if the sampler was single-epoch and already used.")
    value_run2 = None
else:
  print("dl_run2 not configured.")

# 8. (Optional) Test with a different seed
# TODO: Create sampler_ex4_seed100 (seed=100)
# sampler_ex4_seed100 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex4_seed100 = None # Replace this

# TODO: Create dl_run3 with sampler_ex4_seed100
# dl_run3 = grain.DataLoader(...)
# YOUR CODE HERE
dl_run3 = None # Replace this

if dl_run3:
  print("\n--- Run 3 (seed 100) ---")
  iterator_run3 = iter(dl_run3)
  try:
    batch1_run3 = next(iterator_run3)
    value_run3 = batch1_run3['image'][pixel_to_check]
    print(f"Run 3 - Pixel {pixel_to_check} value: {value_run3}")
    if value_run1 is not None and not np.allclose(value_run1, value_run3):
      print("SUCCESS: Pixel values differ from Run 1 (seed 42), as expected with a new seed.")
    elif value_run1 is not None:
      print("NOTE: Pixel values are the same as Run 1. Check seed or logic.")
  except StopIteration:
    print("dl_run3 exhausted.")
else:
  print("\nOptional part (dl_run3) not configured.")
# @title Exercise 4: Solution
# 1. Define RandomBrightnessAdjust
class RandomBrightnessAdjust(grain.RandomMapTransform):
  def random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any]:
    image = features['image'] # Assumes image is already float, e.g. from ConvertToFloat
    # Generate a random brightness factor using the provided rng
    brightness_factor = rng.uniform(0.7, 1.3)

    # Apply brightness adjustment and clip
    adjusted_image = image * brightness_factor
    adjusted_image = np.clip(adjusted_image, 0.0, 1.0)

    updated_features = features.copy()
    updated_features['image'] = adjusted_image
    return updated_features

# 2. Reuse/Create DataSource, Sampler, ConvertToFloat
if 'MySource' not in globals(): # Basic check for MySource
  class MySource(grain.RandomAccessDataSource):
    def __init__(self, num_records: int = 1000):
      self._num_records = num_records
    def __len__(self) -> int:
      return self._num_records
    def __getitem__(self, idx: int) -> Dict[str, Any]:
      effective_idx = idx % self._num_records
      image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
      label = effective_idx % 10
      return {'image': image, 'label': label}
  print("Redefined MySource for Ex4.")
source_ex4 = MySource(num_records=1000)

# Sampler with a fixed seed. num_epochs=None allows re-iteration for multiple DataLoaders.
# If num_epochs=1, the sampler instance can only be fully iterated once.
# For this test, using num_epochs=None or re-creating the sampler for each DataLoader is safest.
# Let's use num_epochs=None to allow the same sampler instance to be used.
sampler_ex4_seed42 = grain.IndexSampler(
  num_records=len(source_ex4),
  shard_options=grain.NoSharding(),
  shuffle=True,
  num_epochs=None, # Allow indefinite iteration
  seed=42
  )
print("DataSource and Sampler (seed 42) for Ex4 ready.")

if 'ConvertToFloat' not in globals(): # Basic check for ConvertToFloat
  class ConvertToFloat(grain.MapTransform):
    def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
      image = features['image'].astype(np.float32) / 255.0
      return {'image': image, 'label': features['label']}
  print("Redefined ConvertToFloat for Ex4.")

# 3. Create list of operations
random_brightness_adjuster = RandomBrightnessAdjust()
transformations_ex4 = [
  ConvertToFloat(),
  random_brightness_adjuster,
  grain.Batch(batch_size=64, drop_remainder=True)
]
print("Transformations for Ex4 defined.")

# 4. Instantiate two DataLoaders with the same config
# Using worker_count > 0 to also test picklability of RandomBrightnessAdjust
num_workers_ex4 = 2
dl_run1 = grain.DataLoader(
  data_source=source_ex4,
  operations=transformations_ex4,
  sampler=sampler_ex4_seed42, # Same sampler instance
  worker_count=num_workers_ex4,
  shard_options=grain.NoSharding(),
  read_options=grain.ReadOptions(num_threads=0)
  )

dl_run2 = grain.DataLoader(
  data_source=source_ex4,
  operations=transformations_ex4,
  sampler=sampler_ex4_seed42, # Same sampler instance
  worker_count=num_workers_ex4,
  shard_options=grain.NoSharding(),
  read_options=grain.ReadOptions(num_threads=0)
  )
print(f"DataLoaders for Run1 and Run2 created with worker_count={num_workers_ex4}.")

# 5. & 6. Iterate and compare
pixel_to_check = (0, 0, 0, 0) # Batch_idx=0, H=0, W=0, C=0
print("\n--- Run 1 (seed 42) ---")
iterator_run1 = iter(dl_run1)

try:
  batch1_run1 = next(iterator_run1)
  value_run1 = batch1_run1['image'][pixel_to_check]
  print(f"Run 1 - Pixel {pixel_to_check} value: {value_run1}")
except StopIteration:
  print("dl_run1 exhausted.")
  value_run1 = None

print("\n--- Run 2 (seed 42, same sampler instance) ---")
iterator_run2 = iter(dl_run2) # Gets a new iterator from the DataLoader

try:
  batch1_run2 = next(iterator_run2)
  value_run2 = batch1_run2['image'][pixel_to_check]
  print(f"Run 2 - Pixel {pixel_to_check} value: {value_run2}")
  # 7. Verify
  if value_run1 is not None and value_run2 is not None:
    if np.allclose(value_run1, value_run2):
      print("\nSUCCESS: Pixel values are identical. Randomness is reproducible with the same sampler instance!")
    else:
      print(f"\nFAILURE: Pixel values differ. value1={value_run1}, value2={value_run2}. This shouldn't happen if sampler is re-used correctly.")
except StopIteration:
  print("dl_run2 exhausted.")

# 8. (Optional) Test with a different seed
sampler_ex4_seed100 = grain.IndexSampler(
  num_records=len(source_ex4),
  shard_options=grain.NoSharding(),
  shuffle=True,
  num_epochs=None,
  seed=100 # Different seed
  )

dl_run3 = grain.DataLoader(
  data_source=source_ex4,
  operations=transformations_ex4,
  sampler=sampler_ex4_seed100, # Sampler with different seed
  worker_count=num_workers_ex4,
  shard_options=grain.NoSharding(),
  read_options=grain.ReadOptions(num_threads=0)
  )
print("\nDataLoader for Run3 (seed 100) created.")
print("\n--- Run 3 (seed 100) ---")
iterator_run3 = iter(dl_run3)

try:
  batch1_run3 = next(iterator_run3)
  value_run3 = batch1_run3['image'][pixel_to_check]
  print(f"Run 3 - Pixel {pixel_to_check} value: {value_run3}")
  if value_run1 is not None and not np.allclose(value_run1, value_run3):
    print("SUCCESS: Pixel values differ from Run 1 (seed 42), as expected with a new sampler seed.")
  elif value_run1 is not None:
    print("NOTE: Pixel values are the same as Run 1. This is unexpected if seeds are different.")
except StopIteration:
  print("dl_run3 exhausted.")

---

Exercise 5: Data sharding for distributed training

Goal: Understand how Grain handles data sharding, which is essential in distributed training so that each JAX process gets its own slice of the data. Background: In a real distributed JAX setup there are multiple Python processes. Each process calls jax.process_index() to get its own ID and jax.process_count() to get the total number of processes. grain.sharding.ShardByJaxProcess() uses these values automatically to configure sharding.

Since we are inside a single Colab notebook (only one JAX process, even though we simulate several virtual devices), we cannot actually run multiple processes. Instead, we create grain.ShardOptions manually to emulate two different processes; the sketch below shows what the real multi-process configuration would look like.
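
For reference, here is a hedged sketch of that real multi-process configuration. It also runs in this notebook (where jax.process_count() is 1), but only becomes interesting when every process of a multi-host job executes it; source and transformations are assumed to be the objects defined in Exercise 1:

# Each process derives its own shard from the JAX process info.
real_shard_options = grain.ShardOptions(
    shard_index=jax.process_index(),   # 0 in this notebook; 0..N-1 in a real job
    shard_count=jax.process_count(),   # 1 in this notebook; N in a real job
    drop_remainder=True,
)
# Equivalently, grain.sharding.ShardByJaxProcess(drop_remainder=True) builds the
# same options from jax.process_index()/jax.process_count() automatically.

per_process_sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=real_shard_options,
    shuffle=True,
    num_epochs=1,
    seed=42,                           # the seed must be identical on every process
)
per_process_loader = grain.DataLoader(
    data_source=source,
    operations=transformations,
    sampler=per_process_sampler,
    worker_count=0,
    shard_options=real_shard_options,
    read_options=grain.ReadOptions(num_threads=0),
)
print(f"This process would see roughly {len(source) // jax.process_count()} records.")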

Steps:
  1. Reuse MySource and the earlier transformations (e.g. ConvertToFloat and grain.Batch).
  2. Define shard_count = 2.
  3. Simulate process 0:
     * Create shard_options_p0 = grain.ShardOptions(shard_index=0, shard_count=shard_count, drop_remainder=True).
     * Create an IndexSampler (sampler_p0) with shard_options_p0, making sure shuffling is enabled and a common seed (e.g. 42) is used.
     * Create a DataLoader (dl_p0) with sampler_p0, passing shard_options_p0 to the DataLoader as well.
     * Iterate over dl_p0 and collect all labels seen in the first few batches (or in all batches if num_epochs=1).
  4. Simulate process 1:
     * Create shard_options_p1 = grain.ShardOptions(shard_index=1, shard_count=shard_count, drop_remainder=True).
     * Create an IndexSampler (sampler_p1) with shard_options_p1 (same seed as sampler_p0).
     * Create a DataLoader (dl_p1) with sampler_p1 and shard_options_p1.
     * Iterate over dl_p1 and collect all labels seen.
  5. Verify:
     * Print the sets of unique labels seen by "process 0" and "process 1". Confirm that the two sets are largely disjoint (shuffling may occasionally duplicate indices at shard boundaries, but most indices should differ). The key point is that the indices sampled by sampler_p0 and sampler_p1 should be mutually exclusive.
     * drop_remainder=True in ShardOptions ensures that when the number of records is not divisible by shard_count, some records are dropped so the shards stay (roughly) equal in size, depending on implementation details.

A note on shard_options in IndexSampler vs. DataLoader: the shard_options argument of grain.DataLoader is the primary way to enable sharding for a given JAX process. The DataLoader makes its internal sampler respect the current process's global sharding configuration (even if the sampler you pass in is not sharded). If the IndexSampler you provide is already sharded, its sharding must be compatible with the DataLoader's shard_options. In real distributed scenarios, for simplicity and clarity, you usually pass ShardByJaxProcess() or manually configured ShardOptions to the DataLoader.
# @title Exercise 5: Student Code
# 1. Reuse DataSource and basic transformations
# TODO: Instantiate MySource (from Ex1)
# source_ex5 = ...
# YOUR CODE HERE
source_ex5 = None # Replace this

# TODO: Define basic_transformations_ex5 (e.g., ConvertToFloat, Batch)
# (Assuming ConvertToFloat is defined)
# basic_transformations_ex5 = [...]
# YOUR CODE HERE
basic_transformations_ex5 = [] # Replace this

# 2. Define shard_count
shard_count = 2
common_seed = 42
num_epochs_for_sharding_test = 1 # To make collection of all labels feasible

# 3. Simulate Process 0
# TODO: Create shard_options_p0
# shard_options_p0 = grain.ShardOptions(...)
# YOUR CODE HERE
shard_options_p0 = None # Replace this

# TODO: Create sampler_p0. Pass shard_options_p0 to the IndexSampler.
# sampler_p0 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_p0 = None # Replace this

# TODO: Create dl_p0. Pass shard_options_p0 to the DataLoader as well.
# dl_p0 = grain.DataLoader(...)
# YOUR CODE HERE
dl_p0 = None # Replace this

labels_p0 = set()
if dl_p0:
  print("--- Simulating Process 0 ---")
  # YOUR CODE HERE: Iterate through dl_p0 and collect all unique original labels.
  # Remember that labels might be batched. You need to iterate through items in a batch.
  # For simplicity, if your MySource generates labels like idx % 10,
  # you can try to collect the indices that were sampled.
  # Or, more directly, collect the 'label' field from each item.
  # To get original indices, you might need a transform that passes index through.
  # Let's collect the 'label' values directly.
  pass # Replace with iteration logic

# 4. Simulate Process 1
# TODO: Create shard_options_p1
# shard_options_p1 = grain.ShardOptions(...)
# YOUR CODE HERE
shard_options_p1 = None # Replace this

# TODO: Create sampler_p1
# sampler_p1 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_p1 = None # Replace this

# TODO: Create dl_p1
# dl_p1 = grain.DataLoader(...)
# YOUR CODE HERE
dl_p1 = None # Replace this

labels_p1 = set()
if dl_p1:
  print("\n--- Simulating Process 1 ---")
  # YOUR CODE HERE: Iterate through dl_p1 and collect all unique labels.
  pass # Replace with iteration logic

# 5. Verify
print(f"\n--- Verification (Total records in source: {len(source_ex5) if source_ex5 else 'N/A'}) ---")
print(f"Unique labels collected by Process 0 (count {len(labels_p0)}): sorted {sorted(list(labels_p0))[:20]}...")
print(f"Unique labels collected by Process 1 (count {len(labels_p1)}): sorted {sorted(list(labels_p1))[:20]}...")
if labels_p0 and labels_p1:
  intersection = labels_p0.intersection(labels_p1)
  if not intersection:
    print("\nSUCCESS: No overlap in labels between Process 0 and Process 1. Sharding works as expected!")
  else:
    print(f"\nNOTE: Some overlap in labels found (count {len(intersection)}): {intersection}.")
    print("This can happen if labels are not unique per index, or if sharding logic has issues.")
    print("With MySource's label = idx % 10, an overlap in labels is expected even if indices are disjoint.")
    print("A better test would be to collect original indices if possible.")

# For a more direct test of sharding of indices:
# We can define a DataSource that returns the index itself.
class IndexSource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int):
    self._num_records = num_records
  def __len__(self) -> int:
    return self._num_records
  def __getitem__(self, idx: int) -> int:
    return idx % self._num_records # Return the index
index_source = IndexSource(num_records=100) # Smaller source for easier inspection

# These samplers only make sense once shard_options_p0/p1 have been defined above.
if shard_options_p0 and shard_options_p1:
  idx_sampler_p0 = grain.IndexSampler(len(index_source), shard_options_p0, shuffle=False, num_epochs=1, seed=common_seed)
  idx_sampler_p1 = grain.IndexSampler(len(index_source), shard_options_p1, shuffle=False, num_epochs=1, seed=common_seed)

# DataLoader for indices (no batching, just to see raw sampled indices)
# Note: DataLoader expects dicts. Let's make IndexSource return {'index': idx}
class IndexDictSource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int): self._num_records = num_records
  def __len__(self) -> int:
    return self._num_records
  def __getitem__(self, idx: int) -> Dict[str,int]:
    return {'index': idx % self._num_records}
index_dict_source = IndexDictSource(num_records=100)

# Samplers and DataLoaders for IndexDictSource
# Pass shard_options to the DataLoader as well.
indices_from_p0, indices_from_p1 = set(), set()
if shard_options_p0 and shard_options_p1:
  idx_dict_sampler_p0 = grain.IndexSampler(len(index_dict_source), shard_options_p0, shuffle=False, num_epochs=1, seed=common_seed)
  idx_dict_sampler_p1 = grain.IndexSampler(len(index_dict_source), shard_options_p1, shuffle=False, num_epochs=1, seed=common_seed)
  dl_indices_p0 = grain.DataLoader(
      data_source=index_dict_source, operations=[], sampler=idx_dict_sampler_p0,
      worker_count=0, shard_options=shard_options_p0)
  dl_indices_p1 = grain.DataLoader(
      data_source=index_dict_source, operations=[], sampler=idx_dict_sampler_p1,
      worker_count=0, shard_options=shard_options_p1)
  indices_from_p0 = {item['index'] for item in dl_indices_p0}
  indices_from_p1 = {item['index'] for item in dl_indices_p1}

print(f"\n--- Verification of INDICES (Source size: {len(index_dict_source)}) ---")
print(f"Indices from P0 (count {len(indices_from_p0)}, shuffle=False): {sorted(list(indices_from_p0))}")
print(f"Indices from P1 (count {len(indices_from_p1)}, shuffle=False): {sorted(list(indices_from_p1))}")

if indices_from_p0 and indices_from_p1:
  idx_intersection = indices_from_p0.intersection(indices_from_p1)
  if not idx_intersection:
    print("SUCCESS: No overlap in INDICES. Sharding of data sources works correctly!")
  else:
    print(f"FAILURE: Overlap in INDICES found: {idx_intersection}")
else:
  print("Skipping index verification: shard_options_p0/p1 (and the index DataLoaders) are not defined yet.")

print("\nReminder: In a real distributed setup, you'd use grain.sharding.ShardByJaxProcess() "
      "and JAX would manage jax.process_index() automatically for each process.")
# @title Exercise 5: Solution
# Redefine MySource for Ex5 to include 'original_index'.
class MySource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records
  def __len__(self) -> int:
    return self._num_records
  def __getitem__(self, idx: int) -> Dict[str, Any]:
    effective_idx = idx % self._num_records
    image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
    label = effective_idx % 10 # Label is idx % 10
    # For better sharding verification, let's also pass the original index
    return {'image': image, 'label_test': label, 'original_index': effective_idx}

print("Redefined MySource for Ex5 to include 'original_index'.")
source_ex5 = MySource(num_records=1000)

# Redefine ConvertToFloat for Ex5 to include 'original_index'.
class ConvertToFloat(grain.MapTransform):
  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    # This transform should pass through all keys it doesn't modify
    updated_features = features.copy()
    updated_features['image'] = features['image'].astype(np.float32) / 255.0
    return updated_features
print("Redefined ConvertToFloat for Ex5.")

# We will collect 'original_index' after batching, so batch must preserve it.
# grain.Batch by default collates features with the same name.
basic_transformations_ex5 = [
  ConvertToFloat(),
  grain.Batch(batch_size=64, drop_remainder=True) # drop_remainder for batching
  ]
print("DataSource and Transformations for Ex5 ready.")

# 2. Define shard_count
shard_count = 2
common_seed = 42
num_epochs_for_sharding_test = 1

# 3. Simulate Process 0
shard_options_p0 = grain.ShardOptions(shard_index=0, shard_count=shard_count, drop_remainder=True) # drop_remainder for sharding

# Sampler for Process 0. It's important that the sampler itself is sharded.
# The DataLoader's shard_options will also apply this sharding if the sampler isn't already sharded,
# or verify consistency if it is.
sampler_p0 = grain.IndexSampler(
  num_records=len(source_ex5),
  shard_options=shard_options_p0, # Shard the sampler
  shuffle=True, # Shuffle for more realistic scenario
  num_epochs=num_epochs_for_sharding_test,
  seed=common_seed
  )

dl_p0 = grain.DataLoader(
  data_source=source_ex5,
  operations=basic_transformations_ex5,
  sampler=sampler_p0,
  worker_count=0, # Keep it simple for verification
  shard_options=shard_options_p0, # Also inform DataLoader about the sharding context
  read_options=grain.ReadOptions(num_threads=0)
  )

indices_p0 = set()
if dl_p0:
  print("--- Simulating Process 0 ---")
  for batch in dl_p0:
    indices_p0.update(batch['original_index'].tolist()) # Collect original indices
  print(f"Process 0 collected {len(indices_p0)} unique indices.")

# 4. Simulate Process 1
shard_options_p1 = grain.ShardOptions(shard_index=1, shard_count=shard_count, drop_remainder=True)
sampler_p1 = grain.IndexSampler(
  num_records=len(source_ex5),
  shard_options=shard_options_p1, # Shard the sampler
  shuffle=True, # Use same shuffle setting and seed for apples-to-apples comparison of sharding logic
  num_epochs=num_epochs_for_sharding_test,
  seed=common_seed # Same seed ensures shuffle order is same before sharding
  )

dl_p1 = grain.DataLoader(
  data_source=source_ex5,
  operations=basic_transformations_ex5,
  sampler=sampler_p1,
  worker_count=0,
  shard_options=shard_options_p1, # Inform DataLoader
  read_options=grain.ReadOptions(num_threads=0)
  )

indices_p1 = set()
if dl_p1:
  print("\n--- Simulating Process 1 ---")
  for batch in dl_p1:
    indices_p1.update(batch['original_index'].tolist()) # Collect original indices
  print(f"Process 1 collected {len(indices_p1)} unique indices.")

# 5. Verify
print(f"\n--- Verification of original_indices (Total records in source: {len(source_ex5)}) ---")
# Showing a few from each for brevity
print(f"Unique original_indices from P0 (first 20 sorted): {sorted(list(indices_p0))[:20]}...")
print(f"Unique original_indices from P1 (first 20 sorted): {sorted(list(indices_p1))[:20]}...")
expected_per_shard = len(source_ex5) // shard_count # Due to drop_remainder=True in ShardOptions
print(f"Expected records per shard (approx, due to drop_remainder in sharding): {expected_per_shard}")
print(f"Actual for P0: {len(indices_p0)}, P1: {len(indices_p1)}")

if indices_p0 and indices_p1:
  intersection = indices_p0.intersection(indices_p1)
  if not intersection:
    print("\nSUCCESS: No overlap in original_indices between Process 0 and Process 1. Sharding works!")
  else:
    print(f"\nFAILURE: Overlap in original_indices found (count {len(intersection)}): {sorted(list(intersection))[:20]}...")
    print("This should not happen if sharding is correct and seeds/shuffle are consistent.")
else:
  print("Could not perform intersection test as one or both sets of indices are empty.")
  total_unique_indices_seen = len(indices_p0.union(indices_p1))
  print(f"Total unique indices seen across both simulated processes: {total_unique_indices_seen}")

# With drop_remainder=True in sharding, total might be less than len(source_ex5)
# if len(source_ex5) is not divisible by shard_count.
# Example: 1000 records, 2 shards. Each gets 500. Total 1000.
# Example: 1001 records, 2 shards. drop_remainder=True means each gets 500. Total 1000. 1 record dropped.
print("\nReminder: In a real distributed JAX application:")
print("1. Each JAX process would run this script (or similar).")
print("2. shard_options = grain.sharding.ShardByJaxProcess(drop_remainder=True) would be used.")
print("3. jax.process_index() and jax.process_count() would provide the correct shard info automatically.")
print("4. The IndexSampler and DataLoader would be configured with these auto-detected shard_options.")

---

Exercise 6: Plugging Grain into a (conceptual) JAX/Flax NNX loop

Goal: Understand how Grain's DataLoader feeds a typical JAX/Flax NNX training loop. This exercise focuses only on the data flow; the model training is conceptual (no real weight updates). Steps:
  1. Define a simple Flax NNX model:
     * Create SimpleNNXModel, inheriting from nnx.Module. In __init__, initialize an nnx.Linear layer whose input features match the flattened image size (e.g. 32*32*3) and whose output features can be num_classes (e.g. 10). Remember to pass rngs so the parameters get initialized.
     * Implement __call__(self, x): flatten the input x (shape B, H, W, C) to 2-D and pass it through the linear layer.
  2. Define a conceptual train_step:
     * This JAX function would normally be compiled with @jax.jit (the conceptual version in this notebook leaves it uncompiled).
     * It takes the model (your SimpleNNXModel instance) and a batch from Grain.
     * Internally it runs the forward pass: logits = model(batch['image']).
     * It computes a dummy loss, e.g. loss = jnp.mean(logits). (No real loss or gradients are needed for this exercise.)
     * It returns the loss and the model. In a real setting with nnx.Optimizer, the optimizer updates the model parameters in place, and train_step typically returns the loss, the updated model, and whatever optimizer state the next step needs; see the sketch after this list.
  3. Configure the DataLoader:
     * Use MySource (producing {'image': ..., 'label': ...}), an IndexSampler (e.g. running a few epochs), and transformations (e.g. ConvertToFloat and grain.Batch).
     * Instantiate a grain.DataLoader.
  4. Write the training loop:
     * Initialize SimpleNNXModel with a suitable JAX PRNG key.
     * Get an iterator from the DataLoader.
     * Loop for a fixed number of steps (e.g. 100): get next_batch from the iterator (handling StopIteration if it is exhausted), call train_step with the current model and next_batch, and occasionally print the dummy loss.
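
For context, here is a hedged sketch of what a non-conceptual NNX train step could look like, mirroring the pattern described in the solution's comments below. It assumes optax is available (it ships with the jax-ai-stack installed above); the exact nnx.Optimizer construction/update signature varies between Flax versions, so treat this as a template rather than a definitive implementation:

import optax

def real_train_step(model: nnx.Module, optimizer: nnx.Optimizer,
                    batch: Dict[str, Any]) -> jax.Array:
  """One real optimization step: loss -> grads -> in-place parameter update."""
  def loss_fn(model):
    logits = model(batch['image'])
    labels = jnp.asarray(batch['label'], dtype=jnp.int32)
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)  # updates the model parameters in place
  return loss

# Usage sketch (names are illustrative): after building `model` as in the solution,
#   optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
#   loss = real_train_step(model, optimizer, next(iter(data_loader)))
# Wrapping real_train_step with nnx.jit would compile it.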
# @title Exercise 6: Student Code
# 1. Define SimpleNNXModel
class SimpleNNXModel(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    # TODO: Initialize an nnx.Linear layer
    # self.linear = nnx.Linear(...)
    # YOUR CODE HERE
    self.linear = None # Replace this
  def __call__(self, x: jax.Array):
    # TODO: Flatten the input image (if B, H, W, C) and pass through linear layer
    # x_flat = x.reshape((x.shape[0], -1))
    # return self.linear(x_flat)
    # YOUR CODE HERE
    return x # Replace this

# 2. Define train_step
def train_step(model: SimpleNNXModel, batch: Dict[str, jax.Array]):
  # TODO: Perform forward pass: model(batch['image'])
  # logits = ...
  # YOUR CODE HERE
  logits = model(batch['image']) # Assuming model handles it
  # TODO: Calculate a dummy loss (e.g., mean of logits)
  # loss = ...
  # YOUR CODE HERE
  loss = jnp.array(0.0) # Replace this

  # In a real scenario, you'd also compute gradients and update model parameters here.
  # For this exercise, we just return the original model and the loss.
  return model, loss

# 3. Set up DataLoader
# TODO: Instantiate MySource (from Ex1, or the one with 'original_index' if you prefer)
# source_ex6 = ...
# YOUR CODE HERE
source_ex6 = None # Replace this

# TODO: Instantiate an IndexSampler for a few epochs (e.g., 2 epochs)
# sampler_ex6 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex6 = None # Replace this

# TODO: Define transformations_ex6 (e.g., ConvertToFloat, grain.Batch)
# (Assuming ConvertToFloat is defined)
# transformations_ex6 = [...]
# YOUR CODE HERE
transformations_ex6 = [] # Replace this

# TODO: Instantiate data_loader_ex6
# data_loader_ex6 = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_ex6 = None # Replace this

# 4. Write the Training Loop
if data_loader_ex6: # Proceed only if DataLoader is configured
  # TODO: Initialize SimpleNNXModel
  # image_height, image_width, image_channels = 32, 32, 3
  # num_classes_ex6 = 10
  # model_key = jax.random.key(0)
  # model_ex6 = SimpleNNXModel(...)
  # YOUR CODE HERE
  model_ex6 = None # Replace this
else:
  print("DataLoader for Ex6 not configured.")
  model_ex6 = None

if model_ex6:
  # TODO: Get an iterator from data_loader_ex6
  # grain_iterator_ex6 = ...
  # YOUR CODE HERE
  grain_iterator_ex6 = iter([]) # Replace this

  num_steps = 100
  print(f"\nStarting conceptual training loop for {num_steps} steps...")
  for step in range(num_steps):
    try:
      # TODO: Get next_batch from iterator
      # next_batch = ...
      # YOUR CODE HERE
      next_batch = None # Replace this
      if next_batch is None:
        raise StopIteration # Simulate exhaustion if not implemented
    except StopIteration:
      print(f"DataLoader exhausted at step {step}. Ending loop.")
      break

    # TODO: Call train_step
    # model_ex6, loss = train_step(model_ex6, next_batch) # model_ex6 isn't actually updated here
    # YOUR CODE HERE
    loss = jnp.array(0.0) # Replace this

    if step % 20 == 0 or step == num_steps - 1:
      print(f"Step {step}: Dummy Loss = {loss.item():.4f}")
  print("Conceptual training loop finished.")
else:
  print("Model for Ex6 not initialized.")
# @title Exercise 6: Solution
# 1. Define SimpleNNXModel
class SimpleNNXModel(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dout, rngs=rngs)
  def __call__(self, x: jax.Array) -> jax.Array:
    # Assuming x is (B, H, W, C)
    batch_size = x.shape[0]
    x_flat = x.reshape((batch_size, -1)) # Flatten H, W, C dimensions
    return self.linear(x_flat)

# 2. Define conceptual train_step
def train_step(model: SimpleNNXModel, batch: Dict[str, jax.Array]):
  # Perform forward pass
  logits = model(batch['image']) # model.__call__ is invoked
  # Calculate a dummy loss
  loss = jnp.mean(logits**2) # Example: mean of squared logits

  # In a real training step:
  # # 1. Define a loss function.
  # def loss_fn(model):
  #   logits = model(batch['image'])
  #   # loss_value = ... (e.g., optax.softmax_cross_entropy_with_integer_labels)
  #   return jnp.mean(logits**2) # Using dummy loss from exercise
  #
  # # 2. Calculate gradients.
  # grads = nnx.grad(loss_fn, wrt=nnx.Param)(model)
  #
  # # 3. Update the model's parameters in-place using the optimizer.
  # #    Note: The optimizer is defined outside the train step.
  # #    e.g., optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
  # optimizer.update(model, grads)
  #
  # # 4. For this exercise, we just return the original model and the loss.
  return model, loss

# 3. Set up DataLoader
# Redefine MySource for Ex6.
class MySource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records
  def __len__(self) -> int:
    return self._num_records
  def __getitem__(self, idx: int) -> Dict[str, Any]:
    effective_idx = idx % self._num_records
    image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
    label = effective_idx % 10
    return {'image': image, 'label': label}
print("Redefined MySource for Ex6.")

source_ex6 = MySource(num_records=1000)
sampler_ex6 = grain.IndexSampler(
  num_records=len(source_ex6),
  shard_options=grain.NoSharding(),
  shuffle=True,
  num_epochs=2, # Run for 2 epochs
  seed=42
  )

# Redefine ConvertToFloat for Ex6.
class ConvertToFloat(grain.MapTransform):
  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    updated_features = features.copy()
    updated_features['image'] = features['image'].astype(np.float32) / 255.0
    return updated_features
print("Redefined ConvertToFloat for Ex6.")

transformations_ex6 = [
  ConvertToFloat(),
  grain.Batch(batch_size=64, drop_remainder=True)
  ]

data_loader_ex6 = grain.DataLoader(
  data_source=source_ex6,
  operations=transformations_ex6,
  sampler=sampler_ex6,
  worker_count=2, # Use a couple of workers
  shard_options=grain.NoSharding(),
  read_options=grain.ReadOptions(num_threads=0)
  )
print("DataLoader for Ex6 configured.")

# 4. Write the Training Loop
# Define image dimensions and number of classes
image_height, image_width, image_channels = 32, 32, 3
input_dim = image_height * image_width * image_channels
num_classes_ex6 = 10

# Initialize SimpleNNXModel
# NNX modules are typically initialized outside JIT, then their state can be passed.
# For this conceptual example, the model instance itself is passed.
model_key = jax.random.key(0)
model_ex6 = SimpleNNXModel(din=input_dim, dout=num_classes_ex6, rngs=nnx.Rngs(params=model_key))
print(f"SimpleNNXModel initialized. Input dim: {input_dim}, Output dim: {num_classes_ex6}")

# Get an iterator from data_loader_ex6
grain_iterator_ex6 = iter(data_loader_ex6)
num_steps = 100 # Total steps for the conceptual loop

print(f"\nStarting conceptual training loop for {num_steps} steps...")
for step_idx in range(num_steps):
  try:
    # Get next_batch from iterator
    next_batch = next(grain_iterator_ex6)
  except StopIteration:
    print(f"DataLoader exhausted at step {step_idx}. Ending loop.")
    # Example: Re-initialize iterator if sampler allows multiple epochs
    # if sampler_ex6.num_epochs is None or sampler_ex6.num_epochs > 1 (and we tracked current epoch):
      # print("Re-initializing iterator for new epoch...")
      # grain_iterator_ex6 = iter(data_loader_ex6)
      # next_batch = next(grain_iterator_ex6)
    # else:
    break # Exit loop if truly exhausted

  # Call train_step (the NumPy arrays in the batch become JAX arrays inside the model call)
  _, loss = train_step(model_ex6, next_batch) # model_ex6 state isn't actually updated here

  if step_idx % 20 == 0 or step_idx == num_steps - 1:
    print(f"Step {step_idx}: Dummy Loss = {loss.item():.4f}") # .item() to get Python scalar from JAX array
print("Conceptual training loop finished.")

---

Exercise 7: Checkpointing and Restoring Data Iteration

Goal: Understand how to save and restore the state of a Grain data iterator, enabling reproducible experiments and resumption of long-running training jobs.

Background: When you call iter(data_loader), Grain's DataLoader produces an iterator (grain.PyGrainDatasetIterator). This iterator exposes get_state() and set_state() methods, which capture its internal iteration state (e.g., the current position and the RNG state of the sampler/transforms) so it can be restored later. To checkpoint an experiment completely, this iterator state should be saved alongside the model parameters (typically with Orbax). A minimal sketch of the pattern follows the steps below.

Steps:
  1. Configure a DataLoader (e.g., using MySource, an IndexSampler with num_epochs=None and a seed, plus some basic transformations).
  2. Get an iterator from this DataLoader (iterator1).
  3. Call next(iterator1) a few times (e.g., 3 batches) and keep the last batch.
  4. Save state: call saved_iterator_state = iterator1.get_state().
  5. Simulate resumption:
     * Get a new iterator (iterator2) from the same DataLoader instance.
     * Restore state: call iterator2.set_state(saved_iterator_state).
  6. Iterate once with next(iterator2) to obtain a batch (resumed_batch).
  7. Verify:
     * The resumed_batch from iterator2 should be identical to the batch iterator1 would have produced next after the saved point.
     * How to verify: after taking saved_iterator_state from iterator1, call next(iterator1) once more to get expected_next_batch_from_iterator1, then compare resumed_batch (obtained from iterator2 after set_state) with it; their contents (e.g., the image data) should match.
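The pattern in miniature (names are illustrative; the exercise below builds the real data_loader and verification):

  iterator = iter(data_loader)
  _ = [next(iterator) for _ in range(3)]  # consume a few batches
  state = iterator.get_state()            # capture position and RNG state
  restored = iter(data_loader)            # e.g., after a job restart
  restored.set_state(state)
  resumed = next(restored)                # the batch 'iterator' would have produced next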
# @title Exercise 7: Student Code
# 1. Set up DataLoader
# TODO: Instantiate MySource (e.g., from Ex1)
# source_ex7 = ...
# YOUR CODE HERE
source_ex7 = None # Replace this

# TODO: Instantiate an IndexSampler (num_epochs=None, seed=42)
# sampler_ex7 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex7 = None # Replace this

# TODO: Define transformations_ex7 (e.g., ConvertToFloat, Batch)
# (Assuming ConvertToFloat is defined)
# transformations_ex7 = [...]
# YOUR CODE HERE
transformations_ex7 = [] # Replace this

# TODO: Instantiate data_loader_ex7
# data_loader_ex7 = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_ex7 = None # Replace this

if data_loader_ex7:
  # 2. Get iterator1
  # TODO: iterator1 = iter(...)
  # YOUR CODE HERE
  iterator1 = iter([]) # Replace this

# 3. Iterate a few times
num_initial_iterations = 3
print(f"--- Initial Iteration (iterator1) for {num_initial_iterations} batches ---")
last_batch_iterator1 = None
for i in range(num_initial_iterations):
  try:
    # TODO: last_batch_iterator1 = next(...)
    # YOUR CODE HERE
    last_batch_iterator1 = {} # Replace this
    print(f"iterator1, batch {i+1} - first label: {last_batch_iterator1.get('label', [None])[0]}")
  except StopIteration:
    print("iterator1 exhausted prematurely.")
    break

# 4. Save State
# TODO: saved_iterator_state = iterator1.get_state()
# YOUR CODE HERE
saved_iterator_state = None # Replace this
print(f"\nIterator state saved. Type: {type(saved_iterator_state)}")

# For verification: get the *next* batch from iterator1 *after* saving state
expected_next_batch_from_iterator1 = None
if saved_iterator_state is not None: # Ensure state was actually saved
  try:
    # TODO: expected_next_batch_from_iterator1 = next(...)
    # YOUR CODE HERE
    expected_next_batch_from_iterator1 = {} # Replace this
    print(f"Expected next batch (from iterator1 after get_state) - first label: {expected_next_batch_from_iterator1.get('label', [None])[0]}")
  except StopIteration:
    print("iterator1 exhausted when trying to get expected_next_batch.")

# 5. Simulate Resumption
# TODO: Get iterator2 from the same data_loader_ex7
# iterator2 = iter(...)
# YOUR CODE HERE
iterator2 = iter([]) # Replace this

if saved_iterator_state is not None:
  # TODO: iterator2.set_state(...)
  # YOUR CODE HERE
  print("\n--- Resumed Iteration (iterator2) ---")
  print("Iterator state restored to iterator2.")
else:
  print("\nSkipping resumption, saved_iterator_state is None.")

# 6. Iterate once from iterator2
resumed_batch = None
if saved_iterator_state is not None: # Only if state was set
  try:
    # TODO: resumed_batch = next(...)
    # YOUR CODE HERE
    resumed_batch = {} # Replace this
    print(f"Resumed batch (from iterator2 after set_state) - first label: {resumed_batch.get('label', [None])[0]}")
  except StopIteration:
    print("iterator2 exhausted immediately after set_state.")

# 7. Verify
if expected_next_batch_from_iterator1 is not None and resumed_batch is not None:
  # Compare 'image' data of the first element in the batch
  # TODO: Perform comparison (e.g., np.allclose on image data)
  # are_identical = np.allclose(...)
  # YOUR CODE HERE
  are_identical = False # Replace this

  if are_identical:
    print("\nSUCCESS: Resumed batch is identical to the expected next batch. Checkpointing works!")
  else:
    print("\nFAILURE: Resumed batch differs from the expected next batch.")
    # print(f"Expected image data (sample): {expected_next_batch_from_iterator1['image'][0,0,0,:3]}")
    # print(f"Resumed image data (sample): {resumed_batch['image'][0,0,0,:3]}")
elif saved_iterator_state is not None:
  print("\nVerification inconclusive: could not obtain both expected and resumed batches.")
else:
  print("Verification skipped: no saved iterator state (check the DataLoader/TODOs above).")
# @title Exercise 7: Solution
# 1. Set up DataLoader
# Redefine MySource for Ex7.
class MySource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records
  def __len__(self) -> int:
    return self._num_records
  def __getitem__(self, idx: int) -> Dict[str, Any]:
    effective_idx = idx % self._num_records
    image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
    label = effective_idx % 10
    # Add original_index for easier verification if needed, though label might suffice
    return {'image': image, 'label': label, 'original_index': effective_idx}
print("Redefined MySource for Ex7.")

source_ex7 = MySource(num_records=1000)
sampler_ex7 = grain.IndexSampler(
  num_records=len(source_ex7),
  shard_options=grain.NoSharding(),
  shuffle=True, # Shuffling makes the test more robust
  num_epochs=None, # Indefinite iteration
  seed=42
  )

# Redefine ConvertToFloat for Ex7.
class ConvertToFloat(grain.MapTransform):
  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    updated_features = features.copy()
    updated_features['image'] = features['image'].astype(np.float32) / 255.0
    return updated_features
print("Redefined ConvertToFloat for Ex7.")

transformations_ex7 = [
  ConvertToFloat(),
  grain.Batch(batch_size=64, drop_remainder=True)
  ]

data_loader_ex7 = grain.DataLoader(
  data_source=source_ex7,
  operations=transformations_ex7,
  sampler=sampler_ex7,
  worker_count=0, # Simpler for state verification, but works with >0 too
  shard_options=grain.NoSharding(),
  read_options=grain.ReadOptions(num_threads=0)
  )
print("DataLoader for Ex7 configured.")

# 2. Get iterator1
iterator1 = iter(data_loader_ex7)

# 3. Iterate a few times
num_initial_iterations = 3
print(f"--- Initial Iteration (iterator1) for {num_initial_iterations} batches ---")
last_batch_iterator1 = None
for i in range(num_initial_iterations):
  try:
    last_batch_iterator1 = next(iterator1)
    # Using 'original_index' for more robust check than just 'label'
    print(f"iterator1, batch {i+1} - first original_index: {last_batch_iterator1['original_index'][0]}")
  except StopIteration:
    print("iterator1 exhausted prematurely.")
    break

# 4. Save State
# Note: get_state() returns a new state object, so continuing to iterate
# iterator1 will not mutate it; no explicit deep copy is needed here.
saved_iterator_state = iterator1.get_state()
print(f"\nIterator state saved. Type: {type(saved_iterator_state)}")

# For verification: get the next batch from iterator1 after saving state
expected_next_batch_from_iterator1 = None
if saved_iterator_state is not None:
  try:
    expected_next_batch_from_iterator1 = next(iterator1)
    print(f"Expected next batch (from iterator1 after get_state) - first original_index: {expected_next_batch_from_iterator1['original_index'][0]}")
  except StopIteration:
    print("iterator1 exhausted when trying to get expected_next_batch.")

# 5. Simulate Resumption
# Get a new iterator from the same DataLoader instance.
iterator2 = iter(data_loader_ex7)
if saved_iterator_state is not None:
  iterator2.set_state(saved_iterator_state)
  print("\n--- Resumed Iteration (iterator2) ---")
  print("Iterator state restored to iterator2.")
else:
  print("\nSkipping resumption, saved_iterator_state is None.")

# 6. Iterate once from iterator2
resumed_batch = None
if saved_iterator_state is not None:
  try:
    resumed_batch = next(iterator2)
    print(f"Resumed batch (from iterator2 after set_state) - first original_index: {resumed_batch['original_index'][0]}")
  except StopIteration:
    print("iterator2 exhausted immediately after set_state. This means the saved state was at the very end.")

# 7. Verify
if expected_next_batch_from_iterator1 is not None and resumed_batch is not None:
  # Compare 'image' data and 'original_index' of the first element in the batch for robustness
  # (Labels might repeat, indices are better for this check if available)
  expected_img_sample = expected_next_batch_from_iterator1['image'][0]
  resumed_img_sample = resumed_batch['image'][0]
  expected_idx_sample = expected_next_batch_from_iterator1['original_index'][0]
  resumed_idx_sample = resumed_batch['original_index'][0]
  are_indices_identical = (expected_idx_sample == resumed_idx_sample)
  are_images_identical = np.allclose(expected_img_sample, resumed_img_sample)

  are_identical = are_indices_identical and are_images_identical

  if are_identical:
    print("\nSUCCESS: Resumed batch is identical to the expected next batch. Checkpointing works!")
  else:
    print("\nFAILURE: Resumed batch differs from the expected next batch.")
    if not are_indices_identical:
      print(f"  - Mismatch in first original_index: Expected {expected_idx_sample}, Got {resumed_idx_sample}")
    if not are_images_identical:
      print("  - Mismatch in image data for first element.")
      # print(f"    Expected image data (sample [0,0,0]): {expected_img_sample[0,0,0]}")
      # print(f"    Resumed image data (sample [0,0,0]): {resumed_img_sample[0,0,0]}")
elif saved_iterator_state is not None: # If state was saved but verification couldn't complete
  if expected_next_batch_from_iterator1 is None and resumed_batch is None:
    print("\nVERIFICATION NOTE: Both iterators seem to be at the end of the dataset after the initial iterations. This is valid if the dataset was short.")
  else:
    print("\nVerification inconclusive: could not obtain both expected and resumed batches for comparison.")
    print(f" expected_next_batch_from_iterator1 is None: {expected_next_batch_from_iterator1 is None}")
    print(f" resumed_batch is None: {resumed_batch is None}")
else: # saved_iterator_state is None, i.e. get_state() was never reached
  print("Verification skipped: no saved iterator state.")

---

Summary and Further Exploration

Congratulations on completing all the exercises! You should now understand:

  • Grain's basic components (DataSource, Sampler, Operations).
  • How to build and use a grain.DataLoader for efficient (including parallel) data input.
  • How to implement custom deterministic and random transformations.
  • The basics of data sharding for distributed training.
  • How a Grain iterator fits into a JAX/Flax NNX training loop.
  • How to save and restore data iterator state for reproducibility.
  Key takeaways (from the lecture notes):
    • Use Grain: it removes JAX's data bottleneck and improves end-to-end performance.
    • Speed things up: set worker_count > 0 on the DataLoader to enable parallelism.
    • Ensure reproducibility: use the sampler/seed, and use the provided rng inside RandomMapTransform (see the sketch right after this list).
    • Distributed: use grain.sharding.ShardByJaxProcess (or manual ShardOptions) for sharding across JAX processes.
    • Save everything: checkpoint the data iterator state together with the model state (e.g., a complete save via Orbax).
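A minimal sketch of the reproducibility point, using a hypothetical RandomBrightness transform (not defined in the exercises): Grain calls random_map with an np.random.Generator derived from the seed, so the same seed reproduces the same augmentations. The comment assumes the image is already float32 in [0, 1] (e.g., after ConvertToFloat).

class RandomBrightness(grain.RandomMapTransform):
  def random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any]:
    out = features.copy()
    # Scale the float32 [0, 1] image by a factor drawn from the provided rng.
    out['image'] = np.clip(out['image'] * rng.uniform(0.9, 1.1), 0.0, 1.0)
    return out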
  Further exploration:
    • Orbax integration: for robust checkpointing in real projects, integrate Grain with Orbax so the Grain iterator state is saved/loaded atomically together with Flax model parameters and optimizer state. Note: if you are migrating from NNX v0.10, the checkpoint structure of models that carry RNG state (e.g., Dropout, BatchNorm) has changed; use the migration script to convert to v0.11, as described in the official migration guide.
    • Different data sources: try reading various on-disk formats (e.g., TFRecord, RecordIO) via a suitable DataSource implementation or libraries such as TFDS.
    • Performance profiling: use JAX's profiling tools to identify and optimize data-loading bottlenecks in more complex scenarios.

We hope these exercises help you go further on your journey with JAX and Grain!