Welcome! This Colab notebook contains exercises to help you learn Google's Grain library for efficient data loading in JAX. These exercises are designed for developers familiar with PyTorch who are now exploring the JAX ecosystem, including the new Flax NNX API.
Goals of this notebook: to introduce Grain's core concepts (data sources, samplers, transformations, parallel and sharded loading, and iterator checkpointing) through hands-on exercises.
Note on the environment: to demonstrate parallelism and sharding concepts effectively in Colab (which typically provides a single CPU/GPU), this notebook starts by configuring JAX to simulate 8 CPU devices. This is achieved using the XLA_FLAGS environment variable and chex.set_n_cpu_devices(8).
Let's get started! Please run the next cell to set up the environment.
# Environment Setup
# This cell configures the environment to simulate multiple CPU devices
# and installs necessary libraries.
# IMPORTANT: RUN THIS CELL FIRST. If you encounter issues with JAX device
# counts later, try 'Runtime -> Restart runtime' in the Colab menu
# and run this cell again before any others.
import os
# Configure JAX to see 8 virtual CPU devices.
# This must be done before JAX is imported for the first time in a session.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
# Install libraries
!pip install -q jax-ai-stack==2025.9.3
print("Libraries installed.")
# Now, import chex and attempt to set_n_cpu_devices.
# This must be called after setting XLA_FLAGS and before JAX initializes its backends.
import chex
try:
chex.set_n_cpu_devices(8)
print("chex.set_n_cpu_devices(8) called successfully.")
except RuntimeError as e:
print(f"Note on chex.set_n_cpu_devices: {e}")
print("This usually means JAX was already initialized. The XLA_FLAGS environment variable should still apply.")
print("If you see issues with device counts, ensure you 'Restart runtime' and run this cell first.")
# Verify JAX device count
import jax
print(f"JAX version: {jax.__version__}")
print(f"JAX found {jax.device_count()} devices.")
print(f"JAX devices: {jax.devices()}")
if jax.device_count() != 8:
print("\nWARNING: JAX does not see 8 devices. Parallelism/sharding exercises might not behave as expected.")
print("Please try 'Runtime -> Restart runtime' and run this setup cell again first.")
# Common imports for the exercises
import jax.numpy as jnp
import numpy as np
import grain.python as grain # Main Grain API
from flax import nnx # Flax NNX API
import time # For simulating work and observing performance
import copy # For checkpointing example
from typing import Dict, Any, List # For type hints
import dataclasses # For ShardOptions if needed manually
import functools # For functools.partial
print("Imports complete. Setup finished.")
As highlighted in the lecture, JAX is incredibly fast for numerical computation, especially on accelerators. However, this speed can be bottlenecked by inefficient data loading. Standard Python data loading can struggle due to I/O limitations, CPU-bound transformations, and the Global Interpreter Lock (GIL).
Grain is Google's solution for high-performance data loading in JAX. Conceptually, Grain's DataLoader is analogous to PyTorch's torch.utils.data.DataLoader: it orchestrates data reading, transformation, batching, and parallelization.
The grain.DataLoader API has three main components:
* DataSource: Provides access to individual raw data records (must implement __len__ and __getitem__).
* Sampler: Determines the order in which records are loaded and provides seeds for random operations, ensuring reproducibility.
* Operations: A list of transformations (e.g., augmentation, filtering, batching) applied sequentially to the records.
A minimal sketch of how these pieces fit together is shown below. Let's dive into the exercises!
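The following is a minimal, hedged sketch of how these pieces typically compose. It uses a toy in-memory source; the names ToySource, sampler, and loader are illustrative only and not part of Grain.
# Minimal wiring sketch (assumptions: toy in-memory source; illustrative names).
import numpy as np
import grain.python as grain

class ToySource(grain.RandomAccessDataSource):
  """Ten records, each a dict with a scalar feature."""
  def __len__(self):
    return 10
  def __getitem__(self, idx):
    return {'x': np.float32(idx)}

sampler = grain.IndexSampler(
    num_records=10,
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=1,
    seed=0,
)
loader = grain.DataLoader(
    data_source=ToySource(),
    sampler=sampler,
    operations=[grain.Batch(batch_size=2, drop_remainder=True)],
    worker_count=0,  # sequential; > 0 enables multiprocessing
)
for batch in loader:
  print(batch['x'].shape)  # (2,)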
---
Exercise 1: A first grain.DataLoader (Sequential)
In this exercise you will build a DataSource, an IndexSampler, a simple MapTransform, and a grain.DataLoader running in sequential mode (worker_count=0).
Instructions:
1. Define MySource, a custom RandomAccessDataSource.
* __init__: Store num_records.
* __len__: Return num_records.
* __getitem__: Given an idx, return a dictionary {'image': image_array, 'label': label_int}.
* The image_array should be a NumPy array of shape (32, 32, 3) with dtype=np.uint8. Its values can depend on idx (e.g., np.ones(...) * (idx % 255)).
* The label_int should be an integer (e.g., idx % 10).
* Handle potential index wrap-around for multiple epochs: idx = idx % self._num_records.
2. Instantiate MySource.
3. Create an IndexSampler for shuffling, running for 1 epoch, with a fixed seed.
4. Define operations:
* A ConvertToFloat class inheriting from grain.MapTransform that converts the 'image' to np.float32 and normalizes it to [0, 1].
* A grain.Batch operation to batch 64 items, dropping any remainder.
5. Instantiate a grain.DataLoader with worker_count=0 (for debugging/sequential mode). Since MySource is in-memory, use read_options=grain.ReadOptions(num_threads=0) to disable Grain's internal read threads.
6. Iterate over the DataLoader to get the first batch and print its image shape and the shape of its labels.
# @title Exercise 1: Student Code
# 1. Define MySource
class MySource(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
# TODO: Return the total number of records
# YOUR CODE HERE
return 0 # Replace this
def __getitem__(self, idx: int) -> Dict[str, Any]:
# TODO: Handle potential index wrap-around for multiple epochs
# effective_idx = ...
# YOUR CODE HERE
effective_idx = idx # Replace this
# TODO: Simulate loading data: an image and a label
# image = np.ones(...) * (effective_idx % 255)
# label = effective_idx % 10
# YOUR CODE HERE
image = np.zeros((32,32,3), dtype=np.uint8) # Replace this
label = 0 # Replace this
return {'image': image, 'label': label}
# 2. Instantiate MySource
# TODO: Create an instance of MySource
# source = ...
# YOUR CODE HERE
source = None # Replace this
# 3. Create an IndexSampler
# TODO: Create an IndexSampler that shuffles, runs for 1 epoch, and uses seed 42.
# num_records should be len(source).
# index_sampler = grain.IndexSampler(...)
# YOUR CODE HERE
index_sampler = None # Replace this
# 4. Define Operations
# TODO: Define ConvertToFloat transform
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
# TODO: Convert 'image' to float32 and normalize to [0, 1].
# Keep 'label' as is.
# YOUR CODE HERE
image = features['image'] # Replace this
return {'image': image.astype(np.float32) / 255.0, 'label': features['label']}
# TODO: Create a list of transformations: ConvertToFloat instance, then grain.Batch
# with batch_size=64 and drop_remainder=True
# transformations = [...]
# YOUR CODE HERE
transformations = [] # Replace this
# 5. Instantiate DataLoader
# TODO: Create a DataLoader with worker_count=0 and appropriate read_options
# data_loader_sequential = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_sequential = None # Replace this
# 6. Iterate and print batch info
if data_loader_sequential:
print("DataLoader configured sequentially.")
data_iterator_seq = iter(data_loader_sequential)
try:
first_batch_seq = next(data_iterator_seq)
print(f"Sequential - First batch image shape: {first_batch_seq['image'].shape}")
print(f"Sequential - First batch label shape: {first_batch_seq['label'].shape}")
# Example: Check a value from the first image of the first batch
print(f"Sequential - Example image value (first item, [0,0,0]): {first_batch_seq['image'][0, 0, 0, 0]}")
print(f"Sequential - Example label value (first item): {first_batch_seq['label'][0]}")
except StopIteration:
print("Sequential DataLoader is empty or exhausted.")
else:
print("Sequential DataLoader not configured yet.")
# @title Exercise 1: Solution
# 1. Define MySource
class MySource(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records # Handle wrap-around
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
return {'image': image, 'label': label}
# 2. Instantiate MySource
source = MySource(num_records=1000)
print(f"DataSource created with {len(source)} records.")
# 3. Create an IndexSampler
index_sampler = grain.IndexSampler(
num_records=len(source),
shard_options=grain.NoSharding(), # No sharding for this exercise
shuffle=True,
num_epochs=1, # Run for 1 epoch
seed=42
)
print("IndexSampler created.")
# 4. Define Operations
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
# Convert 'image' to float32 and normalize to [0, 1].
# Keep 'label' as is.
image = features['image'].astype(np.float32) / 255.0
return {'image': image, 'label': features['label']}
transformations = [
ConvertToFloat(),
grain.Batch(batch_size=64, drop_remainder=True)
]
print("Transformations defined.")
# 5. Instantiate DataLoader
data_loader_sequential = grain.DataLoader(
data_source=source,
operations=transformations,
sampler=index_sampler,
worker_count=0, # Sequential mode
shard_options=grain.NoSharding(), # Explicitly no sharding for this loader instance
read_options=grain.ReadOptions(num_threads=0) # Dataset is in-memory
)
# 6. Iterate and print batch info
if data_loader_sequential:
print("DataLoader configured sequentially.")
data_iterator_seq = iter(data_loader_sequential)
try:
first_batch_seq = next(data_iterator_seq)
print(f"Sequential - First batch image shape: {first_batch_seq['image'].shape}") # Expected: (64, 32, 32, 3)
print(f"Sequential - First batch label shape: {first_batch_seq['label'].shape}") # Expected: (64,)
# Example: Check a value from the first image of the first batch
print(f"Sequential - Example image value (first item, [0,0,0]): {first_batch_seq['image'][0, 0, 0, 0]}")
print(f"Sequential - Example label value (first item): {first_batch_seq['label'][0]}")
except StopIteration:
print("Sequential DataLoader is empty or exhausted.")
else:
print("Sequential DataLoader not configured yet.")
---
Exercise 2: Parallel data loading with worker_count
Setting worker_count > 0 enables multiprocessing for faster data loading.
Instructions:
1. Reuse MySource, the IndexSampler (or create a new one if you prefer, e.g., for indefinite epochs: num_epochs=None), and the transformations from Exercise 1. Modify MySource slightly: add a small time.sleep(0.01) (10 milliseconds) inside __getitem__ to simulate some I/O or CPU work for each item.
2. Create a new grain.DataLoader (e.g., data_loader_parallel). This time, set worker_count to a value greater than 0 (e.g., 2 or 4). Remember our environment is faking 8 CPUs.
3. Iterate and print the first batch's shapes.
4. (Optional) Time how long fetching a few batches takes with worker_count=0 versus worker_count > 0 to observe the effect of the simulated time.sleep.
Important: with worker_count > 0, Grain uses multiprocessing. This means all components (DataSource, Sampler, Operations, and custom transform instances) must be picklable by Python's pickle module. Simple classes and functions are usually fine, but avoid complex closures or unpicklable objects in your transform logic. A short picklability sketch follows these instructions.
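As a quick, hedged illustration of the picklability requirement (using only the standard pickle module): a transform class defined at the top level of a cell pickles fine, while a transform built from locally defined lambdas or classes typically does not. The names ScaleImage and make_unpicklable_transform below are illustrative only.
# Hedged sketch of the picklability requirement.
import pickle
import grain.python as grain

class ScaleImage(grain.MapTransform):
  """Top-level class: picklable, safe for worker_count > 0."""
  def map(self, features):
    features['image'] = features['image'] * 2
    return features

pickle.dumps(ScaleImage())  # works: instance of an importable, top-level class

def make_unpicklable_transform():
  scale = lambda x: x * 2  # local lambda captured in a closure
  class LocalTransform(grain.MapTransform):  # class defined inside a function
    def map(self, features):
      features['image'] = scale(features['image'])
      return features
  return LocalTransform()

# pickle.dumps(make_unpicklable_transform())  # would fail: locally defined
# classes and lambdas cannot be pickled by the standard pickle module.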
# @title Exercise 2: Student Code
# 1. Reuse/Recreate components (DataSource with simulated work, Sampler, Operations)
# TODO: Define MySourceWithWork, adding time.sleep(0.01) in getitem
class MySourceWithWork(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
# YOUR CODE HERE
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
# TODO: Add time.sleep(0.01) to simulate work
# YOUR CODE HERE
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
return {'image': image, 'label': label}
# TODO: Instantiate MySourceWithWork
# source_with_work = ...
# YOUR CODE HERE
source_with_work = None # Replace this
# TODO: Create a new IndexSampler (e.g., for indefinite epochs, num_epochs=None)
# Or reuse the one from Ex1 if you reset it or it's for multiple epochs.
# For simplicity, let's create one for indefinite epochs.
# parallel_sampler = grain.IndexSampler(...)
# YOUR CODE HERE
parallel_sampler = None # Replace this
# Transformations can be reused from Exercise 1
# transformations = [ConvertToFloat(), grain.Batch(batch_size=64, drop_remainder=True)]
# (Assuming ConvertToFloat is defined from Ex1 solution)
# 2. Instantiate DataLoader with worker_count > 0
# TODO: Set num_workers (e.g., 4)
# num_workers = ...
# YOUR CODE HERE
num_workers = 0 # Replace this
# TODO: Create data_loader_parallel
# data_loader_parallel = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_parallel = None # Replace this
# 3. Iterate and print batch info
if data_loader_parallel:
print(f"DataLoader configured with worker_count={num_workers}.")
data_iterator_parallel = iter(data_loader_parallel)
try:
first_batch_parallel = next(data_iterator_parallel)
print(f"Parallel - First batch image shape: {first_batch_parallel['image'].shape}")
print(f"Parallel - First batch label shape: {first_batch_parallel['label'].shape}")
except StopIteration:
print("Parallel DataLoader is empty or exhausted.")
else:
print("Parallel DataLoader not configured yet.")
# 4. (Optional) Timing comparison
# Re-create sequential loader with MySourceWithWork for a fair comparison
if source_with_work and transformations and index_sampler: # index_sampler from Ex1
data_loader_seq_with_work = grain.DataLoader(
data_source=source_with_work,
operations=transformations, # Reusing from Ex1
sampler=index_sampler, # Reusing from Ex1 (ensure it's fresh or allows re-iteration)
worker_count=0,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
num_batches_to_test = 5 # Small number for quick test
if data_loader_seq_with_work:
print(f"\nTiming test for {num_batches_to_test} batches:")
# Sequential
iterator_seq = iter(data_loader_seq_with_work)
start_time = time.time()
try:
for i in range(num_batches_to_test):
batch = next(iterator_seq)
if i == 0: print(f" Seq batch 1 label sum: {batch['label'].sum()}") # to ensure work is done
except StopIteration:
print("Sequential loader exhausted early.")
end_time = time.time()
print(f"Sequential ({num_batches_to_test} batches) took: {end_time - start_time:.4f} seconds")
if data_loader_parallel:
# Parallel
# Ensure sampler is fresh for parallel loader if it was used above
# For this optional part, let's use a fresh sampler for the parallel loader
# to avoid StopIteration if the previous sampler was single-epoch and exhausted.
fresh_parallel_sampler = grain.IndexSampler(
num_records=len(source_with_work),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=None, # Indefinite
seed=43 # Different seed or same, for this test it's about speed
)
data_loader_parallel_for_timing = grain.DataLoader(
data_source=source_with_work,
operations=transformations, # Reusing from Ex1
sampler=fresh_parallel_sampler,
worker_count=num_workers if num_workers > 0 else 2, # Ensure parallelism
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
iterator_parallel = iter(data_loader_parallel_for_timing)
start_time = time.time()
try:
for i in range(num_batches_to_test):
batch = next(iterator_parallel)
if i == 0: print(f" Parallel batch 1 label sum: {batch['label'].sum()}") # to ensure work is done
except StopIteration:
print("Parallel loader exhausted early.")
end_time = time.time()
print(f"Parallel ({num_batches_to_test} batches, {num_workers if num_workers > 0 else 2} workers) took: {end_time - start_time:.4f} seconds")
else:
print("Skipping optional timing: source_with_work, transformations, or index_sampler not defined.")
# @title Exercise 2: Solution
# 1. Reuse/Recreate components
# Define MySourceWithWork, adding time.sleep(0.01) in getitem
class MySourceWithWork(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
time.sleep(0.01) # Simulate 10ms of work per item
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
return {'image': image, 'label': label}
source_with_work = MySourceWithWork(num_records=1000)
print(f"MySourceWithWork created with {len(source_with_work)} records.")
# Sampler for parallel loading (indefinite epochs for robust testing)
parallel_sampler = grain.IndexSampler(
num_records=len(source_with_work),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=None, # Run indefinitely
seed=42
)
print("Parallel IndexSampler created.")
# Transformations can be reused from Exercise 1 solution
# Ensure ConvertToFloat is defined (it was in Ex1 solution cell)
if 'ConvertToFloat' not in globals(): # Basic check
class ConvertToFloat(grain.MapTransform): # Redefine if not in current scope
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
image = features['image'].astype(np.float32) / 255.0
return {'image': image, 'label': features['label']}
print("Redefined ConvertToFloat for safety.")
transformations_ex2 = [
ConvertToFloat(),
grain.Batch(batch_size=64, drop_remainder=True)
]
print("Transformations for Ex2 ready.")
# 2. Instantiate DataLoader with worker_count > 0
num_workers = 4 # Use 4 workers; JAX is configured for 8 virtual CPUs
# Max useful workers often related to num CPU cores available.
data_loader_parallel = grain.DataLoader(
data_source=source_with_work,
operations=transformations_ex2,
sampler=parallel_sampler,
worker_count=num_workers,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0) # Data source simulates work but is "in-memory"
)
# 3. Iterate and print batch info
if data_loader_parallel:
print(f"DataLoader configured with worker_count={num_workers}.")
data_iterator_parallel = iter(data_loader_parallel)
try:
first_batch_parallel = next(data_iterator_parallel)
print(f"Parallel - First batch image shape: {first_batch_parallel['image'].shape}")
print(f"Parallel - First batch label shape: {first_batch_parallel['label'].shape}")
except StopIteration:
print("Parallel DataLoader is empty or exhausted.")
else:
print("Parallel DataLoader not configured yet.")
# 4. (Optional) Timing comparison
# Create a fresh IndexSampler for the sequential loader for a fair comparison start
# (num_epochs=1 to match typical test for sequential pass)
seq_sampler_for_timing = grain.IndexSampler(
num_records=len(source_with_work),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=1, # Single epoch for this timing test
seed=42
)
data_loader_seq_with_work = grain.DataLoader(
data_source=source_with_work,
operations=transformations_ex2,
sampler=seq_sampler_for_timing,
worker_count=0,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
num_batches_to_test = 5 # Number of batches to fetch for timing
print(f"\nTiming test for {num_batches_to_test} batches (each item has 0.01s simulated work):")
# Sequential
iterator_seq = iter(data_loader_seq_with_work)
start_time_seq = time.time()
try:
for i in range(num_batches_to_test):
batch_seq = next(iterator_seq)
if i == 0 and num_batches_to_test > 0 : print(f" Seq batch 1 label sum: {batch_seq['label'].sum()}") # to ensure work is done
except StopIteration:
print(f"Sequential loader exhausted before {num_batches_to_test} batches.")
end_time_seq = time.time()
print(f"Sequential ({num_batches_to_test} batches) took: {end_time_seq - start_time_seq:.4f} seconds")
# Parallel
# Use a fresh sampler for the parallel loader for timing to ensure it's not exhausted
# and runs for enough batches.
parallel_sampler_for_timing = grain.IndexSampler(
num_records=len(source_with_work),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=None, # Indefinite, or ensure enough for num_batches_to_test
seed=43 # Can be same or different seed
)
data_loader_parallel_for_timing = grain.DataLoader(
data_source=source_with_work,
operations=transformations_ex2,
sampler=parallel_sampler_for_timing,
worker_count=num_workers,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
iterator_parallel_timed = iter(data_loader_parallel_for_timing)
start_time_parallel = time.time()
try:
for i in range(num_batches_to_test):
batch_par = next(iterator_parallel_timed)
if i == 0 and num_batches_to_test > 0 : print(f" Parallel batch 1 label sum: {batch_par['label'].sum()}") # to ensure work is done
except StopIteration:
print(f"Parallel loader exhausted before {num_batches_to_test} batches.")
end_time_parallel = time.time()
print(f"Parallel ({num_batches_to_test} batches, {num_workers} workers) took: {end_time_parallel - start_time_parallel:.4f} seconds")
if end_time_parallel - start_time_parallel < end_time_seq - start_time_seq:
print("Parallel loading was faster, as expected!")
else:
print("Parallel loading was not significantly faster. This might happen for very small num_batches_to_test due to overhead, or if simulated work is too little.")
---
Exercise 3: Custom transformations (MapTransform)
Instructions:
1. Define a class OneHotEncodeLabel that inherits from grain.MapTransform.
* Its __init__ method should take num_classes.
* Its map(self, features: Dict[str, Any]) method should:
* Take the input features dictionary.
* Convert the features['label'] (an integer) into a one-hot encoded NumPy array of type np.float32. The length of this array should be num_classes.
* Update features['label'] with this new one-hot array.
* Return the modified features dictionary.
2. Reuse MySource (the one without time.sleep) and the IndexSampler from Exercise 1 (or create new ones).
3. Create a new list of operations that includes:
* An instance of OneHotEncodeLabel (e.g., with num_classes=10, matching idx % 10 from MySource).
* The ConvertToFloat transform (if not already applied to image).
* grain.Batch.
4. Instantiate a grain.DataLoader (you can use worker_count=0 or >0).
5. Iterate and print the batch's image and label shapes.
# @title Exercise 3: Student Code
# 1. Define OneHotEncodeLabel
class OneHotEncodeLabel(grain.MapTransform):
def __init__(self, num_classes: int):
# TODO: Store num_classes
# YOUR CODE HERE
self._num_classes = 0 # Replace this
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
label = features['label']
# TODO: Create one-hot encoded version of the label
# one_hot_label = np.zeros(...)
# one_hot_label[label] = 1.0
# YOUR CODE HERE
one_hot_label = np.array([label]) # Replace this
features['label'] = one_hot_label
return features
# 2. Reuse/Create DataSource and Sampler
# TODO: Instantiate MySource (from Ex1, no sleep)
# source_ex3 = ...
# YOUR CODE HERE
source_ex3 = None # Replace this
# TODO: Instantiate an IndexSampler (e.g., from Ex1, or a new one)
# sampler_ex3 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex3 = None # Replace this
# 3. Create new list of operations
# TODO: Instantiate OneHotEncodeLabel
num_classes_for_ohe = 10
# one_hot_encoder = OneHotEncodeLabel(num_classes=num_classes_for_ohe)
# YOUR CODE HERE
one_hot_encoder = None # Replace this
# TODO: Define transformations_ex3 list including one_hot_encoder,
# ConvertToFloat (if not already applied), and grain.Batch
# (Assuming ConvertToFloat is defined from Ex1 solution)
# transformations_ex3 = [...]
# YOUR CODE HERE
transformations_ex3 = [] # Replace this
# 4. Instantiate DataLoader
# TODO: Create data_loader_ex3
# data_loader_ex3 = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_ex3 = None # Replace this
# 5. Iterate and print batch info
if data_loader_ex3:
print("DataLoader with OneHotEncodeLabel configured.")
iterator_ex3 = iter(data_loader_ex3)
try:
first_batch_ex3 = next(iterator_ex3)
print(f"Custom MapTransform - Batch image shape: {first_batch_ex3['image'].shape}")
print(f"Custom MapTransform - Batch label shape: {first_batch_ex3['label'].shape}") # Expected: (batch_size, num_classes)
if first_batch_ex3['label'].size > 0:
print(f"Custom MapTransform - Example one-hot label: {first_batch_ex3['label'][0]}")
except StopIteration:
print("DataLoader for Ex3 is empty or exhausted.")
else:
print("DataLoader for Ex3 not configured yet.")
# @title Exercise 3: Solution
# 1. Define OneHotEncodeLabel
class OneHotEncodeLabel(grain.MapTransform):
def __init__(self, num_classes: int):
self._num_classes = num_classes
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
label_scalar = features['label']
one_hot_label = np.zeros(self._num_classes, dtype=np.float32)
one_hot_label[label_scalar] = 1.0
# Create a new dictionary to avoid modifying the input dict in place if it's reused
# by other transforms or parts of the pipeline, though often direct modification is fine.
# For safety and clarity, let's return a new dict or an updated copy.
updated_features = features.copy()
updated_features['label'] = one_hot_label
return updated_features
# 2. Reuse/Create DataSource and Sampler
# Using MySource from Exercise 1 solution (no artificial sleep)
if 'MySource' not in globals(): # Basic check
class MySource(grain.RandomAccessDataSource): # Redefine if not in current scope
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
return {'image': image, 'label': label}
print("Redefined MySource for Ex3.")
source_ex3 = MySource(num_records=1000)
sampler_ex3 = grain.IndexSampler(
num_records=len(source_ex3),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=1,
seed=42
)
print("DataSource and Sampler for Ex3 ready.")
# 3. Create new list of operations
num_classes_for_ohe = 10 # Matches idx % 10 in MySource
one_hot_encoder = OneHotEncodeLabel(num_classes=num_classes_for_ohe)
# Ensure ConvertToFloat is defined
if 'ConvertToFloat' not in globals():
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
image = features['image'].astype(np.float32) / 255.0
return {'image': image, 'label': features['label']} # Pass label through
print("Redefined ConvertToFloat for Ex3.")
transformations_ex3 = [
ConvertToFloat(), # Apply first to have float images
one_hot_encoder, # Then one-hot encode labels
grain.Batch(batch_size=64, drop_remainder=True)
]
print("Transformations for Ex3 defined.")
# 4. Instantiate DataLoader
data_loader_ex3 = grain.DataLoader(
data_source=source_ex3,
operations=transformations_ex3,
sampler=sampler_ex3,
worker_count=0, # Can be > 0 as well, OneHotEncodeLabel is picklable
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
# 5. Iterate and print batch info
if data_loader_ex3:
print("DataLoader with OneHotEncodeLabel configured.")
iterator_ex3 = iter(data_loader_ex3)
try:
first_batch_ex3 = next(iterator_ex3)
print(f"Custom MapTransform - Batch image shape: {first_batch_ex3['image'].shape}")
print(f"Custom MapTransform - Batch label shape: {first_batch_ex3['label'].shape}") # Expected: (64, 10)
if first_batch_ex3['label'].size > 0:
print(f"Custom MapTransform - Example one-hot label (first item): {first_batch_ex3['label'][0]}")
original_label_example = np.argmax(first_batch_ex3['label'][0])
print(f"Custom MapTransform - Decoded original label (first item): {original_label_example}")
except StopIteration:
print("DataLoader for Ex3 is empty or exhausted.")
else:
print("DataLoader for Ex3 not configured yet.")
---
Exercise 4: Reproducible random transformations (RandomMapTransform)
1. Define a class RandomBrightnessAdjust that inherits from grain.RandomMapTransform.
* Its random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any] method should:
* Take features and an rng (NumPy random number generator).
* Crucially, use the provided rng for all random operations. This ensures that the same record, when processed with the same initial seed for the sampler, gets the same "random" augmentation.
* Generate a random brightness factor using rng.uniform(0.7, 1.3).
* Multiply the features['image'] (assuming it's already float and normalized) by this factor.
* Clip the image values to stay within [0.0, 1.0] using np.clip().
* Return the modified features.
2. Reuse MySource, an IndexSampler (ensure it has a seed), and ConvertToFloat from previous exercises.
3. Create a list of operations including ConvertToFloat, your RandomBrightnessAdjust, and grain.Batch.
4. Create two DataLoader instances (dl_run1, dl_run2) with the exact same configuration (same source, sampler instance or sampler with same seed, operations, worker_count).
5. Get the first batch from dl_run1. Print a sample pixel value.
6. If your sampler used num_epochs=1 and is exhausted, re-create it (or use num_epochs=None). Then get the first batch from dl_run2 and print the same sample pixel value.
7. Verify that the two values are identical, demonstrating reproducible randomness.
8. (Optional) Create a new IndexSampler with a different seed for a third DataLoader and observe that the pixel values now differ.
# @title Exercise 4: Student Code
# 1. Define RandomBrightnessAdjust
class RandomBrightnessAdjust(grain.RandomMapTransform):
def random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any]:
# TODO: Ensure image is float (e.g. by placing ConvertToFloat before this in ops)
image = features['image']
# TODO: Generate a random brightness factor using the provided rng
# brightness_factor = rng.uniform(...)
# YOUR CODE HERE
brightness_factor = 1.0 # Replace this
# TODO: Apply brightness adjustment and clip
# adjusted_image = np.clip(...)
# YOUR CODE HERE
adjusted_image = image # Replace this
# Create a new dictionary or update a copy
updated_features = features.copy()
updated_features['image'] = adjusted_image
return updated_features
# 2. Reuse/Create DataSource, Sampler, ConvertToFloat
# TODO: Instantiate MySource (from Ex1)
# source_ex4 = ...
# YOUR CODE HERE
source_ex4 = None # Replace this
# TODO: Instantiate an IndexSampler with a seed (e.g., seed=42, num_epochs=1 or None)
# sampler_ex4_seed42 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex4_seed42 = None # Replace this
# (Assuming ConvertToFloat is defined from Ex1 solution)
# 3. Create list of operations
# TODO: Instantiate RandomBrightnessAdjust
# random_brightness_adjuster = ...
# YOUR CODE HERE
random_brightness_adjuster = None # Replace this
# TODO: Define transformations_ex4 list: ConvertToFloat, random_brightness_adjuster, grain.Batch
# transformations_ex4 = [...]
# YOUR CODE HERE
transformations_ex4 = [] # Replace this
# 4. Instantiate two DataLoaders with the same config
# TODO: Create dl_run1
# dl_run1 = grain.DataLoader(...)
# YOUR CODE HERE
dl_run1 = None # Replace this
# TODO: Create dl_run2 (using the exact same sampler instance or a new one with the same seed)
# dl_run2 = grain.DataLoader(...)
# YOUR CODE HERE
dl_run2 = None # Replace this
# 5. & 6. Iterate and compare
pixel_to_check = (0, 0, 0, 0) # Batch_idx, H, W, C
if dl_run1:
print("--- Run 1 (seed 42) ---")
iterator_run1 = iter(dl_run1)
try:
batch1_run1 = next(iterator_run1)
value_run1 = batch1_run1['image'][pixel_to_check]
print(f"Run 1 - Pixel {pixel_to_check} value: {value_run1}")
except StopIteration:
print("dl_run1 exhausted.")
value_run1 = None
else:
print("dl_run1 not configured.")
value_run1 = None
if dl_run2:
print("\n--- Run 2 (seed 42, same sampler) ---")
# If sampler_ex4_seed42 was single-epoch and already used by dl_run1,
# dl_run2 might be empty. For robust test, ensure sampler allows re-iteration
# or use a new sampler instance with the same seed.
# If sampler_ex4_seed42 had num_epochs=None, iter(dl_run2) is fine.
# If num_epochs=1, you might need to re-create sampler_ex4_seed42 for dl_run2
# or ensure dl_run1 didn't exhaust it (e.g. by not fully iterating it).
# For this exercise, assume sampler_ex4_seed42 can be re-used or is fresh for dl_run2.
iterator_run2 = iter(dl_run2)
try:
batch1_run2 = next(iterator_run2)
value_run2 = batch1_run2['image'][pixel_to_check]
print(f"Run 2 - Pixel {pixel_to_check} value: {value_run2}")
# 7. Verify
if value_run1 is not None and value_run2 is not None:
if np.allclose(value_run1, value_run2):
print("\nSUCCESS: Pixel values are identical. Randomness is reproducible!")
else:
print(f"\nFAILURE: Pixel values differ. value1={value_run1}, value2={value_run2}")
except StopIteration:
print("dl_run2 exhausted. This might happen if the sampler was single-epoch and already used.")
value_run2 = None
else:
print("dl_run2 not configured.")
# 8. (Optional) Test with a different seed
# TODO: Create sampler_ex4_seed100 (seed=100)
# sampler_ex4_seed100 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex4_seed100 = None # Replace this
# TODO: Create dl_run3 with sampler_ex4_seed100
# dl_run3 = grain.DataLoader(...)
# YOUR CODE HERE
dl_run3 = None # Replace this
if dl_run3:
print("\n--- Run 3 (seed 100) ---")
iterator_run3 = iter(dl_run3)
try:
batch1_run3 = next(iterator_run3)
value_run3 = batch1_run3['image'][pixel_to_check]
print(f"Run 3 - Pixel {pixel_to_check} value: {value_run3}")
if value_run1 is not None and not np.allclose(value_run1, value_run3):
print("SUCCESS: Pixel values differ from Run 1 (seed 42), as expected with a new seed.")
elif value_run1 is not None:
print("NOTE: Pixel values are the same as Run 1. Check seed or logic.")
except StopIteration:
print("dl_run3 exhausted.")
else:
print("\nOptional part (dl_run3) not configured.")
# @title Exercise 4: Solution
# 1. Define RandomBrightnessAdjust
class RandomBrightnessAdjust(grain.RandomMapTransform):
def random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any]:
image = features['image'] # Assumes image is already float, e.g. from ConvertToFloat
# Generate a random brightness factor using the provided rng
brightness_factor = rng.uniform(0.7, 1.3)
# Apply brightness adjustment and clip
adjusted_image = image * brightness_factor
adjusted_image = np.clip(adjusted_image, 0.0, 1.0)
updated_features = features.copy()
updated_features['image'] = adjusted_image
return updated_features
# 2. Reuse/Create DataSource, Sampler, ConvertToFloat
if 'MySource' not in globals(): # Basic check for MySource
class MySource(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
return {'image': image, 'label': label}
print("Redefined MySource for Ex4.")
source_ex4 = MySource(num_records=1000)
# Sampler with a fixed seed. num_epochs=None allows re-iteration for multiple DataLoaders.
# If num_epochs=1, the sampler instance can only be fully iterated once.
# For this test, using num_epochs=None or re-creating the sampler for each DataLoader is safest.
# Let's use num_epochs=None to allow the same sampler instance to be used.
sampler_ex4_seed42 = grain.IndexSampler(
num_records=len(source_ex4),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=None, # Allow indefinite iteration
seed=42
)
print("DataSource and Sampler (seed 42) for Ex4 ready.")
if 'ConvertToFloat' not in globals(): # Basic check for ConvertToFloat
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
image = features['image'].astype(np.float32) / 255.0
return {'image': image, 'label': features['label']}
print("Redefined ConvertToFloat for Ex4.")
# 3. Create list of operations
random_brightness_adjuster = RandomBrightnessAdjust()
transformations_ex4 = [
ConvertToFloat(),
random_brightness_adjuster,
grain.Batch(batch_size=64, drop_remainder=True)
]
print("Transformations for Ex4 defined.")
# 4. Instantiate two DataLoaders with the same config
# Using worker_count > 0 to also test picklability of RandomBrightnessAdjust
num_workers_ex4 = 2
dl_run1 = grain.DataLoader(
data_source=source_ex4,
operations=transformations_ex4,
sampler=sampler_ex4_seed42, # Same sampler instance
worker_count=num_workers_ex4,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
dl_run2 = grain.DataLoader(
data_source=source_ex4,
operations=transformations_ex4,
sampler=sampler_ex4_seed42, # Same sampler instance
worker_count=num_workers_ex4,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
print(f"DataLoaders for Run1 and Run2 created with worker_count={num_workers_ex4}.")
# 5. & 6. Iterate and compare
pixel_to_check = (0, 0, 0, 0) # Batch_idx=0, H=0, W=0, C=0
print("\n--- Run 1 (seed 42) ---")
iterator_run1 = iter(dl_run1)
try:
batch1_run1 = next(iterator_run1)
value_run1 = batch1_run1['image'][pixel_to_check]
print(f"Run 1 - Pixel {pixel_to_check} value: {value_run1}")
except StopIteration:
print("dl_run1 exhausted.")
value_run1 = None
print("\n--- Run 2 (seed 42, same sampler instance) ---")
iterator_run2 = iter(dl_run2) # Gets a new iterator from the DataLoader
try:
batch1_run2 = next(iterator_run2)
value_run2 = batch1_run2['image'][pixel_to_check]
print(f"Run 2 - Pixel {pixel_to_check} value: {value_run2}")
# 7. Verify
if value_run1 is not None and value_run2 is not None:
if np.allclose(value_run1, value_run2):
print("\nSUCCESS: Pixel values are identical. Randomness is reproducible with the same sampler instance!")
else:
print(f"\nFAILURE: Pixel values differ. value1={value_run1}, value2={value_run2}. This shouldn't happen if sampler is re-used correctly.")
except StopIteration:
print("dl_run2 exhausted.")
# 8. (Optional) Test with a different seed
sampler_ex4_seed100 = grain.IndexSampler(
num_records=len(source_ex4),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=None,
seed=100 # Different seed
)
dl_run3 = grain.DataLoader(
data_source=source_ex4,
operations=transformations_ex4,
sampler=sampler_ex4_seed100, # Sampler with different seed
worker_count=num_workers_ex4,
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
print("\nDataLoader for Run3 (seed 100) created.")
print("\n--- Run 3 (seed 100) ---")
iterator_run3 = iter(dl_run3)
try:
batch1_run3 = next(iterator_run3)
value_run3 = batch1_run3['image'][pixel_to_check]
print(f"Run 3 - Pixel {pixel_to_check} value: {value_run3}")
if value_run1 is not None and not np.allclose(value_run1, value_run3):
print("SUCCESS: Pixel values differ from Run 1 (seed 42), as expected with a new sampler seed.")
elif value_run1 is not None:
print("NOTE: Pixel values are the same as Run 1. This is unexpected if seeds are different.")
except StopIteration:
print("dl_run3 exhausted.")
---
Exercise 5: Data sharding across processes
In a multi-process (multi-host) JAX setup, each process uses jax.process_index() to know its ID and jax.process_count() for the total number of processes. grain.sharding.ShardByJaxProcess() is a helper that automatically uses these values.
Since we are in a single Colab notebook (simulating one JAX process, even with multiple virtual devices), we can't directly run multiple JAX processes. Instead, we will manually create grain.ShardOptions to simulate what would happen on two different processes.
Instructions:
1. Reuse MySource and transformations (e.g., ConvertToFloat and grain.Batch) from previous exercises.
2. Define shard_count = 2.
3. Simulate Process 0:
* Create shard_options_p0 = grain.ShardOptions(shard_index=0, shard_count=shard_count, drop_remainder=True).
* Create an IndexSampler (sampler_p0) using these shard_options_p0. Ensure it shuffles and uses a common seed (e.g., 42).
* Create a DataLoader (dl_p0) using this sampler_p0, with shard_options_p0 also passed to the DataLoader itself.
* Iterate through dl_p0 and collect all unique labels from the first few batches (or all batches if num_epochs=1).
4. Simulate Process 1:
* Create shard_options_p1 = grain.ShardOptions(shard_index=1, shard_count=shard_count, drop_remainder=True).
* Create an IndexSampler (sampler_p1) using shard_options_p1 (same seed as sampler_p0).
* Create a DataLoader (dl_p1) using sampler_p1 and shard_options_p1.
* Iterate through dl_p1 and collect all unique labels.
5. Verify: the sets of records seen by sampler_p0 and sampler_p1 should be disjoint.
* drop_remainder=True in ShardOptions means that if the dataset size isn't perfectly divisible by shard_count, some records may be dropped so that shards are equal or nearly equal in size (depending on implementation details).
Note on shard_options in IndexSampler vs DataLoader:
The shard_options argument to grain.DataLoader is the primary way to enable sharding for a JAX process. The DataLoader will then ensure its underlying sampler (even if you provide a non-sharded one) respects these global sharding options for the current JAX process. If you provide an IndexSampler that is already sharded, its sharding must be compatible with the DataLoader's shard_options. For simplicity and clarity in distributed settings, passing ShardByJaxProcess() or manually configured ShardOptions to the DataLoader is typical; a short sketch of this pattern follows.
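For reference, here is a hedged sketch of what the per-process configuration might look like in a real multi-process job. It assumes the MySource and ConvertToFloat classes defined earlier in this notebook and the grain.sharding.ShardByJaxProcess helper mentioned above; it is a sketch, not the exercise solution.
# Hedged sketch of per-process sharding in a real multi-process JAX job.
# Assumes MySource and ConvertToFloat from earlier cells are defined.
import jax
import grain.python as grain

shard_options = grain.sharding.ShardByJaxProcess(drop_remainder=True)
print(f"Running as process {jax.process_index()} of {jax.process_count()}")

source = MySource(num_records=1000)
per_process_sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=shard_options,  # each process is assigned a disjoint shard
    shuffle=True,
    num_epochs=None,
    seed=42,                      # same seed on every process
)
per_process_loader = grain.DataLoader(
    data_source=source,
    operations=[ConvertToFloat(), grain.Batch(batch_size=64, drop_remainder=True)],
    sampler=per_process_sampler,
    worker_count=2,
    shard_options=shard_options,  # also passed to the DataLoader
)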
# @title Exercise 5: Student Code
# 1. Reuse DataSource and basic transformations
# TODO: Instantiate MySource (from Ex1)
# source_ex5 = ...
# YOUR CODE HERE
source_ex5 = None # Replace this
# TODO: Define basic_transformations_ex5 (e.g., ConvertToFloat, Batch)
# (Assuming ConvertToFloat is defined)
# basic_transformations_ex5 = [...]
# YOUR CODE HERE
basic_transformations_ex5 = [] # Replace this
# 2. Define shard_count
shard_count = 2
common_seed = 42
num_epochs_for_sharding_test = 1 # To make collection of all labels feasible
# 3. Simulate Process 0
# TODO: Create shard_options_p0
# shard_options_p0 = grain.ShardOptions(...)
# YOUR CODE HERE
shard_options_p0 = None # Replace this
# TODO: Create sampler_p0. Pass shard_options_p0 to the IndexSampler.
# sampler_p0 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_p0 = None # Replace this
# TODO: Create dl_p0. Pass shard_options_p0 to the DataLoader as well.
# dl_p0 = grain.DataLoader(...)
# YOUR CODE HERE
dl_p0 = None # Replace this
labels_p0 = set()
if dl_p0:
print("--- Simulating Process 0 ---")
# YOUR CODE HERE: Iterate through dl_p0 and collect all unique original labels.
# Remember that labels might be batched. You need to iterate through items in a batch.
# For simplicity, if your MySource generates labels like idx % 10,
# you can try to collect the indices that were sampled.
# Or, more directly, collect the 'label' field from each item.
# To get original indices, you might need a transform that passes index through.
# Let's collect the 'label' values directly.
pass # Replace with iteration logic
# 4. Simulate Process 1
# TODO: Create shard_options_p1
# shard_options_p1 = grain.ShardOptions(...)
# YOUR CODE HERE
shard_options_p1 = None # Replace this
# TODO: Create sampler_p1
# sampler_p1 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_p1 = None # Replace this
# TODO: Create dl_p1
# dl_p1 = grain.DataLoader(...)
# YOUR CODE HERE
dl_p1 = None # Replace this
labels_p1 = set()
if dl_p1:
print("\n--- Simulating Process 1 ---")
# YOUR CODE HERE: Iterate through dl_p1 and collect all unique labels.
pass # Replace with iteration logic
# 5. Verify
print(f"\n--- Verification (Total records in source: {len(source_ex5) if source_ex5 else 'N/A'}) ---")
print(f"Unique labels collected by Process 0 (count {len(labels_p0)}): sorted {sorted(list(labels_p0))[:20]}...")
print(f"Unique labels collected by Process 1 (count {len(labels_p1)}): sorted {sorted(list(labels_p1))[:20]}...")
if labels_p0 and labels_p1:
intersection = labels_p0.intersection(labels_p1)
if not intersection:
print("\nSUCCESS: No overlap in labels between Process 0 and Process 1. Sharding works as expected!")
else:
print(f"\nNOTE: Some overlap in labels found (count {len(intersection)}): {intersection}.")
print("This can happen if labels are not unique per index, or if sharding logic has issues.")
print("With MySource's label = idx % 10, an overlap in labels is expected even if indices are disjoint.")
print("A better test would be to collect original indices if possible.")
# For a more direct test of sharding of indices:
# We can define a DataSource that returns the index itself.
class IndexSource(grain.RandomAccessDataSource):
def __init__(self, num_records: int):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> int:
return idx % self._num_records # Return the index
index_source = IndexSource(num_records=100) # Smaller source for easier inspection
idx_sampler_p0 = grain.IndexSampler(len(index_source), shard_options_p0, shuffle=False, num_epochs=1, seed=common_seed)
idx_sampler_p1 = grain.IndexSampler(len(index_source), shard_options_p1, shuffle=False, num_epochs=1, seed=common_seed)
# DataLoader for indices (no batching, just to see raw sampled indices)
# Note: DataLoader expects dicts. Let's make IndexSource return {'index': idx}
class IndexDictSource(grain.RandomAccessDataSource):
def __init__(self, num_records: int): self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str,int]:
return {'index': idx % self._num_records}
index_dict_source = IndexDictSource(num_records=100)
# Samplers for IndexDictSource
idx_dict_sampler_p0 = grain.IndexSampler(len(index_dict_source), shard_options_p0, shuffle=False, num_epochs=1, seed=common_seed)
idx_dict_sampler_p1 = grain.IndexSampler(len(index_dict_source), shard_options_p1, shuffle=False, num_epochs=1, seed=common_seed)
# DataLoaders for IndexDictSource
# Pass shard_options to DataLoader as well.
if shard_options_p0 and shard_options_p1:
dl_indices_p0 = grain.DataLoader(index_dict_source, [], idx_dict_sampler_p0, worker_count=0, shard_options=shard_options_p0)
dl_indices_p1 = grain.DataLoader(index_dict_source, [], idx_dict_sampler_p1, worker_count=0, shard_options=shard_options_p1)
indices_from_p0 = {item['index'] for item in dl_indices_p0} if dl_indices_p0 else set()
indices_from_p1 = {item['index'] for item in dl_indices_p1} if dl_indices_p1 else set()
print(f"\n--- Verification of INDICES (Source size: {len(index_dict_source)}) ---")
print(f"Indices from P0 (count {len(indices_from_p0)}, shuffle=False): {sorted(list(indices_from_p0))}")
print(f"Indices from P1 (count {len(indices_from_p1)}, shuffle=False): {sorted(list(indices_from_p1))}")
if indices_from_p0 and indices_from_p1:
idx_intersection = indices_from_p0.intersection(indices_from_p1)
if not idx_intersection:
print("SUCCESS: No overlap in INDICES. Sharding of data sources works correctly!")
else:
print(f"FAILURE: Overlap in INDICES found: {idx_intersection}")
else:
print("Skipping index verification part as shard_options are not defined.")
print("\nReminder: In a real distributed setup, you'd use grain.sharding.ShardByJaxProcess() "
"and JAX would manage jax.process_index() automatically for each process.")
# @title Exercise 5: Solution
# Redefine MySource for Ex5 to include 'original_index'.
class MySource(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10 # Label is idx % 10
# For better sharding verification, let's also pass the original index
return {'image': image, 'label_test': label, 'original_index': effective_idx}
print("Redefined MySource for Ex5 to include 'original_index'.")
source_ex5 = MySource(num_records=1000)
# Redefine ConvertToFloat for Ex5 to include 'original_index'.
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
# This transform should pass through all keys it doesn't modify
updated_features = features.copy()
updated_features['image'] = features['image'].astype(np.float32) / 255.0
return updated_features
print("Redefined ConvertToFloat for Ex5.")
# We will collect 'original_index' after batching, so batch must preserve it.
# grain.Batch by default collates features with the same name.
basic_transformations_ex5 = [
ConvertToFloat(),
grain.Batch(batch_size=64, drop_remainder=True) # drop_remainder for batching
]
print("DataSource and Transformations for Ex5 ready.")
# 2. Define shard_count
shard_count = 2
common_seed = 42
num_epochs_for_sharding_test = 1
# 3. Simulate Process 0
shard_options_p0 = grain.ShardOptions(shard_index=0, shard_count=shard_count, drop_remainder=True) # drop_remainder for sharding
# Sampler for Process 0. It's important that the sampler itself is sharded.
# The DataLoader's shard_options will also apply this sharding if the sampler isn't already sharded,
# or verify consistency if it is.
sampler_p0 = grain.IndexSampler(
num_records=len(source_ex5),
shard_options=shard_options_p0, # Shard the sampler
shuffle=True, # Shuffle for more realistic scenario
num_epochs=num_epochs_for_sharding_test,
seed=common_seed
)
dl_p0 = grain.DataLoader(
data_source=source_ex5,
operations=basic_transformations_ex5,
sampler=sampler_p0,
worker_count=0, # Keep it simple for verification
shard_options=shard_options_p0, # Also inform DataLoader about the sharding context
read_options=grain.ReadOptions(num_threads=0)
)
indices_p0 = set()
if dl_p0:
print("--- Simulating Process 0 ---")
for batch in dl_p0:
indices_p0.update(batch['original_index'].tolist()) # Collect original indices
print(f"Process 0 collected {len(indices_p0)} unique indices.")
# 4. Simulate Process 1
shard_options_p1 = grain.ShardOptions(shard_index=1, shard_count=shard_count, drop_remainder=True)
sampler_p1 = grain.IndexSampler(
num_records=len(source_ex5),
shard_options=shard_options_p1, # Shard the sampler
shuffle=True, # Use same shuffle setting and seed for apples-to-apples comparison of sharding logic
num_epochs=num_epochs_for_sharding_test,
seed=common_seed # Same seed ensures shuffle order is same before sharding
)
dl_p1 = grain.DataLoader(
data_source=source_ex5,
operations=basic_transformations_ex5,
sampler=sampler_p1,
worker_count=0,
shard_options=shard_options_p1, # Inform DataLoader
read_options=grain.ReadOptions(num_threads=0)
)
indices_p1 = set()
if dl_p1:
print("\n--- Simulating Process 1 ---")
for batch in dl_p1:
indices_p1.update(batch['original_index'].tolist()) # Collect original indices
print(f"Process 1 collected {len(indices_p1)} unique indices.")
# 5. Verify
print(f"\n--- Verification of original_indices (Total records in source: {len(source_ex5)}) ---")
# Showing a few from each for brevity
print(f"Unique original_indices from P0 (first 20 sorted): {sorted(list(indices_p0))[:20]}...")
print(f"Unique original_indices from P1 (first 20 sorted): {sorted(list(indices_p1))[:20]}...")
expected_per_shard = len(source_ex5) // shard_count # Due to drop_remainder=True in ShardOptions
print(f"Expected records per shard (approx, due to drop_remainder in sharding): {expected_per_shard}")
print(f"Actual for P0: {len(indices_p0)}, P1: {len(indices_p1)}")
if indices_p0 and indices_p1:
intersection = indices_p0.intersection(indices_p1)
if not intersection:
print("\nSUCCESS: No overlap in original_indices between Process 0 and Process 1. Sharding works!")
else:
print(f"\nFAILURE: Overlap in original_indices found (count {len(intersection)}): {sorted(list(intersection))[:20]}...")
print("This should not happen if sharding is correct and seeds/shuffle are consistent.")
else:
print("Could not perform intersection test as one or both sets of indices are empty.")
total_unique_indices_seen = len(indices_p0.union(indices_p1))
print(f"Total unique indices seen across both simulated processes: {total_unique_indices_seen}")
# With drop_remainder=True in sharding, total might be less than len(source_ex5)
# if len(source_ex5) is not divisible by shard_count.
# Example: 1000 records, 2 shards. Each gets 500. Total 1000.
# Example: 1001 records, 2 shards. drop_remainder=True means each gets 500. Total 1000. 1 record dropped.
print("\nReminder: In a real distributed JAX application:")
print("1. Each JAX process would run this script (or similar).")
print("2. shard_options = grain.sharding.ShardByJaxProcess(drop_remainder=True) would be used.")
print("3. jax.process_index() and jax.process_count() would provide the correct shard info automatically.")
print("4. The IndexSampler and DataLoader would be configured with these auto-detected shard_options.")
---
Exercise 6: Feeding a Flax NNX training loop
This exercise shows how the Grain DataLoader feeds data into a typical JAX/Flax NNX training loop. It is conceptual regarding model training (no actual weight updates) but practical in terms of data flow.
Instructions:
1. Define a SimpleNNXModel inheriting from nnx.Module.
* In __init__, initialize an nnx.Linear layer. The input features should match the flattened image dimensions (e.g., 32*32*3), and the output features can be num_classes (e.g., 10). Remember to pass rngs for parameter initialization.
* Implement __call__(self, x): it should flatten the input image x (if it is (B, H, W, C)) and pass it through the linear layer.
2. Define a train_step function (it can be JIT-compiled, e.g., with @jax.jit; see the sketch after these instructions).
* It takes the model (your SimpleNNXModel instance) and a batch from Grain.
* Inside, it performs a forward pass: logits = model(batch['image']).
* It calculates a dummy loss, e.g., loss = jnp.mean(logits). (No real loss computation or gradients needed for this exercise).
* It returns the loss and the model. In a real training scenario using nnx.Optimizer, the optimizer would update the model's parameters in-place. The train_step function would typically return the loss, the updated model, and the updated optimizer state to be used in the next iteration.
3. Set up the DataLoader: reuse MySource (the one that yields {'image': ..., 'label': ...}), an IndexSampler (e.g., for a few epochs), and transformations (e.g., ConvertToFloat, grain.Batch).
* Instantiate a grain.DataLoader.
4. Write the training loop: initialize SimpleNNXModel with an appropriate JAX PRNG key.
* Get an iterator from your DataLoader.
* Loop for a fixed number of steps (e.g., 100):
* Get the next_batch from the iterator. Handle StopIteration if the loader is exhausted.
* Call your train_step function with the current model and next_batch.
* Print the dummy loss occasionally.
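As a hedged aside (not part of the original exercise): with Flax NNX, a step that takes an nnx.Module directly is usually JIT-compiled with nnx.jit rather than plain jax.jit. The sketch below assumes nnx.jit handles Module arguments; the function name is illustrative.
# Hedged sketch: JIT-compiling a step that takes an nnx.Module directly.
# Assumption: flax.nnx.jit handles Module arguments (plain jax.jit may not).
from flax import nnx
import jax
import jax.numpy as jnp

@nnx.jit
def jitted_dummy_step(model, images):
  logits = model(images)        # forward pass through the model
  return jnp.mean(logits ** 2)  # same dummy loss as in this exercise

# Example usage once SimpleNNXModel is defined below:
# model = SimpleNNXModel(din=32 * 32 * 3, dout=10, rngs=nnx.Rngs(params=jax.random.key(0)))
# loss = jitted_dummy_step(model, jnp.ones((64, 32, 32, 3), jnp.float32))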
# @title Exercise 6: Student Code
# 1. Define SimpleNNXModel
class SimpleNNXModel(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
# TODO: Initialize an nnx.Linear layer
# self.linear = nnx.Linear(...)
# YOUR CODE HERE
self.linear = None # Replace this
def __call__(self, x: jax.Array):
# TODO: Flatten the input image (if B, H, W, C) and pass through linear layer
# x_flat = x.reshape((x.shape[0], -1))
# return self.linear(x_flat)
# YOUR CODE HERE
return x # Replace this
# 2. Define train_step
def train_step(model: SimpleNNXModel, batch: Dict[str, jax.Array]):
# TODO: Perform forward pass: model(batch['image'])
# logits = ...
# YOUR CODE HERE
logits = model(batch['image']) # Assuming model handles it
# TODO: Calculate a dummy loss (e.g., mean of logits)
# loss = ...
# YOUR CODE HERE
loss = jnp.array(0.0) # Replace this
# In a real scenario, you'd also compute gradients and update model parameters here.
# For this exercise, we just return the original model and the loss.
return model, loss
# 3. Set up DataLoader
# TODO: Instantiate MySource (from Ex1, or the one with 'original_index' if you prefer)
# source_ex6 = ...
# YOUR CODE HERE
source_ex6 = None # Replace this
# TODO: Instantiate an IndexSampler for a few epochs (e.g., 2 epochs)
# sampler_ex6 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex6 = None # Replace this
# TODO: Define transformations_ex6 (e.g., ConvertToFloat, grain.Batch)
# (Assuming ConvertToFloat is defined)
# transformations_ex6 = [...]
# YOUR CODE HERE
transformations_ex6 = [] # Replace this
# TODO: Instantiate data_loader_ex6
# data_loader_ex6 = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_ex6 = None # Replace this
# 4. Write the Training Loop
if data_loader_ex6: # Proceed only if DataLoader is configured
# TODO: Initialize SimpleNNXModel
# image_height, image_width, image_channels = 32, 32, 3
# num_classes_ex6 = 10
# model_key = jax.random.key(0)
# model_ex6 = SimpleNNXModel(...)
# YOUR CODE HERE
model_ex6 = None # Replace this
if model_ex6:
# TODO: Get an iterator from data_loader_ex6
# grain_iterator_ex6 = ...
# YOUR CODE HERE
grain_iterator_ex6 = iter([]) # Replace this
num_steps = 100
print(f"\nStarting conceptual training loop for {num_steps} steps...")
for step in range(num_steps):
try:
# TODO: Get next_batch from iterator
# next_batch = ...
# YOUR CODE HERE
next_batch = None # Replace this
if next_batch is None:
raise StopIteration # Simulate exhaustion if not implemented
except StopIteration:
print(f"DataLoader exhausted at step {step}. Ending loop.")
break
# TODO: Call train_step
# model_ex6, loss = train_step(model_ex6, next_batch) # model_ex6 isn't actually updated here
# YOUR CODE HERE
loss = jnp.array(0.0) # Replace this
if step % 20 == 0 or step == num_steps - 1:
print(f"Step {step}: Dummy Loss = {loss.item():.4f}")
print("Conceptual training loop finished.")
else:
print("Model for Ex6 not initialized.")
else:
print("DataLoader for Ex6 not configured.")
# @title Exercise 6: Solution
# 1. Define SimpleNNXModel
class SimpleNNXModel(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dout, rngs=rngs)
def __call__(self, x: jax.Array) -> jax.Array:
# Assuming x is (B, H, W, C)
batch_size = x.shape[0]
x_flat = x.reshape((batch_size, -1)) # Flatten H, W, C dimensions
return self.linear(x_flat)
# 2. Define conceptual train_step
def train_step(model: SimpleNNXModel, batch: Dict[str, jax.Array]):
# Perform forward pass
logits = model(batch['image']) # model.call is invoked
# Calculate a dummy loss
loss = jnp.mean(logits**2) # Example: mean of squared logits
# In a real training step:
# # 1. Define a loss function.
# def loss_fn(model):
# logits = model(batch['image'])
# # loss_value = ... (e.g., optax.softmax_cross_entropy_with_integer_labels)
# return jnp.mean(logits**2) # Using dummy loss from exercise
#
# # 2. Calculate gradients.
# grads = nnx.grad(loss_fn, wrt=nnx.Param)(model)
#
# # 3. Update the model's parameters in-place using the optimizer.
# # Note: The optimizer is defined outside the train step.
# # e.g., optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
# optimizer.update(model, grads)
#
# # 4. For this exercise, we just return the original model and the loss.
return model, loss
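# For reference, a hedged sketch of what a full (non-conceptual) train step could
# look like, following the commented outline above. This is illustrative only:
# the name full_train_step, the optax cross-entropy loss, and the optimizer wiring
# are assumptions for this notebook, and exact NNX/optax signatures may differ
# slightly between versions. It is defined here but never called in this cell.
import optax

@nnx.jit
def full_train_step(model: SimpleNNXModel, optimizer: nnx.Optimizer,
                    batch: Dict[str, jax.Array]):
    def loss_fn(model):
        logits = model(batch['image'])
        # Integer labels from MySource work directly with this optax loss.
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch['label']).mean()
    # Differentiate w.r.t. the model's nnx.Param variables.
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)  # Updates parameters in place.
    return loss

# The optimizer would be created once, outside the step, e.g.:
# optimizer = nnx.Optimizer(model_ex6, optax.adam(1e-3), wrt=nnx.Param)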
# 3. Set up DataLoader
# Redefine MySource for Ex6.
class MySource(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
return {'image': image, 'label': label}
print("Redefined MySource for Ex6.")
source_ex6 = MySource(num_records=1000)
sampler_ex6 = grain.IndexSampler(
num_records=len(source_ex6),
shard_options=grain.NoSharding(),
shuffle=True,
num_epochs=2, # Run for 2 epochs
seed=42
)
# Redefine ConvertToFloat for Ex6.
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
updated_features = features.copy()
updated_features['image'] = features['image'].astype(np.float32) / 255.0
return updated_features
print("Redefined ConvertToFloat for Ex6.")
transformations_ex6 = [
ConvertToFloat(),
grain.Batch(batch_size=64, drop_remainder=True)
]
data_loader_ex6 = grain.DataLoader(
data_source=source_ex6,
operations=transformations_ex6,
sampler=sampler_ex6,
worker_count=2, # Use a couple of workers
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
print("DataLoader for Ex6 configured.")
# 4. Write the Training Loop
# Define image dimensions and number of classes
image_height, image_width, image_channels = 32, 32, 3
input_dim = image_height * image_width * image_channels
num_classes_ex6 = 10
# Initialize SimpleNNXModel
# NNX modules are typically initialized outside JIT, then their state can be passed.
# For this conceptual example, the model instance itself is passed.
model_key = jax.random.key(0)
model_ex6 = SimpleNNXModel(din=input_dim, dout=num_classes_ex6, rngs=nnx.Rngs(params=model_key))
print(f"SimpleNNXModel initialized. Input dim: {input_dim}, Output dim: {num_classes_ex6}")
# Get an iterator from data_loader_ex6
grain_iterator_ex6 = iter(data_loader_ex6)
num_steps = 100 # Total steps for the conceptual loop
print(f"\nStarting conceptual training loop for {num_steps} steps...")
for step_idx in range(num_steps):
try:
# Get next_batch from iterator
next_batch = next(grain_iterator_ex6)
except StopIteration:
print(f"DataLoader exhausted at step {step_idx}. Ending loop.")
# Example: Re-initialize iterator if sampler allows multiple epochs
# if sampler_ex6.num_epochs is None or sampler_ex6.num_epochs > 1 (and we tracked current epoch):
# print("Re-initializing iterator for new epoch...")
# grain_iterator_ex6 = iter(data_loader_ex6)
# next_batch = next(grain_iterator_ex6)
# else:
break # Exit loop if truly exhausted
# Call train_step
# train_step is not jitted here; the NumPy arrays in the batch are converted to JAX arrays when the model uses them.
_, loss = train_step(model_ex6, next_batch) # model_ex6 state isn't actually updated here
if step_idx % 20 == 0 or step_idx == num_steps - 1:
print(f"Step {step_idx}: Dummy Loss = {loss.item():.4f}") # .item() to get Python scalar from JAX array
print("Conceptual training loop finished.")
---
A DataLoader produces an iterator when you call iter(dataloader). This iterator (grain.PyGrainDatasetIterator) has get_state() and set_state() methods. These allow you to capture the internal state of the iteration (e.g., current position, RNG states for samplers/transforms) and restore it later. For full experiment checkpointing, this data iterator state should be saved alongside your model parameters (often using a library like Orbax); a minimal file-based sketch of persisting this state is shown after the instructions below.
Instructions:
* Set up a DataLoader (e.g., using MySource, an IndexSampler with num_epochs=None for indefinite iteration and a seed, and some basic transformations).
* Get an iterator (iterator1) from this DataLoader.
* Iterate a few times with next(iterator1) and store the last batch obtained.
* Save State: saved_iterator_state = iterator1.get_state().
* Simulate Resumption: get a new iterator (iterator2) from the *same* DataLoader instance.
* Restore State: call iterator2.set_state(saved_iterator_state).
* Call next(iterator2) to get a batch (resumed_batch). This resumed_batch from iterator2 should be the same as the batch that would have come *after* the last batch from iterator1.
* To verify this:
  * After getting saved_iterator_state from iterator1, call next(iterator1) one more time to get expected_next_batch_from_iterator1.
  * Compare resumed_batch (from iterator2 after set_state) with expected_next_batch_from_iterator1. Their contents (e.g., image data) should be identical.
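Before the exercise, here is a minimal sketch of persisting the iterator state to disk so it survives a process restart. It is illustrative only: the pickle format, the helper names, and the /tmp path are assumptions for this notebook; in practice you would typically bundle this state with your model checkpoint using a library such as Orbax.
import pickle

def save_iterator_state(iterator, path="/tmp/grain_iterator_state.pkl"):
    # get_state() captures the iterator's current position and RNG state.
    with open(path, "wb") as f:
        pickle.dump(iterator.get_state(), f)

def restore_iterator_state(data_loader, path="/tmp/grain_iterator_state.pkl"):
    # A fresh iterator from the same DataLoader is fast-forwarded to the
    # saved position via set_state().
    restored_iterator = iter(data_loader)
    with open(path, "rb") as f:
        restored_iterator.set_state(pickle.load(f))
    return restored_iterator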
# @title Exercise 7: Student Code
# 1. Set up DataLoader
# TODO: Instantiate MySource (e.g., from Ex1)
# source_ex7 = ...
# YOUR CODE HERE
source_ex7 = None # Replace this
# TODO: Instantiate an IndexSampler (num_epochs=None, seed=42)
# sampler_ex7 = grain.IndexSampler(...)
# YOUR CODE HERE
sampler_ex7 = None # Replace this
# TODO: Define transformations_ex7 (e.g., ConvertToFloat, Batch)
# (Assuming ConvertToFloat is defined)
# transformations_ex7 = [...]
# YOUR CODE HERE
transformations_ex7 = [] # Replace this
# TODO: Instantiate data_loader_ex7
# data_loader_ex7 = grain.DataLoader(...)
# YOUR CODE HERE
data_loader_ex7 = None # Replace this
if data_loader_ex7:
# 2. Get iterator1
# TODO: iterator1 = iter(...)
# YOUR CODE HERE
iterator1 = iter([]) # Replace this
# 3. Iterate a few times
num_initial_iterations = 3
print(f"--- Initial Iteration (iterator1) for {num_initial_iterations} batches ---")
last_batch_iterator1 = None
for i in range(num_initial_iterations):
try:
# TODO: last_batch_iterator1 = next(...)
# YOUR CODE HERE
last_batch_iterator1 = {} # Replace this
print(f"iterator1, batch {i+1} - first label: {last_batch_iterator1.get('label', [None])[0]}")
except StopIteration:
print("iterator1 exhausted prematurely.")
break
# 4. Save State
# TODO: saved_iterator_state = iterator1.get_state()
# YOUR CODE HERE
saved_iterator_state = None # Replace this
print(f"\nIterator state saved. Type: {type(saved_iterator_state)}")
# For verification: get the *next* batch from iterator1 *after* saving state
expected_next_batch_from_iterator1 = None
if saved_iterator_state is not None: # Ensure state was actually saved
try:
# TODO: expected_next_batch_from_iterator1 = next(...)
# YOUR CODE HERE
expected_next_batch_from_iterator1 = {} # Replace this
print(f"Expected next batch (from iterator1 after get_state) - first label: {expected_next_batch_from_iterator1.get('label', [None])[0]}")
except StopIteration:
print("iterator1 exhausted when trying to get expected_next_batch.")
# 5. Simulate Resumption
# TODO: Get iterator2 from the same data_loader_ex7
# iterator2 = iter(...)
# YOUR CODE HERE
iterator2 = iter([]) # Replace this
if saved_iterator_state is not None:
# TODO: iterator2.set_state(...)
# YOUR CODE HERE
print("\n--- Resumed Iteration (iterator2) ---")
print("Iterator state restored to iterator2.")
else:
print("\nSkipping resumption, saved_iterator_state is None.")
# 6. Iterate once from iterator2
resumed_batch = None
if saved_iterator_state is not None: # Only if state was set
try:
# TODO: resumed_batch = next(...)
# YOUR CODE HERE
resumed_batch = {} # Replace this
print(f"Resumed batch (from iterator2 after set_state) - first label: {resumed_batch.get('label', [None])[0]}")
except StopIteration:
print("iterator2 exhausted immediately after set_state.")
# 7. Verify
if expected_next_batch_from_iterator1 is not None and resumed_batch is not None:
# Compare 'image' data of the first element in the batch
# TODO: Perform comparison (e.g., np.allclose on image data)
# are_identical = np.allclose(...)
# YOUR CODE HERE
are_identical = False # Replace this
if are_identical:
print("\nSUCCESS: Resumed batch is identical to the expected next batch. Checkpointing works!")
else:
print("\nFAILURE: Resumed batch differs from the expected next batch.")
# print(f"Expected image data (sample): {expected_next_batch_from_iterator1['image'][0,0,0,:3]}")
# print(f"Resumed image data (sample): {resumed_batch['image'][0,0,0,:3]}")
elif saved_iterator_state is not None:
print("\nVerification inconclusive: could not obtain both expected and resumed batches.")
else:
print("DataLoader for Ex7 not configured.")
# @title Exercise 7: Solution
# 1. Set up DataLoader
# Redefine MySource for Ex7.
class MySource(grain.RandomAccessDataSource):
def __init__(self, num_records: int = 1000):
self._num_records = num_records
def __len__(self) -> int:
return self._num_records
def __getitem__(self, idx: int) -> Dict[str, Any]:
effective_idx = idx % self._num_records
image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)
label = effective_idx % 10
# Add original_index for easier verification if needed, though label might suffice
return {'image': image, 'label': label, 'original_index': effective_idx}
print("Redefined MySource for Ex7.")
source_ex7 = MySource(num_records=1000)
sampler_ex7 = grain.IndexSampler(
num_records=len(source_ex7),
shard_options=grain.NoSharding(),
shuffle=True, # Shuffling makes the test more robust
num_epochs=None, # Indefinite iteration
seed=42
)
# Redefine ConvertToFloat for Ex7.
class ConvertToFloat(grain.MapTransform):
def map(self, features: Dict[str, Any]) -> Dict[str, Any]:
updated_features = features.copy()
updated_features['image'] = features['image'].astype(np.float32) / 255.0
return updated_features
print("Redefined ConvertToFloat for Ex7.")
transformations_ex7 = [
ConvertToFloat(),
grain.Batch(batch_size=64, drop_remainder=True)
]
data_loader_ex7 = grain.DataLoader(
data_source=source_ex7,
operations=transformations_ex7,
sampler=sampler_ex7,
worker_count=0, # Simpler for state verification, but works with >0 too
shard_options=grain.NoSharding(),
read_options=grain.ReadOptions(num_threads=0)
)
print("DataLoader for Ex7 configured.")
# 2. Get iterator1
iterator1 = iter(data_loader_ex7)
# 3. Iterate a few times
num_initial_iterations = 3
print(f"--- Initial Iteration (iterator1) for {num_initial_iterations} batches ---")
last_batch_iterator1 = None
for i in range(num_initial_iterations):
try:
last_batch_iterator1 = next(iterator1)
# Using 'original_index' for more robust check than just 'label'
print(f"iterator1, batch {i+1} - first original_index: {last_batch_iterator1['original_index'][0]}")
except StopIteration:
print("iterator1 exhausted prematurely.")
break
# 4. Save State
# For PyGrainDatasetIterator, get_state() returns a new state object, so you can
# keep iterating with iterator1 afterwards without the saved state being mutated.
saved_iterator_state = iterator1.get_state()
print(f"\nIterator state saved. Type: {type(saved_iterator_state)}")
# For verification: get the next batch from iterator1 after saving state
expected_next_batch_from_iterator1 = None
if saved_iterator_state is not None:
try:
expected_next_batch_from_iterator1 = next(iterator1)
print(f"Expected next batch (from iterator1 after get_state) - first original_index: {expected_next_batch_from_iterator1['original_index'][0]}")
except StopIteration:
print("iterator1 exhausted when trying to get expected_next_batch.")
# 5. Simulate Resumption
# Get a new iterator from the same DataLoader instance.
iterator2 = iter(data_loader_ex7)
if saved_iterator_state is not None:
iterator2.set_state(saved_iterator_state)
print("\n--- Resumed Iteration (iterator2) ---")
print("Iterator state restored to iterator2.")
else:
print("\nSkipping resumption, saved_iterator_state is None.")
# 6. Iterate once from iterator2
resumed_batch = None
if saved_iterator_state is not None:
try:
resumed_batch = next(iterator2)
print(f"Resumed batch (from iterator2 after set_state) - first original_index: {resumed_batch['original_index'][0]}")
except StopIteration:
print("iterator2 exhausted immediately after set_state. This means the saved state was at the very end.")
# 7. Verify
if expected_next_batch_from_iterator1 is not None and resumed_batch is not None:
# Compare 'image' data and 'original_index' of the first element in the batch for robustness
# (Labels might repeat, indices are better for this check if available)
expected_img_sample = expected_next_batch_from_iterator1['image'][0]
resumed_img_sample = resumed_batch['image'][0]
expected_idx_sample = expected_next_batch_from_iterator1['original_index'][0]
resumed_idx_sample = resumed_batch['original_index'][0]
are_indices_identical = (expected_idx_sample == resumed_idx_sample)
are_images_identical = np.allclose(expected_img_sample, resumed_img_sample)
are_identical = are_indices_identical and are_images_identical
if are_identical:
print("\nSUCCESS: Resumed batch is identical to the expected next batch. Checkpointing works!")
else:
print("\nFAILURE: Resumed batch differs from the expected next batch.")
if not are_indices_identical:
print(f" - Mismatch in first original_index: Expected {expected_idx_sample}, Got {resumed_idx_sample}")
if not are_images_identical:
print(f" - Mismatch in image data for first element.")
# print(f" Expected image data (sample [0,0,0]): {expected_img_sample[0,0,0]}")
# print(f" Resumed image data (sample [0,0,0]): {resumed_img_sample[0,0,0]}")
elif saved_iterator_state is not None: # If state was saved but verification couldn't complete
if expected_next_batch_from_iterator1 is None and resumed_batch is None:
print("\nVERIFICATION NOTE: Both iterators seem to be at the end of the dataset after the initial iterations. This is valid if the dataset was short.")
else:
print("\nVerification inconclusive: could not obtain both expected and resumed batches for comparison.")
print(f" expected_next_batch_from_iterator1 is None: {expected_next_batch_from_iterator1 is None}")
print(f" resumed_batch is None: {resumed_batch is None}")
else: # If DataLoader itself wasn't configured
print("DataLoader for Ex7 not configured.")
---
Congratulations on completing the exercises! You should now have a good understanding of:
* Grain's core components (DataSource, Sampler, Operations).
* Using grain.DataLoader for efficient data input, including parallel loading.
* Configuring the DataLoader (worker_count > 0) for parallelism.
* Writing deterministic random augmentations with a RandomMapTransform's provided rng.
* Using grain.sharding.ShardByJaxProcess (or manual ShardOptions) for JAX sharding.
* Note that the handling of stateful layers (e.g., Dropout or BatchNorm) has changed; you will need to use a migration script to update old checkpoints to the v0.11 format, as described in the official migration guide.
* Extending Grain with custom DataSource implementations or by integrating with libraries like TFDS (a short, illustrative sketch follows below).
We hope these exercises have been helpful in your journey with JAX and Grain!
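As a pointer for further exploration, here is a brief, hedged sketch of swapping in a TFDS-backed data source in place of MySource. It assumes tensorflow-datasets is installed and that the chosen dataset ("mnist" here) has a random-access format available via tfds.data_source; availability varies by dataset and TFDS version.
import tensorflow_datasets as tfds
import grain.python as grain

# tfds.data_source returns a random-access source with __len__/__getitem__,
# so it can be used wherever MySource was used above.
tfds_source = tfds.data_source("mnist", split="train")
tfds_sampler = grain.IndexSampler(
    num_records=len(tfds_source),
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=1,
    seed=0,
)
tfds_loader = grain.DataLoader(
    data_source=tfds_source,
    operations=[grain.Batch(batch_size=64, drop_remainder=True)],
    sampler=tfds_sampler,
    worker_count=0,
)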