Efficient Data Loading with Grain: Exercises for JAX/Flax NNX

Welcome! This Colab notebook contains exercises to help you learn Google's Grain library for efficient data loading in JAX. These exercises are designed for developers familiar with PyTorch who are now exploring the JAX ecosystem, including the new Flax NNX API.

Goals of this notebook:
  • Understand the core components of Grain: DataSource, Sampler, and Operations.
  • Learn how to use grain.DataLoader for both sequential and parallel data loading.
  • Implement custom data transformations.
  • Explore data sharding for distributed training scenarios.
  • See how Grain integrates into a conceptual JAX/Flax NNX training loop.
  • Learn about checkpointing data iterator state for reproducibility.

Simulated Multi-Device Environment:

To demonstrate parallelism and sharding concepts effectively in Colab (which typically provides a single CPU/GPU), this notebook starts by configuring JAX to simulate 8 CPU devices. This is achieved using XLA_FLAGS and chex.set_n_cpu_devices(8).
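
The setup cell assigns XLA_FLAGS outright, which would discard any flags already present in the environment. Here is a minimal sketch for the case where you want to keep existing flags. The helper name with_host_device_count is hypothetical (it is not part of JAX or chex); it only edits the flag string, and the result must still be set before JAX is imported for the first time.

```python
import os

FLAG = "--xla_force_host_platform_device_count"

def with_host_device_count(existing: str, n: int) -> str:
    """Return an XLA_FLAGS string requesting n host devices, keeping other flags."""
    # Drop any previous setting of this flag, keep everything else untouched.
    kept = [f for f in existing.split() if not f.startswith(FLAG + "=")]
    return " ".join(kept + [f"{FLAG}={n}"])

# Hypothetical usage: run this before the first `import jax` in the session.
os.environ["XLA_FLAGS"] = with_host_device_count(os.environ.get("XLA_FLAGS", ""), 8)
print(os.environ["XLA_FLAGS"])
```

In a fresh Colab runtime XLA_FLAGS is usually unset, so the plain assignment in the setup cell behaves the same way.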

Let's get started! Please run the next cell to set up the environment.
# Environment Setup
# This cell configures the environment to simulate multiple CPU devices
# and installs necessary libraries.
# IMPORTANT: RUN THIS CELL FIRST. If you encounter issues with JAX device
# counts later, try 'Runtime -> Restart runtime' in the Colab menu
# and run this cell again before any others.
import os

# Configure JAX to see 8 virtual CPU devices.
# This must be done before JAX is imported for the first time in a session.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

# Install libraries
!pip install -q jax-ai-stack==2025.9.3
print("Libraries installed.")

# Now, import chex and attempt to set_n_cpu_devices.
# This must be called after setting XLA_FLAGS and before JAX initializes its backends.
import chex

try:
  chex.set_n_cpu_devices(8)
  print("chex.set_n_cpu_devices(8) called successfully.")
except RuntimeError as e:
  print(f"Note on chex.set_n_cpu_devices: {e}")
  print("This usually means JAX was already initialized. The XLA_FLAGS environment variable should still apply.")
  print("If you see issues with device counts, ensure you 'Restart runtime' and run this cell first.")

# Verify JAX device count
import jax
print(f"JAX version: {jax.__version__}")
print(f"JAX found {jax.device_count()} devices.")
print(f"JAX devices: {jax.devices()}")

if jax.device_count() != 8:
  print("\nWARNING: JAX does not see 8 devices. Parallelism/sharding exercises might not behave as expected.")
  print("Please try 'Runtime -> Restart runtime' and run this setup cell again first.")

# Common imports for the exercises
import jax.numpy as jnp
import numpy as np
import grain.python as grain # Main Grain API
from flax import nnx # Flax NNX API
import time # For simulating work and observing performance
import copy # For checkpointing example
from typing import Dict, Any, List # For type hints
import dataclasses # For ShardOptions if needed manually
import functools # For functools.partial
print("Imports complete. Setup finished.")

Introduction to Grain

As highlighted in the lecture, JAX is incredibly fast for numerical computation, especially on accelerators. However, this speed can be bottlenecked by inefficient data loading. Standard Python data loading can struggle due to I/O limitations, CPU-bound transformations, and the Global Interpreter Lock (GIL).

Grain is Google's solution for high-performance data loading in JAX. Its key goals are:
  • Speed: Achieved through multiprocessing, shared memory, and prefetching.
  • Determinism: Ensuring reproducibility in experiments.
  • Flexibility & Simplicity: Declarative pipeline definition.
  • JAX Ecosystem Focus: Integrates with concepts like distributed sharding.

Conceptually, Grain's DataLoader is analogous to PyTorch's torch.utils.data.DataLoader. It orchestrates data reading, transformation, batching, and parallelization.

Core Components of grain.DataLoader API:
  1. DataSource: Provides access to individual raw data records (must implement __len__ and __getitem__).
  2. Sampler: Determines the order in which records are loaded and provides seeds for random operations, ensuring reproducibility.
  3. Operations: A list of transformations (e.g., augmentation, filtering, batching) applied sequentially to the records.
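
To make the division of labor concrete, here is a toy sketch in plain Python of how the three pieces fit together. It does not use Grain and is not how Grain is implemented; ToySource, toy_sampler, toy_loader, and scale are hypothetical names made up for illustration.

```python
import random

class ToySource:
    """Random-access source: implements __len__ and __getitem__."""
    def __init__(self, num_records):
        self._num_records = num_records

    def __len__(self):
        return self._num_records

    def __getitem__(self, idx):
        return {"value": idx, "label": idx % 10}

def toy_sampler(num_records, seed):
    """Yield every index once, in an order fixed by the seed."""
    order = list(range(num_records))
    random.Random(seed).shuffle(order)
    yield from order

def toy_loader(source, sampler, operations, batch_size):
    """Read records in sampler order, apply operations in sequence, then batch."""
    batch = []
    for idx in sampler:
        record = source[idx]
        for op in operations:
            record = op(record)
        batch.append(record)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # A trailing partial batch is dropped here (similar to drop_remainder=True).

def scale(record):
    """Example operation: rescale one field, pass the rest through."""
    return {**record, "value": record["value"] / 10}

run_a = list(toy_loader(ToySource(10), toy_sampler(10, seed=42), [scale], batch_size=4))
run_b = list(toy_loader(ToySource(10), toy_sampler(10, seed=42), [scale], batch_size=4))
print(len(run_a), run_a == run_b)
```

Because the order comes only from the seed, two runs with the same seed produce identical batches. That is the property the Sampler provides in the real pipeline.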

Let's dive into the exercises!

---

Exercise 1: Building Your First grain.DataLoader (Sequential)

Goal: Get familiar with the basic components: DataSource, IndexSampler, a simple MapTransform, and grain.DataLoader running in sequential mode (worker_count=0).

Instructions:
  1. Define MySource, a custom RandomAccessDataSource.
     * __init__: Store num_records.
     * __len__: Return num_records.
     * __getitem__: Given an idx, return a dictionary {'image': image_array, 'label': label_int}.
     * The image_array should be a NumPy array of shape (32, 32, 3) with dtype=np.uint8. Its values can depend on idx (e.g., np.ones(...) * (idx % 255)).
     * The label_int should be an integer (e.g., idx % 10).
     * Handle potential index wrap-around for multiple epochs: idx = idx % self._num_records.
  2. Instantiate MySource.
  3. Create an IndexSampler for shuffling, running for 1 epoch, with a fixed seed.
  4. Define a list of operations:
     * A ConvertToFloat class inheriting from grain.MapTransform that converts the 'image' to np.float32 and normalizes it to [0, 1].
     * A grain.Batch operation to batch 64 items, dropping any remainder.
  5. Instantiate grain.DataLoader with worker_count=0 (for debugging/sequential mode).
     * Since MySource is in-memory, use read_options=grain.ReadOptions(num_threads=0) to disable Grain's internal read threads.
  6. Iterate through the DataLoader to get the first batch and print its shape and the shape of its labels.
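
Before running the loader, it helps to know how many batches one epoch should produce. A quick arithmetic sketch, assuming the 1000-record default used in the code for this exercise and the batch size of 64 from the instructions:

```python
num_records = 1000   # default size of MySource in this exercise
batch_size = 64      # batch size from the instructions

# With the remainder dropped, only full batches are emitted.
full_batches, leftover = divmod(num_records, batch_size)
print(f"full batches per epoch: {full_batches}")   # 15
print(f"records dropped per epoch: {leftover}")    # 40
```

So a single-epoch loader should raise StopIteration after 15 batches, and 40 records are never seen in that epoch.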
# @title Exercise 1: Student Code
# 1. Define MySource
class MySource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records

  def __len__(self) -> int:
    # TODO: Return the total number of records
    # YOUR CODE HERE
    return 0 # Replace this

  def __getitem__(self, idx: int) -> Dict[str, Any]:
    # TODO: Handle potential index wrap-around for multiple epochs
    # effective_idx = ...
    # YOUR CODE HERE
    effective_idx = idx # Replace this

    # TODO: Simulate loading data: an image and a label
    # image = np.ones(...) * (effective_idx % 255)
    # label = effective_idx % 10
    # YOUR CODE HERE
    image = np.zeros((32, 32, 3), dtype=np.uint8) # Replace this
    label = 0 # Replace this
    return {'image': image, 'label': label}

# 2. Instantiate MySource
# TODO: Create an instance of MySource
# source = ...
# YOUR CODE HERE
source = None # Replace this

# 3. Create an IndexSampler
# TODO: Create an IndexSampler that shuffles, runs for 1 epoch, and uses seed 42.
# num_records should be len(source).
# index_sampler = grain.IndexSampler(...)
# YOUR CODE HERE
index_sampler = None # Replace this

# 4. Define Operations
# TODO: Define ConvertToFloat transform
class ConvertToFloat(grain.MapTransform):
  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    # TODO: Convert 'image' to float32 and normalize to [0, 1].
    # Keep 'label' as is.
    # YOUR CODE HERE
    image = features['image'] # Replace this
    return {'image': image.astype(np.float32) / 255.0, 'label': features['label']}

# TODO: Create a list of transformations: ConvertToFloat instance, then grain.Batch
# (batch_size=64, drop_remainder=True)
# transformations = [...]
# YOUR CODE HERE
transformations = [] # Replace this

# 5. Instantiate DataLoader
# TODO: Create a DataLoader with worker_count=0 and appropriate read_options
# data_loader_sequential = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_sequential = None # Replace this

# 6. Iterate and print batch info
if data_loader_sequential:
  print("DataLoader configured sequentially.")
  data_iterator_seq = iter(data_loader_sequential)
  try:
    first_batch_seq = next(data_iterator_seq)
    print(f"Sequential - First batch image shape: {first_batch_seq['image'].shape}")
    print(f"Sequential - First batch label shape: {first_batch_seq['label'].shape}")
    # Example: Check a value from the first image of the first batch
    print(f"Sequential - Example image value (first item, [0,0,0]): {first_batch_seq['image'][0, 0, 0, 0]}")
    print(f"Sequential - Example label value (first item): {first_batch_seq['label'][0]}")
  except StopIteration:
    print("Sequential DataLoader is empty or exhausted.")
else:
  print("Sequential DataLoader not configured yet.")
# @title Exercise 1: Solution
# 1. Define MySource
class MySource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records

  def __len__(self) -> int:
    return self._num_records

  def __getitem__(self, idx: int) -> Dict[str, Any]:
    effective_idx = idx % self._num_records # Handle wrap-around
    image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
    label = effective_idx % 10
    return {'image': image, 'label': label}

# 2. Instantiate MySource
source = MySource(num_records=1000)
print(f"DataSource created with {len(source)} records.")

# 3. Create an IndexSampler
index_sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=grain.NoSharding(), # No sharding for this exercise
    shuffle=True,
    num_epochs=1, # Run for 1 epoch
    seed=42
    )
print("IndexSampler created.")

# 4. Define Operations
class ConvertToFloat(grain.MapTransform):
  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    # Convert 'image' to float32 and normalize to [0, 1].
    # Keep 'label' as is.
    image = features['image'].astype(np.float32) / 255.0
    return {'image': image, 'label': features['label']}

transformations = [
    ConvertToFloat(),
    grain.Batch(batch_size=64, drop_remainder=True)
    ]
print("Transformations defined.")

# 5. Instantiate DataLoader
data_loader_sequential = grain.DataLoader(
    data_source=source,
    operations=transformations,
    sampler=index_sampler,
    worker_count=0, # Sequential mode
    shard_options=grain.NoSharding(), # Explicitly no sharding for this loader instance
    read_options=grain.ReadOptions(num_threads=0) # Dataset is in-memory
)

# 6. Iterate and print batch info
if data_loader_sequential:
  print("DataLoader configured sequentially.")
  data_iterator_seq = iter(data_loader_sequential)
  try:
    first_batch_seq = next(data_iterator_seq)
    print(f"Sequential - First batch image shape: {first_batch_seq['image'].shape}") # Expected: (64, 32, 32, 3)
    print(f"Sequential - First batch label shape: {first_batch_seq['label'].shape}") # Expected: (64,)
    # Example: Check a value from the first image of the first batch
    print(f"Sequential - Example image value (first item, [0,0,0]): {first_batch_seq['image'][0, 0, 0, 0]}")
    print(f"Sequential - Example label value (first item): {first_batch_seq['label'][0]}")
  except StopIteration:
    print("Sequential DataLoader is empty or exhausted.")
else:
  print("Sequential DataLoader not configured yet.")

---

Exercise 2: Enabling Parallelism with worker_count

Goal: Understand how worker_count > 0 enables multiprocessing for faster data loading.

Instructions:
  1. Reuse MySource, IndexSampler (or create a new one if you prefer, e.g., for indefinite epochs: num_epochs=None), and transformations from Exercise 1.
  2. To better observe the potential benefits of parallelism, let's modify MySource slightly. Add a small time.sleep(0.01) (10 milliseconds) inside __getitem__ to simulate some I/O or CPU work for each item.
  3. Instantiate a new grain.DataLoader (e.g., data_loader_parallel). This time, set worker_count to a value greater than 0 (e.g., 2 or 4). Remember our environment is faking 8 CPUs.
  4. Iterate to get the first batch and print its shape info.
  5. (Optional) Time how long it takes to get, for example, 10 batches from the sequential loader vs. the parallel loader. You should see a speed-up with the parallel loader, especially with the added time.sleep.
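
A rough back-of-the-envelope estimate tells you what to expect from the timing test. This sketch uses the 10 ms sleep and the 10 batches suggested above, a batch size of 64, and 4 workers as an example value. The parallel figure is an idealized lower bound, not a prediction.

```python
per_item_ms = 10     # time.sleep(0.01) per record
batch_size = 64
num_batches = 10     # the number of batches suggested in the instructions
num_workers = 4      # example value; an assumption for this estimate

sequential_ms = per_item_ms * batch_size * num_batches
# Idealized lower bound: assumes the work splits evenly across workers and
# ignores process startup and inter-process transfer costs.
ideal_parallel_ms = sequential_ms // num_workers

print(f"sequential estimate: {sequential_ms / 1000:.2f} s")
print(f"ideal lower bound with {num_workers} workers: {ideal_parallel_ms / 1000:.2f} s")
```

Measured parallel times will sit above the lower bound, and for a handful of batches the worker startup cost can dominate.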
A note on pickling: When worker_count > 0, Grain uses multiprocessing. This means all components (DataSource, Sampler, Operations, and custom transform instances) must be picklable by Python's pickle module. Simple classes and functions are usually fine, but avoid complex closures or unpicklable objects in your transform logic.
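
A quick way to check a component before handing it to a multi-worker loader is to try pickling it yourself. This is a minimal sketch with the standard pickle module; is_picklable is a hypothetical helper name, and the two callables are only examples.

```python
import functools
import operator
import pickle

def is_picklable(obj) -> bool:
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False

# An importable function with a bound argument pickles by reference.
scale_by_two = functools.partial(operator.mul, 2)
# A lambda has no importable name, so pickle cannot look it up.
scale_lambda = lambda x: x * 2

print(is_picklable(scale_by_two))   # True
print(is_picklable(scale_lambda))   # False
```

The same check can be applied to instances of your own transform classes, as long as the classes are defined at module (or notebook cell) top level rather than inside a function.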
# @title Exercise 2: Student Code
# 1. Reuse/Recreate components (DataSource with simulated work, Sampler, Operations)
# TODO: Define MySourceWithWork, adding time.sleep(0.01) in __getitem__
class MySourceWithWork(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records

  def __len__(self) -> int:
    # YOUR CODE HERE
    return self._num_records

  def __getitem__(self, idx: int) -> Dict[str, Any]:
    effective_idx = idx % self._num_records
    # TODO: Add time.sleep(0.01) to simulate work
    # YOUR CODE HERE

    image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
    label = effective_idx % 10
    return {'image': image, 'label': label}

# TODO: Instantiate MySourceWithWork
# source_with_work = ...
# YOUR CODE HERE
source_with_work = None # Replace this

# TODO: Create a new IndexSampler (e.g., for indefinite epochs, num_epochs=None)
# Or reuse the one from Ex1 if you reset it or it's for multiple epochs.
# For simplicity, let's create one for indefinite epochs.
# parallel_sampler = grain.IndexSampler(...)
# YOUR CODE HERE
parallel_sampler = None # Replace this

# Transformations can be reused from Exercise 1
# transformations = [ConvertToFloat(), grain.Batch(batch_size=64, drop_remainder=True)]
# (Assuming ConvertToFloat is defined from Ex1 solution)

# 2. Instantiate DataLoader with worker_count > 0
# TODO: Set num_workers (e.g., 4)
# num_workers = ...
# YOUR CODE HERE
num_workers = 0 # Replace this

# TODO: Create data_loader_parallel
# data_loader_parallel = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_parallel = None # Replace this

# 3. Iterate and print batch info
if data_loader_parallel:
  print(f"DataLoader configured with worker_count={num_workers}.")
  data_iterator_parallel = iter(data_loader_parallel)
  try:
    first_batch_parallel = next(data_iterator_parallel)
    print(f"Parallel - First batch image shape: {first_batch_parallel['image'].shape}")
    print(f"Parallel - First batch label shape: {first_batch_parallel['label'].shape}")
  except StopIteration:
    print("Parallel DataLoader is empty or exhausted.")
else:
  print("Parallel DataLoader not configured yet.")

# 4. (Optional) Timing comparison
# Re-create sequential loader with MySourceWithWork for a fair comparison
if source_with_work and transformations and index_sampler: # index_sampler from Ex1
  data_loader_seq_with_work = grain.DataLoader(
      data_source=source_with_work,
      operations=transformations, # Reusing from Ex1
      sampler=index_sampler, # Reusing from Ex1 (ensure it's fresh or allows re-iteration)
      worker_count=0,
      shard_options=grain.NoSharding(),
      read_options=grain.ReadOptions(num_threads=0)
      )
  num_batches_to_test = 5 # Small number for quick test

  if data_loader_seq_with_work:
    print(f"\nTiming test for {num_batches_to_test} batches:")
    # Sequential
    iterator_seq = iter(data_loader_seq_with_work)
    start_time = time.time()
    try:
      for i in range(num_batches_to_test):
        batch = next(iterator_seq)
        if i == 0: print(f"  Seq batch 1 label sum: {batch['label'].sum()}") # to ensure work is done
    except StopIteration:
      print("Sequential loader exhausted early.")
    end_time = time.time()
    print(f"Sequential ({num_batches_to_test} batches) took: {end_time - start_time:.4f} seconds")

  if data_loader_parallel:
    # Parallel
    # Ensure sampler is fresh for parallel loader if it was used above
    # For this optional part, let's use a fresh sampler for the parallel loader
    # to avoid StopIteration if the previous sampler was single-epoch and exhausted.
    fresh_parallel_sampler = grain.IndexSampler(
        num_records=len(source_with_work),
        shard_options=grain.NoSharding(),
        shuffle=True,
        num_epochs=None, # Indefinite
        seed=43 # Different seed or same, for this test it's about speed
    )
    data_loader_parallel_for_timing = grain.DataLoader(
        data_source=source_with_work,
        operations=transformations, # Reusing from Ex1
        sampler=fresh_parallel_sampler,
        worker_count=num_workers if num_workers > 0 else 2, # Ensure parallelism
        shard_options=grain.NoSharding(),
        read_options=grain.ReadOptions(num_threads=0)
    )
    iterator_parallel = iter(data_loader_parallel_for_timing)
    start_time = time.time()
    try:
      for i in range(num_batches_to_test):
        batch = next(iterator_parallel)
        if i == 0: print(f"  Parallel batch 1 label sum: {batch['label'].sum()}") # to ensure work is done
    except StopIteration:
      print("Parallel loader exhausted early.")
    end_time = time.time()
    print(f"Parallel ({num_batches_to_test} batches, {num_workers if num_workers > 0 else 2} workers) took: {end_time - start_time:.4f} seconds")
else:
  print("Skipping optional timing: source_with_work, transformations, or index_sampler not defined.")
# @title Exercise 2: Solution
# 1. Reuse/Recreate components
# Define MySourceWithWork, adding time.sleep(0.01) in __getitem__
class MySourceWithWork(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records

  def __len__(self) -> int:
    return self._num_records

  def __getitem__(self, idx: int) -> Dict[str, Any]:
    effective_idx = idx % self._num_records
    time.sleep(0.01) # Simulate 10ms of work per item
    image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
    label = effective_idx % 10
    return {'image': image, 'label': label}

source_with_work = MySourceWithWork(num_records=1000)
print(f"MySourceWithWork created with {len(source_with_work)} records.")

# Sampler for parallel loading (indefinite epochs for robust testing)
parallel_sampler = grain.IndexSampler(
    num_records=len(source_with_work),
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=None, # Run indefinitely
    seed=42
    )
print("Parallel IndexSampler created.")

# Transformations can be reused from Exercise 1 solution
# Ensure ConvertToFloat is defined (it was in Ex1 solution cell)
if 'ConvertToFloat' not in globals(): # Basic check
  class ConvertToFloat(grain.MapTransform): # Redefine if not in current scope
    def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
      image = features['image'].astype(np.float32) / 255.0
      return {'image': image, 'label': features['label']}

  print("Redefined ConvertToFloat for safety.")

transformations_ex2 = [
    ConvertToFloat(),
    grain.Batch(batch_size=64, drop_remainder=True)
    ]
print("Transformations for Ex2 ready.")

# 2. Instantiate DataLoader with worker_count > 0
num_workers = 4 # Use 4 workers; JAX is configured for 8 virtual CPUs
# Max useful workers often related to num CPU cores available.
data_loader_parallel = grain.DataLoader(
    data_source=source_with_work,
    operations=transformations_ex2,
    sampler=parallel_sampler,
    worker_count=num_workers,
    shard_options=grain.NoSharding(),
    read_options=grain.ReadOptions(num_threads=0) # Data source simulates work but is "in-memory"
    )

# 3. Iterate and print batch info
if data_loader_parallel:
  print(f"DataLoader configured with worker_count={num_workers}.")
  data_iterator_parallel = iter(data_loader_parallel)
  try:
    first_batch_parallel = next(data_iterator_parallel)
    print(f"Parallel - First batch image shape: {first_batch_parallel['image'].shape}")
    print(f"Parallel - First batch label shape: {first_batch_parallel['label'].shape}")
  except StopIteration:
    print("Parallel DataLoader is empty or exhausted.")
else:
  print("Parallel DataLoader not configured yet.")

# 4. (Optional) Timing comparison
# Create a fresh IndexSampler for the sequential loader for a fair comparison start
# (num_epochs=1 to match typical test for sequential pass)
seq_sampler_for_timing = grain.IndexSampler(
    num_records=len(source_with_work),
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=1, # Single epoch for this timing test
    seed=42
    )

data_loader_seq_with_work = grain.DataLoader(
    data_source=source_with_work,
    operations=transformations_ex2,
    sampler=seq_sampler_for_timing,
    worker_count=0,
    shard_options=grain.NoSharding(),
    read_options=grain.ReadOptions(num_threads=0)
    )
num_batches_to_test = 5 # Number of batches to fetch for timing
print(f"\nTiming test for {num_batches_to_test} batches (each item has 0.01s simulated work):")

# Sequential
iterator_seq = iter(data_loader_seq_with_work)
start_time_seq = time.time()
try:
  for i in range(num_batches_to_test):
    batch_seq = next(iterator_seq)
    if i == 0 and num_batches_to_test > 0 : print(f" Seq batch 1 label sum: {batch_seq['label'].sum()}") # to ensure work is done
except StopIteration:
  print(f"Sequential loader exhausted before {num_batches_to_test} batches.")
end_time_seq = time.time()
print(f"Sequential ({num_batches_to_test} batches) took: {end_time_seq - start_time_seq:.4f} seconds")

# Parallel
# Use a fresh sampler for the parallel loader for timing to ensure it's not exhausted
# and runs for enough batches.
parallel_sampler_for_timing = grain.IndexSampler(
    num_records=len(source_with_work),
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=None, # Indefinite, or ensure enough for num_batches_to_test
    seed=43 # Can be same or different seed
    )

data_loader_parallel_for_timing = grain.DataLoader(
    data_source=source_with_work,
    operations=transformations_ex2,
    sampler=parallel_sampler_for_timing,
    worker_count=num_workers,
    shard_options=grain.NoSharding(),
    read_options=grain.ReadOptions(num_threads=0)
    )

iterator_parallel_timed = iter(data_loader_parallel_for_timing)
start_time_parallel = time.time()
try:
  for i in range(num_batches_to_test):
    batch_par = next(iterator_parallel_timed)
    if i == 0 and num_batches_to_test > 0 : print(f" Parallel batch 1 label sum: {batch_par['label'].sum()}") # to ensure work is done
except StopIteration:
  print(f"Parallel loader exhausted before {num_batches_to_test} batches.")
end_time_parallel = time.time()
print(f"Parallel ({num_batches_to_test} batches, {num_workers} workers) took: {end_time_parallel - start_time_parallel:.4f} seconds")

if end_time_parallel - start_time_parallel < end_time_seq - start_time_seq:
  print("Parallel loading was faster, as expected!")
else:
  print("Parallel loading was not significantly faster. This might happen for very small num_batches_to_test due to overhead, or if simulated work is too little.")

---

Exercise 3: Custom Deterministic Transformations (MapTransform)

Goal: Implement a custom data transformation that behaves deterministically.

Instructions:
  1. Define a custom class OneHotEncodeLabel that inherits from grain.MapTransform.
     * Its __init__ method should take num_classes.
     * Its map(self, features: Dict[str, Any]) method should:
       * Take the input features dictionary.
       * Convert the features['label'] (an integer) into a one-hot encoded NumPy array of type np.float32. The length of this array should be num_classes.
       * Update features['label'] with this new one-hot array.
       * Return the modified features dictionary.
  2. Reuse MySource (the one without time.sleep) and IndexSampler from Exercise 1 (or create new ones).
  3. Create a new list of operations that includes:
     * An instance of your OneHotEncodeLabel (e.g., with num_classes=10, matching idx % 10 from MySource).
     * The ConvertToFloat transform (if not already applied to image).
     * grain.Batch.
  4. Instantiate a grain.DataLoader (you can use worker_count=0 or >0).
  5. Iterate to get the first batch and print the shape of the one-hot encoded labels and an example label vector.
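
One pitfall worth knowing before you write the transform: NumPy accepts negative indices, so a label of -1 would silently set the last class instead of failing. Here is a standalone sketch of the encoding with an explicit range check. The function name one_hot is hypothetical and the check is an optional extra, not something the exercise requires.

```python
import numpy as np

def one_hot(label: int, num_classes: int) -> np.ndarray:
    """One-hot encode a single integer label as a float32 vector."""
    if not 0 <= label < num_classes:
        # Without this check, a negative label would silently index from the end.
        raise ValueError(f"label {label} is outside [0, {num_classes})")
    encoded = np.zeros(num_classes, dtype=np.float32)
    encoded[label] = 1.0
    return encoded

vec = one_hot(3, num_classes=10)
print(vec, vec.dtype, int(np.argmax(vec)))
```

Taking np.argmax of the result recovers the original integer label, which is a handy sanity check on a batch.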
# @title Exercise 3: Student Code
# 1. Define OneHotEncodeLabel
class OneHotEncodeLabel(grain.MapTransform):
  def __init__(self, num_classes: int):
    # TODO: Store num_classes
    # YOUR CODE HERE
    self._num_classes = 0 # Replace this

  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    label = features['label']
    # TODO: Create one-hot encoded version of the label
    # one_hot_label = np.zeros(...)
    # one_hot_label[label] = 1.0
    # YOUR CODE HERE
    one_hot_label = np.array([label]) # Replace this

    features['label'] = one_hot_label
    return features

# 2. Reuse/Create DataSource and Sampler
# TODO: Instantiate MySource (from Ex1, no sleep)
# source_ex3 = ...
# YOUR CODE HERE
source_ex3 = None # Replace this

# TODO: Instantiate an IndexSampler (e.g., from Ex1, or a new one)
# sampler_ex3 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex3 = None # Replace this

# 3. Create new list of operations
# TODO: Instantiate OneHotEncodeLabel
num_classes_for_ohe = 10
# one_hot_encoder = OneHotEncodeLabel(num_classes=num_classes_for_ohe)
# YOUR CODE HERE
one_hot_encoder = None # Replace this

# TODO: Define transformations_ex3 list including one_hot_encoder,
# ConvertToFloat (if not already applied), and grain.Batch
# (Assuming ConvertToFloat is defined from Ex1 solution)
# transformations_ex3 = [...]
# YOUR CODE HERE
transformations_ex3 = [] # Replace this

# 4. Instantiate DataLoader
# TODO: Create data_loader_ex3
# data_loader_ex3 = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_ex3 = None # Replace this

# 5. Iterate and print batch info
if data_loader_ex3:
  print("DataLoader with OneHotEncodeLabel configured.")
  iterator_ex3 = iter(data_loader_ex3)
  try:
    first_batch_ex3 = next(iterator_ex3)
    print(f"Custom MapTransform - Batch image shape: {first_batch_ex3['image'].shape}")
    print(f"Custom MapTransform - Batch label shape: {first_batch_ex3['label'].shape}") # Expected: (batch_size, num_classes)
    if first_batch_ex3['label'].size > 0:
      print(f"Custom MapTransform - Example one-hot label: {first_batch_ex3['label'][0]}")
  except StopIteration:
    print("DataLoader for Ex3 is empty or exhausted.")
else:
  print("DataLoader for Ex3 not configured yet.")
# @title Exercise 3: Solution
# 1. Define OneHotEncodeLabel
class OneHotEncodeLabel(grain.MapTransform):
  def __init__(self, num_classes: int):
    self._num_classes = num_classes

  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    label_scalar = features['label']
    one_hot_label = np.zeros(self._num_classes, dtype=np.float32)
    one_hot_label[label_scalar] = 1.0

    # Create a new dictionary to avoid modifying the input dict in place if it's reused
    # by other transforms or parts of the pipeline, though often direct modification is fine.
    # For safety and clarity, let's return a new dict or an updated copy.
    updated_features = features.copy()
    updated_features['label'] = one_hot_label
    return updated_features

# 2. Reuse/Create DataSource and Sampler
# Using MySource from Exercise 1 solution (no artificial sleep)
if 'MySource' not in globals(): # Basic check
  class MySource(grain.RandomAccessDataSource): # Redefine if not in current scope
    def __init__(self, num_records: int = 1000):
      self._num_records = num_records
    def __len__(self) -> int:
      return self._num_records

    def __getitem__(self, idx: int) -> Dict[str, Any]:
      effective_idx = idx % self._num_records
      image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
      label = effective_idx % 10
      return {'image': image, 'label': label}
  print("Redefined MySource for Ex3.")

source_ex3 = MySource(num_records=1000)
sampler_ex3 = grain.IndexSampler(
  num_records=len(source_ex3),
  shard_options=grain.NoSharding(),
  shuffle=True,
  num_epochs=1,
  seed=42
  )
print("DataSource and Sampler for Ex3 ready.")

# 3. Create new list of operations
num_classes_for_ohe = 10 # Matches idx % 10 in MySource
one_hot_encoder = OneHotEncodeLabel(num_classes=num_classes_for_ohe)

# Ensure ConvertToFloat is defined
if 'ConvertToFloat' not in globals():
  class ConvertToFloat(grain.MapTransform):
    def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
      image = features['image'].astype(np.float32) / 255.0
      return {'image': image, 'label': features['label']} # Pass label through
  print("Redefined ConvertToFloat for Ex3.")

transformations_ex3 = [
  ConvertToFloat(), # Apply first to have float images
  one_hot_encoder, # Then one-hot encode labels
  grain.Batch(batch_size=64, drop_remainder=True)
  ]
print("Transformations for Ex3 defined.")

# 4. Instantiate DataLoader
data_loader_ex3 = grain.DataLoader(
  data_source=source_ex3,
  operations=transformations_ex3,
  sampler=sampler_ex3,
  worker_count=0, # Can be > 0 as well, OneHotEncodeLabel is picklable
  shard_options=grain.NoSharding(),
  read_options=grain.ReadOptions(num_threads=0)
  )

# 5. Iterate and print batch info
if data_loader_ex3:
  print("DataLoader with OneHotEncodeLabel configured.")
  iterator_ex3 = iter(data_loader_ex3)
  try:
    first_batch_ex3 = next(iterator_ex3)
    print(f"Custom MapTransform - Batch image shape: {first_batch_ex3['image'].shape}")
    print(f"Custom MapTransform - Batch label shape: {first_batch_ex3['label'].shape}") # Expected: (64, 10)
    if first_batch_ex3['label'].size > 0:
      print(f"Custom MapTransform - Example one-hot label (first item): {first_batch_ex3['label'][0]}")
    original_label_example = np.argmax(first_batch_ex3['label'][0])
    print(f"Custom MapTransform - Decoded original label (first item): {original_label_example}")
  except StopIteration:
    print("DataLoader for Ex3 is empty or exhausted.")
else:
  print("DataLoader for Ex3 not configured yet.")

---

Exercise 4: Custom Randomized Transformations (RandomMapTransform)

Goal: Implement a custom transformation that involves randomness while ensuring reproducibility using Grain's mechanisms.

Instructions:
  1. Define a custom class RandomBrightnessAdjust that inherits from grain.RandomMapTransform.
     * Its random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any] method should:
       * Take features and an rng (NumPy random number generator).
       * Crucially, use the provided rng for all random operations. This ensures that the same record, when processed with the same initial seed for the sampler, gets the same "random" augmentation.
       * Generate a random brightness factor using rng.uniform(0.7, 1.3).
       * Multiply the features['image'] (assuming it's already float and normalized) by this factor.
       * Clip the image values to stay within [0.0, 1.0] using np.clip().
       * Return the modified features.
  2. Reuse MySource, IndexSampler (ensure it has a seed), and ConvertToFloat from previous exercises.
  3. Create a list of operations including ConvertToFloat, your RandomBrightnessAdjust, and grain.Batch.
  4. Instantiate two DataLoader instances (dl_run1, dl_run2) with the exact same configuration (same source, sampler instance or sampler with same seed, operations, worker_count).
  5. Iterate and get the first batch from dl_run1. Print a sample pixel value.
  6. Reset the iterator or re-create the sampler if necessary (if num_epochs=1). Then, get the first batch from dl_run2. Print the same sample pixel value.
  7. Verify: The pixel values should be identical, demonstrating reproducible random augmentation.
  8. (Optional) Change the seed in the IndexSampler for dl_run2 and observe that the pixel values now differ.
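
The idea behind the provided rng can be tried in isolation with plain NumPy. In this sketch each record gets its own generator derived from a (seed, record index) pair. That derivation is an assumption made for illustration, and Grain's actual seed handling may differ. adjust_brightness is a hypothetical standalone function, not the transform class you will write.

```python
import numpy as np

def adjust_brightness(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Scale a [0, 1] float image by a random factor drawn from rng, then clip."""
    factor = rng.uniform(0.7, 1.3)
    return np.clip(image * factor, 0.0, 1.0)

image = np.full((32, 32, 3), 0.5, dtype=np.float32)

# Assumption for illustration: one generator per record, seeded from
# (seed, record_index). Grain's own seed derivation may differ.
same_a = adjust_brightness(image, np.random.default_rng([42, 7]))
same_b = adjust_brightness(image, np.random.default_rng([42, 7]))
other = adjust_brightness(image, np.random.default_rng([43, 7]))

print(np.array_equal(same_a, same_b))  # same seed and record: identical result
print(np.array_equal(same_a, other))   # different seed: a different factor is drawn
```

Calling np.random.uniform or creating an unseeded generator inside the transform would break this property, which is why the instructions insist on using only the rng that is passed in.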
# @title Exercise 4: Student Code
# 1. Define RandomBrightnessAdjust
class RandomBrightnessAdjust(grain.RandomMapTransform):
  def random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any]:
    # TODO: Ensure image is float (e.g. by placing ConvertToFloat before this in ops)
    image = features['image']

    # TODO: Generate a random brightness factor using the provided rng
    # brightness_factor = rng.uniform(...)
    # YOUR CODE HERE
    brightness_factor = 1.0 # Replace this

    # TODO: Apply brightness adjustment and clip
    # adjusted_image = np.clip(...)
    # YOUR CODE HERE
    adjusted_image = image # Replace this

    # Create a new dictionary or update a copy
    updated_features = features.copy()
    updated_features['image'] = adjusted_image
    return updated_features

# 2. Reuse/Create DataSource, Sampler, ConvertToFloat
# TODO: Instantiate MySource (from Ex1)
# source_ex4 = ...
# YOUR CODE HERE
source_ex4 = None # Replace this

# TODO: Instantiate an IndexSampler with a seed (e.g., seed=42, num_epochs=1 or None)
# sampler_ex4_seed42 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex4_seed42 = None # Replace this

# (Assuming ConvertToFloat is defined from Ex1 solution)

# 3. Create list of operations
# TODO: Instantiate RandomBrightnessAdjust
# random_brightness_adjuster = ...
# YOUR CODE HERE
random_brightness_adjuster = None # Replace this

# TODO: Define transformations_ex4 list: ConvertToFloat, random_brightness_adjuster, grain.Batch
# transformations_ex4 = [...]
# YOUR CODE HERE
transformations_ex4 = [] # Replace this

# 4. Instantiate two DataLoaders with the same config
# TODO: Create dl_run1
# dl_run1 = grain.DataLoader(...)
# YOUR CODE HERE
dl_run1 = None # Replace this

# TODO: Create dl_run2 (using the exact same sampler instance or a new one with the same seed)
# dl_run2 = grain.DataLoader(...)
# YOUR CODE HERE
dl_run2 = None # Replace this

# 5. & 6. Iterate and compare
pixel_to_check = (0, 0, 0, 0) # Batch_idx, H, W, C
if dl_run1:
  print("--- Run 1 (seed 42) ---")
  iterator_run1 = iter(dl_run1)
  try:
    batch1_run1 = next(iterator_run1)
    value_run1 = batch1_run1['image'][pixel_to_check]
    print(f"Run 1 - Pixel {pixel_to_check} value: {value_run1}")
  except StopIteration:
    print("dl_run1 exhausted.")
    value_run1 = None
else:
  print("dl_run1 not configured.")
  value_run1 = None

if dl_run2:
  print("\n--- Run 2 (seed 42, same sampler) ---")
  # If sampler_ex4_seed42 was single-epoch and already used by dl_run1,
  # dl_run2 might be empty. For robust test, ensure sampler allows re-iteration
  # or use a new sampler instance with the same seed.
  # If sampler_ex4_seed42 had num_epochs=None, iter(dl_run2) is fine.
  # If num_epochs=1, you might need to re-create sampler_ex4_seed42 for dl_run2
  # or ensure dl_run1 didn't exhaust it (e.g. by not fully iterating it).
  # For this exercise, assume sampler_ex4_seed42 can be re-used or is fresh for dl_run2.
  iterator_run2 = iter(dl_run2)
  try:
    batch1_run2 = next(iterator_run2)
    value_run2 = batch1_run2['image'][pixel_to_check]
    print(f"Run 2 - Pixel {pixel_to_check} value: {value_run2}")

    # 7. Verify
    if value_run1 is not None and value_run2 is not None:
      if np.allclose(value_run1, value_run2):
        print("\nSUCCESS: Pixel values are identical. Randomness is reproducible!")
      else:
        print(f"\nFAILURE: Pixel values differ. value1={value_run1}, value2={value_run2}")
  except StopIteration:
    print("dl_run2 exhausted. This might happen if the sampler was single-epoch and already used.")
    value_run2 = None
else:
  print("dl_run2 not configured.")

# 8. (Optional) Test with a different seed
# TODO: Create sampler_ex4_seed100 (seed=100)
# sampler_ex4_seed100 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex4_seed100 = None # Replace this

# TODO: Create dl_run3 with sampler_ex4_seed100
# dl_run3 = grain.DataLoader(...)
# YOUR CODE HERE
dl_run3 = None # Replace this

if dl_run3:
  print("\n--- Run 3 (seed 100) ---")
  iterator_run3 = iter(dl_run3)
  try:
    batch1_run3 = next(iterator_run3)
    value_run3 = batch1_run3['image'][pixel_to_check]
    print(f"Run 3 - Pixel {pixel_to_check} value: {value_run3}")
    if value_run1 is not None and not np.allclose(value_run1, value_run3):
      print("SUCCESS: Pixel values differ from Run 1 (seed 42), as expected with a new seed.")
    elif value_run1 is not None:
      print("NOTE: Pixel values are the same as Run 1. Check seed or logic.")
  except StopIteration:
    print("dl_run3 exhausted.")
else:
  print("\nOptional part (dl_run3) not configured.")
# @title Exercise 4: Solution
# 1. Define RandomBrightnessAdjust
class RandomBrightnessAdjust(grain.RandomMapTransform):
  def random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any]:
    image = features['image'] # Assumes image is already float, e.g. from ConvertToFloat
    # Generate a random brightness factor using the provided rng
    brightness_factor = rng.uniform(0.7, 1.3)

    # Apply brightness adjustment and clip
    adjusted_image = image * brightness_factor
    adjusted_image = np.clip(adjusted_image, 0.0, 1.0)

    updated_features = features.copy()
    updated_features['image'] = adjusted_image
    return updated_features

# 2. Reuse/Create DataSource, Sampler, ConvertToFloat
if 'MySource' not in globals(): # Basic check for MySource
  class MySource(grain.RandomAccessDataSource):
    def __init__(self, num_records: int = 1000):
      self._num_records = num_records
    def __len__(self) -> int:
      return self._num_records
    def __getitem__(self, idx: int) -> Dict[str, Any]:
      effective_idx = idx % self._num_records
      image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
      label = effective_idx % 10
      return {'image': image, 'label': label}
  print("Redefined MySource for Ex4.")
source_ex4 = MySource(num_records=1000)

# Sampler with a fixed seed. num_epochs=None allows re-iteration for multiple DataLoaders.
# If num_epochs=1, the sampler instance can only be fully iterated once.
# For this test, using num_epochs=None or re-creating the sampler for each DataLoader is safest.
# Let's use num_epochs=None to allow the same sampler instance to be used.
sampler_ex4_seed42 = grain.IndexSampler(
  num_records=len(source_ex4),
  shard_options=grain.NoSharding(),
  shuffle=True,
  num_epochs=None, # Allow indefinite iteration
  seed=42
  )
print("DataSource and Sampler (seed 42) for Ex4 ready.")

if 'ConvertToFloat' not in globals(): # Basic check for ConvertToFloat
  class ConvertToFloat(grain.MapTransform):
    def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
      image = features['image'].astype(np.float32) / 255.0
      return {'image': image, 'label': features['label']}
  print("Redefined ConvertToFloat for Ex4.")

# 3. Create list of operations
random_brightness_adjuster = RandomBrightnessAdjust()
transformations_ex4 = [
  ConvertToFloat(),
  random_brightness_adjuster,
  grain.Batch(batch_size=64, drop_remainder=True)
]
print("Transformations for Ex4 defined.")

# 4. Instantiate two DataLoaders with the same config
# Using worker_count > 0 to also test picklability of RandomBrightnessAdjust
num_workers_ex4 = 2
dl_run1 = grain.DataLoader(
  data_source=source_ex4,
  operations=transformations_ex4,
  sampler=sampler_ex4_seed42, # Same sampler instance
  worker_count=num_workers_ex4,
  shard_options=grain.NoSharding(),
  read_options=grain.ReadOptions(num_threads=0)
  )

dl_run2 = grain.DataLoader(
  data_source=source_ex4,
  operations=transformations_ex4,
  sampler=sampler_ex4_seed42, # Same sampler instance
  worker_count=num_workers_ex4,
  shard_options=grain.NoSharding(),
  read_options=grain.ReadOptions(num_threads=0)
  )
print(f"DataLoaders for Run1 and Run2 created with worker_count={num_workers_ex4}.")

# 5. & 6. Iterate and compare
pixel_to_check = (0, 0, 0, 0) # Batch_idx=0, H=0, W=0, C=0
print("\n--- Run 1 (seed 42) ---")
iterator_run1 = iter(dl_run1)

try:
  batch1_run1 = next(iterator_run1)
  value_run1 = batch1_run1['image'][pixel_to_check]
  print(f"Run 1 - Pixel {pixel_to_check} value: {value_run1}")
except StopIteration:
  print("dl_run1 exhausted.")
  value_run1 = None

print("\n--- Run 2 (seed 42, same sampler instance) ---")
iterator_run2 = iter(dl_run2) # Gets a new iterator from the DataLoader

try:
  batch1_run2 = next(iterator_run2)
  value_run2 = batch1_run2['image'][pixel_to_check]
  print(f"Run 2 - Pixel {pixel_to_check} value: {value_run2}")
  # 7. Verify
  if value_run1 is not None and value_run2 is not None:
    if np.allclose(value_run1, value_run2):
      print("\nSUCCESS: Pixel values are identical. Randomness is reproducible with the same sampler instance!")
    else:
      print(f"\nFAILURE: Pixel values differ. value1={value_run1}, value2={value_run2}. This shouldn't happen if sampler is re-used correctly.")
except StopIteration:
  print("dl_run2 exhausted.")

# 8. (Optional) Test with a different seed
sampler_ex4_seed100 = grain.IndexSampler(
  num_records=len(source_ex4),
  shard_options=grain.NoSharding(),
  shuffle=True,
  num_epochs=None,
  seed=100 # Different seed
  )

dl_run3 = grain.DataLoader(
  data_source=source_ex4,
  operations=transformations_ex4,
  sampler=sampler_ex4_seed100, # Sampler with different seed
  worker_count=num_workers_ex4,
  shard_options=grain.NoSharding(),
  read_options=grain.ReadOptions(num_threads=0)
  )
print("\nDataLoader for Run3 (seed 100) created.")
print("\n--- Run 3 (seed 100) ---")
iterator_run3 = iter(dl_run3)

try:
  batch1_run3 = next(iterator_run3)
  value_run3 = batch1_run3['image'][pixel_to_check]
  print(f"Run 3 - Pixel {pixel_to_check} value: {value_run3}")
  if value_run1 is not None and not np.allclose(value_run1, value_run3):
    print("SUCCESS: Pixel values differ from Run 1 (seed 42), as expected with a new sampler seed.")
  elif value_run1 is not None:
    print("NOTE: Pixel values are the same as Run 1. This is unexpected if seeds are different.")
except StopIteration:
  print("dl_run3 exhausted.")

---

Exercise 5: Data Sharding for Distributed Training

Goal: Understand how Grain handles data sharding, essential for distributed training where each JAX process needs a unique slice of data.

Background: In a real distributed JAX setup, you'd have multiple Python processes. Each process would call jax.process_index() to know its ID and jax.process_count() for the total number of processes. grain.sharding.ShardByJaxProcess() is a helper that automatically uses these values.

Since we are in a single Colab notebook (simulating one JAX process, even with multiple virtual devices), we can't directly run multiple JAX processes. Instead, we will manually create grain.ShardOptions to simulate what would happen on two different processes.

Instructions:
  1. Reuse MySource and transformations (e.g., ConvertToFloat and grain.Batch) from previous exercises.
  2. Define shard_count = 2.
  3. Simulate Process 0:
    * Create shard_options_p0 = grain.ShardOptions(shard_index=0, shard_count=shard_count, drop_remainder=True).
    * Create an IndexSampler (sampler_p0) using shard_options_p0. Ensure it shuffles and uses a common seed (e.g., 42).
    * Create a DataLoader (dl_p0) using sampler_p0, with shard_options_p0 also passed to the DataLoader itself.
    * Iterate through dl_p0 and collect all unique labels from the first few batches (or all batches if num_epochs=1).
  4. Simulate Process 1:
    * Create shard_options_p1 = grain.ShardOptions(shard_index=1, shard_count=shard_count, drop_remainder=True).
    * Create an IndexSampler (sampler_p1) using shard_options_p1 (same seed as sampler_p0).
    * Create a DataLoader (dl_p1) using sampler_p1 and shard_options_p1.
    * Iterate through dl_p1 and collect all unique labels.
  5. Verify:
    * Print the set of unique labels obtained by "Process 0" and "Process 1". Because MySource derives labels as idx % 10, the two label sets will overlap heavily even when sharding is correct, so labels alone cannot confirm it. The key is that the indices sampled by sampler_p0 and sampler_p1 should be disjoint; collect the original indices to check this directly.
    * The drop_remainder=True in ShardOptions ensures that if the dataset size isn't perfectly divisible by shard_count, some data might be dropped to ensure shards are equal or nearly equal (depending on implementation details).

Note on shard_options in IndexSampler vs DataLoader: The shard_options argument to grain.DataLoader is the primary way to enable sharding for a JAX process. The DataLoader will then ensure its underlying sampler (even if you provide a non-sharded one) respects these global sharding options for the current JAX process. If you provide an IndexSampler that is already sharded, its sharding must be compatible with the DataLoader's shard_options. For simplicity and clarity in distributed settings, passing ShardByJaxProcess() or manually configured ShardOptions to the DataLoader is typical.
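The effect of drop_remainder on shard sizes can be sketched in plain Python. contiguous_shard_range is a hypothetical helper, and the even, contiguous split it implements is an assumption made for illustration; which indices Grain actually assigns to each shard (especially with shuffling) is decided by the library.

```python
def contiguous_shard_range(num_records: int, shard_index: int, shard_count: int,
                           drop_remainder: bool) -> range:
  """Hypothetical helper: index range owned by one shard under an even, contiguous split."""
  base, remainder = divmod(num_records, shard_count)
  if drop_remainder:
    # Every shard gets exactly `base` records; the trailing `remainder` records are unused.
    start = shard_index * base
    return range(start, start + base)
  # Without dropping, the first `remainder` shards each take one extra record.
  start = shard_index * base + min(shard_index, remainder)
  size = base + (1 if shard_index < remainder else 0)
  return range(start, start + size)

num_records, shard_count = 1001, 2
dropped = [set(contiguous_shard_range(num_records, i, shard_count, True))
           for i in range(shard_count)]
kept = [set(contiguous_shard_range(num_records, i, shard_count, False))
        for i in range(shard_count)]

print([len(s) for s in dropped], "disjoint:", dropped[0].isdisjoint(dropped[1]))
print([len(s) for s in kept], "covered:", len(kept[0] | kept[1]))
```

Under this assumed scheme, 1001 records over 2 shards gives two equal shards of 500 with one record unused when drop_remainder=True, and shards of 501 and 500 covering everything when it is False. Either way the shards stay disjoint, which is the property the exercise asks you to verify.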
# @title Exercise 5: Student Code
# 1. Reuse DataSource and basic transformations
# TODO: Instantiate MySource (from Ex1)
# source_ex5 = ...
# YOUR CODE HERE
source_ex5 = None # Replace this

# TODO: Define basic_transformations_ex5 (e.g., ConvertToFloat, Batch)
# (Assuming ConvertToFloat is defined)
# basic_transformations_ex5 = [...]
# YOUR CODE HERE
basic_transformations_ex5 = [] # Replace this

# 2. Define shard_count
shard_count = 2
common_seed = 42
num_epochs_for_sharding_test = 1 # To make collection of all labels feasible

# 3. Simulate Process 0
# TODO: Create shard_options_p0
# shard_options_p0 = grain.ShardOptions(...)
# YOUR CODE HERE
shard_options_p0 = None # Replace this

# TODO: Create sampler_p0. Pass shard_options_p0 to the IndexSampler.
# sampler_p0 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_p0 = None # Replace this

# TODO: Create dl_p0. Pass shard_options_p0 to the DataLoader as well.
# dl_p0 = grain.DataLoader(...)
# YOUR CODE HERE
dl_p0 = None # Replace this

labels_p0 = set()
if dl_p0:
  print("--- Simulating Process 0 ---")
  # YOUR CODE HERE: Iterate through dl_p0 and collect all unique original labels.
  # Remember that labels might be batched. You need to iterate through items in a batch.
  # For simplicity, if your MySource generates labels like idx % 10,
  # you can try to collect the indices that were sampled.
  # Or, more directly, collect the 'label' field from each item.
  # To get original indices, you might need a transform that passes index through.
  # Let's collect the 'label' values directly.
  pass # Replace with iteration logic

# 4. Simulate Process 1
# TODO: Create shard_options_p1
# shard_options_p1 = grain.ShardOptions(...)
# YOUR CODE HERE
shard_options_p1 = None # Replace this

# TODO: Create sampler_p1
# sampler_p1 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_p1 = None # Replace this

# TODO: Create dl_p1
# dl_p1 = grain.DataLoader(...)
# YOUR CODE HERE
dl_p1 = None # Replace this

labels_p1 = set()
if dl_p1:
  print("\n--- Simulating Process 1 ---")
  # YOUR CODE HERE: Iterate through dl_p1 and collect all unique labels.
  pass # Replace with iteration logic

# 5. Verify
print(f"\n--- Verification (Total records in source: {len(source_ex5) if source_ex5 else 'N/A'}) ---")
print(f"Unique labels collected by Process 0 (count {len(labels_p0)}): sorted {sorted(list(labels_p0))[:20]}...")
print(f"Unique labels collected by Process 1 (count {len(labels_p1)}): sorted {sorted(list(labels_p1))[:20]}...")
if labels_p0 and labels_p1:
  intersection = labels_p0.intersection(labels_p1)
  if not intersection:
    print("\nSUCCESS: No overlap in labels between Process 0 and Process 1. Sharding works as expected!")
  else:
    print(f"\nNOTE: Some overlap in labels found (count {len(intersection)}): {intersection}.")
    print("This can happen if labels are not unique per index, or if sharding logic has issues.")
    print("With MySource's label = idx % 10, an overlap in labels is expected even if indices are disjoint.")
    print("A better test would be to collect original indices if possible.")

# For a more direct test of sharding of indices:
# We can define a DataSource that returns the index itself.
class IndexSource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int):
    self._num_records = num_records
  def __len__(self) -> int:
    return self._num_records
  def __getitem__(self, idx: int) -> int:
    return idx % self._num_records # Return the index
index_source = IndexSource(num_records=100) # Smaller source for easier inspection

# DataLoader for indices (no batching, just to see raw sampled indices).
# Returning a dict ({'index': idx}) keeps the record format consistent with the other exercises.
class IndexDictSource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int):
    self._num_records = num_records
  def __len__(self) -> int:
    return self._num_records
  def __getitem__(self, idx: int) -> Dict[str, int]:
    return {'index': idx % self._num_records}
index_dict_source = IndexDictSource(num_records=100)

# Build the samplers and DataLoaders only once shard_options_p0/p1 are defined,
# so the unfinished template does not pass shard_options=None to IndexSampler.
if shard_options_p0 and shard_options_p1:
  idx_sampler_p0 = grain.IndexSampler(len(index_source), shard_options_p0, shuffle=False, num_epochs=1, seed=common_seed)
  idx_sampler_p1 = grain.IndexSampler(len(index_source), shard_options_p1, shuffle=False, num_epochs=1, seed=common_seed)

  # Samplers for IndexDictSource
  idx_dict_sampler_p0 = grain.IndexSampler(len(index_dict_source), shard_options_p0, shuffle=False, num_epochs=1, seed=common_seed)
  idx_dict_sampler_p1 = grain.IndexSampler(len(index_dict_source), shard_options_p1, shuffle=False, num_epochs=1, seed=common_seed)

  # DataLoaders for IndexDictSource (arguments passed by keyword).
  # Pass shard_options to DataLoader as well.
  dl_indices_p0 = grain.DataLoader(data_source=index_dict_source, sampler=idx_dict_sampler_p0, operations=[], worker_count=0, shard_options=shard_options_p0)
  dl_indices_p1 = grain.DataLoader(data_source=index_dict_source, sampler=idx_dict_sampler_p1, operations=[], worker_count=0, shard_options=shard_options_p1)
  indices_from_p0 = {item['index'] for item in dl_indices_p0}
  indices_from_p1 = {item['index'] for item in dl_indices_p1}

  print(f"\n--- Verification of INDICES (Source size: {len(index_dict_source)}) ---")
  print(f"Indices from P0 (count {len(indices_from_p0)}, shuffle=False): {sorted(list(indices_from_p0))}")
  print(f"Indices from P1 (count {len(indices_from_p1)}, shuffle=False): {sorted(list(indices_from_p1))}")

  idx_intersection = indices_from_p0.intersection(indices_from_p1)
  if not idx_intersection:
    print("SUCCESS: No overlap in INDICES. Sharding of data sources works correctly!")
  else:
    print(f"FAILURE: Overlap in INDICES found: {idx_intersection}")
else:
  print("Skipping index verification part as shard_options are not defined.")

print("\nReminder: In a real distributed setup, you'd use grain.sharding.ShardByJaxProcess() "
      "and JAX would manage jax.process_index() automatically for each process.")
# @title Exercise 5: Solution
# Redefine MySource for Ex5 to include 'original_index'.
class MySource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records
  def __len__(self) -> int:
    return self._num_records
  def __getitem__(self, idx: int) -> Dict[str, Any]:
    effective_idx = idx % self._num_records
    image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
    label = effective_idx % 10 # Label is idx % 10
    # For better sharding verification, let's also pass the original index
    return {'image': image, 'label': label, 'original_index': effective_idx}

print("Redefined MySource for Ex5 to include 'original_index'.")
source_ex5 = MySource(num_records=1000)

# Redefine ConvertToFloat for Ex5 so that it passes through every key it does not modify (including 'original_index').
class ConvertToFloat(grain.MapTransform):
  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    # This transform should pass through all keys it doesn't modify
    updated_features = features.copy()
    updated_features['image'] = features['image'].astype(np.float32) / 255.0
    return updated_features
print("Redefined ConvertToFloat for Ex5.")

# We will collect 'original_index' after batching, so batch must preserve it.
# grain.Batch by default collates features with the same name.
basic_transformations_ex5 = [
  ConvertToFloat(),
  grain.Batch(batch_size=64, drop_remainder=True) # drop_remainder for batching
  ]
print("DataSource and Transformations for Ex5 ready.")

# 2. Define shard_count
shard_count = 2
common_seed = 42
num_epochs_for_sharding_test = 1

# 3. Simulate Process 0
shard_options_p0 = grain.ShardOptions(shard_index=0, shard_count=shard_count, drop_remainder=True) # drop_remainder for sharding

# Sampler for Process 0. It's important that the sampler itself is sharded.
# The DataLoader's shard_options will also apply this sharding if the sampler isn't already sharded,
# or verify consistency if it is.
sampler_p0 = grain.IndexSampler(
  num_records=len(source_ex5),
  shard_options=shard_options_p0, # Shard the sampler
  shuffle=True, # Shuffle for more realistic scenario
  num_epochs=num_epochs_for_sharding_test,
  seed=common_seed
  )

dl_p0 = grain.DataLoader(
  data_source=source_ex5,
  operations=basic_transformations_ex5,
  sampler=sampler_p0,
  worker_count=0, # Keep it simple for verification
  shard_options=shard_options_p0, # Also inform DataLoader about the sharding context
  read_options=grain.ReadOptions(num_threads=0)
  )

indices_p0 = set()
if dl_p0:
  print("--- Simulating Process 0 ---")
  for batch in dl_p0:
    indices_p0.update(batch['original_index'].tolist()) # Collect original indices
  print(f"Process 0 collected {len(indices_p0)} unique indices.")

# 4. Simulate Process 1
shard_options_p1 = grain.ShardOptions(shard_index=1, shard_count=shard_count, drop_remainder=True)
sampler_p1 = grain.IndexSampler(
  num_records=len(source_ex5),
  shard_options=shard_options_p1, # Shard the sampler
  shuffle=True, # Use same shuffle setting and seed for apples-to-apples comparison of sharding logic
  num_epochs=num_epochs_for_sharding_test,
  seed=common_seed # Same seed ensures shuffle order is same before sharding
  )

dl_p1 = grain.DataLoader(
  data_source=source_ex5,
  operations=basic_transformations_ex5,
  sampler=sampler_p1,
  worker_count=0,
  shard_options=shard_options_p1, # Inform DataLoader
  read_options=grain.ReadOptions(num_threads=0)
  )

indices_p1 = set()
if dl_p1:
  print("\n--- Simulating Process 1 ---")
  for batch in dl_p1:
    indices_p1.update(batch['original_index'].tolist()) # Collect original indices
  print(f"Process 1 collected {len(indices_p1)} unique indices.")

# 5. Verify
print(f"\n--- Verification of original_indices (Total records in source: {len(source_ex5)}) ---")
# Showing a few from each for brevity
print(f"Unique original_indices from P0 (first 20 sorted): {sorted(list(indices_p0))[:20]}...")
print(f"Unique original_indices from P1 (first 20 sorted): {sorted(list(indices_p1))[:20]}...")
expected_per_shard = len(source_ex5) // shard_count # Due to drop_remainder=True in ShardOptions
print(f"Expected records per shard (approx, due to drop_remainder in sharding): {expected_per_shard}")
print(f"Actual for P0: {len(indices_p0)}, P1: {len(indices_p1)}")

if indices_p0 and indices_p1:
  intersection = indices_p0.intersection(indices_p1)
  if not intersection:
    print("\nSUCCESS: No overlap in original_indices between Process 0 and Process 1. Sharding works!")
  else:
    print(f"\nFAILURE: Overlap in original_indices found (count {len(intersection)}): {sorted(list(intersection))[:20]}...")
    print("This should not happen if sharding is correct and seeds/shuffle are consistent.")
else:
  print("Could not perform intersection test as one or both sets of indices are empty.")

total_unique_indices_seen = len(indices_p0.union(indices_p1))
print(f"Total unique indices seen across both simulated processes: {total_unique_indices_seen}")

# With drop_remainder=True in sharding, total might be less than len(source_ex5)
# if len(source_ex5) is not divisible by shard_count.
# Example: 1000 records, 2 shards. Each gets 500. Total 1000.
# Example: 1001 records, 2 shards. drop_remainder=True means each gets 500. Total 1000. 1 record dropped.
print("\nReminder: In a real distributed JAX application:")
print("1. Each JAX process would run this script (or similar).")
print("2. shard_options = grain.sharding.ShardByJaxProcess(drop_remainder=True) would be used.")
print("3. jax.process_index() and jax.process_count() would provide the correct shard info automatically.")
print("4. The IndexSampler and DataLoader would be configured with these auto-detected shard_options.")
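One detail worth keeping in mind when reading the "Actual for P0 / P1" counts: the solution applies drop_remainder=True twice, once in ShardOptions and once in grain.Batch. Assuming batching runs over each shard's record stream as a whole (as with worker_count=0 here), the arithmetic below sketches how many indices remain after both steps. records_after_batching is a hypothetical helper, not part of Grain.

```python
def records_after_batching(records_in_shard: int, batch_size: int, drop_remainder: bool) -> int:
  """Number of records that survive batching of a single shard's stream."""
  if drop_remainder:
    # Only full batches are emitted; the final partial batch is discarded.
    return (records_in_shard // batch_size) * batch_size
  return records_in_shard

num_records, shard_count, batch_size = 1000, 2, 64
per_shard = num_records // shard_count  # Sharding with drop_remainder=True
seen = records_after_batching(per_shard, batch_size, drop_remainder=True)
print(f"per shard: {per_shard}, full batches: {per_shard // batch_size}, records seen: {seen}")
```

Under that assumption each simulated process would report 448 unique indices rather than 500, because the last 52 records of its shard never form a full batch of 64. A count below the "expected" per-shard value is therefore not, by itself, evidence of a sharding problem.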

---

Exercise 6: Integrating Grain with a JAX/Flax NNX (Conceptual) Loop

Goal: Understand how a Grain DataLoader feeds data into a typical JAX/Flax NNX training loop. This exercise is conceptual regarding model training (no actual weight updates) but practical in terms of data flow.

Instructions:
  1. Define a Simple Flax NNX Model:
    * Create a class SimpleNNXModel inheriting from nnx.Module.
    * In __init__, initialize an nnx.Linear layer. The input features should match the flattened image dimensions (e.g., 32*32*3), and output features can be num_classes (e.g., 10). Remember to pass rngs for parameter initialization.
    * Implement __call__(self, x): it should flatten the input image x (if it's B, H, W, C) and pass it through the linear layer.
  2. Define a Conceptual train_step:
    * This JAX function should be JIT-compiled (@jax.jit).
    * It takes the model (your SimpleNNXModel instance) and a batch from Grain.
    * Inside, it performs a forward pass: logits = model(batch['image']).
    * It calculates a dummy loss, e.g., loss = jnp.mean(logits). (No real loss computation or gradients needed for this exercise).
    * It returns the model and the loss, in that order. In a real training scenario using nnx.Optimizer, the optimizer would update the model's parameters in-place. The train_step function would typically return the loss, the updated model, and the updated optimizer state to be used in the next iteration.
  3. Set up DataLoader:
    * Use MySource (the one that yields {'image': ..., 'label': ...}), an IndexSampler (e.g., for a few epochs), and transformations (e.g., ConvertToFloat, grain.Batch).
    * Instantiate a grain.DataLoader.
  4. Write the Training Loop:
    * Initialize your SimpleNNXModel with an appropriate JAX PRNG key.
    * Get an iterator from your DataLoader.
    * Loop for a fixed number of steps (e.g., 100):
      * Get the next_batch from the iterator. Handle StopIteration if the loader is exhausted.
      * Call your train_step function with the current model and next_batch.
      * Print the dummy loss occasionally.
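The data flow in these steps can be sketched without JAX or Grain. In the sketch below, fake_batches is a hypothetical stand-in for iter(data_loader), and a plain NumPy weight matrix stands in for the nnx.Linear layer; the point is only to show the shapes involved and how the loop ends when the iterator is exhausted before the step budget runs out.

```python
import numpy as np

def fake_batches(num_batches: int, batch_size: int = 64):
  """Hypothetical stand-in for iter(data_loader): yields dict batches of float images."""
  for i in range(num_batches):
    yield {
        'image': np.full((batch_size, 32, 32, 3), i / 10.0, dtype=np.float32),
        'label': np.arange(batch_size) % 10,
    }

def forward_and_dummy_loss(weights: np.ndarray, batch: dict) -> float:
  x = batch['image']
  x_flat = x.reshape((x.shape[0], -1))  # (B, 32, 32, 3) -> (B, 3072)
  logits = x_flat @ weights             # (B, 3072) @ (3072, 10) -> (B, 10)
  return float(np.mean(logits ** 2))    # Dummy loss, no gradients involved

weights = np.full((32 * 32 * 3, 10), 0.001, dtype=np.float32)
iterator = fake_batches(num_batches=5)
num_steps, steps_run, losses = 8, 0, []
for step in range(num_steps):
  try:
    batch = next(iterator)
  except StopIteration:
    print(f"Iterator exhausted at step {step}. Ending loop.")
    break
  losses.append(forward_and_dummy_loss(weights, batch))
  steps_run += 1
print("steps run:", steps_run)
```

In the exercise itself the same structure applies, with the batches coming from grain.DataLoader and the forward pass handled by your SimpleNNXModel inside train_step.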
# @title Exercise 6: Student Code
# 1. Define SimpleNNXModel
class SimpleNNXModel(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    # TODO: Initialize an nnx.Linear layer
    # self.linear = nnx.Linear(...)
    # YOUR CODE HERE
    self.linear = None # Replace this
  def __call__(self, x: jax.Array):
    # TODO: Flatten the input image (if B, H, W, C) and pass through linear layer
    # x_flat = x.reshape((x.shape[0], -1))
    # return self.linear(x_flat)
    # YOUR CODE HERE
    return x # Replace this

# 2. Define train_step
def train_step(model: SimpleNNXModel, batch: Dict[str, jax.Array]):
  # TODO: Perform forward pass: model(batch['image'])
  # logits = ...
  # YOUR CODE HERE
  logits = model(batch['image']) # Assuming model handles it
  # TODO: Calculate a dummy loss (e.g., mean of logits)
  # loss = ...
  # YOUR CODE HERE
  loss = jnp.array(0.0) # Replace this

  # In a real scenario, you'd also compute gradients and update model parameters here.
  # For this exercise, we just return the original model and the loss.
  return model, loss

# 3. Set up DataLoader
# TODO: Instantiate MySource (from Ex1, or the one with 'original_index' if you prefer)
# source_ex6 = ...
# YOUR CODE HERE
source_ex6 = None # Replace this

# TODO: Instantiate an IndexSampler for a few epochs (e.g., 2 epochs)
# sampler_ex6 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex6 = None # Replace this

# TODO: Define transformations_ex6 (e.g., ConvertToFloat, grain.Batch)
# (Assuming ConvertToFloat is defined)
# transformations_ex6 = [...]
# YOUR CODE HERE
transformations_ex6 = [] # Replace this

# TODO: Instantiate data_loader_ex6
# data_loader_ex6 = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_ex6 = None # Replace this

# 4. Write the Training Loop
if data_loader_ex6: # Proceed only if DataLoader is configured
  # TODO: Initialize SimpleNNXModel
  # image_height, image_width, image_channels = 32, 32, 3
  # num_classes_ex6 = 10
  # model_key = jax.random.key(0)
  # model_ex6 = SimpleNNXModel(...)
  # YOUR CODE HERE
  model_ex6 = None # Replace this

  if model_ex6:
    # TODO: Get an iterator from data_loader_ex6
    # grain_iterator_ex6 = ...
    # YOUR CODE HERE
    grain_iterator_ex6 = iter([]) # Replace this

    num_steps = 100
    print(f"\nStarting conceptual training loop for {num_steps} steps...")
    for step in range(num_steps):
      try:
        # TODO: Get next_batch from iterator
        # next_batch = ...
        # YOUR CODE HERE
        next_batch = None # Replace this
        if next_batch is None:
          raise StopIteration # Simulate exhaustion if not implemented
      except StopIteration:
        print(f"DataLoader exhausted at step {step}. Ending loop.")
        break

      # TODO: Call train_step (it returns the model first, then the loss)
      # model_ex6, loss = train_step(model_ex6, next_batch) # model_ex6 isn't actually updated here
      # YOUR CODE HERE
      loss = jnp.array(0.0) # Replace this

      if step % 20 == 0 or step == num_steps - 1:
        print(f"Step {step}: Dummy Loss = {loss.item():.4f}")
    print("Conceptual training loop finished.")
  else:
    print("Model for Ex6 not initialized.")
else:
  print("DataLoader for Ex6 not configured.")
# @title Exercise 6: Solution
# 1. Define SimpleNNXModel
class SimpleNNXModel(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dout, rngs=rngs)
  def __call__(self, x: jax.Array) -> jax.Array:
    # Assuming x is (B, H, W, C)
    batch_size = x.shape[0]
    x_flat = x.reshape((batch_size, -1)) # Flatten H, W, C dimensions
    return self.linear(x_flat)

# 2. Define conceptual train_step
def train_step(model: SimpleNNXModel, batch: Dict[str, jax.Array]):
  # Perform forward pass
  logits = model(batch['image']) # model.__call__ is invoked
  # Calculate a dummy loss
  loss = jnp.mean(logits**2) # Example: mean of squared logits

  # In a real training step:
  # # 1. Define a loss function.
  # def loss_fn(model):
  #   logits = model(batch['image'])
  #   # loss_value = ... (e.g., optax.softmax_cross_entropy_with_integer_labels)
  #   return jnp.mean(logits**2) # Using dummy loss from exercise
  #
  # # 2. Calculate gradients.
  # grads = nnx.grad(loss_fn)(model) # Differentiates with respect to nnx.Param by default
  #
  # # 3. Update the model's parameters in-place using the optimizer.
  # #    Note: The optimizer is defined outside the train step.
  # #    e.g., optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
  # optimizer.update(model, grads)
  #
  # # 4. For this exercise, we just return the original model and the loss.
  return model, loss

# 3. Set up DataLoader
# Redefine MySource for Ex6.
class MySource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records
  def __len__(self) -> int:
    return self._num_records
  def __getitem__(self, idx: int) -> Dict[str, Any]:
    effective_idx = idx % self._num_records
    image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
    label = effective_idx % 10
    return {'image': image, 'label': label}
print("Redefined MySource for Ex6.")

source_ex6 = MySource(num_records=1000)
sampler_ex6 = grain.IndexSampler(
  num_records=len(source_ex6),
  shard_options=grain.NoSharding(),
  shuffle=True,
  num_epochs=2, # Run for 2 epochs
  seed=42
  )

# Redefine ConvertToFloat for Ex6.
class ConvertToFloat(grain.MapTransform):
  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    updated_features = features.copy()
    updated_features['image'] = features['image'].astype(np.float32) / 255.0
    return updated_features
print("Redefined ConvertToFloat for Ex6.")

transformations_ex6 = [
  ConvertToFloat(),
  grain.Batch(batch_size=64, drop_remainder=True)
  ]

data_loader_ex6 = grain.DataLoader(
  data_source=source_ex6,
  operations=transformations_ex6,
  sampler=sampler_ex6,
  worker_count=2, # Use a couple of workers
  shard_options=grain.NoSharding(),
  read_options=grain.ReadOptions(num_threads=0)
  )
print("DataLoader for Ex6 configured.")

# 4. Write the Training Loop
# Define image dimensions and number of classes
image_height, image_width, image_channels = 32, 32, 3
input_dim = image_height * image_width * image_channels
num_classes_ex6 = 10

# Initialize SimpleNNXModel
# NNX modules are typically initialized outside JIT, then their state can be passed.
# For this conceptual example, the model instance itself is passed.
model_key = jax.random.key(0)
model_ex6 = SimpleNNXModel(din=input_dim, dout=num_classes_ex6, rngs=nnx.Rngs(params=model_key))
print(f"SimpleNNXModel initialized. Input dim: {input_dim}, Output dim: {num_classes_ex6}")

# Get an iterator from data_loader_ex6
grain_iterator_ex6 = iter(data_loader_ex6)
num_steps = 100 # Total steps for the conceptual loop

print(f"\nStarting conceptual training loop for {num_steps} steps...")
for step_idx in range(num_steps):
  try:
    # Get next_batch from iterator
    next_batch = next(grain_iterator_ex6)
  except StopIteration:
    print(f"DataLoader exhausted at step {step_idx}. Ending loop.")
    # Example: Re-initialize iterator if sampler allows multiple epochs
    # if sampler_ex6.num_epochs is None or sampler_ex6.num_epochs > 1 (and we tracked current epoch):
      # print("Re-initializing iterator for new epoch...")
      # grain_iterator_ex6 = iter(data_loader_ex6)
      # next_batch = next(grain_iterator_ex6)
    # else:
    break # Exit loop if truly exhausted

  # Call train_step
  # Grain yields NumPy arrays; JAX converts them to jax.Array when they reach JAX ops
  # (train_step is not jitted in this conceptual example).
  _, loss = train_step(model_ex6, next_batch) # model_ex6 state isn't actually updated here

  if step_idx % 20 == 0 or step_idx == num_steps - 1:
    print(f"Step {step_idx}: Dummy Loss = {loss.item():.4f}") # .item() to get Python scalar from JAX array
print("Conceptual training loop finished.")
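
The commented-out idea in the StopIteration handler above (re-creating the iterator when the data runs out) can be sketched without Grain. This is a minimal illustration under stated assumptions: a plain list of batches stands in for a DataLoader that yields one epoch per iterator, and run_steps and toy_loader are hypothetical names introduced only for this example.

```python
from typing import Iterable, List, Tuple


def run_steps(loader: Iterable[List[int]], num_steps: int) -> Tuple[List[int], int]:
  """Consumes num_steps batches, restarting the iterator whenever it is exhausted."""
  iterator = iter(loader)
  restarts = 0
  losses = []
  for _ in range(num_steps):
    try:
      batch = next(iterator)
    except StopIteration:
      # The epoch ended: ask the loader for a fresh iterator and continue.
      restarts += 1
      iterator = iter(loader)
      batch = next(iterator)  # Raises StopIteration again if the loader is empty.
    losses.append(sum(batch))  # Placeholder for a real train_step call.
  return losses, restarts


# Three batches per "epoch"; seven steps therefore need two restarts.
toy_loader = [[0, 1], [2, 3], [4, 5]]
losses, restarts = run_steps(toy_loader, num_steps=7)
print(losses, restarts)
```

With a sampler configured with num_epochs=None, as in Exercise 7, this restart logic is unnecessary because the iterator does not run out.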

---

Exercise 7: Checkpointing and Resuming Data Iteration

Goal: Understand how to save and restore the state of a Grain data iterator for reproducible experiments, especially when resuming long training runs.

Background: Grain's DataLoader produces an iterator when you call iter(data_loader). This iterator (grain.PyGrainDatasetIterator) has get_state() and set_state() methods. These allow you to capture the internal state of the iteration (e.g., current position, RNG states for samplers/transforms) and restore it later. For full experiment checkpointing, this data iterator state should be saved alongside your model parameters (often using a library like Orbax).

Instructions:
  1. Set up a DataLoader (e.g., using MySource, an IndexSampler with num_epochs=None for indefinite iteration and a seed, and some basic transformations).
  2. Get an iterator (iterator1) from this DataLoader.
  3. Iterate a few times (e.g., 3 batches) using next(iterator1) and store the last batch obtained.
  4. Save State: Call saved_iterator_state = iterator1.get_state().
  5. Simulate Resumption:
    • Get a new iterator (iterator2) from the same DataLoader instance.
    • Restore State: Call iterator2.set_state(saved_iterator_state).
  6. Iterate once using next(iterator2) to get a batch (resumed_batch).
  7. Verify:
    • The resumed_batch obtained from iterator2 should be the same as the batch that would have come after the last batch from iterator1.
    • To verify this:
      • After getting saved_iterator_state from iterator1, call next(iterator1) one more time to get expected_next_batch_from_iterator1.
      • Compare resumed_batch (from iterator2 after set_state) with expected_next_batch_from_iterator1. Their contents (e.g., image data) should be identical.
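
Before filling in the cell below, it may help to see the get_state()/set_state() contract in isolation. The following is a minimal, framework-free sketch under stated assumptions: ToyCheckpointableIterator is a hypothetical stand-in written for this illustration, not Grain's implementation, and its state is just a JSON-encoded batch position. The shuffle order is recomputed from the seed rather than stored.

```python
import json
import random
from typing import List


class ToyCheckpointableIterator:
  """Hypothetical stand-in for a checkpointable data iterator (illustration only)."""

  def __init__(self, num_records: int, batch_size: int, seed: int):
    self._num_records = num_records
    self._batch_size = batch_size
    self._seed = seed
    self._next_batch_idx = 0  # Index of the next batch to produce.

  def _epoch_order(self, epoch: int) -> List[int]:
    # The order depends only on (seed, epoch), so it can be rebuilt after a restore.
    order = list(range(self._num_records))
    random.Random(self._seed + epoch).shuffle(order)
    return order

  def __iter__(self):
    return self

  def __next__(self) -> List[int]:
    start = self._next_batch_idx * self._batch_size
    batch = []
    for pos in range(start, start + self._batch_size):
      epoch, offset = divmod(pos, self._num_records)
      batch.append(self._epoch_order(epoch)[offset])
    self._next_batch_idx += 1
    return batch

  def get_state(self) -> bytes:
    return json.dumps({"next_batch_idx": self._next_batch_idx}).encode()

  def set_state(self, state: bytes) -> None:
    self._next_batch_idx = json.loads(state)["next_batch_idx"]


# Same procedure as the exercise: iterate, save, peek at the next batch, resume.
iterator1 = ToyCheckpointableIterator(num_records=10, batch_size=4, seed=42)
for _ in range(3):
  next(iterator1)
saved_state = iterator1.get_state()
expected_next_batch = next(iterator1)

iterator2 = ToyCheckpointableIterator(num_records=10, batch_size=4, seed=42)
iterator2.set_state(saved_state)
resumed_batch = next(iterator2)
print(resumed_batch == expected_next_batch)
```

The point of the design is that the state is small: only the position is saved, and everything else is derived deterministically from the seed.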
# @title Exercise 7: Student Code
# 1. Set up DataLoader
# TODO: Instantiate MySource (e.g., from Ex1)
# source_ex7 = ...
# YOUR CODE HERE
source_ex7 = None # Replace this

# TODO: Instantiate an IndexSampler (num_epochs=None, seed=42)
# sampler_ex7 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex7 = None # Replace this

# TODO: Define transformations_ex7 (e.g., ConvertToFloat, Batch)
# (Assuming ConvertToFloat is defined)
# transformations_ex7 = [...]
# YOUR CODE HERE
transformations_ex7 = [] # Replace this

# TODO: Instantiate data_loader_ex7
# data_loader_ex7 = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_ex7 = None # Replace this

if data_loader_ex7:
  # 2. Get iterator1
  # TODO: iterator1 = iter(...)
  # YOUR CODE HERE
  iterator1 = iter([]) # Replace this

# 3. Iterate a few times
num_initial_iterations = 3
print(f"--- Initial Iteration (iterator1) for {num_initial_iterations} batches ---")
last_batch_iterator1 = None
for i in range(num_initial_iterations):
  try:
    # TODO: last_batch_iterator1 = next(...)
    # YOUR CODE HERE
    last_batch_iterator1 = {} # Replace this
    print(f"iterator1, batch {i+1} - first label: {last_batch_iterator1.get('label', [None])[0]}")
  except StopIteration:
    print("iterator1 exhausted prematurely.")
    break

# 4. Save State
# TODO: saved_iterator_state = iterator1.get_state()
# YOUR CODE HERE
saved_iterator_state = None # Replace this
print(f"\nIterator state saved. Type: {type(saved_iterator_state)}")

# For verification: get the *next* batch from iterator1 *after* saving state
expected_next_batch_from_iterator1 = None
if saved_iterator_state is not None: # Ensure state was actually saved
  try:
    # TODO: expected_next_batch_from_iterator1 = next(...)
    # YOUR CODE HERE
    expected_next_batch_from_iterator1 = {} # Replace this
    print(f"Expected next batch (from iterator1 after get_state) - first label: {expected_next_batch_from_iterator1.get('label', [None])[0]}")
  except StopIteration:
    print("iterator1 exhausted when trying to get expected_next_batch.")

# 5. Simulate Resumption
# TODO: Get iterator2 from the same data_loader_ex7
# iterator2 = iter(...)
# YOUR CODE HERE
iterator2 = iter([]) # Replace this

if saved_iterator_state is not None:
  # TODO: iterator2.set_state(...)
  # YOUR CODE HERE
  print("\n--- Resumed Iteration (iterator2) ---")
  print("Iterator state restored to iterator2.")
else:
  print("\nSkipping resumption, saved_iterator_state is None.")

# 6. Iterate once from iterator2
resumed_batch = None
if saved_iterator_state is not None: # Only if state was set
  try:
    # TODO: resumed_batch = next(...)
    # YOUR CODE HERE
    resumed_batch = {} # Replace this
    print(f"Resumed batch (from iterator2 after set_state) - first label: {resumed_batch.get('label', [None])[0]}")
  except StopIteration:
    print("iterator2 exhausted immediately after set_state.")

# 7. Verify
if expected_next_batch_from_iterator1 is not None and resumed_batch is not None:
  # Compare 'image' data of the first element in the batch
  # TODO: Perform comparison (e.g., np.allclose on image data)
  # are_identical = np.allclose(...)
  # YOUR CODE HERE
  are_identical = False # Replace this

  if are_identical:
    print("\nSUCCESS: Resumed batch is identical to the expected next batch. Checkpointing works!")
  else:
    print("\nFAILURE: Resumed batch differs from the expected next batch.")
    # print(f"Expected image data (sample): {expected_next_batch_from_iterator1['image'][0,0,0,:3]}")
    # print(f"Resumed image data (sample): {resumed_batch['image'][0,0,0,:3]}")
elif saved_iterator_state is not None:
  print("\nVerification inconclusive: could not obtain both expected and resumed batches.")
else:
  print("DataLoader for Ex7 not configured.")
# @title Exercise 7: Solution
# 1. Set up DataLoader
# Redefine MySource for Ex7.
class MySource(grain.RandomAccessDataSource):
  def __init__(self, num_records: int = 1000):
    self._num_records = num_records
  def __len__(self) -> int:
    return self._num_records
  def __getitem__(self, idx: int) -> Dict[str, Any]:
    effective_idx = idx % self._num_records
    image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
    label = effective_idx % 10
    # Add original_index for easier verification if needed, though label might suffice
    return {'image': image, 'label': label, 'original_index': effective_idx}
print("Redefined MySource for Ex7.")

source_ex7 = MySource(num_records=1000)
sampler_ex7 = grain.IndexSampler(
  num_records=len(source_ex7),
  shard_options=grain.NoSharding(),
  shuffle=True, # Shuffling makes the test more robust
  num_epochs=None, # Indefinite iteration
  seed=42
  )

# Redefine ConvertToFloat for Ex7.
class ConvertToFloat(grain.MapTransform):
  def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
    updated_features = features.copy()
    updated_features['image'] = features['image'].astype(np.float32) / 255.0
    return updated_features
print("Redefined ConvertToFloat for Ex7.")

transformations_ex7 = [
  ConvertToFloat(),
  grain.Batch(batch_size=64, drop_remainder=True)
  ]

data_loader_ex7 = grain.DataLoader(
  data_source=source_ex7,
  operations=transformations_ex7,
  sampler=sampler_ex7,
  worker_count=0, # Simpler for state verification, but works with >0 too
  shard_options=grain.NoSharding(),
  read_options=grain.ReadOptions(num_threads=0)
  )
print("DataLoader for Ex7 configured.")

# 2. Get iterator1
iterator1 = iter(data_loader_ex7)

# 3. Iterate a few times
num_initial_iterations = 3
print(f"--- Initial Iteration (iterator1) for {num_initial_iterations} batches ---")
last_batch_iterator1 = None
for i in range(num_initial_iterations):
  try:
    last_batch_iterator1 = next(iterator1)
    # Using 'original_index' for more robust check than just 'label'
    print(f"iterator1, batch {i+1} - first original_index: {last_batch_iterator1['original_index'][0]}")
  except StopIteration:
    print("iterator1 exhausted prematurely.")
    break

# 4. Save State
# get_state() returns a new state object that captures the iterator's position
# at this point, so advancing iterator1 afterwards does not change
# saved_iterator_state and no deep copy is needed.
saved_iterator_state = iterator1.get_state()
print(f"\nIterator state saved. Type: {type(saved_iterator_state)}")

# For verification: get the next batch from iterator1 after saving state
expected_next_batch_from_iterator1 = None
if saved_iterator_state is not None:
  try:
    expected_next_batch_from_iterator1 = next(iterator1)
    print(f"Expected next batch (from iterator1 after get_state) - first original_index: {expected_next_batch_from_iterator1['original_index'][0]}")
  except StopIteration:
    print("iterator1 exhausted when trying to get expected_next_batch.")

# 5. Simulate Resumption
# Get a new iterator from the same DataLoader instance.
iterator2 = iter(data_loader_ex7)
if saved_iterator_state is not None:
  iterator2.set_state(saved_iterator_state)
  print("\n--- Resumed Iteration (iterator2) ---")
  print("Iterator state restored to iterator2.")
else:
  print("\nSkipping resumption, saved_iterator_state is None.")

# 6. Iterate once from iterator2
resumed_batch = None
if saved_iterator_state is not None:
  try:
    resumed_batch = next(iterator2)
    print(f"Resumed batch (from iterator2 after set_state) - first original_index: {resumed_batch['original_index'][0]}")
  except StopIteration:
    print("iterator2 exhausted immediately after set_state. This means the saved state was at the very end.")

# 7. Verify
if expected_next_batch_from_iterator1 is not None and resumed_batch is not None:
  # Compare 'image' data and 'original_index' of the first element in the batch for robustness
  # (Labels might repeat, indices are better for this check if available)
  expected_img_sample = expected_next_batch_from_iterator1['image'][0]
  resumed_img_sample = resumed_batch['image'][0]
  expected_idx_sample = expected_next_batch_from_iterator1['original_index'][0]
  resumed_idx_sample = resumed_batch['original_index'][0]
  are_indices_identical = (expected_idx_sample == resumed_idx_sample)
  are_images_identical = np.allclose(expected_img_sample, resumed_img_sample)

  are_identical = are_indices_identical and are_images_identical

  if are_identical:
    print("\nSUCCESS: Resumed batch is identical to the expected next batch. Checkpointing works!")
  else:
    print("\nFAILURE: Resumed batch differs from the expected next batch.")
    if not are_indices_identical:
      print(f"  - Mismatch in first original_index: Expected {expected_idx_sample}, Got {resumed_idx_sample}")
    if not are_images_identical:
      print("  - Mismatch in image data for first element.")
      # print(f"    Expected image data (sample [0,0,0]): {expected_img_sample[0,0,0]}")
      # print(f"    Resumed image data (sample [0,0,0]): {resumed_img_sample[0,0,0]}")
elif saved_iterator_state is not None: # If state was saved but verification couldn't complete
  if expected_next_batch_from_iterator1 is None and resumed_batch is None:
    print("\nVERIFICATION NOTE: Both iterators seem to be at the end of the dataset after the initial iterations. This is valid if the dataset was short.")
  else:
    print("\nVerification inconclusive: could not obtain both expected and resumed batches for comparison.")
    print(f" expected_next_batch_from_iterator1 is None: {expected_next_batch_from_iterator1 is None}")
    print(f" resumed_batch is None: {resumed_batch is None}")
else: # Reached only if no iterator state was saved
  print("Iterator state was not saved; nothing to verify.")
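
In a real run, the state captured in step 4 would be written to disk together with the training step so that a later process can restore it. The sketch below assumes the state is a bytes object (check the type printed in step 4 for your Grain version) and uses a plain pair of files as a stand-in for Orbax. save_data_state, load_data_state, and the file names are hypothetical, introduced only for this illustration.

```python
import json
import pathlib
import tempfile
from typing import Tuple


def save_data_state(ckpt_dir: pathlib.Path, step: int, state: bytes) -> None:
  """Writes the iterator state and the training step into ckpt_dir."""
  ckpt_dir.mkdir(parents=True, exist_ok=True)
  (ckpt_dir / "data_iterator.state").write_bytes(state)
  (ckpt_dir / "metadata.json").write_text(json.dumps({"step": step}))


def load_data_state(ckpt_dir: pathlib.Path) -> Tuple[int, bytes]:
  """Reads back the training step and the iterator state from ckpt_dir."""
  step = json.loads((ckpt_dir / "metadata.json").read_text())["step"]
  state = (ckpt_dir / "data_iterator.state").read_bytes()
  return step, state


# Placeholder bytes; in the solution above this would be saved_iterator_state.
fake_state = b'{"example": "placeholder"}'
with tempfile.TemporaryDirectory() as tmp:
  ckpt_dir = pathlib.Path(tmp) / "step_3"
  save_data_state(ckpt_dir, step=3, state=fake_state)
  restored_step, restored_state = load_data_state(ckpt_dir)

print(restored_step, restored_state == fake_state)
```

After loading, calling iterator2.set_state(restored_state) would resume iteration as in step 5. Unlike a dedicated checkpointing library, this sketch does not write the two files atomically.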

---

Conclusion and Further Exploration

Congratulations on completing the exercises! You should now have a good understanding of:

  • The fundamental components of Grain (DataSource, Sampler, Operations).
  • How to construct and use grain.DataLoader for efficient data input, including parallel loading.
  • Implementing custom deterministic and random transformations.
  • The basics of data sharding for distributed training.
  • How Grain iterators fit into a JAX/Flax NNX training loop.
  • Saving and restoring data iterator state for reproducibility.

Key Takeaways (Recap from Slides):
  • Use Grain: Solves JAX data bottlenecks for better performance.
  • Boost Speed: Use DataLoader(worker_count > 0) for parallelism.
  • Ensure Reproducibility: Use samplers/seeds & RandomMapTransform's provided rng.
  • Distribute: Use grain.sharding.ShardByJaxProcess (or manual ShardOptions) for JAX sharding.
  • Save Everything: Checkpoint data iterator state (e.g., via Orbax for comprehensive checkpointing) along with your model state.

Further Exploration:
  • Orbax Integration: For robust checkpointing in real-world projects, explore integrating Grain with Orbax. Orbax can manage saving and loading your Grain iterator state alongside your Flax model parameters and optimizer states atomically. Note: If you are migrating a project from NNX v0.10, be aware that the checkpoint structure for models with RNGs (like Dropout or BatchNorm) has changed. You will need to use a migration script to update old checkpoints to the v0.11 format, as described in the official migration guide.
  • Different Data Sources: Explore reading from various on-disk formats (e.g., TFRecord, RecordIO) using appropriate DataSource implementations or by integrating with libraries like TFDS.
  • Performance Profiling: Use JAX's profiling tools to identify and optimize data loading bottlenecks in more complex scenarios.

We hope these exercises have been helpful in your journey with JAX and Grain!