Welcome! This Colab notebook contains exercises to help you learn Google's Grain library for efficient data loading in JAX. The exercises are aimed at developers who are familiar with PyTorch and are exploring the JAX ecosystem (including the new Flax NNX API).
The goal of this notebook is to practice sequential and parallel data loading with grain.DataLoader. To better demonstrate parallelism and sharding concepts in Colab, which typically provides only a single CPU/GPU, the notebook begins by configuring JAX to simulate 8 CPU devices via XLA_FLAGS and chex.set_n_cpu_devices(8).
# Environment Setup
# This cell configures the environment to simulate multiple CPU devices
# and installs necessary libraries.
# IMPORTANT: RUN THIS CELL FIRST. If you encounter issues with JAX device
# counts later, try 'Runtime -> Restart runtime' in the Colab menu
# and run this cell again before any others.
import os
# Configure JAX to see 8 virtual CPU devices.
# This must be done before JAX is imported for the first time in a session.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
# Install libraries
!pip install -q jax-ai-stack==2025.9.3
print("Libraries installed.")
# Now, import chex and attempt to set_n_cpu_devices.
# This must be called after setting XLA_FLAGS and before JAX initializes its backends.
import chex
try:
chex.set_n_cpu_devices(8)
print("chex.set_n_cpu_devices(8) called successfully.")
except RuntimeError as e:
print(f"Note on chex.set_n_cpu_devices: {e}")
print("This usually means JAX was already initialized. The XLA_FLAGS environment variable should still apply.")
print("If you see issues with device counts, ensure you 'Restart runtime' and run this cell first.")
# Verify JAX device count
import jax
print(f"JAX version: {jax.__version__}")
print(f"JAX found {jax.device_count()} devices.")
print(f"JAX devices: {jax.devices()}")
if jax.device_count() != 8:
print("\nWARNING: JAX does not see 8 devices. Parallelism/sharding exercises might not behave as expected.")
print("Please try 'Runtime -> Restart runtime' and run this setup cell again first.")
# Common imports for the exercises
import jax.numpy as jnp
import numpy as np
import grain.python as grain # Main Grain API
from flax import nnx # Flax NNX API
import time # For simulating work and observing performance
import copy # For checkpointing example
from typing import Dict, Any, List # For type hints
import dataclasses # For ShardOptions if needed manually
import functools # For functools.partial
print("Imports complete. Setup finished.")
As highlighted in the lecture, JAX is extremely fast at numerical computation, especially on accelerators. However, overall speed suffers if data loading is inefficient. Standard Python data loading is often limited by I/O, CPU-bound transformations, and the Global Interpreter Lock (GIL).
Grain is Google's high-performance data loading solution for JAX. Conceptually, Grain's DataLoader is similar to PyTorch's torch.utils.data.DataLoader: it is responsible for reading data, applying transformations, batching, and parallelism.
Core components of the grain.DataLoader API:
* DataSource: provides random access to individual raw records (must implement __len__ and __getitem__).
* Sampler: determines the order in which records are loaded and seeds random operations for reproducibility.
* Operations: a sequence of transformations applied in order (e.g. augmentation, filtering, batching).
A minimal end-to-end sketch of how these pieces fit together is shown below; after that, the exercises begin!
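For orientation, here is a minimal sketch wiring the three components into a grain.DataLoader, mirroring the API used in the exercise solutions below (the names TinySource, AddOne, and sketch_loader are illustrative only):
import numpy as np
import grain.python as grain

class TinySource(grain.RandomAccessDataSource):
    # 100 in-memory records of the form {'x': float32, 'y': int}.
    def __init__(self, num_records: int = 100):
        self._num_records = num_records
    def __len__(self) -> int:
        return self._num_records
    def __getitem__(self, idx: int):
        idx = idx % self._num_records
        return {'x': np.float32(idx), 'y': idx % 10}

class AddOne(grain.MapTransform):
    # Example per-record transform: shift 'x' by one.
    def map(self, features):
        features = features.copy()
        features['x'] = features['x'] + 1.0
        return features

sketch_loader = grain.DataLoader(
    data_source=TinySource(),            # what to read
    sampler=grain.IndexSampler(          # in which order
        num_records=100,
        shard_options=grain.NoSharding(),
        shuffle=True,
        num_epochs=1,
        seed=0,
    ),
    operations=[AddOne(), grain.Batch(batch_size=10, drop_remainder=True)],  # how to transform/batch
    worker_count=0,                      # sequential; >0 uses multiprocessing (Exercise 2)
)
sketch_batch = next(iter(sketch_loader))
print(sketch_batch['x'].shape, sketch_batch['y'].shape)  # (10,), (10,)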
---
Exercise 1: grain.DataLoader in sequential mode. This exercise covers DataSource, IndexSampler, a simple MapTransform, and grain.DataLoader in sequential mode (worker_count=0).
Steps:
1. Define MySource, a custom RandomAccessDataSource. __init__: store num_records.
* __len__: return num_records.
* __getitem__: given idx, return the dict {'image': image_array, 'label': label_int}.
* image_array should be a NumPy array of shape (32, 32, 3) and dtype=np.uint8 whose values may depend on idx (e.g. np.ones(...) * (idx % 255)).
* label_int is an integer (e.g. idx % 10).
* Handle index wrap-around across epochs: idx = idx % self.num_records.
2. Instantiate MySource. 3. Create an IndexSampler that shuffles, runs for 1 epoch, and uses a fixed seed. 4. Define the operations list: a ConvertToFloat class (a grain.MapTransform) that converts image to np.float32 and normalizes it to [0, 1].
* Use grain.Batch to group samples into batches of 64, dropping the remainder.
5. Instantiate grain.DataLoader with worker_count=0 (debug/sequential mode). Since MySource holds in-memory data, pass read_options=grain.ReadOptions(num_threads=0) to disable Grain's internal read threads.
6. Iterate the DataLoader, fetch the first batch, and print the image and label shapes.
# @title Exercise 1: Student Code
# 1. Define MySource
class MySource(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
# TODO: Return the total number of records
# YOUR CODE HERE
return 0 # Replace this
def __getitem__(self, idx: int) -> Dict[str, Any]:
# TODO: Handle potential index wrap-around for multiple epochs
# effective_idx = ...
# YOUR CODE HERE
effective_idx = idx # Replace this
# TODO: Simulate loading data: an image and a label
# image = np.ones(...) * (effective_idx % 255)
# label = effective_idx % 10
# YOUR CODE HERE
image = np.zeros((32,32,3), dtype=np.uint8) # Replace this
label = 0 # Replace this
return {'image': image, 'label': label}
# 2. Instantiate MySource
# TODO: Create an instance of MySource
# source = ...
# YOUR CODE HERE
source = None # Replace this
# 3. Create an IndexSampler
# TODO: Create an IndexSampler that shuffles, runs for 1 epoch, and uses seed 42.
# num_records should be len(source).
# index_sampler = grain.IndexSampler(...)
# YOUR CODE HERE
index_sampler = None # Replace this
# 4. Define Operations
# TODO: Define ConvertToFloat transform
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
# TODO: Convert 'image' to float32 and normalize to [0, 1].
# Keep 'label' as is.
# YOUR CODE HERE
image = features['image'] # Replace this
return {'image': image.astype(np.float32) / 255.0, 'label': features['label']}
# TODO: Create a list of transformations: ConvertToFloat instance, then grain.Batch
# with batch_size=64, drop_remainder=True
# transformations = [...]
# YOUR CODE HERE
transformations = [] # Replace this
# 5. Instantiate DataLoader
# TODO: Create a DataLoader with worker_count=0 and appropriate read_options
# data_loader_sequential = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_sequential = None # Replace this
# 6. Iterate and print batch info
if data_loader_sequential:
print("DataLoader configured sequentially.")
data_iterator_seq = iter(data_loader_sequential)
try:
first_batch_seq = next(data_iterator_seq)
print(f"Sequential - First batch image shape: {first_batch_seq['image'].shape}")
print(f"Sequential - First batch label shape: {first_batch_seq['label'].shape}")
# Example: Check a value from the first image of the first batch
print(f"Sequential - Example image value (first item, [0,0,0]): {first_batch_seq['image'][0, 0, 0, 0]}")
print(f"Sequential - Example label value (first item): {first_batch_seq['label'][0]}")
except StopIteration:
print("Sequential DataLoader is empty or exhausted.")
else:
print("Sequential DataLoader not configured yet.")
# @title Exercise 1: Solution
# 1. Define MySource
class MySource(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records # Handle wrap-around
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
return {'image': image, 'label': label}
# 2. Instantiate MySource
source = MySource(num_records=1000)
print(f"DataSource created with {len(source)} records.")
# 3. Create an IndexSampler
index_sampler = grain.IndexSampler(
num_records=len(source),
shard_options=grain.NoSharding(), # No sharding for this exercise
shuffle=True,
num_epochs=1, # Run for 1 epoch
seed=42
)
print("IndexSampler created.")
# 4. Define Operations
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
# Convert 'image' to float32 and normalize to [0, 1].
# Keep 'label' as is.
image = features['image'].astype(np.float32) / 255.0
return {'image': image, 'label': features['label']}
transformations = [
ConvertToFloat(),
grain.Batch(batch_size=64, drop_remainder=True)
]
print("Transformations defined.")
# 5. Instantiate DataLoader
data_loader_sequential = grain.DataLoader(
data_source=source,
operations=transformations,
sampler=index_sampler,
worker_count=0, # Sequential mode
shard_options=grain.NoSharding(), # Explicitly no sharding for this loader instance
read_options=grain.ReadOptions(num_threads=0) # Dataset is in-memory
)
# 6. Iterate and print batch info
if data_loader_sequential:
print("DataLoader configured sequentially.")
data_iterator_seq = iter(data_loader_sequential)
try:
first_batch_seq = next(data_iterator_seq)
print(f"Sequential - First batch image shape: {first_batch_seq['image'].shape}") # Expected: (64, 32, 32, 3)
print(f"Sequential - First batch label shape: {first_batch_seq['label'].shape}") # Expected: (64,)
# Example: Check a value from the first image of the first batch
print(f"Sequential - Example image value (first item, [0,0,0]): {first_batch_seq['image'][0, 0, 0, 0]}")
print(f"Sequential - Example label value (first item): {first_batch_seq['label'][0]}")
except StopIteration:
print("Sequential DataLoader is empty or exhausted.")
else:
print("Sequential DataLoader not configured yet.")
---
Exercise 2: Enabling parallelism with worker_count. This exercise shows how worker_count > 0 enables multiprocessing to speed up data loading.
Steps:
1. Reuse MySource, the IndexSampler (or create a new one, e.g. with num_epochs=None for indefinite epochs), and the transformations from Exercise 1.
2. Modify MySource: add a time.sleep(0.01) (10 ms) inside __getitem__ to simulate per-item I/O or CPU cost.
3. Create a new grain.DataLoader (e.g. data_loader_parallel), this time with worker_count set to a value greater than 0 (such as 2 or 4). Remember we are simulating 8 CPUs.
4. With the time.sleep in place, the parallel loader should be faster than the sequential one.
Important: when worker_count > 0, Grain uses multiprocessing. This means every component (DataSource, Sampler, Operations, and custom transform instances) must be serializable by Python's pickle module. Ordinary classes and functions are fine, but avoid complex closures and unpicklable objects (see the picklability check sketched after this list).
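As a quick sanity check, you can test picklability directly with Python's pickle module before handing a transform to Grain. ScaleImage and BadTransform below are hypothetical transforms used only to illustrate what does and does not survive multiprocessing:
import pickle
import grain.python as grain

class ScaleImage(grain.MapTransform):
    # Plain attributes on a top-level class: instances pickle fine, safe for worker_count > 0.
    def __init__(self, factor: float):
        self._factor = factor
    def map(self, features):
        features = features.copy()
        features['image'] = features['image'] * self._factor
        return features

class BadTransform(grain.MapTransform):
    # Stores a lambda: instances cannot be pickled, so worker processes cannot receive them.
    def __init__(self):
        self._fn = lambda x: x * 0.5
    def map(self, features):
        features = features.copy()
        features['image'] = self._fn(features['image'])
        return features

pickle.dumps(ScaleImage(0.5))  # works
try:
    pickle.dumps(BadTransform())
except Exception as e:
    print(f"BadTransform is not picklable: {type(e).__name__}")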
# @title Exercise 2: Student Code
# 1. Reuse/Recreate components (DataSource with simulated work, Sampler, Operations)
# TODO: Define MySourceWithWork, adding time.sleep(0.01) in getitem
class MySourceWithWork(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
# YOUR CODE HERE
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
# TODO: Add time.sleep(0.01) to simulate work
# YOUR CODE HERE
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
return {'image': image, 'label': label}
# TODO: Instantiate MySourceWithWork
# source_with_work = ...
# YOUR CODE HERE
source_with_work = None # Replace this
# TODO: Create a new IndexSampler (e.g., for indefinite epochs, num_epochs=None)
# Or reuse the one from Ex1 if you reset it or it's for multiple epochs.
# For simplicity, let's create one for indefinite epochs.
# parallel_sampler = grain.IndexSampler(...)
# YOUR CODE HERE
parallel_sampler = None # Replace this
# Transformations can be reused from Exercise 1
# transformations = [ConvertToFloat(), grain.Batch(batch_size=64, drop_remainder=True)]
# (Assuming ConvertToFloat is defined from Ex1 solution)
# 2. Instantiate DataLoader with worker_count > 0
# TODO: Set num_workers (e.g., 4)
# num_workers = ...
# YOUR CODE HERE
num_workers = 0 # Replace this
# TODO: Create data_loader_parallel
# data_loader_parallel = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_parallel = None # Replace this
# 3. Iterate and print batch info
if data_loader_parallel:
print(f"DataLoader configured with worker_count={num_workers}.")
data_iterator_parallel = iter(data_loader_parallel)
try:
first_batch_parallel = next(data_iterator_parallel)
print(f"Parallel - First batch image shape: {first_batch_parallel['image'].shape}")
print(f"Parallel - First batch label shape: {first_batch_parallel['label'].shape}")
except StopIteration:
print("Parallel DataLoader is empty or exhausted.")
else:
print("Parallel DataLoader not configured yet.")
# 4. (Optional) Timing comparison
# Re-create sequential loader with MySourceWithWork for a fair comparison
if source_with_work and transformations and index_sampler: # index_sampler from Ex1
data_loader_seq_with_work = grain.DataLoader(
data_source=source_with_work,
operations=transformations, # Reusing from Ex1
sampler=index_sampler, # Reusing from Ex1 (ensure it's fresh or allows re-iteration)
worker_count=0,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
num_batches_to_test = 5 # Small number for quick test
if data_loader_seq_with_work:
print(f"\nTiming test for {num_batches_to_test} batches:")
# Sequential
iterator_seq = iter(data_loader_seq_with_work)
start_time = time.time()
try:
for i in range(num_batches_to_test):
batch = next(iterator_seq)
if i == 0: print(f" Seq batch 1 label sum: {batch['label'].sum()}") # to ensure work is done
except StopIteration:
print("Sequential loader exhausted early.")
end_time = time.time()
print(f"Sequential ({num_batches_to_test} batches) took: {end_time - start_time:.4f} seconds")
if data_loader_parallel:
# Parallel
# Ensure sampler is fresh for parallel loader if it was used above
# For this optional part, let's use a fresh sampler for the parallel loader
# to avoid StopIteration if the previous sampler was single-epoch and exhausted.
fresh_parallel_sampler = grain.IndexSampler(
num_records=len(source_with_work),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=None, # Indefinite
seed=43 # Different seed or same, for this test it's about speed
)
data_loader_parallel_for_timing = grain.DataLoader(
data_source=source_with_work,
operations=transformations, # Reusing from Ex1
sampler=fresh_parallel_sampler,
worker_count=num_workers if num_workers > 0 else 2, # Ensure parallelism
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
iterator_parallel = iter(data_loader_parallel_for_timing)
start_time = time.time()
try:
for i in range(num_batches_to_test):
batch = next(iterator_parallel)
if i == 0: print(f" Parallel batch 1 label sum: {batch['label'].sum()}") # to ensure work is done
except StopIteration:
print("Parallel loader exhausted early.")
end_time = time.time()
print(f"Parallel ({num_batches_to_test} batches, {num_workers if num_workers > 0 else 2} workers) took: {end_time - start_time:.4f} seconds")
else:
print("Skipping optional timing: source_with_work, transformations, or index_sampler not defined.")
# @title Exercise 2: Solution
# 1. Reuse/Recreate components
# Define MySourceWithWork, adding time.sleep(0.01) in getitem
class MySourceWithWork(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
time.sleep(0.01) # Simulate 10ms of work per item
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
return {'image': image, 'label': label}
source_with_work = MySourceWithWork(num_records=1000)
print(f"MySourceWithWork created with {len(source_with_work)} records.")
# Sampler for parallel loading (indefinite epochs for robust testing)
parallel_sampler = grain.IndexSampler(
num_records=len(source_with_work),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=None, # Run indefinitely
seed=42
)
print("Parallel IndexSampler created.")
# Transformations can be reused from Exercise 1 solution
# Ensure ConvertToFloat is defined (it was in Ex1 solution cell)
if 'ConvertToFloat' not in globals(): # Basic check
class ConvertToFloat(grain.MapTransform): # Redefine if not in current scope
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
image = features['image'].astype(np.float32) / 255.0
return {'image': image, 'label': features['label']}
print("Redefined ConvertToFloat for safety.")
transformations_ex2 = [
ConvertToFloat(),
grain.Batch(batch_size=64, drop_remainder=True)
]
print("Transformations for Ex2 ready.")
# 2. Instantiate DataLoader with worker_count > 0
num_workers = 4 # Use 4 workers; JAX is configured for 8 virtual CPUs
# Max useful workers often related to num CPU cores available.
data_loader_parallel = grain.DataLoader(
data_source=source_with_work,
operations=transformations_ex2,
sampler=parallel_sampler,
worker_count=num_workers,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0) # Data source simulates work but is "in-memory"
)
# 3. Iterate and print batch info
if data_loader_parallel:
print(f"DataLoader configured with worker_count={num_workers}.")
data_iterator_parallel = iter(data_loader_parallel)
try:
first_batch_parallel = next(data_iterator_parallel)
print(f"Parallel - First batch image shape: {first_batch_parallel['image'].shape}")
print(f"Parallel - First batch label shape: {first_batch_parallel['label'].shape}")
except StopIteration:
print("Parallel DataLoader is empty or exhausted.")
else:
print("Parallel DataLoader not configured yet.")
# 4. (Optional) Timing comparison
# Create a fresh IndexSampler for the sequential loader for a fair comparison start
# (num_epochs=1 to match typical test for sequential pass)
seq_sampler_for_timing = grain.IndexSampler(
num_records=len(source_with_work),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=1, # Single epoch for this timing test
seed=42
)
data_loader_seq_with_work = grain.DataLoader(
data_source=source_with_work,
operations=transformations_ex2,
sampler=seq_sampler_for_timing,
worker_count=0,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
num_batches_to_test = 5 # Number of batches to fetch for timing
print(f"\nTiming test for {num_batches_to_test} batches (each item has 0.01s simulated work):")
# Sequential
iterator_seq = iter(data_loader_seq_with_work)
start_time_seq = time.time()
try:
for i in range(num_batches_to_test):
batch_seq = next(iterator_seq)
if i == 0 and num_batches_to_test > 0 : print(f" Seq batch 1 label sum: {batch_seq['label'].sum()}") # to ensure work is done
except StopIteration:
print(f"Sequential loader exhausted before {num_batches_to_test} batches.")
end_time_seq = time.time()
print(f"Sequential ({num_batches_to_test} batches) took: {end_time_seq - start_time_seq:.4f} seconds")
# Parallel
# Use a fresh sampler for the parallel loader for timing to ensure it's not exhausted
# and runs for enough batches.
parallel_sampler_for_timing = grain.IndexSampler(
num_records=len(source_with_work),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=None, # Indefinite, or ensure enough for num_batches_to_test
seed=43 # Can be same or different seed
)
data_loader_parallel_for_timing = grain.DataLoader(
data_source=source_with_work,
operations=transformations_ex2,
sampler=parallel_sampler_for_timing,
worker_count=num_workers,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
iterator_parallel_timed = iter(data_loader_parallel_for_timing)
start_time_parallel = time.time()
try:
for i in range(num_batches_to_test):
batch_par = next(iterator_parallel_timed)
if i == 0 and num_batches_to_test > 0 : print(f" Parallel batch 1 label sum: {batch_par['label'].sum()}") # to ensure work is done
except StopIteration:
print(f"Parallel loader exhausted before {num_batches_to_test} batches.")
end_time_parallel = time.time()
print(f"Parallel ({num_batches_to_test} batches, {num_workers} workers) took: {end_time_parallel - start_time_parallel:.4f} seconds")
if end_time_parallel - start_time_parallel < end_time_seq - start_time_seq:
print("Parallel loading was faster, as expected!")
else:
print("Parallel loading was not significantly faster. This might happen for very small num_batches_to_test due to overhead, or if simulated work is too little.")
---
Exercise 3: A custom transform (MapTransform).
1. Define OneHotEncodeLabel, inheriting from grain.MapTransform. Its __init__ method takes num_classes.
* map(self, features: Dict[str, Any]) should:
* read the incoming features dict,
* convert features['label'] (an integer) into a one-hot np.float32 NumPy vector of length num_classes,
* update features['label'] with the new one-hot vector,
* return the modified features dict.
2. Reuse MySource (without time.sleep) and the IndexSampler (or recreate them). 3. Create a new operations list containing: an OneHotEncodeLabel instance (e.g. num_classes=10, matching idx % 10 in MySource),
* the ConvertToFloat transform (if the images have not been converted yet),
* grain.Batch.
4. Instantiate grain.DataLoader (worker_count can be 0 or > 0).
# @title Exercise 3: Student Code
# 1. Define OneHotEncodeLabel
class OneHotEncodeLabel(grain.MapTransform):
def __init__(self, num_classes: int):
# TODO: Store num_classes
# YOUR CODE HERE
self._num_classes = 0 # Replace this
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
label = features['label']
# TODO: Create one-hot encoded version of the label
# one_hot_label = np.zeros(...)
# one_hot_label[label] = 1.0
# YOUR CODE HERE
one_hot_label = np.array([label]) # Replace this
features['label'] = one_hot_label
return features
# 2. Reuse/Create DataSource and Sampler
# TODO: Instantiate MySource (from Ex1, no sleep)
# source_ex3 = ...
# YOUR CODE HERE
source_ex3 = None # Replace this
# TODO: Instantiate an IndexSampler (e.g., from Ex1, or a new one)
# sampler_ex3 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex3 = None # Replace this
# 3. Create new list of operations
# TODO: Instantiate OneHotEncodeLabel
num_classes_for_ohe = 10
# one_hot_encoder = OneHotEncodeLabel(num_classes=num_classes_for_ohe)
# YOUR CODE HERE
one_hot_encoder = None # Replace this
# TODO: Define transformations_ex3 list including one_hot_encoder,
# ConvertToFloat (if not already applied), and grain.Batch
# (Assuming ConvertToFloat is defined from Ex1 solution)
# transformations_ex3 = [...]
# YOUR CODE HERE
transformations_ex3 = [] # Replace this
# 4. Instantiate DataLoader
# TODO: Create data_loader_ex3
# data_loader_ex3 = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_ex3 = None # Replace this
# 5. Iterate and print batch info
if data_loader_ex3:
print("DataLoader with OneHotEncodeLabel configured.")
iterator_ex3 = iter(data_loader_ex3)
try:
first_batch_ex3 = next(iterator_ex3)
print(f"Custom MapTransform - Batch image shape: {first_batch_ex3['image'].shape}")
print(f"Custom MapTransform - Batch label shape: {first_batch_ex3['label'].shape}") # Expected: (batch_size, num_classes)
if first_batch_ex3['label'].size > 0:
print(f"Custom MapTransform - Example one-hot label: {first_batch_ex3['label'][0]}")
except StopIteration:
print("DataLoader for Ex3 is empty or exhausted.")
else:
print("DataLoader for Ex3 not configured yet.")
# @title Exercise 3: Solution
# 1. Define OneHotEncodeLabel
class OneHotEncodeLabel(grain.MapTransform):
def __init__(self, num_classes: int):
self._num_classes = num_classes
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
label_scalar = features['label']
one_hot_label = np.zeros(self._num_classes, dtype=np.float32)
one_hot_label[label_scalar] = 1.0
# Create a new dictionary to avoid modifying the input dict in place if it's reused
# by other transforms or parts of the pipeline, though often direct modification is fine.
# For safety and clarity, let's return a new dict or an updated copy.
updated_features = features.copy()
updated_features['label'] = one_hot_label
return updated_features
# 2. Reuse/Create DataSource and Sampler
# Using MySource from Exercise 1 solution (no artificial sleep)
if 'MySource' not in globals(): # Basic check
class MySource(grain.RandomAccessDataSource): # Redefine if not in current scope
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
return {'image': image, 'label': label}
print("Redefined MySource for Ex3.")
source_ex3 = MySource(num_records=1000)
sampler_ex3 = grain.IndexSampler(
num_records=len(source_ex3),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=1,
seed=42
)
print("DataSource and Sampler for Ex3 ready.")
# 3. Create new list of operations
num_classes_for_ohe = 10 # Matches idx % 10 in MySource
one_hot_encoder = OneHotEncodeLabel(num_classes=num_classes_for_ohe)
# Ensure ConvertToFloat is defined
if 'ConvertToFloat' not in globals():
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
image = features['image'].astype(np.float32) / 255.0
return {'image': image, 'label': features['label']} # Pass label through
print("Redefined ConvertToFloat for Ex3.")
transformations_ex3 = [
ConvertToFloat(), # Apply first to have float images
one_hot_encoder, # Then one-hot encode labels
grain.Batch(batch_size=64, drop_remainder=True)
]
print("Transformations for Ex3 defined.")
# 4. Instantiate DataLoader
data_loader_ex3 = grain.DataLoader(
data_source=source_ex3,
operations=transformations_ex3,
sampler=sampler_ex3,
worker_count=0, # Can be > 0 as well, OneHotEncodeLabel is picklable
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
# 5. Iterate and print batch info
if data_loader_ex3:
print("DataLoader with OneHotEncodeLabel configured.")
iterator_ex3 = iter(data_loader_ex3)
try:
first_batch_ex3 = next(iterator_ex3)
print(f"Custom MapTransform - Batch image shape: {first_batch_ex3['image'].shape}")
print(f"Custom MapTransform - Batch label shape: {first_batch_ex3['label'].shape}") # Expected: (64, 10)
if first_batch_ex3['label'].size > 0:
print(f"Custom MapTransform - Example one-hot label (first item): {first_batch_ex3['label'][0]}")
original_label_example = np.argmax(first_batch_ex3['label'][0])
print(f"Custom MapTransform - Decoded original label (first item): {original_label_example}")
except StopIteration:
print("DataLoader for Ex3 is empty or exhausted.")
else:
print("DataLoader for Ex3 not configured yet.")
---
Exercise 4: Random transforms (RandomMapTransform).
1. Define RandomBrightnessAdjust, inheriting from grain.RandomMapTransform. Its random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any] should:
* receive features and rng (a NumPy random number generator),
* use the provided rng for all random operations, so that under the same seed the same record gets the same "random" augmentation,
* generate a random brightness factor with rng.uniform(0.7, 1.3),
* multiply features['image'] (assumed to already be float and normalized) by this factor,
* clip the image to [0.0, 1.0] with np.clip(),
* return the modified features.
2. Reuse MySource, an IndexSampler (make sure to set a seed), and the ConvertToFloat transform from the earlier exercises.
3. Create an operations list containing ConvertToFloat, your RandomBrightnessAdjust, and grain.Batch.
4. Create two DataLoaders with the same configuration (dl_run1, dl_run2).
5. Fetch the first batch from dl_run1 and print one pixel value.
6. (If the sampler used num_epochs=1, reset the iterator or rebuild the sampler.) Then fetch the first batch from dl_run2 and print the same pixel value; the two values should be identical.
7. (Optional) Create another DataLoader whose IndexSampler uses a different seed and observe that the pixel value changes.
# @title Exercise 4: Student Code
# 1. Define RandomBrightnessAdjust
class RandomBrightnessAdjust(grain.RandomMapTransform):
def random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any]:
# TODO: Ensure image is float (e.g. by placing ConvertToFloat before this in ops)
image = features['image']
# TODO: Generate a random brightness factor using the provided rng
# brightness_factor = rng.uniform(...)
# YOUR CODE HERE
brightness_factor = 1.0 # Replace this
# TODO: Apply brightness adjustment and clip
# adjusted_image = np.clip(...)
# YOUR CODE HERE
adjusted_image = image # Replace this
# Create a new dictionary or update a copy
updated_features = features.copy()
updated_features['image'] = adjusted_image
return updated_features
# 2. Reuse/Create DataSource, Sampler, ConvertToFloat
# TODO: Instantiate MySource (from Ex1)
# source_ex4 = ...
# YOUR CODE HERE
source_ex4 = None # Replace this
# TODO: Instantiate an IndexSampler with a seed (e.g., seed=42, num_epochs=1 or None)
# sampler_ex4_seed42 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex4_seed42 = None # Replace this
# (Assuming ConvertToFloat is defined from Ex1 solution)
# 3. Create list of operations
# TODO: Instantiate RandomBrightnessAdjust
# random_brightness_adjuster = ...
# YOUR CODE HERE
random_brightness_adjuster = None # Replace this
# TODO: Define transformations_ex4 list: ConvertToFloat, random_brightness_adjuster, grain.Batch
# transformations_ex4 = [...]
# YOUR CODE HERE
transformations_ex4 = [] # Replace this
# 4. Instantiate two DataLoaders with the same config
# TODO: Create dl_run1
# dl_run1 = grain.DataLoader(...)
# YOUR CODE HERE
dl_run1 = None # Replace this
# TODO: Create dl_run2 (using the exact same sampler instance or a new one with the same seed)
# dl_run2 = grain.DataLoader(...)
# YOUR CODE HERE
dl_run2 = None # Replace this
# 5. & 6. Iterate and compare
pixel_to_check = (0, 0, 0, 0) # Batch_idx, H, W, C
if dl_run1:
print("--- Run 1 (seed 42) ---")
iterator_run1 = iter(dl_run1)
try:
batch1_run1 = next(iterator_run1)
value_run1 = batch1_run1['image'][pixel_to_check]
print(f"Run 1 - Pixel {pixel_to_check} value: {value_run1}")
except StopIteration:
print("dl_run1 exhausted.")
value_run1 = None
else:
print("dl_run1 not configured.")
value_run1 = None
if dl_run2:
print("\n--- Run 2 (seed 42, same sampler) ---")
# If sampler_ex4_seed42 was single-epoch and already used by dl_run1,
# dl_run2 might be empty. For robust test, ensure sampler allows re-iteration
# or use a new sampler instance with the same seed.
# If sampler_ex4_seed42 had num_epochs=None, iter(dl_run2) is fine.
# If num_epochs=1, you might need to re-create sampler_ex4_seed42 for dl_run2
# or ensure dl_run1 didn't exhaust it (e.g. by not fully iterating it).
# For this exercise, assume sampler_ex4_seed42 can be re-used or is fresh for dl_run2.
iterator_run2 = iter(dl_run2)
try:
batch1_run2 = next(iterator_run2)
value_run2 = batch1_run2['image'][pixel_to_check]
print(f"Run 2 - Pixel {pixel_to_check} value: {value_run2}")
# 7. Verify
if value_run1 is not None and value_run2 is not None:
if np.allclose(value_run1, value_run2):
print("\nSUCCESS: Pixel values are identical. Randomness is reproducible!")
else:
print(f"\nFAILURE: Pixel values differ. value1={value_run1}, value2={value_run2}")
except StopIteration:
print("dl_run2 exhausted. This might happen if the sampler was single-epoch and already used.")
value_run2 = None
else:
print("dl_run2 not configured.")
# 8. (Optional) Test with a different seed
# TODO: Create sampler_ex4_seed100 (seed=100)
# sampler_ex4_seed100 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex4_seed100 = None # Replace this
# TODO: Create dl_run3 with sampler_ex4_seed100
# dl_run3 = grain.DataLoader(...)
# YOUR CODE HERE
dl_run3 = None # Replace this
if dl_run3:
print("\n--- Run 3 (seed 100) ---")
iterator_run3 = iter(dl_run3)
try:
batch1_run3 = next(iterator_run3)
value_run3 = batch1_run3['image'][pixel_to_check]
print(f"Run 3 - Pixel {pixel_to_check} value: {value_run3}")
if value_run1 is not None and not np.allclose(value_run1, value_run3):
print("SUCCESS: Pixel values differ from Run 1 (seed 42), as expected with a new seed.")
elif value_run1 is not None:
print("NOTE: Pixel values are the same as Run 1. Check seed or logic.")
except StopIteration:
print("dl_run3 exhausted.")
else:
print("\nOptional part (dl_run3) not configured.")
# @title Exercise 4: Solution
# 1. Define RandomBrightnessAdjust
class RandomBrightnessAdjust(grain.RandomMapTransform):
def random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any]:
image = features['image'] # Assumes image is already float, e.g. from ConvertToFloat
# Generate a random brightness factor using the provided rng
brightness_factor = rng.uniform(0.7, 1.3)
# Apply brightness adjustment and clip
adjusted_image = image * brightness_factor
adjusted_image = np.clip(adjusted_image, 0.0, 1.0)
updated_features = features.copy()
updated_features['image'] = adjusted_image
return updated_features
# 2. Reuse/Create DataSource, Sampler, ConvertToFloat
if 'MySource' not in globals(): # Basic check for MySource
class MySource(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
return {'image': image, 'label': label}
print("Redefined MySource for Ex4.")
source_ex4 = MySource(num_records=1000)
# Sampler with a fixed seed. num_epochs=None allows re-iteration for multiple DataLoaders.
# If num_epochs=1, the sampler instance can only be fully iterated once.
# For this test, using num_epochs=None or re-creating the sampler for each DataLoader is safest.
# Let's use num_epochs=None to allow the same sampler instance to be used.
sampler_ex4_seed42 = grain.IndexSampler(
num_records=len(source_ex4),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=None, # Allow indefinite iteration
seed=42
)
print("DataSource and Sampler (seed 42) for Ex4 ready.")
if 'ConvertToFloat' not in globals(): # Basic check for ConvertToFloat
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
image = features['image'].astype(np.float32) / 255.0
return {'image': image, 'label': features['label']}
print("Redefined ConvertToFloat for Ex4.")
# 3. Create list of operations
random_brightness_adjuster = RandomBrightnessAdjust()
transformations_ex4 = [
ConvertToFloat(),
random_brightness_adjuster,
grain.Batch(batch_size=64, drop_remainder=True)
]
print("Transformations for Ex4 defined.")
# 4. Instantiate two DataLoaders with the same config
# Using worker_count > 0 to also test picklability of RandomBrightnessAdjust
num_workers_ex4 = 2
dl_run1 = grain.DataLoader(
data_source=source_ex4,
operations=transformations_ex4,
sampler=sampler_ex4_seed42, # Same sampler instance
worker_count=num_workers_ex4,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
dl_run2 = grain.DataLoader(
data_source=source_ex4,
operations=transformations_ex4,
sampler=sampler_ex4_seed42, # Same sampler instance
worker_count=num_workers_ex4,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
print(f"DataLoaders for Run1 and Run2 created with worker_count={num_workers_ex4}.")
# 5. & 6. Iterate and compare
pixel_to_check = (0, 0, 0, 0) # Batch_idx=0, H=0, W=0, C=0
print("\n--- Run 1 (seed 42) ---")
iterator_run1 = iter(dl_run1)
try:
batch1_run1 = next(iterator_run1)
value_run1 = batch1_run1['image'][pixel_to_check]
print(f"Run 1 - Pixel {pixel_to_check} value: {value_run1}")
except StopIteration:
print("dl_run1 exhausted.")
value_run1 = None
print("\n--- Run 2 (seed 42, same sampler instance) ---")
iterator_run2 = iter(dl_run2) # Gets a new iterator from the DataLoader
try:
batch1_run2 = next(iterator_run2)
value_run2 = batch1_run2['image'][pixel_to_check]
print(f"Run 2 - Pixel {pixel_to_check} value: {value_run2}")
# 7. Verify
if value_run1 is not None and value_run2 is not None:
if np.allclose(value_run1, value_run2):
print("\nSUCCESS: Pixel values are identical. Randomness is reproducible with the same sampler instance!")
else:
print(f"\nFAILURE: Pixel values differ. value1={value_run1}, value2={value_run2}. This shouldn't happen if sampler is re-used correctly.")
except StopIteration:
print("dl_run2 exhausted.")
# 8. (Optional) Test with a different seed
sampler_ex4_seed100 = grain.IndexSampler(
num_records=len(source_ex4),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=None,
seed=100 # Different seed
)
dl_run3 = grain.DataLoader(
data_source=source_ex4,
operations=transformations_ex4,
sampler=sampler_ex4_seed100, # Sampler with different seed
worker_count=num_workers_ex4,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
print("\nDataLoader for Run3 (seed 100) created.")
print("\n--- Run 3 (seed 100) ---")
iterator_run3 = iter(dl_run3)
try:
batch1_run3 = next(iterator_run3)
value_run3 = batch1_run3['image'][pixel_to_check]
print(f"Run 3 - Pixel {pixel_to_check} value: {value_run3}")
if value_run1 is not None and not np.allclose(value_run1, value_run3):
print("SUCCESS: Pixel values differ from Run 1 (seed 42), as expected with a new sampler seed.")
elif value_run1 is not None:
print("NOTE: Pixel values are the same as Run 1. This is unexpected if seeds are different.")
except StopIteration:
print("dl_run3 exhausted.")
---
Exercise 5: Sharding data across processes. In a real distributed setup, each JAX process calls jax.process_index() to get its own ID and jax.process_count() to get the total number of processes; grain.sharding.ShardByJaxProcess() uses these values automatically to shard the data.
Since we are in a single Colab notebook (only one JAX process, even though we simulate several virtual devices), we cannot actually run multiple processes. We therefore create grain.ShardOptions manually to simulate two different processes.
Steps:
1. Reuse MySource and the transformations (e.g. ConvertToFloat and grain.Batch). 2. Set shard_count = 2. 3. Simulate process 0: create shard_options_p0 = grain.ShardOptions(shard_index=0, shard_count=shard_count, drop_remainder=True).
* Create an IndexSampler (sampler_p0) with shard_options_p0, with shuffling enabled and a common seed (e.g. 42).
* Create a DataLoader (dl_p0) using sampler_p0, passing shard_options_p0 to the DataLoader as well.
* Iterate over dl_p0 and collect all labels that appear in the first few batches (or, with num_epochs=1, in all batches).
4. Simulate process 1: create shard_options_p1 = grain.ShardOptions(shard_index=1, shard_count=shard_count, drop_remainder=True).
* Create an IndexSampler (sampler_p1) with shard_options_p1 (same seed as sampler_p0).
* Create a DataLoader (dl_p1) using sampler_p1 and shard_options_p1.
* Iterate over dl_p1 and collect all labels that appear.
5. Verify: the indices sampled by sampler_p0 and sampler_p1 should be disjoint.
* drop_remainder=True in ShardOptions ensures that, when the number of records is not divisible by shard_count, some records are dropped so the shards stay (roughly) equal in size (depending on implementation details).
Note on shard_options for the IndexSampler vs. the DataLoader:
The shard_options argument of grain.DataLoader is the primary way to enable sharding for a given JAX process. The DataLoader makes its internal sampler respect the current process's global sharding settings (even if the sampler you pass in is not sharded). If the IndexSampler you provide is already sharded, its sharding must be compatible with the DataLoader's shard_options. In real distributed setups, for simplicity and clarity, you typically pass ShardByJaxProcess() or a manually configured ShardOptions to the DataLoader. (Both options are sketched below.)
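To make the two options concrete, here is a short sketch (illustrative only; the exact public alias for ShardByJaxProcess may differ across Grain versions, so treat that name as an assumption):
import jax
import grain.python as grain

# In a real multi-process JAX job every process runs the same code; Grain picks
# the shard for the current process from jax.process_index()/jax.process_count().
# (Name assumed here as grain.ShardByJaxProcess; the text above refers to it as
# grain.sharding.ShardByJaxProcess - adjust the import path to your Grain version.)
auto_shard = grain.ShardByJaxProcess(drop_remainder=True)

# In this single-process notebook we instead spell out "shard 0 of 2" by hand:
manual_shard = grain.ShardOptions(shard_index=0, shard_count=2, drop_remainder=True)

# Either object is passed as shard_options to the IndexSampler (and DataLoader):
example_sampler = grain.IndexSampler(
    num_records=1000,
    shard_options=manual_shard,  # use auto_shard in a real multi-host job
    shuffle=True,
    num_epochs=1,
    seed=42,
)
print(f"process {jax.process_index()} of {jax.process_count()}")  # 0 of 1 in this notebook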
# @title Exercise 5: Student Code
# 1. Reuse DataSource and basic transformations
# TODO: Instantiate MySource (from Ex1)
# source_ex5 = ...
# YOUR CODE HERE
source_ex5 = None # Replace this
# TODO: Define basic_transformations_ex5 (e.g., ConvertToFloat, Batch)
# (Assuming ConvertToFloat is defined)
# basic_transformations_ex5 = [...]
# YOUR CODE HERE
basic_transformations_ex5 = [] # Replace this
# 2. Define shard_count
shard_count = 2
common_seed = 42
num_epochs_for_sharding_test = 1 # To make collection of all labels feasible
# 3. Simulate Process 0
# TODO: Create shard_options_p0
# shard_options_p0 = grain.ShardOptions(...)
# YOUR CODE HERE
shard_options_p0 = None # Replace this
# TODO: Create sampler_p0. Pass shard_options_p0 to the IndexSampler.
# sampler_p0 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_p0 = None # Replace this
# TODO: Create dl_p0. Pass shard_options_p0 to the DataLoader as well.
# dl_p0 = grain.DataLoader(...)
# YOUR CODE HERE
dl_p0 = None # Replace this
labels_p0 = set()
if dl_p0:
print("--- Simulating Process 0 ---")
# YOUR CODE HERE: Iterate through dl_p0 and collect all unique original labels.
# Remember that labels might be batched. You need to iterate through items in a batch.
# For simplicity, if your MySource generates labels like idx % 10,
# you can try to collect the indices that were sampled.
# Or, more directly, collect the 'label' field from each item.
# To get original indices, you might need a transform that passes index through.
# Let's collect the 'label' values directly.
pass # Replace with iteration logic
# 4. Simulate Process 1
# TODO: Create shard_options_p1
# shard_options_p1 = grain.ShardOptions(...)
# YOUR CODE HERE
shard_options_p1 = None # Replace this
# TODO: Create sampler_p1
# sampler_p1 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_p1 = None # Replace this
# TODO: Create dl_p1
# dl_p1 = grain.DataLoader(...)
# YOUR CODE HERE
dl_p1 = None # Replace this
labels_p1 = set()
if dl_p1:
print("\n--- Simulating Process 1 ---")
# YOUR CODE HERE: Iterate through dl_p1 and collect all unique labels.
pass # Replace with iteration logic
# 5. Verify
print(f"\n--- Verification (Total records in source: {len(source_ex5) if source_ex5 else 'N/A'}) ---")
print(f"Unique labels collected by Process 0 (count {len(labels_p0)}): sorted {sorted(list(labels_p0))[:20]}...")
print(f"Unique labels collected by Process 1 (count {len(labels_p1)}): sorted {sorted(list(labels_p1))[:20]}...")
if labels_p0 and labels_p1:
intersection = labels_p0.intersection(labels_p1)
if not intersection:
print("\nSUCCESS: No overlap in labels between Process 0 and Process 1. Sharding works as expected!")
else:
print(f"\nNOTE: Some overlap in labels found (count {len(intersection)}): {intersection}.")
print("This can happen if labels are not unique per index, or if sharding logic has issues.")
print("With MySource's label = idx % 10, an overlap in labels is expected even if indices are disjoint.")
print("A better test would be to collect original indices if possible.")
# For a more direct test of sharding of indices:
# We can define a DataSource that returns the index itself.
class IndexSource(grain.RandomAccessDataSource):
def __init__(self, num_records: int):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> int:
return idx % self._num_records # Return the index
index_source = IndexSource(num_records=100) # Smaller source for easier inspection
idx_sampler_p0 = grain.IndexSampler(len(index_source), shard_options_p0, shuffle=False, num_epochs=1, seed=common_seed)
idx_sampler_p1 = grain.IndexSampler(len(index_source), shard_options_p1, shuffle=False, num_epochs=1, seed=common_seed)
# DataLoader for indices (no batching, just to see raw sampled indices)
# Note: DataLoader expects dicts. Let's make IndexSource return {'index': idx}
class IndexDictSource(grain.RandomAccessDataSource):
def __init__(self, num_records: int): self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str,int]:
return {'index': idx % self._num_records}
index_dict_source = IndexDictSource(num_records=100)
# Samplers for IndexDictSource
idx_dict_sampler_p0 = grain.IndexSampler(len(index_dict_source), shard_options_p0, shuffle=False, num_epochs=1, seed=common_seed)
idx_dict_sampler_p1 = grain.IndexSampler(len(index_dict_source), shard_options_p1, shuffle=False, num_epochs=1, seed=common_seed)
# DataLoaders for IndexDictSource
# Pass shard_options to DataLoader as well.
if shard_options_p0 and shard_options_p1:
dl_indices_p0 = grain.DataLoader(index_dict_source, [], idx_dict_sampler_p0, worker_count=0, shard_options=shard_options_p0)
dl_indices_p1 = grain.DataLoader(index_dict_source, [], idx_dict_sampler_p1, worker_count=0, shard_options=shard_options_p1)
indices_from_p0 = {item['index'] for item in dl_indices_p0} if dl_indices_p0 else set()
indices_from_p1 = {item['index'] for item in dl_indices_p1} if dl_indices_p1 else set()
print(f"\n--- Verification of INDICES (Source size: {len(index_dict_source)}) ---")
print(f"Indices from P0 (count {len(indices_from_p0)}, shuffle=False): {sorted(list(indices_from_p0))}")
print(f"Indices from P1 (count {len(indices_from_p1)}, shuffle=False): {sorted(list(indices_from_p1))}")
if indices_from_p0 and indices_from_p1:
idx_intersection = indices_from_p0.intersection(indices_from_p1)
if not idx_intersection:
print("SUCCESS: No overlap in INDICES. Sharding of data sources works correctly!")
else:
print(f"FAILURE: Overlap in INDICES found: {idx_intersection}")
else:
print("Skipping index verification part as shard_options are not defined.")
print("\nReminder: In a real distributed setup, you'd use grain.sharding.ShardByJaxProcess() "
"and JAX would manage jax.process_index() automatically for each process.")
# @title Exercise 5: Solution
# Redefine MySource for Ex5 to include 'original_index'.
class MySource(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10 # Label is idx % 10
# For better sharding verification, let's also pass the original index
return {'image': image, 'label_test': label, 'original_index': effective_idx}
print("Redefined MySource for Ex5 to include 'original_index'.")
source_ex5 = MySource(num_records=1000)
# Redefine ConvertToFloat for Ex5 to include 'original_index'.
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
# This transform should pass through all keys it doesn't modify
updated_features = features.copy()
updated_features['image'] = features['image'].astype(np.float32) / 255.0
return updated_features
print("Redefined ConvertToFloat for Ex5.")
# We will collect 'original_index' after batching, so batch must preserve it.
# grain.Batch by default collates features with the same name.
basic_transformations_ex5 = [
ConvertToFloat(),
grain.Batch(batch_size=64, drop_remainder=True) # drop_remainder for batching
]
print("DataSource and Transformations for Ex5 ready.")
# 2. Define shard_count
shard_count = 2
common_seed = 42
num_epochs_for_sharding_test = 1
# 3. Simulate Process 0
shard_options_p0 = grain.ShardOptions(shard_index=0, shard_count=shard_count, drop_remainder=True) # drop_remainder for sharding
# Sampler for Process 0. It's important that the sampler itself is sharded.
# The DataLoader's shard_options will also apply this sharding if the sampler isn't already sharded,
# or verify consistency if it is.
sampler_p0 = grain.IndexSampler(
num_records=len(source_ex5),
shard_options=shard_options_p0, # Shard the sampler
shuffle=True, # Shuffle for more realistic scenario
num_epochs=num_epochs_for_sharding_test,
seed=common_seed
)
dl_p0 = grain.DataLoader(
data_source=source_ex5,
operations=basic_transformations_ex5,
sampler=sampler_p0,
worker_count=0, # Keep it simple for verification
shard_options=shard_options_p0, # Also inform DataLoader about the sharding context
read_options=grain.ReadOptions(num_threads=0)
)
indices_p0 = set()
if dl_p0:
print("--- Simulating Process 0 ---")
for batch in dl_p0:
indices_p0.update(batch['original_index'].tolist()) # Collect original indices
print(f"Process 0 collected {len(indices_p0)} unique indices.")
# 4. Simulate Process 1
shard_options_p1 = grain.ShardOptions(shard_index=1, shard_count=shard_count, drop_remainder=True)
sampler_p1 = grain.IndexSampler(
num_records=len(source_ex5),
shard_options=shard_options_p1, # Shard the sampler
shuffle=True, # Use same shuffle setting and seed for apples-to-apples comparison of sharding logic
num_epochs=num_epochs_for_sharding_test,
seed=common_seed # Same seed ensures shuffle order is same before sharding
)
dl_p1 = grain.DataLoader(
data_source=source_ex5,
operations=basic_transformations_ex5,
sampler=sampler_p1,
worker_count=0,
shard_options=shard_options_p1, # Inform DataLoader
read_options=grain.ReadOptions(num_threads=0)
)
indices_p1 = set()
if dl_p1:
print("\n--- Simulating Process 1 ---")
for batch in dl_p1:
indices_p1.update(batch['original_index'].tolist()) # Collect original indices
print(f"Process 1 collected {len(indices_p1)} unique indices.")
# 5. Verify
print(f"\n--- Verification of original_indices (Total records in source: {len(source_ex5)}) ---")
# Showing a few from each for brevity
print(f"Unique original_indices from P0 (first 20 sorted): {sorted(list(indices_p0))[:20]}...")
print(f"Unique original_indices from P1 (first 20 sorted): {sorted(list(indices_p1))[:20]}...")
expected_per_shard = len(source_ex5) // shard_count # Due to drop_remainder=True in ShardOptions
print(f"Expected records per shard (approx, due to drop_remainder in sharding): {expected_per_shard}")
print(f"Actual for P0: {len(indices_p0)}, P1: {len(indices_p1)}")
if indices_p0 and indices_p1:
intersection = indices_p0.intersection(indices_p1)
if not intersection:
print("\nSUCCESS: No overlap in original_indices between Process 0 and Process 1. Sharding works!")
else:
print(f"\nFAILURE: Overlap in original_indices found (count {len(intersection)}): {sorted(list(intersection))[:20]}...")
print("This should not happen if sharding is correct and seeds/shuffle are consistent.")
else:
print("Could not perform intersection test as one or both sets of indices are empty.")
total_unique_indices_seen = len(indices_p0.union(indices_p1))
print(f"Total unique indices seen across both simulated processes: {total_unique_indices_seen}")
# With drop_remainder=True in sharding, total might be less than len(source_ex5)
# if len(source_ex5) is not divisible by shard_count.
# Example: 1000 records, 2 shards. Each gets 500. Total 1000.
# Example: 1001 records, 2 shards. drop_remainder=True means each gets 500. Total 1000. 1 record dropped.
print("\nReminder: In a real distributed JAX application:")
print("1. Each JAX process would run this script (or similar).")
print("2. shard_options = grain.sharding.ShardByJaxProcess(drop_remainder=True) would be used.")
print("3. jax.process_index() and jax.process_count() would provide the correct shard info automatically.")
print("4. The IndexSampler and DataLoader would be configured with these auto-detected shard_options.")
---
Exercise 6: Feeding a JAX/Flax NNX training loop. This exercise shows how the DataLoader feeds data into a typical JAX/Flax NNX training loop. It focuses only on the data flow; the model training is conceptual (no real weight updates).
Steps:
1. Define SimpleNNXModel, an nnx.Module.
* In __init__, initialize an nnx.Linear layer whose input features match the flattened image size (e.g. 32*32*3) and whose output features are num_classes (e.g. 10). Remember to pass rngs so the parameters get initialized.
* Implement __call__(self, x): flatten the input x (shape B, H, W, C) to 2D and pass it through the linear layer.
2. Define train_step (compiled with @jax.jit).
* It takes model (your SimpleNNXModel instance) and a batch from Grain.
* Inside, run the forward pass: logits = model(batch['image']).
* Compute an example loss, e.g. loss = jnp.mean(logits). (No real loss or gradients are needed for this exercise.)
* Return loss and model. In a real setup with nnx.Optimizer, the optimizer updates the model parameters in place, and train_step typically returns the loss, the updated model, and whatever optimizer state the next step needs.
3. Set up the DataLoader: reuse MySource (producing {'image': ..., 'label': ...}), an IndexSampler (e.g. running for a few epochs), and transformations (e.g. ConvertToFloat, grain.Batch).
* Instantiate a grain.DataLoader.
4. Write the training loop: initialize SimpleNNXModel.
* Get an iterator from the DataLoader.
* Loop for a fixed number of steps (e.g. 100):
* Get next_batch from the iterator, handling StopIteration if it is exhausted.
* Call train_step with the current model and next_batch.
* Print the example loss occasionally.
# @title Exercise 6: Student Code
# 1. Define SimpleNNXModel
class SimpleNNXModel(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
# TODO: Initialize an nnx.Linear layer
# self.linear = nnx.Linear(...)
# YOUR CODE HERE
self.linear = None # Replace this
def __call__(self, x: jax.Array):
# TODO: Flatten the input image (if B, H, W, C) and pass through linear layer
# x_flat = x.reshape((x.shape[0], -1))
# return self.linear(x_flat)
# YOUR CODE HERE
return x # Replace this
# 2. Define train_step
def train_step(model: SimpleNNXModel, batch: Dict[str, jax.Array]):
# TODO: Perform forward pass: model(batch['image'])
# logits = ...
# YOUR CODE HERE
logits = model(batch['image']) # Assuming model handles it
# TODO: Calculate a dummy loss (e.g., mean of logits)
# loss = ...
# YOUR CODE HERE
loss = jnp.array(0.0) # Replace this
# In a real scenario, you'd also compute gradients and update model parameters here.
# For this exercise, we just return the original model and the loss.
return model, loss
# 3. Set up DataLoader
# TODO: Instantiate MySource (from Ex1, or the one with 'original_index' if you prefer)
# source_ex6 = ...
# YOUR CODE HERE
source_ex6 = None # Replace this
# TODO: Instantiate an IndexSampler for a few epochs (e.g., 2 epochs)
# sampler_ex6 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex6 = None # Replace this
# TODO: Define transformations_ex6 (e.g., ConvertToFloat, grain.Batch)
# (Assuming ConvertToFloat is defined)
# transformations_ex6 = [...]
# YOUR CODE HERE
transformations_ex6 = [] # Replace this
# TODO: Instantiate data_loader_ex6
# data_loader_ex6 = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_ex6 = None # Replace this
# 4. Write the Training Loop
if data_loader_ex6: # Proceed only if DataLoader is configured
# TODO: Initialize SimpleNNXModel
# image_height, image_width, image_channels = 32, 32, 3
# num_classes_ex6 = 10
# model_key = jax.random.key(0)
# model_ex6 = SimpleNNXModel(...)
# YOUR CODE HERE
model_ex6 = None # Replace this
if model_ex6:
# TODO: Get an iterator from data_loader_ex6
# grain_iterator_ex6 = ...
# YOUR CODE HERE
grain_iterator_ex6 = iter([]) # Replace this
num_steps = 100
print(f"\nStarting conceptual training loop for {num_steps} steps...")
for step in range(num_steps):
try:
# TODO: Get next_batch from iterator
# next_batch = ...
# YOUR CODE HERE
next_batch = None # Replace this
if next_batch is None:
raise StopIteration # Simulate exhaustion if not implemented
except StopIteration:
print(f"DataLoader exhausted at step {step}. Ending loop.")
break
# TODO: Call train_step
# model_ex6, loss = train_step(model_ex6, next_batch)  # train_step returns (model, loss); model_ex6 isn't actually updated here
# YOUR CODE HERE
loss = jnp.array(0.0) # Replace this
if step % 20 == 0 or step == num_steps - 1:
print(f"Step {step}: Dummy Loss = {loss.item():.4f}")
print("Conceptual training loop finished.")
else:
print("Model for Ex6 not initialized.")
else:
print("DataLoader for Ex6 not configured.")
# @title Exercise 6: Solution
# 1. Define SimpleNNXModel
class SimpleNNXModel(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dout, rngs=rngs)
def __call__(self, x: jax.Array) -> jax.Array:
# Assuming x is (B, H, W, C)
batch_size = x.shape[0]
x_flat = x.reshape((batch_size, -1)) # Flatten H, W, C dimensions
return self.linear(x_flat)
# 2. Define conceptual train_step
def train_step(model: SimpleNNXModel, batch: Dict[str, jax.Array]):
# Perform forward pass
logits = model(batch['image']) # model.call is invoked
# Calculate a dummy loss
loss = jnp.mean(logits**2) # Example: mean of squared logits
# In a real training step:
# # 1. Define a loss function.
# def loss_fn(model):
# logits = model(batch['image'])
# # loss_value = ... (e.g., optax.softmax_cross_entropy_with_integer_labels)
# return jnp.mean(logits**2) # Using dummy loss from exercise
#
# # 2. Calculate gradients.
# grads = nnx.grad(loss_fn, wrt=nnx.Param)(model)
#
# # 3. Update the model's parameters in-place using the optimizer.
# # Note: The optimizer is defined outside the train step.
# # e.g., optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
# optimizer.update(model, grads)
#
# # 4. For this exercise, we just return the original model and the loss.
return model, loss
# 3. Set up DataLoader
# Redefine MySource for Ex6.
class MySource(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
return {'image': image, 'label': label}
print("Redefined MySource for Ex6.")
source_ex6 = MySource(num_records=1000)
sampler_ex6 = grain.IndexSampler(
num_records=len(source_ex6),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=2, # Run for 2 epochs
seed=42
)
# Redefine ConvertToFloat for Ex6.
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
updated_features = features.copy()
updated_features['image'] = features['image'].astype(np.float32) / 255.0
return updated_features
print("Redefined ConvertToFloat for Ex6.")
transformations_ex6 = [
ConvertToFloat(),
grain.Batch(batch_size=64, drop_remainder=True)
]
data_loader_ex6 = grain.DataLoader(
data_source=source_ex6,
operations=transformations_ex6,
sampler=sampler_ex6,
worker_count=2, # Use a couple of workers
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
print("DataLoader for Ex6 configured.")
# 4. Write the Training Loop
# Define image dimensions and number of classes
image_height, image_width, image_channels = 32, 32, 3
input_dim = image_height * image_width * image_channels
num_classes_ex6 = 10
# Initialize SimpleNNXModel
# NNX modules are typically initialized outside JIT, then their state can be passed.
# For this conceptual example, the model instance itself is passed.
model_key = jax.random.key(0)
model_ex6 = SimpleNNXModel(din=input_dim, dout=num_classes_ex6, rngs=nnx.Rngs(params=model_key))
print(f"SimpleNNXModel initialized. Input dim: {input_dim}, Output dim: {num_classes_ex6}")
# Get an iterator from data_loader_ex6
grain_iterator_ex6 = iter(data_loader_ex6)
num_steps = 100 # Total steps for the conceptual loop
print(f"\nStarting conceptual training loop for {num_steps} steps...")
for step_idx in range(num_steps):
try:
# Get next_batch from iterator
next_batch = next(grain_iterator_ex6)
except StopIteration:
print(f"DataLoader exhausted at step {step_idx}. Ending loop.")
# Example: Re-initialize iterator if sampler allows multiple epochs
# if sampler_ex6.num_epochs is None or sampler_ex6.num_epochs > 1 (and we tracked current epoch):
# print("Re-initializing iterator for new epoch...")
# grain_iterator_ex6 = iter(data_loader_ex6)
# next_batch = next(grain_iterator_ex6)
# else:
break # Exit loop if truly exhausted
# Call train_step
# JAX arrays in batch are automatically handled by jax.jit
_, loss = train_step(model_ex6, next_batch) # model_ex6 state isn't actually updated here
if step_idx % 20 == 0 or step_idx == num_steps - 1:
print(f"Step {step_idx}: Dummy Loss = {loss.item():.4f}") # .item() to get Python scalar from JAX array
print("Conceptual training loop finished.")
---
Exercise 7: Checkpointing the DataLoader iterator. When you call iter(data_loader), Grain's DataLoader produces an iterator (grain.PyGrainDatasetIterator). This iterator exposes get_state() and set_state() methods that capture its internal iteration state (e.g. the current position and the RNG state of the sampler/transforms) and restore it later. To fully checkpoint an experiment, this iterator state should be saved together with the model parameters (typically via Orbax); see the sketch after the steps below.
Steps:
1. Set up a DataLoader (e.g. using MySource, an IndexSampler with num_epochs=None and a seed, and some basic transformations).
2. Get an iterator from the DataLoader (iterator1).
3. Iterate a few times with next(iterator1) (e.g. 3 batches) and keep the last batch.
4. Save the state: saved_iterator_state = iterator1.get_state().
5. Simulate resumption: get a new iterator (iterator2) from the same DataLoader instance.
* Restore the state: call iterator2.set_state(saved_iterator_state).
6. Iterate once with next(iterator2) to get a batch (resumed_batch). 7. Verify: the resumed_batch from iterator2 should match the batch that iterator1 would have produced next after the saved state.
* How to verify:
* After taking saved_iterator_state from iterator1, call next(iterator1) once more to obtain expected_next_batch_from_iterator1.
* Compare resumed_batch (obtained from iterator2 after set_state) with expected_next_batch_from_iterator1; their contents (e.g. the image data) should be identical.
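Here is a minimal self-contained sketch of saving and restoring the iterator state with plain pickle (CkptSketchSource, sketch_loader_ckpt, and the /tmp path are illustrative assumptions; in a real experiment you would usually persist this state next to the model checkpoint, e.g. via Orbax):
import pickle
import pathlib
import numpy as np
import grain.python as grain

class CkptSketchSource(grain.RandomAccessDataSource):
    # Tiny in-memory source just for this sketch.
    def __init__(self, num_records: int = 100):
        self._num_records = num_records
    def __len__(self) -> int:
        return self._num_records
    def __getitem__(self, idx: int):
        return {'value': np.int32(idx % self._num_records)}

sketch_loader_ckpt = grain.DataLoader(
    data_source=CkptSketchSource(),
    operations=[grain.Batch(batch_size=10, drop_remainder=True)],
    sampler=grain.IndexSampler(
        num_records=100,
        shard_options=grain.NoSharding(),
        shuffle=True,
        num_epochs=None,  # indefinite, so resuming always has data left
        seed=7,
    ),
    worker_count=0,
)

it1 = iter(sketch_loader_ckpt)
for _ in range(3):  # consume a few batches...
    _ = next(it1)
state_path = pathlib.Path('/tmp/grain_iter_state.pkl')     # assumed location
state_path.write_bytes(pickle.dumps(it1.get_state()))      # ...then snapshot the position

# Later / after a restart: rebuild the same DataLoader, then restore the state.
it2 = iter(sketch_loader_ckpt)
it2.set_state(pickle.loads(state_path.read_bytes()))
resumed = next(it2)  # same batch that it1 would have produced next
print(resumed['value'][:5], next(it1)['value'][:5])        # should match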
# @title Exercise 7: Student Code
# 1. Set up DataLoader
# TODO: Instantiate MySource (e.g., from Ex1)
# source_ex7 = ...
# YOUR CODE HERE
source_ex7 = None # Replace this
# TODO: Instantiate an IndexSampler (num_epochs=None, seed=42)
# sampler_ex7 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex7 = None # Replace this
# TODO: Define transformations_ex7 (e.g., ConvertToFloat, Batch)
# (Assuming ConvertToFloat is defined)
# transformations_ex7 = [...]
# YOUR CODE HERE
transformations_ex7 = [] # Replace this
# TODO: Instantiate data_loader_ex7
# data_loader_ex7 = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_ex7 = None # Replace this
if data_loader_ex7:
# 2. Get iterator1
# TODO: iterator1 = iter(...)
# YOUR CODE HERE
iterator1 = iter([]) # Replace this
# 3. Iterate a few times
num_initial_iterations = 3
print(f"--- Initial Iteration (iterator1) for {num_initial_iterations} batches ---")
last_batch_iterator1 = None
for i in range(num_initial_iterations):
try:
# TODO: last_batch_iterator1 = next(...)
# YOUR CODE HERE
last_batch_iterator1 = {} # Replace this
print(f"iterator1, batch {i+1} - first label: {last_batch_iterator1.get('label', [None])[0]}")
except StopIteration:
print("iterator1 exhausted prematurely.")
break
# 4. Save State
# TODO: saved_iterator_state = iterator1.get_state()
# YOUR CODE HERE
saved_iterator_state = None # Replace this
print(f"\nIterator state saved. Type: {type(saved_iterator_state)}")
# For verification: get the *next* batch from iterator1 *after* saving state
expected_next_batch_from_iterator1 = None
if saved_iterator_state is not None: # Ensure state was actually saved
try:
# TODO: expected_next_batch_from_iterator1 = next(...)
# YOUR CODE HERE
expected_next_batch_from_iterator1 = {} # Replace this
print(f"Expected next batch (from iterator1 after get_state) - first label: {expected_next_batch_from_iterator1.get('label', [None])[0]}")
except StopIteration:
print("iterator1 exhausted when trying to get expected_next_batch.")
# 5. Simulate Resumption
# TODO: Get iterator2 from the same data_loader_ex7
# iterator2 = iter(...)
# YOUR CODE HERE
iterator2 = iter([]) # Replace this
if saved_iterator_state is not None:
# TODO: iterator2.set_state(...)
# YOUR CODE HERE
print("\n--- Resumed Iteration (iterator2) ---")
print("Iterator state restored to iterator2.")
else:
print("\nSkipping resumption, saved_iterator_state is None.")
# 6. Iterate once from iterator2
resumed_batch = None
if saved_iterator_state is not None: # Only if state was set
try:
# TODO: resumed_batch = next(...)
# YOUR CODE HERE
resumed_batch = {} # Replace this
print(f"Resumed batch (from iterator2 after set_state) - first label: {resumed_batch.get('label', [None])[0]}")
except StopIteration:
print("iterator2 exhausted immediately after set_state.")
# 7. Verify
if expected_next_batch_from_iterator1 is not None and resumed_batch is not None:
# Compare 'image' data of the first element in the batch
# TODO: Perform comparison (e.g., np.allclose on image data)
# are_identical = np.allclose(...)
# YOUR CODE HERE
are_identical = False # Replace this
if are_identical:
print("\nSUCCESS: Resumed batch is identical to the expected next batch. Checkpointing works!")
else:
print("\nFAILURE: Resumed batch differs from the expected next batch.")
# print(f"Expected image data (sample): {expected_next_batch_from_iterator1['image'][0,0,0,:3]}")
# print(f"Resumed image data (sample): {resumed_batch['image'][0,0,0,:3]}")
elif saved_iterator_state is not None:
print("\nVerification inconclusive: could not obtain both expected and resumed batches.")
else:
print("DataLoader for Ex7 not configured.")
# @title Exercise 7: Solution
# 1. Set up DataLoader
# Redefine MySource for Ex7.
class MySource(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
# Add original_index for easier verification if needed, though label might suffice
return {'image': image, 'label': label, 'original_index': effective_idx}
print("Redefined MySource for Ex7.")
source_ex7 = MySource(num_records=1000)
sampler_ex7 = grain.IndexSampler(
num_records=len(source_ex7),
shard_options=grain.NoSharding(),
shuffle=True, # Shuffling makes the test more robust
num_epochs=None, # Indefinite iteration
seed=42
)
# Redefine ConvertToFloat for Ex7.
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
updated_features = features.copy()
updated_features['image'] = features['image'].astype(np.float32) / 255.0
return updated_features
print("Redefined ConvertToFloat for Ex7.")
transformations_ex7 = [
ConvertToFloat(),
grain.Batch(batch_size=64, drop_remainder=True)
]
data_loader_ex7 = grain.DataLoader(
data_source=source_ex7,
operations=transformations_ex7,
sampler=sampler_ex7,
worker_count=0, # Simpler for state verification, but works with >0 too
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
print("DataLoader for Ex7 configured.")
# 2. Get iterator1
iterator1 = iter(data_loader_ex7)
# 3. Iterate a few times
num_initial_iterations = 3
print(f"--- Initial Iteration (iterator1) for {num_initial_iterations} batches ---")
last_batch_iterator1 = None
for i in range(num_initial_iterations):
try:
last_batch_iterator1 = next(iterator1)
# Use 'original_index' for a more robust check than 'label' alone (labels repeat every 10 records)
print(f"iterator1, batch {i+1} - first original_index: {last_batch_iterator1['original_index'][0]}")
except StopIteration:
print("iterator1 exhausted prematurely.")
break
# 4. Save State
# get_state() on a PyGrainDatasetIterator returns a new state object, so the
# saved state can be kept as-is even while we continue using iterator1; no
# defensive deep copy is needed here.
saved_iterator_state = iterator1.get_state()
print(f"\nIterator state saved. Type: {type(saved_iterator_state)}")
# For verification: get the next batch from iterator1 after saving state
expected_next_batch_from_iterator1 = None
if saved_iterator_state is not None:
try:
expected_next_batch_from_iterator1 = next(iterator1)
print(f"Expected next batch (from iterator1 after get_state) - first original_index: {expected_next_batch_from_iterator1['original_index'][0]}")
except StopIteration:
print("iterator1 exhausted when trying to get expected_next_batch.")
# 5. Simulate Resumption
# Get a new iterator from the same DataLoader instance.
iterator2 = iter(data_loader_ex7)
if saved_iterator_state is not None:
iterator2.set_state(saved_iterator_state)
print("\n--- Resumed Iteration (iterator2) ---")
print("Iterator state restored to iterator2.")
else:
print("\nSkipping resumption, saved_iterator_state is None.")
# 6. Iterate once from iterator2
resumed_batch = None
if saved_iterator_state is not None:
try:
resumed_batch = next(iterator2)
print(f"Resumed batch (from iterator2 after set_state) - first original_index: {resumed_batch['original_index'][0]}")
except StopIteration:
print("iterator2 exhausted immediately after set_state. This means the saved state was at the very end.")
# 7. Verify
if expected_next_batch_from_iterator1 is not None and resumed_batch is not None:
# Compare 'image' data and 'original_index' of the first element in the batch for robustness
# (Labels might repeat, indices are better for this check if available)
expected_img_sample = expected_next_batch_from_iterator1['image'][0]
resumed_img_sample = resumed_batch['image'][0]
expected_idx_sample = expected_next_batch_from_iterator1['original_index'][0]
resumed_idx_sample = resumed_batch['original_index'][0]
are_indices_identical = (expected_idx_sample == resumed_idx_sample)
are_images_identical = np.allclose(expected_img_sample, resumed_img_sample)
are_identical = are_indices_identical and are_images_identical
if are_identical:
print("\nSUCCESS: Resumed batch is identical to the expected next batch. Checkpointing works!")
else:
print("\nFAILURE: Resumed batch differs from the expected next batch.")
if not are_indices_identical:
print(f" - Mismatch in first original_index: Expected {expected_idx_sample}, Got {resumed_idx_sample}")
if not are_images_identical:
print(f" - Mismatch in image data for first element.")
# print(f" Expected image data (sample [0,0,0]): {expected_img_sample[0,0,0]}")
# print(f" Resumed image data (sample [0,0,0]): {resumed_img_sample[0,0,0]}")
elif saved_iterator_state is not None: # If state was saved but verification couldn't complete
if expected_next_batch_from_iterator1 is None and resumed_batch is None:
print("\nVERIFICATION NOTE: Both iterators seem to be at the end of the dataset after the initial iterations. This is valid if the dataset was short.")
else:
print("\nVerification inconclusive: could not obtain both expected and resumed batches for comparison.")
print(f" expected_next_batch_from_iterator1 is None: {expected_next_batch_from_iterator1 is None}")
print(f" resumed_batch is None: {resumed_batch is None}")
else: # saved_iterator_state is None, so verification could not be attempted
print("Verification skipped: iterator state was never saved.")
---
Congratulations on completing all the exercises! You should now understand:
* Grain's core components (DataSource, Sampler, Operations).
* Using grain.DataLoader for efficient (including parallel) data input.
* Configuring the DataLoader (worker_count > 0) to enable parallelism.
* Using the provided rng inside a RandomMapTransform.
* Using grain.sharding.ShardByJaxProcess (or a manual ShardOptions) for JAX sharding.
* Note: for modules such as Dropout and BatchNorm, the checkpoint structure has changed; use the migration script to convert existing checkpoints to the v0.11 format (see the official migration guide).
* For real datasets, you can read many on-disk formats (e.g., TFRecord, RecordIO) through custom DataSource implementations or libraries such as TFDS.
We hope these exercises take you further on your journey with JAX and Grain! A short manual-sharding sketch follows below as a final reference.
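The sketch below shows one way manual sharding could be wired into an IndexSampler. The shard_index/shard_count values are placeholder assumptions for a two-shard setup; in a real multi-process JAX job you would typically derive them from jax.process_index()/jax.process_count(), or simply use grain.sharding.ShardByJaxProcess.
# Minimal manual-sharding sketch (placeholder values, not tied to any exercise above).
shard_options = grain.ShardOptions(
    shard_index=0,       # which shard this process reads (e.g., jax.process_index())
    shard_count=2,       # total number of shards (e.g., jax.process_count())
    drop_remainder=True,
)
sharded_sampler = grain.IndexSampler(
    num_records=1000,    # assumed dataset size for this sketch
    shard_options=shard_options,
    shuffle=True,
    num_epochs=1,
    seed=42,
)
# Each shard's sampler then yields a disjoint subset of the record indices.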