This Colab notebook is designed to provide you with hands-on experience of the concepts covered in the lecture. You'll work through examples of defining hardware meshes, specifying sharding for your data and model parameters, initializing large models in a distributed fashion, and setting up a sharded training step using JAX's powerful SPMD (Single Program, Multiple Data) capabilities with Flax NNX.
Let's dive into scaling your JAX models!
# @title Setup and Imports
# Install necessary libraries
!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
import numpy as np
import flax.nnx as nnx
from flax.nnx import spmd # For sharding utilities like get_partition_spec
import chex
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import optax # For the training loop example
# --- IMPORTANT: Simulate a multi-device environment ---
# This allows us to run sharding examples in Colab which typically has only one accelerator.
# We'll simulate 8 CPU devices.
try:
chex.set_n_cpu_devices(8)
except RuntimeError as e:
print(f"Note: {e}. This is expected if devices are already set or on a real multi-device TPU/GPU setup.")
# Verify the number of devices JAX sees
print(f"JAX visible devices: {jax.devices()}")
print(f"Number of JAX devices: {jax.device_count()}")
# Helper alias for creating JAX PRNG keys
PRNGKey = jax.random.PRNGKey
# For NNX module parameter initialization, keys come from an nnx.Rngs object
# (e.g. `rngs.params()` inside a module's __init__), so no standalone helper is needed here.
# For general JAX operations
MAIN_KEY = PRNGKey(0)
# Silence some NNX warnings for cleaner output in the notebook
import logging
logging.getLogger('flax.nnx').setLevel(logging.ERROR)
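Before moving on, it can help to check that the simulated device count matches what the exercises below assume; this guard is an addition to the original setup, since the 2x4 and 1x8 meshes used later require exactly 8 devices.
# Optional guard (not part of the original exercises): later cells reshape jax.devices() into (2, 4) and (1, 8) grids.
assert jax.device_count() == 8, (
    f"Expected 8 devices for the mesh examples below, got {jax.device_count()}")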
The foundation of explicit parallelism in JAX involves:
- Mesh: a logical grid representing your physical accelerator devices.
- PartitionSpec (often aliased as P): a tuple describing how a tensor's dimensions map to Mesh axes.
- NamedSharding: combines a Mesh and a PartitionSpec into a reusable sharding strategy.
- jax.device_put: explicitly places data onto devices with a specific NamedSharding.
# Exercise 1: Instructions Cell
# TODO: 1. Create a Mesh for 8 devices in a 2x4 grid with axes 'data' and 'model'.
# Tip: jax.devices() gives a list of devices. Reshape a numpy array of these devices.
mesh_devices = np.array(jax.devices()).reshape((2, 4)) # Devices for 'data' axis, then 'model' axis
mesh = Mesh(mesh_devices, axis_names=('data', 'model'))
# TODO: 2. Define PartitionSpecs
pspec_data_parallel = P('data', None) # Shard batch dim, replicate feature dim
pspec_model_parallel_dim1 = P(None, 'model') # Replicate in_feature dim, shard out_feature dim
pspec_replicated = P() # Fully replicated
# TODO: 3. Create NamedSharding for data parallelism
data_named_sharding = NamedSharding(mesh, pspec_data_parallel)
# TODO: 4. Create a sample NumPy array
numpy_array = np.arange(16 * 128, dtype=np.float32).reshape((16, 128))
# TODO: 5. Shard the array using jax.device_put
sharded_array = jax.device_put(numpy_array, data_named_sharding)
# TODO: 6. Print mesh and sharded array's sharding
print("Mesh:", mesh)
print("\nPartitionSpec for data parallel:", pspec_data_parallel)
print("PartitionSpec for model parallel (dim 1):", pspec_model_parallel_dim1)
print("PartitionSpec for replicated:", pspec_replicated)
print("\nNamedSharding for data:", data_named_sharding)
print("\nOriginal NumPy array shape:", numpy_array.shape)
print("Sharded JAX array sharding:", sharded_array.sharding)
print("Sharded JAX array device buffers (should show multiple devices):")
for i, db in enumerate(sharded_array.addressable_shards):
    print(f"Buffer {i}: Device={db.device}, Shape={db.data.shape}")
# Verify that each device along the 'data' axis gets a slice.
# For a (2, 4) mesh with P('data', None) on a (16, 128) array:
# the 'data' axis has 2 devices, so dim 0 is split 16 / 2 = 8, and dim 1 (128) is replicated.
# Each of the 4 devices along the 'model' axis holds a replica of its group's (8, 128) slice,
# so every one of the 8 devices stores a buffer of shape (8, 128).
jax.debug.visualize_array_sharding(sharded_array)
# Exercise 1: Solution Cell
# 1. Create a Mesh for 8 devices in a 2x4 grid with axes 'data' and 'model'.
# The 'data' axis will have 2 devices, and the 'model' axis will have 4 devices.
mesh_devices = np.array(jax.devices()).reshape((2, 4))
mesh = Mesh(mesh_devices, axis_names=('data', 'model'))
# 2. Define PartitionSpecs
pspec_data_parallel = P('data', None) # Shard batch (dim 0) along 'data', replicate features (dim 1)
pspec_model_parallel_dim1 = P(None, 'model') # Replicate dim 0, shard dim 1 along 'model'
pspec_replicated = P() # Fully replicated across all devices in the mesh
# 3. Create NamedSharding for data parallelism
data_named_sharding = NamedSharding(mesh, pspec_data_parallel)
# 4. Create a sample NumPy array
numpy_array = np.arange(16 * 128, dtype=np.float32).reshape((16, 128)) # (batch_size=16, features=128)
# 5. Shard the array using jax.device_put
# This places the array onto the devices defined by the mesh, according to the NamedSharding.
sharded_array = jax.device_put(numpy_array, data_named_sharding)
# 6. Print mesh and sharded array's sharding
print("Mesh:", mesh)
print("\nPartitionSpec for data parallel:", pspec_data_parallel)
print("PartitionSpec for model parallel (dim 1):", pspec_model_parallel_dim1)
print("PartitionSpec for replicated:", pspec_replicated)
print("\nNamedSharding for data:", data_named_sharding)
print("\nOriginal NumPy array shape:", numpy_array.shape)
print("Sharded JAX array object:", sharded_array)
print("Sharded JAX array sharding:", sharded_array.sharding)
print("\nInspecting device buffers for the sharded array:")
# For a (16, 128) array sharded with P('data', None) on a ('data':2, 'model':4) mesh:
# The 'data' axis (size 2) shards the first dimension (16). So, 16/2 = 8.
# The second dimension (128) is replicated (None).
# Each device will hold a piece of shape (8, 128).
for i, db in enumerate(sharded_array.addressable_shards):
# Access shape from the data attribute of the Shard object
print(f"Buffer {i}: Device={db.device}, Shape={db.data.shape}")
# You should see that the array is split across devices.
# With P('data', None) on a 2x4 mesh, the first dimension (16) is split over the 'data' axis (size 2).
# So, devices (0,0),(0,1),(0,2),(0,3) will get the first half of the data (rows 0-7), replicated.
# And devices (1,0),(1,1),(1,2),(1,3) will get the second half (rows 8-15), replicated.
# Each device buffer will have shape (8, 128).
jax.debug.visualize_array_sharding(sharded_array)
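As an optional extra (not part of the original exercise), you can reshard the same array so that both dimensions are split by combining the 'data' and 'model' axes; on the 2x4 mesh each device then holds an (8, 32) tile instead of a replicated (8, 128) slice.
# Optional extra: shard both dimensions at once. With P('data', 'model') on the (2, 4) mesh,
# dim 0 is split 16 / 2 = 8 and dim 1 is split 128 / 4 = 32, so nothing is replicated.
fully_sharded_array = jax.device_put(numpy_array, NamedSharding(mesh, P('data', 'model')))
print("Fully sharded array sharding:", fully_sharded_array.sharding)
print("Per-device shape:", fully_sharded_array.addressable_shards[0].data.shape)  # (8, 32)
jax.debug.visualize_array_sharding(fully_sharded_array)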
Flax NNX modules can store sharding hints (PartitionSpec tuples) directly within their parameter metadata. This is crucial for guiding the JAX compiler when performing sharded initialization and distributed training.
Instructions:
1. Define a SimpleLinear nnx.Module.
2. In __init__, initialize kernel (a 2D weight matrix) and bias (a 1D vector) as nnx.Param.
3. Attach sharding metadata (PartitionSpec tuples) to each parameter.
4. Instantiate the module.
5. Print the .sharding metadata of the kernel and bias.
# Exercise 2: Instructions Cell
class SimpleLinear(nnx.Module):
def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
key = rngs.params()
# TODO: 3. Initialize kernel and bias with sharding metadata
# Kernel sharding: P(None, 'model')
# Bias sharding: P('model')
# Use nnx.initializers.lecun_normal() for kernel and nnx.initializers.zeros for bias
# Example of using nnx.with_metadata:
# self.my_param = nnx.Param(
# nnx.with_metadata(
# nnx.initializers.zeros,
# sharding=P(...) # Your PartitionSpec tuple
# )(key, param_shape)
# )
# Or, more directly if nnx.Param supports 'sharding' kwarg for its value:
# self.my_param = nnx.Param(
# nnx.initializers.zeros(key, param_shape),
# sharding=P(...)
# )
# The slides show both `nnx.with_metadata` and direct `sharding=` to nnx.Param.
# Let's use nnx.with_metadata for clarity as it's explicitly for metadata.
self.kernel = nnx.Param(
nnx.with_metadata(
nnx.initializers.lecun_normal(),
sharding=P(None, 'model') # Shard out_features along 'model'
)(key, (in_features, out_features))
)
self.bias = nnx.Param(
nnx.with_metadata(
nnx.initializers.zeros,
sharding=P('model') # Shard bias along 'model'
)(key, (out_features,))
)
self.in_features = in_features
self.out_features = out_features
def __call__(self, x: jax.Array):
# This part is not the focus of sharding annotation, but good to have
return x @ self.kernel.value + self.bias.value
# TODO: 4. Instantiate the module
rngs_init = nnx.Rngs(params=jax.random.key(0)) # Create Rngs for NNX module
linear_module_annotated = SimpleLinear(in_features=128, out_features=256, rngs=rngs_init)
# 5. Print the .sharding metadata from the kernel
# The sharding information is stored as metadata
print(f"Type of kernel: {type(linear_module_annotated.kernel)}")
print(f"Type of kernel's value: {type(linear_module_annotated.kernel.value)}") # JAX array with metadata
# Access the sharding from the value
print("\nKernel sharding metadata:", linear_module_annotated.kernel.sharding)
print("Bias sharding metadata:", linear_module_annotated.bias.sharding)
# Verify output:
# Kernel sharding metadata: PartitionSpec(None, 'model')
# Bias sharding metadata: PartitionSpec('model',)
# Exercise 2: Solution Cell
class SimpleLinear(nnx.Module):
def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
key = rngs.params() # Get a JAX PRNGKey for parameter initialization
# 3. Initialize kernel and bias with sharding metadata
# Kernel: shape (in_features, out_features), shard out_features along 'model'
self.kernel = nnx.Param(
nnx.with_metadata(
nnx.initializers.lecun_normal(), # Initializer function
sharding=P(None, 'model') # PartitionSpec tuple for metadata
)(key, (in_features, out_features)) # Call initializer with key and shape
)
# Bias: shape (out_features,), shard along 'model'
self.bias = nnx.Param(
nnx.with_metadata(
nnx.initializers.zeros, # Initializer function
sharding=P('model') # PartitionSpec tuple for metadata
)(key, (out_features,)) # Call initializer with key and shape
)
self.in_features = in_features
self.out_features = out_features
def __call__(self, x: jax.Array):
return x @ self.kernel.value + self.bias.value
# 4. Instantiate the module
# We need an Rngs object for NNX modules, even if just for 'params'
rngs_for_module_init = nnx.Rngs(params=jax.random.key(0))
linear_module_annotated = SimpleLinear(in_features=128, out_features=256, rngs=rngs_for_module_init)
# 5. Print the .sharding metadata from kernel
# The sharding information is stored as metadata
print(f"Type of kernel: {type(linear_module_annotated.kernel)}")
print(f"Type of kernel's value: {type(linear_module_annotated.kernel.value)}") # JAX array with metadata
# Access the sharding from the value
print("\nKernel sharding metadata:", linear_module_annotated.kernel.sharding)
print("Bias sharding metadata:", linear_module_annotated.bias.sharding)
# Verify output:
# Kernel sharding metadata: PartitionSpec(None, 'model')
# Bias sharding metadata: PartitionSpec('model',)
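Although this metadata does not shard anything by itself, you can already extract a PartitionSpec PyTree from the annotated module; the short preview below (reusing the objects defined above) is exactly the step that Exercise 3 performs inside a jitted function.
# Preview of Exercise 3: extract the PartitionSpec PyTree from the parameter metadata.
preview_state = nnx.state(linear_module_annotated)
preview_pspecs = nnx.spmd.get_partition_spec(preview_state)
print("PartitionSpec tree extracted from metadata:")
nnx.display(preview_pspecs)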
Initializing a very large model directly can cause Out-Of-Memory (OOM) errors if all parameters are created on a single default device. The solution is to perform initialization and apply sharding constraints inside a JIT-compiled function executed within a Mesh context.
Workflow:
1. Inside a jitted function, instantiate the NNX module.
2. Extract the functional State with nnx.state.
3. Extract the PartitionSpec PyTree from the parameter metadata with nnx.spmd.get_partition_spec.
4. Apply the shardings with jax.lax.with_sharding_constraint and write the state back with nnx.update.
5. Call the jitted function inside a Mesh context so the parameters are materialized directly in their sharded layout.
# Exercise 3: Instructions Cell
# 1. Reuse SimpleLinear (already defined above)
# 2. Define a Mesh. For this exercise, let's assume we want to shard
# the 'model' dimension of SimpleLinear across all 8 devices.
# A (1, 8) mesh with axes ('data', 'model') would mean 'data' axis has size 1 (no data parallelism here)
# and 'model' axis has size 8.
mesh_devices_ex3 = np.array(jax.devices()).reshape((1, 8))
mesh_ex3 = Mesh(mesh_devices_ex3, axis_names=('data', 'model')) # or just ('model',) if using 1D mesh
# 3. Implement the sharded initialization function
@nnx.jit(static_argnums=(1, 2)) # nnx.jit handles split/merge of NNX state; in_f/out_f must be static to build parameter shapes
def create_sharded_linear_model(rngs_for_creation, in_f, out_f):
# Step 1: Instantiate the NNX module (parameters are created here, typically on default device initially)
model = SimpleLinear(in_features=in_f, out_features=out_f, rngs=rngs_for_creation)
# Step 2: Extract the functional State PyTree
state = nnx.state(model)
# Step 3: Extract the PartitionSpec PyTree from metadata
# This uses the .sharding attributes we defined in SimpleLinear's __init__
pspecs = nnx.spmd.get_partition_spec(state)
# Step 4: Apply sharding constraints to the State
# This tells the JAX compiler the desired final layout for the parameters.
# The actual resharding happens when this JITted function is executed.
sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
# Step 5: Update the original module object with the now sharded state
nnx.update(model, sharded_state)
# Step 6: Return the model (which now contains sharded parameters)
return model
# 4. Call the function within the Mesh context
rngs_for_sharded_init = nnx.Rngs(params=jax.random.key(1)) # Use a different key
# TODO: Call create_sharded_linear_model within the mesh_ex3 context
with mesh_ex3:
sharded_linear_model = create_sharded_linear_model(rngs_for_sharded_init, 128, 256)
# 5. Verify sharding of the actual JAX arrays
# The .value of an nnx.Param is the JAX array
print("\n--- Verification after sharded initialization ---")
print("Sharded Kernel's JAX array sharding:", sharded_linear_model.kernel.value.sharding)
print("Sharded Bias's JAX array sharding:", sharded_linear_model.bias.value.sharding)
# Expected output for kernel: NamedSharding(mesh=..., spec=PartitionSpec(None, 'model'))
# Expected output for bias: NamedSharding(mesh=..., spec=PartitionSpec('model',))
# The mesh in NamedSharding should match mesh_ex3.
# Exercise 3: Solution Cell
# 1. Reuse SimpleLinear (already defined above)
# 2. Define a Mesh.
# We want to shard the 'model' dimension. Let's use a 1x8 mesh,
# dedicating all 8 devices to the 'model' axis for this example.
# 'data' axis size 1 means parameters are replicated along it (which is trivial here).
mesh_devices_ex3 = np.array(jax.devices()).reshape((1, 8))
mesh_ex3 = Mesh(mesh_devices_ex3, axis_names=('data', 'model'))
# 3. Implement the sharded initialization function
# Tell JAX that in_f and out_f are static arguments
@nnx.jit(static_argnums=(1, 2))
def create_sharded_linear_model(rngs_for_creation, in_f, out_f):
# Step 1: Instantiate the NNX module. Params are created with metadata hints.
# At this point inside JIT, they might be on a default device or abstract.
print(f"Inside JIT: Instantiating SimpleLinear({in_f}, {out_f})")
model = SimpleLinear(in_features=in_f, out_features=out_f, rngs=rngs_for_creation)
# Step 2: Extract the functional State PyTree. This is JAX-compatible.
state = nnx.state(model)
# print(f"Inside JIT: Extracted state - Kernel sharding metadata: {state['kernel'].sharding}")
# Step 3: Extract the PartitionSpec PyTree from metadata.
pspecs = nnx.spmd.get_partition_spec(state)
# print(f"Inside JIT: Extracted PartitionSpecs - Kernel PSpec: {pspecs['kernel']}")
# Step 4: Apply sharding constraints to the State.
# This is a hint to XLA; the actual sharding occurs when data is materialized on devices.
sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
# print(f"Inside JIT: Applied sharding constraint. Kernel value sharding (if available): {getattr(sharded_state['kernel'].value, 'sharding', 'Not yet concrete')}")
# Step 5: Update the original module object with the (conceptually) sharded state.
nnx.update(model, sharded_state)
# Step 6: Return the model.
return model
# 4. Call the function within the Mesh context
# The 'with mesh:' block provides the context for JAX to fulfill the sharding.
rngs_for_sharded_init = nnx.Rngs(params=jax.random.key(1))
print(f"Executing create_sharded_linear_model within mesh: {mesh_ex3}")
with mesh_ex3:
sharded_linear_model = create_sharded_linear_model(rngs_for_sharded_init, 128, 256)
print("Sharded model created.")
# 5. Verify sharding of the actual JAX arrays
# After execution, the .value of each nnx.Param should be a sharded jax.Array.
print("\n--- Verification after sharded initialization ---")
print("Sharded Kernel's JAX array (.value) sharding:", sharded_linear_model.kernel.value.sharding)
print("Sharded Kernel's JAX array shape:", sharded_linear_model.kernel.value.shape)
print("Sharded Bias's JAX array (.value) sharding:", sharded_linear_model.bias.value.sharding)
print("Sharded Bias's JAX array shape:", sharded_linear_model.bias.value.shape)
# For kernel (128, 256) with P(None, 'model') on a ('data':1, 'model':8) mesh:
# The 'model' axis (size 8) shards the second dimension (256). So, 256/8 = 32.
# Each device buffer for the kernel should have shape (128, 32).
print("\nKernel device buffers:")
for i, db in enumerate(sharded_linear_model.kernel.value.addressable_shards):
print(f" Buffer {i}: Device={db.device}, Shape={db.data.shape}")
jax.debug.visualize_array_sharding(sharded_linear_model.kernel.value)
# For bias (256,) with P('model') on a ('data':1, 'model':8) mesh:
# The 'model' axis (size 8) shards the first dimension (256). So, 256/8 = 32.
# Each device buffer for the bias should have shape (32,).
print("\nBias device buffers:")
for i, db in enumerate(sharded_linear_model.bias.value.addressable_shards):
print(f" Buffer {i}: Device={db.device}, Shape={db.data.shape}")
jax.debug.visualize_array_sharding(sharded_linear_model.bias.value)
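As a quick follow-up (not part of the original exercise), you can run the sharded module eagerly on a sharded input; JAX propagates shardings through the computation, so the output typically comes back sharded along 'model' on its last dimension, though the exact layout is chosen by the compiler.
# Optional check: feed a sharded input through the sharded model and inspect the output.
dummy_input = jax.device_put(
    np.ones((16, 128), dtype=np.float32),
    NamedSharding(mesh_ex3, P('data', None)))
with mesh_ex3:
    dummy_output = sharded_linear_model(dummy_input)
print("Output shape:", dummy_output.shape)        # (16, 256)
print("Output sharding:", dummy_output.sharding)  # last dim expected to follow 'model'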
Let's apply these concepts to a slightly more complex block, akin to a part of a Transformer's FeedForward network: LayerNorm -> Linear1 -> GELU -> Linear2. We'll focus on model parallelism for the Linear layers and LayerNorm parameters.
Instructions:
1. Define the LayerNorm and Linear sub-layers with sharding metadata.
2. Implement a sharded initialization function for NNXFeedForward.
3. Define a 2D ('data', 'model') mesh, e.g. (2, 4).
4. Instantiate the block and verify the parameter shardings.
# Exercise 4: Instructions Cell
class NNXFeedForward(nnx.Module):
def __init__(self, embed_dim: int, ff_dim: int, *, rngs: nnx.Rngs):
# The Rngs object is passed down to the submodules; each draws its own 'params' keys from it.
# TODO: 1. Define LayerNorm and Linear layers with sharding metadata
# LayerNorm: scale P('model'), bias P('model')
# Linear1 (embed_dim -> ff_dim): kernel P(None, 'model'), bias P('model')
# Linear2 (ff_dim -> embed_dim): kernel P(None, 'model'), bias P('model')
# (Note: a typical tensor-parallel layout would shard Linear2's kernel along its input dim
# with P('model', None); here we keep P(None, 'model') for both layers so SimpleLinear can be reused as-is.)
self.layernorm = nnx.LayerNorm(
num_features=embed_dim,
epsilon=1e-6,
scale_init=nnx.with_metadata(nnx.initializers.ones, sharding=P('model')),
bias_init=nnx.with_metadata(nnx.initializers.zeros, sharding=P('model')),
rngs=rngs # LayerNorm draws the keys for its own initializers from this Rngs
)
self.linear1 = SimpleLinear( # Reusing SimpleLinear for convenience
in_features=embed_dim,
out_features=ff_dim,
rngs=rngs # Pass down the Rngs
)
# Ensure SimpleLinear's sharding annotations are: kernel P(None, 'model'), bias P('model')
self.linear2 = SimpleLinear(
in_features=ff_dim,
out_features=embed_dim,
rngs=rngs
)
# Ensure SimpleLinear's sharding annotations are: kernel P(None, 'model'), bias P('model')
# If we wanted linear2 kernel to be P('model', None), we'd need to modify SimpleLinear
# or create a new Linear variant. For now, let's assume SimpleLinear is P(None, 'model') for kernel.
def __call__(self, x: jax.Array, training: bool = False):
x_norm = self.layernorm(x)
x_ff = nnx.gelu(self.linear1(x_norm))
output = self.linear2(x_ff)
return output
# TODO: 2. Implement the sharded initialization function for NNXFeedForward
@nnx.jit(static_argnums=(1, 2))
def create_sharded_ffn_model(rngs_for_creation, embed_dim, ff_dim):
model = NNXFeedForward(embed_dim=embed_dim, ff_dim=ff_dim, rngs=rngs_for_creation)
state = nnx.state(model)
pspecs = nnx.spmd.get_partition_spec(state)
sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
nnx.update(model, sharded_state)
return model
# TODO: 3. Define a 2D mesh ('data', 'model'), e.g., (2, 4)
mesh_devices_ex4 = np.array(jax.devices()).reshape((2, 4))
mesh_ex4 = Mesh(mesh_devices_ex4, axis_names=('data', 'model')) # 'model' axis has size 4
# TODO: 4. Instantiate and verify sharding
rngs_ffn_init = nnx.Rngs(0, params=jax.random.key(2)) # Explicitly create 'params' stream
with mesh_ex4:
sharded_ffn_model = create_sharded_ffn_model(rngs_ffn_init, embed_dim=128, ff_dim=512)
print("\n--- FFN Model Parameter Sharding Verification ---")
print("LayerNorm scale sharding:", sharded_ffn_model.layernorm.scale.value.sharding)
print("LayerNorm bias sharding:", sharded_ffn_model.layernorm.bias.value.sharding)
print("Linear1 kernel sharding:", sharded_ffn_model.linear1.kernel.value.sharding)
print("Linear2 kernel sharding:", sharded_ffn_model.linear2.kernel.value.sharding)
# Example: LayerNorm scale is (embed_dim=128,). Sharded with P('model')
# Mesh 'model' axis has size 4. So, 128 / 4 = 32 elements per device.
print(f"\nLayerNorm scale ({sharded_ffn_model.layernorm.scale.value.shape}) device buffers:")
for i, db in enumerate(sharded_ffn_model.layernorm.scale.value.addressable_shards):
print(f" Buffer {i}: Device={db.device}, Shape={db.data.shape}") # Expected shape (32,)
jax.debug.visualize_array_sharding(sharded_ffn_model.layernorm.scale.value)
# Example: Linear1 kernel is (embed_dim=128, ff_dim=512). Sharded with P(None, 'model')
# Mesh 'model' axis has size 4. ff_dim (512) / 4 = 128.
# Expected shape on device (128, 128).
print(f"\nLinear1 kernel ({sharded_ffn_model.linear1.kernel.value.shape}) device buffers:")
for i, db in enumerate(sharded_ffn_model.linear1.kernel.value.addressable_shards):
print(f" Buffer {i}: Device={db.device}, Shape={db.data.shape}")
jax.debug.visualize_array_sharding(sharded_ffn_model.linear1.kernel.value)
# Exercise 4: Solution Cell
# We need to ensure SimpleLinear used inside NNXFeedForward has the correct sharding.
# Let's redefine it or make it flexible if needed.
# The SimpleLinear from Ex2 already has:
# kernel: P(None, 'model'), bias: P('model')
# This is suitable for our FFN.
class NNXFeedForward(nnx.Module):
def __init__(self, embed_dim: int, ff_dim: int, *, rngs: nnx.Rngs):
# 1. Define LayerNorm and Linear layers with sharding metadata
# LayerNorm takes an Rngs object directly if it has its own params/dropout.
# For nnx.LayerNorm, scale_init/bias_init are callables that take a key.
# We can use nnx.with_metadata with these initializers.
self.layernorm = nnx.LayerNorm(
num_features=embed_dim,
epsilon=1e-6,
# nnx.LayerNorm will call these initializers with a key from its rngs
scale_init=nnx.with_metadata(nnx.initializers.ones, sharding=P('model')),
bias_init=nnx.with_metadata(nnx.initializers.zeros, sharding=P('model')),
rngs=rngs # Pass the Rngs for LayerNorm to use for its initializers
)
# For SimpleLinear, we pass the Rngs object, and it extracts the 'params' key.
self.linear1 = SimpleLinear(
in_features=embed_dim,
out_features=ff_dim,
rngs=rngs # All submodules draw their keys from the shared Rngs
)
# SimpleLinear is defined with: kernel P(None, 'model'), bias P('model')
self.linear2 = SimpleLinear(
in_features=ff_dim,
out_features=embed_dim,
rngs=rngs
)
# SimpleLinear is defined with: kernel P(None, 'model'), bias P('model')
def __call__(self, x: jax.Array, training: bool = False):
x_norm = self.layernorm(x)
x_ff = nnx.gelu(self.linear1(x_norm)) # linear1.__call__
output = self.linear2(x_ff) # linear2.__call__
return output
# 2. Implement the sharded initialization function for NNXFeedForward
@nnx.jit(static_argnums=(1, 2))
def create_sharded_ffn_model(rngs_for_creation, embed_dim, ff_dim):
print(f"Inside JIT (FFN): Instantiating NNXFeedForward({embed_dim}, {ff_dim})")
model = NNXFeedForward(embed_dim=embed_dim, ff_dim=ff_dim, rngs=rngs_for_creation)
state = nnx.state(model)
pspecs = nnx.spmd.get_partition_spec(state)
# print(f"Inside JIT (FFN): PSPECS = {pspecs}")
sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
nnx.update(model, sharded_state)
return model
# 3. Define a 2D mesh ('data', 'model'), e.g., (2, 4)
# 'data' axis size 2, 'model' axis size 4.
mesh_devices_ex4 = np.array(jax.devices()).reshape((2, 4))
mesh_ex4 = Mesh(mesh_devices_ex4, axis_names=('data', 'model'))
# 4. Instantiate and verify sharding
# Top-level Rngs for the FFN model creation
rngs_ffn_init = nnx.Rngs(0, params=jax.random.key(2)) # Explicitly create 'params' stream
print(f"Executing create_sharded_ffn_model within mesh: {mesh_ex4}")
with mesh_ex4:
sharded_ffn_model = create_sharded_ffn_model(rngs_ffn_init, embed_dim=128, ff_dim=512)
print("Sharded FFN model created.")
print("\n--- FFN Model Parameter Sharding Verification ---")
# LayerNorm parameters are typically 'scale' and 'bias'
print("LayerNorm scale sharding:", sharded_ffn_model.layernorm.scale.value.sharding)
print("LayerNorm bias sharding:", sharded_ffn_model.layernorm.bias.value.sharding)
print("Linear1 kernel sharding:", sharded_ffn_model.linear1.kernel.value.sharding)
print("Linear1 bias sharding:", sharded_ffn_model.linear1.bias.value.sharding)
print("Linear2 kernel sharding:", sharded_ffn_model.linear2.kernel.value.sharding)
print("Linear2 bias sharding:", sharded_ffn_model.linear2.bias.value.sharding)
# Example: LayerNorm scale is (embed_dim=128,). Sharded with P('model')
# Mesh 'model' axis has size 4. So, 128 / 4 = 32 elements per device.
print(f"\nLayerNorm scale ({sharded_ffn_model.layernorm.scale.value.shape}) device buffers:")
for i, db in enumerate(sharded_ffn_model.layernorm.scale.value.addressable_shards):
print(f" Buffer {i}: Device={db.device}, Shape={db.data.shape}") # Expected shape (32,)
jax.debug.visualize_array_sharding(sharded_ffn_model.layernorm.scale.value)
# Example: Linear1 kernel is (embed_dim=128, ff_dim=512). Sharded with P(None, 'model')
# Mesh 'model' axis has size 4. ff_dim (512) / 4 = 128.
# Expected shape on device (128, 128).
print(f"\nLinear1 kernel ({sharded_ffn_model.linear1.kernel.value.shape}) device buffers:")
for i, db in enumerate(sharded_ffn_model.linear1.kernel.value.addressable_shards):
print(f" Buffer {i}: Device={db.device}, Shape={db.data.shape}")
jax.debug.visualize_array_sharding(sharded_ffn_model.linear1.kernel.value)
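A compact way (a small addition, reusing the objects above) to confirm that model parallelism actually reduces per-device parameter memory is to compare each parameter's global shape with the shape of one of its device shards:
# Global vs. per-device parameter shapes for the sharded FFN (mesh 'model' axis has size 4).
for name, arr in [
    ("layernorm.scale", sharded_ffn_model.layernorm.scale.value),
    ("layernorm.bias", sharded_ffn_model.layernorm.bias.value),
    ("linear1.kernel", sharded_ffn_model.linear1.kernel.value),
    ("linear1.bias", sharded_ffn_model.linear1.bias.value),
    ("linear2.kernel", sharded_ffn_model.linear2.kernel.value),
    ("linear2.bias", sharded_ffn_model.linear2.bias.value),
]:
    per_device_shape = arr.addressable_shards[0].data.shape
    print(f"{name}: global {arr.shape} -> per-device {per_device_shape}")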
A distributed training loop involves:
- a model whose parameters (and matching optimizer state) are sharded across the mesh,
- input batches and labels sharded along the 'data' axis with jax.device_put,
- a jitted train_step that computes the loss, gradients, and parameter updates, and
- executing that step within the Mesh context so XLA compiles it for SPMD execution.
# Exercise 5: Instructions Cell
# 1. Use sharded_ffn_model and mesh_ex4 from previous exercise.
# Ensure they are available in this cell's scope.
# If not, you might need to re-run parts of Ex 4 or redefine them here.
# For simplicity, let's assume sharded_ffn_model and mesh_ex4 are accessible.
if 'sharded_ffn_model' not in globals() or 'mesh_ex4' not in globals():
print("Please re-run Exercise 4 to define sharded_ffn_model and mesh_ex4.")
# As a fallback for running this cell independently:
_mesh_devices_ex4 = np.array(jax.devices()).reshape((2, 4))
mesh_ex4 = Mesh(_mesh_devices_ex4, axis_names=('data', 'model'))
_rngs_ffn_init = nnx.Rngs(params=jax.random.key(2), layernorm_params=jax.random.key(3),
linear1_params=jax.random.key(4), linear2_params=jax.random.key(5))
with mesh_ex4:
sharded_ffn_model = create_sharded_ffn_model(_rngs_ffn_init, embed_dim=128, ff_dim=512)
# 2. Create a dummy input batch and labels
BATCH_SIZE = 16
EMBED_DIM = 128 # Must match sharded_ffn_model.layernorm.num_features
NUM_CLASSES = 10 # For dummy labels
numpy_batch = np.random.rand(BATCH_SIZE, EMBED_DIM).astype(np.float32)
numpy_labels = np.random.randint(0, NUM_CLASSES, size=(BATCH_SIZE,)).astype(np.int32)
# TODO: 3. Define NamedSharding for input batch (shard along 'data') and labels
# Batch: P('data', None)
# Labels: P('data')
batch_input_sharding = NamedSharding(mesh_ex4, P('data', None))
label_input_sharding = NamedSharding(mesh_ex4, P('data'))
# TODO: 4. Shard the input batch and labels using jax.device_put (within mesh context is best)
# This should ideally be done inside the loop or just before calling train_step,
# and within the mesh context if the jax.device_put itself needs that context for device assignment.
# For jax.device_put, the mesh context isn't strictly necessary if NamedSharding already has the mesh.
with mesh_ex4: # Good practice to do device_put within mesh context
sharded_batch = jax.device_put(numpy_batch, batch_input_sharding)
sharded_labels = jax.device_put(numpy_labels, label_input_sharding)
print("Sharded batch sharding:", sharded_batch.sharding)
print("Sharded labels sharding:", sharded_labels.sharding)
# TODO: 5. Define the train_step function
@nnx.jit
def train_step(model: NNXFeedForward, optimizer: nnx.Optimizer, batch: jax.Array, labels: jax.Array):
# Define loss_fn for nnx.value_and_grad
# It operates on the stateful NNX model directly
def loss_fn(mdl_stateful: NNXFeedForward):
logits = mdl_stateful(batch, training=True) # Forward pass
# The FFN output is (BATCH_SIZE, EMBED_DIM). For classification it would go to (BATCH_SIZE, NUM_CLASSES)
# via a final classification layer in a real model. For this dummy loss we project the logits
# with a fixed, untrained random matrix so the shapes line up (EMBED_DIM is not divisible by
# NUM_CLASSES, so reshaping/averaging the features would fail).
if logits.shape[-1] != NUM_CLASSES:
    projection = jax.random.normal(jax.random.key(99), (logits.shape[-1], NUM_CLASSES))
    logits_for_loss = logits @ projection
else:
    logits_for_loss = logits
loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits_for_loss, labels))
return loss
# nnx.value_and_grad handles model state splitting/merging
loss_value, grads = nnx.value_and_grad(loss_fn)(model)
# Optimizer updates model parameters (and its own state) in-place (conceptually for NNX)
# The actual update happens via nnx.update internally by the optimizer on the model's state.
optimizer.update(model, grads) # This will update sharded_ffn_model's parameters
return loss_value
# TODO: 6. Create an nnx.Optimizer for the sharded_ffn_model
# The optimizer will manage the sharded parameters and its own sharded state.
# When optimizer is created from a model with sharded state, its own state (e.g. momentum)
# should also be sharded appropriately.
# For nnx.Optimizer, we pass the model itself.
optimizer = nnx.Optimizer(sharded_ffn_model, optax.adam(learning_rate=1e-3), wrt=nnx.Param)
# TODO: 7. Execute the train_step once within the mesh context.
with mesh_ex4:
loss = train_step(sharded_ffn_model, optimizer, sharded_batch, sharded_labels)
print(f"\nComputed loss: {loss}")
# A single step won't visibly change the parameter values, but the optimizer.update call
# operated on sharded gradients and sharded parameters, i.e. the update itself ran distributed.
# Exercise 5: Solution Cell
# 1. Use sharded_ffn_model and mesh_ex4 from previous exercise.
# (Assuming they are available from Exercise 4 execution)
if 'sharded_ffn_model' not in globals() or 'mesh_ex4' not in globals():
print("Fallback: Re-defining sharded_ffn_model and mesh_ex4 for Exercise 5.")
_mesh_devices_ex4_sol = np.array(jax.devices()).reshape((2, 4))
mesh_ex4 = Mesh(_mesh_devices_ex4_sol, axis_names=('data', 'model'))
_rngs_ffn_init_sol = nnx.Rngs(
params=jax.random.key(2),
layernorm_params=jax.random.key(30),
linear1_params=jax.random.key(40),
linear2_params=jax.random.key(50)
)
with mesh_ex4:
sharded_ffn_model = create_sharded_ffn_model(_rngs_ffn_init_sol, embed_dim=128, ff_dim=512)
print("Fallback definitions complete.")
# 2. Create a dummy input batch and labels
BATCH_SIZE = 16 # Global batch size
EMBED_DIM = 128 # Must match sharded_ffn_model's input embed_dim
# Output of FFN is also EMBED_DIM. For classification, a final Linear layer to NUM_CLASSES is needed.
# For this exercise, we'll make a dummy adjustment if NUM_CLASSES doesn't match.
NUM_CLASSES = 10 # Example number of classes for dummy loss
numpy_batch = np.random.rand(BATCH_SIZE, EMBED_DIM).astype(np.float32)
# Labels for classification, shape (BATCH_SIZE,)
numpy_labels = np.random.randint(0, NUM_CLASSES, size=(BATCH_SIZE,)).astype(np.int32)
# 3. Define NamedSharding for input batch and labels
# Batch shape (BATCH_SIZE, EMBED_DIM), sharded P('data', None) -> ('data' shards BATCH_SIZE)
batch_input_sharding = NamedSharding(mesh_ex4, P('data', None))
# Labels shape (BATCH_SIZE,), sharded P('data') -> ('data' shards BATCH_SIZE)
label_input_sharding = NamedSharding(mesh_ex4, P('data'))
# 4. Shard the input batch and labels using jax.device_put
# This is typically done inside the training loop for each new batch.
# Performing it within the mesh context ensures devices align if mesh is complex.
with mesh_ex4:
sharded_batch = jax.device_put(numpy_batch, batch_input_sharding)
sharded_labels = jax.device_put(numpy_labels, label_input_sharding)
print("Sharded batch object:", sharded_batch)
print("Sharded batch sharding:", sharded_batch.sharding)
# With ('data':2, 'model':4) mesh and P('data', None) for (16, 128) batch:
# 'data' axis (size 2) shards dim 0 (16). 16/2 = 8.
# Each device buffer will have shape (8, 128).
print(f"Sharded batch per-device shape: {sharded_batch.addressable_shards[0].data.shape}")
print("\nSharded labels object:", sharded_labels)
print("Sharded labels sharding:", sharded_labels.sharding)
# With ('data':2, 'model':4) mesh and P('data') for (16,) labels:
# 'data' axis (size 2) shards dim 0 (16). 16/2 = 8.
# Each device buffer will have shape (8,).
print(f"Sharded labels per-device shape: {sharded_labels.addressable_shards[0].data.shape}")
# 5. Define the train_step function
@nnx.jit
def train_step(model: NNXFeedForward, optimizer: nnx.Optimizer, batch: jax.Array, labels: jax.Array):
# This loss_fn is defined inside train_step to capture 'batch' and 'labels'
# It takes the stateful NNX model as its argument.
def loss_fn(mdl_stateful: NNXFeedForward):
# Forward pass through the model
logits = mdl_stateful(batch, training=True) # model's __call__
# The FFN output is (BATCH_SIZE, EMBED_DIM): under jit the program sees the global (logical)
# shapes even though the data is physically sharded across devices.
# For a typical classification loss, we'd need (BATCH_SIZE, NUM_CLASSES).
# This is a placeholder to make the loss calculation work.
# In a real model, you'd have a final nnx.Linear layer projecting to NUM_CLASSES.
current_out_features = logits.shape[-1]
if current_out_features != NUM_CLASSES:
# Simple (and somewhat arbitrary) projection for the sake of the exercise
# This ensures the logits match the label dimensions for softmax_cross_entropy
# A more realistic approach for a final layer would be needed in practice.
# Note: shapes here are global; XLA inserts the per-device sharding and any needed communication.
# print(f"Logits shape before adjustment: {logits.shape}") # For debugging JIT prints
# This projection is not well-posed for learning but allows loss computation.
if logits.shape[0] > 0: # static shape check; always true here since BATCH_SIZE=16
# Create a dummy projection matrix on the fly (not trained)
# This is just to make the shapes work for the loss function.
# This is NOT how you would typically do a projection in a real model.
dummy_projection_key = jax.random.key(99) # Fixed key for reproducibility inside JIT
projection_matrix = jax.random.normal(dummy_projection_key, (current_out_features, NUM_CLASSES))
projected_logits = logits @ projection_matrix
else: # degenerate case (empty batch): create zero logits
projected_logits = jnp.zeros((logits.shape[0], NUM_CLASSES), dtype=logits.dtype)
else:
projected_logits = logits
# Compute loss. JAX automatically handles SPMD for this calculation
# if inputs (logits, labels) are sharded. Gradients will also be sharded,
# and all-reduce for gradients across data-parallel dimension is inserted by JAX.
loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(projected_logits, labels))
return loss
# Compute loss and gradients. nnx.value_and_grad handles splitting/merging model state.
loss_value, grads = nnx.value_and_grad(loss_fn)(model)
# Apply gradients. The optimizer updates the model's parameters in-place (conceptually).
# If model parameters are sharded, optimizer state (e.g., momentum) is also sharded,
# and updates are applied in a distributed manner.
optimizer.update(model, grads)
return loss_value
# 6. Create an nnx.Optimizer for the sharded_ffn_model
# Pass the sharded model to the optimizer and specify what to differentiate wrt (nnx.Param).
# NNX ensures that the optimizer's state (like Adam's moments) is initialized
# with the same sharding as the parameters.
optimizer = nnx.Optimizer(sharded_ffn_model, optax.adam(learning_rate=1e-3), wrt=nnx.Param)
print(f"\nOptimizer created. Optimizer state sharding should match param sharding.")
# You can inspect nnx.state(optimizer) to see its structure and (if concrete) sharding.
# 7. Execute the train_step once within the mesh context.
# All inputs (model params via optimizer, batch, labels) are sharded.
# JAX/XLA compiles the train_step for SPMD execution.
print(f"\nExecuting train_step within mesh: {mesh_ex4}")
with mesh_ex4:
# JIT compilation happens on the first call
loss = train_step(sharded_ffn_model, optimizer, sharded_batch, sharded_labels)
# Subsequent calls would be faster
# loss_2 = train_step(sharded_ffn_model, optimizer, sharded_batch, sharded_labels)
print(f"\nComputed loss from sharded train_step: {loss}")
# print(f"Computed loss (2nd step): {loss_2}")
# Note: The actual loss value isn't meaningful due to dummy projection and random data.
# The key is that the distributed computation ran.
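To see how this fits into an actual loop, here is a minimal multi-step sketch using synthetic data; it reuses the objects defined above and simply re-shards a fresh batch with jax.device_put before each call to the jitted train_step.
# Minimal multi-step loop sketch (synthetic data; illustrative only).
with mesh_ex4:
    for step in range(3):
        np_batch = np.random.rand(BATCH_SIZE, EMBED_DIM).astype(np.float32)
        np_labels = np.random.randint(0, NUM_CLASSES, size=(BATCH_SIZE,)).astype(np.int32)
        step_batch = jax.device_put(np_batch, batch_input_sharding)
        step_labels = jax.device_put(np_labels, label_input_sharding)
        step_loss = train_step(sharded_ffn_model, optimizer, step_batch, step_labels)
        print(f"Step {step}: loss = {step_loss}")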
To save/load huge sharded models without OOM, libraries like Orbax are used. Orbax needs to know the target NamedSharding for each parameter in order to restore it correctly; nnx.spmd.get_named_sharding helps generate these.
Instructions:
1. Reuse sharded_ffn_model and mesh_ex4 from the previous exercises.
2. Get the state structure of the model.
3. Generate the target NamedSharding PyTree with nnx.spmd.get_named_sharding.
4. Print some of the generated NamedSharding objects.
# Exercise 6: Instructions Cell
# 1. Use sharded_ffn_model and mesh_ex4
if 'sharded_ffn_model' not in globals() or 'mesh_ex4' not in globals():
print("Fallback: Re-defining sharded_ffn_model and mesh_ex4 for Exercise 6.")
_mesh_devices_ex6 = np.array(jax.devices()).reshape((2, 4))
mesh_ex4 = Mesh(_mesh_devices_ex6, axis_names=('data', 'model')) # Re-assign to global mesh_ex4 if needed
_rngs_ffn_init_ex6 = nnx.Rngs(
params=jax.random.key(20),
layernorm_params=jax.random.key(300),
linear1_params=jax.random.key(400),
linear2_params=jax.random.key(500)
)
with mesh_ex4: # Ensure mesh_ex4 is correctly used
sharded_ffn_model = create_sharded_ffn_model(_rngs_ffn_init_ex6, embed_dim=128, ff_dim=512)
# TODO: 2. Get the state structure of the model
# This can be from the concrete sharded model, or an abstract model from nnx.eval_shape
model_state_structure = nnx.state(sharded_ffn_model)
# TODO: 3. Generate the target NamedSharding PyTree
# This uses the .sharding PartitionSpec metadata stored in the model's state
# and combines it with the provided mesh.
target_named_shardings_tree = nnx.spmd.get_named_sharding(model_state_structure, mesh_ex4)
# TODO: 4. Print some of the generated NamedSharding objects
print("\n--- Target NamedShardings for Checkpointing ---")
print("NamedSharding for LayerNorm scale:")
nnx.display(target_named_shardings_tree['layernorm']['scale'])
print("\nNamedSharding for Linear1 kernel:")
nnx.display(target_named_shardings_tree['linear1']['kernel'])
print("\nNamedSharding for Linear2 bias:")
nnx.display(target_named_shardings_tree['linear2']['bias'])
# These NamedSharding objects tell a checkpointing library (like Orbax)
# exactly how each parameter should be laid out across the devices upon restoration.
# Exercise 6: Solution Cell
# 1. Use sharded_ffn_model and mesh_ex4
if 'sharded_ffn_model' not in globals() or 'mesh_ex4' not in globals():
print("Fallback: Re-defining sharded_ffn_model and mesh_ex4 for Exercise 6.")
# Ensure mesh_ex4 is the one associated with sharded_ffn_model
_mesh_devices_ex6_sol = np.array(jax.devices()).reshape((2, 4))
mesh_ex4 = Mesh(_mesh_devices_ex6_sol, axis_names=('data', 'model')) # Make sure this is the correct mesh
_rngs_ffn_init_ex6_sol = nnx.Rngs(
params=jax.random.key(201),
layernorm_params=jax.random.key(301),
linear1_params=jax.random.key(401),
linear2_params=jax.random.key(501)
)
with mesh_ex4: # Use the correct mesh_ex4
sharded_ffn_model = create_sharded_ffn_model(_rngs_ffn_init_ex6_sol, embed_dim=128, ff_dim=512)
print("Fallback definitions complete for Ex6.")
# 2. Get the state structure of the model
# This PyTree has the same structure as the model's parameters,
# and each leaf (parameter state) contains the .sharding PartitionSpec metadata.
model_state_structure = nnx.state(sharded_ffn_model)
# 3. Generate the target NamedSharding PyTree
# nnx.spmd.get_named_sharding combines the PartitionSpec from metadata
# with the provided 'mesh_ex4' to create a full NamedSharding object for each parameter.
target_named_shardings_tree = nnx.spmd.get_named_sharding(model_state_structure, mesh_ex4)
# 4. Print some of the generated NamedSharding objects
print("\n--- Target NamedShardings for Checkpointing (from nnx.spmd.get_named_sharding) ---")
# For LayerNorm's scale parameter
# Original sharding metadata (PartitionSpec): P('model')
# Mesh: ('data':2, 'model':4)
# Expected NamedSharding: NamedSharding(mesh=mesh_ex4, spec=P('model'))
print("\nNamedSharding for LayerNorm scale:")
# The path to scale might be model_state_structure['layernorm']['scale'].sharding
# So target_named_shardings_tree should have a similar path.
nnx.display(target_named_shardings_tree['layernorm']['scale'])
# For Linear1's kernel parameter
# Original sharding metadata (PartitionSpec): P(None, 'model')
# Expected NamedSharding: NamedSharding(mesh=mesh_ex4, spec=P(None, 'model'))
print("\nNamedSharding for Linear1 kernel:")
nnx.display(target_named_shardings_tree['linear1']['kernel'])
# For Linear2's bias parameter
# Original sharding metadata (PartitionSpec): P('model')
# Expected NamedSharding: NamedSharding(mesh=mesh_ex4, spec=P('model'))
print("\nNamedSharding for Linear2 bias:")
nnx.display(target_named_shardings_tree['linear2']['bias'])
# This target_named_shardings_tree would be passed to something like
# orbax.checkpoint.StandardRestore(target_named_shardings_tree)
# to tell Orbax how to reconstruct the sharded arrays when loading from a checkpoint.
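As a follow-up sketch (hedged: the exact pytree handling can vary between Flax versions), the usual next step is to pair each parameter's shape and dtype with its target NamedSharding as jax.ShapeDtypeStruct leaves, giving a checkpointing library an abstract target to restore into without first materializing the full model on one device. In practice this state would typically come from an abstract model built with nnx.eval_shape rather than from the concrete sharded_ffn_model used here.
# Sketch: abstract restore target carrying shape, dtype, and target sharding per parameter.
# (Assumption: the State from nnx.state and the tree from nnx.spmd.get_named_sharding
# line up leaf-for-leaf, as in the Flax checkpointing guides.)
abstract_restore_target = jax.tree.map(
    lambda arr, sharding: jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=sharding),
    model_state_structure, target_named_shardings_tree)
print("\nExample abstract restore target (Linear1 kernel):")
nnx.display(abstract_restore_target['linear1']['kernel'])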
You've now practiced:
- defining device Meshes and PartitionSpecs, and placing data with NamedSharding and jax.device_put,
- annotating Flax NNX parameters with sharding metadata,
- initializing large models in a sharded fashion inside jit and a Mesh context,
- running a sharded training step with nnx.Optimizer, and
- generating target NamedShardings for distributed checkpointing.
These are foundational skills for scaling up your JAX and Flax NNX models.
Further Learning:
Please send us feedback at https://goo.gle/jax-training-feedback