Sharding and Parallelism with JAX & Flax NNX

Introduction

Welcome to the practical exercises for "Scaling Up: Sharding and Parallelism with JAX and Flax NNX"!

This Colab notebook is designed to provide you with hands-on experience of the concepts covered in the lecture. You'll work through examples of defining hardware meshes, specifying sharding for your data and model parameters, initializing large models in a distributed fashion, and setting up a sharded training step using JAX's powerful SPMD (Single Program, Multiple Data) capabilities with Flax NNX.

Let's dive into scaling your JAX models!

# @title Setup and Imports
# Install necessary libraries
!pip install -q jax-ai-stack==2025.9.3

import jax
import jax.numpy as jnp
import numpy as np
import flax.nnx as nnx
from flax.nnx import spmd # For sharding utilities like get_partition_spec
import chex
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import optax # For the training loop example

# --- IMPORTANT: Simulate a multi-device environment ---
# This allows us to run sharding examples in Colab which typically has only one accelerator.
# We'll simulate 8 CPU devices.
try:
    chex.set_n_cpu_devices(8)
except RuntimeError as e:
    print(f"Note: {e}. This is expected if devices are already set or on a real multi-device TPU/GPU setup.")

# Verify the number of devices JAX sees
print(f"JAX visible devices: {jax.devices()}")
print(f"Number of JAX devices: {jax.device_count()}")

# Alias for creating JAX PRNG keys
PRNGKey = jax.random.PRNGKey

# NNX modules draw parameter-initialization keys from an nnx.Rngs object,
# e.g. rngs = nnx.Rngs(params=jax.random.key(0)); key = rngs.params()

# For general JAX operations
MAIN_KEY = PRNGKey(0)

# Silence some NNX warnings for cleaner output in the notebook
import logging
logging.getLogger('flax.experimental.nnx.nnx.graph').setLevel(logging.ERROR)

Exercise 1: JAX Sharding Primitives - Mesh, PartitionSpec, NamedSharding, and device_put

The foundation of explicit parallelism in JAX involves:

  • Mesh: A logical grid representing your physical accelerator devices.
  • PartitionSpec (often aliased as P): A tuple describing how a tensor's dimensions map to Mesh axes.
  • NamedSharding: Combines a Mesh and a PartitionSpec into a reusable sharding strategy.
  • jax.device_put: Explicitly places data onto devices with a specific NamedSharding.

Instructions:
  1. Create a Mesh for our 8 simulated devices, arranging them in a 2x4 grid. Name the mesh axes 'data' and 'model'.
  2. Define three PartitionSpecs:
     - pspec_data_parallel: Shard dimension 0 along 'data', replicate dimension 1. (Typical for input batches [batch, features].)
     - pspec_model_parallel_dim1: Replicate dimension 0, shard dimension 1 along 'model'. (Typical for a weight matrix [in_features, out_features] in some forms of tensor parallelism.)
     - pspec_replicated: Fully replicate the tensor on all devices in the mesh.
  3. Create a NamedSharding object for pspec_data_parallel using your mesh.
  4. Create a sample NumPy array of shape (16, 128).
  5. Use jax.device_put to shard this NumPy array according to the NamedSharding you created.
  6. Print the mesh and the .sharding attribute of the sharded JAX array to verify.
# Exercise 1: Instructions Cell

# TODO: 1. Create a Mesh for 8 devices in a 2x4 grid with axes 'data' and 'model'.
# Tip: jax.devices() gives a list of devices. Reshape a numpy array of these devices.
mesh_devices = np.array(jax.devices()).reshape((2, 4)) # Devices for 'data' axis, then 'model' axis
mesh = Mesh(mesh_devices, axis_names=('data', 'model'))

# TODO: 2. Define PartitionSpecs
pspec_data_parallel = P('data', None)        # Shard batch dim, replicate feature dim
pspec_model_parallel_dim1 = P(None, 'model') # Replicate in_feature dim, shard out_feature dim
pspec_replicated = P()                       # Fully replicated

# TODO: 3. Create NamedSharding for data parallelism
data_named_sharding = NamedSharding(mesh, pspec_data_parallel)

# TODO: 4. Create a sample NumPy array
numpy_array = np.arange(16 * 128, dtype=np.float32).reshape((16, 128))

# TODO: 5. Shard the array using jax.device_put
sharded_array = jax.device_put(numpy_array, data_named_sharding)

# TODO: 6. Print mesh and sharded array's sharding
print("Mesh:", mesh)
print("\nPartitionSpec for data parallel:", pspec_data_parallel)
print("PartitionSpec for model parallel (dim 1):", pspec_model_parallel_dim1)
print("PartitionSpec for replicated:", pspec_replicated)
print("\nNamedSharding for data:", data_named_sharding)
print("\nOriginal NumPy array shape:", numpy_array.shape)
print("Sharded JAX array sharding:", sharded_array.sharding)
print("Sharded JAX array device buffers (should show multiple devices):")
for i, db in enumerate(sharded_array.addressable_shards):
  print(f"Buffer {i}: Device={db.device}, Shape={db.data.shape}")

# Verify that each device in the 'data' axis gets a slice.
# For a (2, 4) mesh with P('data', None) on a (16, 128) array:
# the 'data' axis has 2 devices, so dimension 0 is split 16 / 2 = 8.
# The second dimension (128) is replicated, so each of the 4 devices along the
# 'model' axis within a data-slice group holds a replica of that (8, 128) slice.
# Every one of the 8 devices therefore stores a buffer of shape (8, 128);
# addressable_shards[i].data.shape reflects this per-device shape.

jax.debug.visualize_array_sharding(sharded_array)

Exercise 1: Solution

# Exercise 1: Solution Cell

# 1. Create a Mesh for 8 devices in a 2x4 grid with axes 'data' and 'model'.
# The 'data' axis will have 2 devices, and the 'model' axis will have 4 devices.
mesh_devices = np.array(jax.devices()).reshape((2, 4))
mesh = Mesh(mesh_devices, axis_names=('data', 'model'))

# 2. Define PartitionSpecs
pspec_data_parallel = P('data', None)        # Shard batch (dim 0) along 'data', replicate features (dim 1)
pspec_model_parallel_dim1 = P(None, 'model') # Replicate dim 0, shard dim 1 along 'model'
pspec_replicated = P()                       # Fully replicated across all devices in the mesh

# 3. Create NamedSharding for data parallelism
data_named_sharding = NamedSharding(mesh, pspec_data_parallel)

# 4. Create a sample NumPy array
numpy_array = np.arange(16 * 128, dtype=np.float32).reshape((16, 128)) # (batch_size=16, features=128)

# 5. Shard the array using jax.device_put
# This places the array onto the devices defined by the mesh, according to the NamedSharding.
sharded_array = jax.device_put(numpy_array, data_named_sharding)

# 6. Print mesh and sharded array's sharding
print("Mesh:", mesh)
print("\nPartitionSpec for data parallel:", pspec_data_parallel)
print("PartitionSpec for model parallel (dim 1):", pspec_model_parallel_dim1)
print("PartitionSpec for replicated:", pspec_replicated)
print("\nNamedSharding for data:", data_named_sharding)
print("\nOriginal NumPy array shape:", numpy_array.shape)
print("Sharded JAX array object:", sharded_array)
print("Sharded JAX array sharding:", sharded_array.sharding)

print("\nInspecting device buffers for the sharded array:")
# For a (16, 128) array sharded with P('data', None) on a ('data':2, 'model':4) mesh:
# The 'data' axis (size 2) shards the first dimension (16). So, 16/2 = 8.
# The second dimension (128) is replicated (None).
# Each device will hold a piece of shape (8, 128).
for i, db in enumerate(sharded_array.addressable_shards):
  # Access shape from the data attribute of the Shard object
  print(f"Buffer {i}: Device={db.device}, Shape={db.data.shape}")

# You should see that the array is split across devices.
# With P('data', None) on a 2x4 mesh, the first dimension (16) is split over the 'data' axis (size 2).
# So, devices (0,0),(0,1),(0,2),(0,3) will get the first half of the data (rows 0-7), replicated.
# And devices (1,0),(1,1),(1,2),(1,3) will get the second half (rows 8-15), replicated.
# Each device buffer will have shape (8, 128).

jax.debug.visualize_array_sharding(sharded_array)
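
The exercise only materializes the data-parallel layout, but the other two PartitionSpecs defined above behave the same way. As an optional check (the variable names below are just for this snippet), you can place a weight-shaped array with the model-parallel spec and a small vector with the replicated spec, then inspect the per-device shapes:

# Optional: materialize the other two PartitionSpecs defined above.
weight_like = np.ones((128, 64), dtype=np.float32)  # e.g. an [in_features, out_features] weight
model_sharded = jax.device_put(weight_like, NamedSharding(mesh, pspec_model_parallel_dim1))
replicated_vec = jax.device_put(np.ones((8,), dtype=np.float32), NamedSharding(mesh, pspec_replicated))

# With P(None, 'model') on the ('data':2, 'model':4) mesh, dim 1 (64) is split 4 ways -> (128, 16) per device.
print("Model-parallel per-device shape:", model_sharded.addressable_shards[0].data.shape)
# With P(), every device holds the full (8,) vector.
print("Replicated per-device shape:", replicated_vec.addressable_shards[0].data.shape)
jax.debug.visualize_array_sharding(model_sharded)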

Exercise 2: Annotating Sharding in a Simple NNX Module

Flax NNX modules can store sharding hints (PartitionSpec tuples) directly within their parameter metadata. This is crucial for guiding the JAX compiler when performing sharded initialization and distributed training.

Instructions:
  1. Define a simple NNX Linear module.
  2. In its __init__, initialize kernel (a 2D weight matrix) and bias (a 1D vector) as nnx.Param.
  3. Use nnx.with_metadata to attach PartitionSpec sharding hints during initialization:
     - For kernel (e.g., shape [in_features, out_features]), shard its second dimension (output features) along a mesh axis named 'model' and replicate the first dimension: P(None, 'model').
     - For bias (e.g., shape [out_features]), shard it along the 'model' mesh axis: P('model').
  4. Instantiate the module (without a mesh context for now; these are just metadata annotations).
  5. Access a parameter's .sharding metadata attribute (e.g., module.kernel.sharding) and print it to verify the PartitionSpec was stored.
# Exercise 2: Instructions Cell

class SimpleLinear(nnx.Module):
  def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    # TODO: 3. Initialize kernel and bias with sharding metadata
    # Kernel sharding: P(None, 'model')
    # Bias sharding: P('model')
    # Use nnx.initializers.lecun_normal() for kernel and nnx.initializers.zeros for bias

    # Example of using nnx.with_metadata:
    # self.my_param = nnx.Param(
    #   nnx.with_metadata(
    #       nnx.initializers.zeros,
    #       sharding=P(...) # Your PartitionSpec tuple
    #   )(key, param_shape)
    # )
    # Or, more directly if nnx.Param supports 'sharding' kwarg for its value:
    # self.my_param = nnx.Param(
    #     nnx.initializers.zeros(key, param_shape),
    #     sharding=P(...)
    # )
    # The slides show both `nnx.with_metadata` and direct `sharding=` to nnx.Param.
    # Let's use nnx.with_metadata for clarity as it's explicitly for metadata.

    self.kernel = nnx.Param(
        nnx.with_metadata(
            nnx.initializers.lecun_normal(),
            sharding=P(None, 'model') # Shard out_features along 'model'
        )(key, (in_features, out_features))
    )
    self.bias = nnx.Param(
        nnx.with_metadata(
            nnx.initializers.zeros,
            sharding=P('model') # Shard bias along 'model'
        )(key, (out_features,))
    )
    self.in_features = in_features
    self.out_features = out_features

  def __call__(self, x: jax.Array):
    # This part is not the focus of sharding annotation, but good to have
    return x @ self.kernel.value + self.bias.value

# TODO: 4. Instantiate the module
rngs_init = nnx.Rngs(params=jax.random.key(0)) # Create Rngs for NNX module
linear_module_annotated = SimpleLinear(in_features=128, out_features=256, rngs=rngs_init)

# 5. Print the .sharding metadata from the kernel
# The sharding information is stored as metadata
print(f"Type of kernel: {type(linear_module_annotated.kernel)}")
print(f"Type of kernel's value: {type(linear_module_annotated.kernel.value)}") # JAX array with metadata

# Access the sharding from the value
print("\nKernel sharding metadata:", linear_module_annotated.kernel.sharding)
print("Bias sharding metadata:", linear_module_annotated.bias.sharding)

# Verify output:
# Kernel sharding metadata: PartitionSpec(None, 'model')
# Bias sharding metadata: PartitionSpec('model',)

Exercise 2: Solution

# Exercise 2: Solution Cell

class SimpleLinear(nnx.Module):
  def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
    key = rngs.params() # Get a JAX PRNGKey for parameter initialization

    # 3. Initialize kernel and bias with sharding metadata
    # Kernel: shape (in_features, out_features), shard out_features along 'model'
    self.kernel = nnx.Param(
        nnx.with_metadata(
            nnx.initializers.lecun_normal(), # Initializer function
            sharding=P(None, 'model')        # PartitionSpec tuple for metadata
        )(key, (in_features, out_features))  # Call initializer with key and shape
    )

    # Bias: shape (out_features,), shard along 'model'
    self.bias = nnx.Param(
        nnx.with_metadata(
            nnx.initializers.zeros,         # Initializer function
            sharding=P('model')             # PartitionSpec tuple for metadata
        )(key, (out_features,))           # Call initializer with key and shape
    )
    self.in_features = in_features
    self.out_features = out_features

  def __call__(self, x: jax.Array):
    return x @ self.kernel.value + self.bias.value

# 4. Instantiate the module
# We need an Rngs object for NNX modules, even if just for 'params'
rngs_for_module_init = nnx.Rngs(params=jax.random.key(0))
linear_module_annotated = SimpleLinear(in_features=128, out_features=256, rngs=rngs_for_module_init)

# 5. Print the .sharding metadata from kernel
# The sharding information is stored as metadata
print(f"Type of kernel: {type(linear_module_annotated.kernel)}")
print(f"Type of kernel's value: {type(linear_module_annotated.kernel.value)}") # JAX array with metadata

# Access the sharding from the value
print("\nKernel sharding metadata:", linear_module_annotated.kernel.sharding)
print("Bias sharding metadata:", linear_module_annotated.bias.sharding)

# Verify output:
# Kernel sharding metadata: PartitionSpec(None, 'model')
# Bias sharding metadata: PartitionSpec('model',)
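
The commented-out hints in the instructions cell also mention passing sharding= directly to nnx.Param instead of wrapping the initializer with nnx.with_metadata. Below is a minimal sketch of that variant; the class name SimpleLinearDirect is just for illustration, and it assumes your Flax NNX version stores extra keyword arguments on nnx.Param as metadata (worth verifying against the docs for your version):

# Alternative annotation style (assumption: nnx.Param keeps extra kwargs such as
# `sharding` as metadata, as suggested by the hints in the instructions cell).
class SimpleLinearDirect(nnx.Module):
  def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.kernel = nnx.Param(
        nnx.initializers.lecun_normal()(key, (in_features, out_features)),
        sharding=P(None, 'model'),  # same PartitionSpec hint as before
    )
    self.bias = nnx.Param(nnx.initializers.zeros(key, (out_features,)), sharding=P('model'))

  def __call__(self, x: jax.Array):
    return x @ self.kernel.value + self.bias.value

direct_module = SimpleLinearDirect(8, 16, rngs=nnx.Rngs(params=jax.random.key(3)))
print("Direct-kwarg kernel sharding metadata:", direct_module.kernel.sharding)
print("Direct-kwarg bias sharding metadata:", direct_module.bias.sharding)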

Exercise 3: Sharded Initialization of an NNX Module

Initializing a very large model directly can cause Out-Of-Memory (OOM) errors if all parameters are created on a single default device. The solution is to perform initialization and apply sharding constraints inside a JIT-compiled function executed within a Mesh context.

Workflow:
  1. Instantiate the unsharded NNX module (with sharding metadata already defined).
  2. Extract its functional State PyTree: state = nnx.state(model).
  3. Extract the PartitionSpec PyTree from metadata: pspecs = nnx.spmd.get_partition_spec(state).
  4. Apply sharding constraints to the State: sharded_state = jax.lax.with_sharding_constraint(state, pspecs).
  5. Update the original module object with the sharded state: nnx.update(model, sharded_state).
  6. Return the model.
  7. Execute this entire JITted function within a jax.sharding.Mesh context.
Instructions:
  1. Reuse the SimpleLinear module definition from Exercise 2.
  2. Define a Mesh. Let's use a (1, 8) grid with axes ('data', 'model'), so the 'model' axis spans all 8 devices (a 1-D mesh with a single 'model' axis would also work here).
  3. Implement a function create_sharded_linear_model(rngs, in_features, out_features) that performs steps 1-6 above. Decorate this function with @nnx.jit.
  4. Call this create_sharded_linear_model function within the Mesh context (using with mesh:).
  5. Verify that the parameters (kernel and bias) of the returned model are now physically sharded JAX arrays by printing their .sharding attribute (this time from the JAX array value, not the metadata).
# Exercise 3: Instructions Cell

# 1. Reuse SimpleLinear (already defined above)

# 2. Define a Mesh. For this exercise, let's assume we want to shard
#    the 'model' dimension of SimpleLinear across all 8 devices.
#    A (1, 8) mesh with axes ('data', 'model') would mean 'data' axis has size 1 (no data parallelism here)
#    and 'model' axis has size 8.
mesh_devices_ex3 = np.array(jax.devices()).reshape((1, 8))
mesh_ex3 = Mesh(mesh_devices_ex3, axis_names=('data', 'model')) # or just ('model',) if using 1D mesh

# 3. Implement the sharded initialization function
@nnx.jit(static_argnums=(1, 2)) # nnx.jit handles split/merge of NNX state; in_f/out_f are shapes, so mark them static
def create_sharded_linear_model(rngs_for_creation, in_f, out_f):
  # Step 1: Instantiate the NNX module (parameters are created here, typically on default device initially)
  model = SimpleLinear(in_features=in_f, out_features=out_f, rngs=rngs_for_creation)

  # Step 2: Extract the functional State PyTree
  state = nnx.state(model)

  # Step 3: Extract the PartitionSpec PyTree from metadata
  # This uses the .sharding attributes we defined in SimpleLinear's __init__
  pspecs = nnx.spmd.get_partition_spec(state)

  # Step 4: Apply sharding constraints to the State
  # This tells the JAX compiler the desired final layout for the parameters.
  # The actual resharding happens when this JITted function is executed.
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)

  # Step 5: Update the original module object with the now sharded state
  nnx.update(model, sharded_state)

  # Step 6: Return the model (which now contains sharded parameters)
  return model

# 4. Call the function within the Mesh context
rngs_for_sharded_init = nnx.Rngs(params=jax.random.key(1)) # Use a different key
# TODO: Call create_sharded_linear_model within the mesh_ex3 context
with mesh_ex3:
    sharded_linear_model = create_sharded_linear_model(rngs_for_sharded_init, 128, 256)


# 5. Verify sharding of the actual JAX arrays
# The .value of an nnx.Param is the JAX array
print("\n--- Verification after sharded initialization ---")
print("Sharded Kernel's JAX array sharding:", sharded_linear_model.kernel.value.sharding)
print("Sharded Bias's JAX array sharding:", sharded_linear_model.bias.value.sharding)

# Expected output for kernel: NamedSharding(mesh=..., spec=PartitionSpec(None, 'model'))
# Expected output for bias: NamedSharding(mesh=..., spec=PartitionSpec('model',))
# The mesh in NamedSharding should match mesh_ex3.

Exercise 3: Solution

# Exercise 3: Solution Cell

# 1. Reuse SimpleLinear (already defined above)

# 2. Define a Mesh.
# We want to shard the 'model' dimension. Let's use a 1x8 mesh,
# dedicating all 8 devices to the 'model' axis for this example.
# 'data' axis size 1 means parameters are replicated along it (which is trivial here).
mesh_devices_ex3 = np.array(jax.devices()).reshape((1, 8))
mesh_ex3 = Mesh(mesh_devices_ex3, axis_names=('data', 'model'))

# 3. Implement the sharded initialization function
# Tell JAX that in_f and out_f are static arguments
@nnx.jit(static_argnums=(1, 2))
def create_sharded_linear_model(rngs_for_creation, in_f, out_f):
  # Step 1: Instantiate the NNX module. Params are created with metadata hints.
  # At this point inside JIT, they might be on a default device or abstract.
  print(f"Inside JIT: Instantiating SimpleLinear({in_f}, {out_f})")
  model = SimpleLinear(in_features=in_f, out_features=out_f, rngs=rngs_for_creation)

  # Step 2: Extract the functional State PyTree. This is JAX-compatible.
  state = nnx.state(model)
  # print(f"Inside JIT: Extracted state - Kernel sharding metadata: {state['kernel'].sharding}")


  # Step 3: Extract the PartitionSpec PyTree from metadata.
  pspecs = nnx.spmd.get_partition_spec(state)
  # print(f"Inside JIT: Extracted PartitionSpecs - Kernel PSpec: {pspecs['kernel']}")


  # Step 4: Apply sharding constraints to the State.
  # This is a hint to XLA; the actual sharding occurs when data is materialized on devices.
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  # print(f"Inside JIT: Applied sharding constraint. Kernel value sharding (if available): {getattr(sharded_state['kernel'].value, 'sharding', 'Not yet concrete')}")


  # Step 5: Update the original module object with the (conceptually) sharded state.
  nnx.update(model, sharded_state)

  # Step 6: Return the model.
  return model

# 4. Call the function within the Mesh context
# The 'with mesh:' block provides the context for JAX to fulfill the sharding.
rngs_for_sharded_init = nnx.Rngs(params=jax.random.key(1))
print(f"Executing create_sharded_linear_model within mesh: {mesh_ex3}")
with mesh_ex3:
    sharded_linear_model = create_sharded_linear_model(rngs_for_sharded_init, 128, 256)
print("Sharded model created.")

# 5. Verify sharding of the actual JAX arrays
# After execution, the .value of nnx.Param should be a sharded jax.Array.
print("\n--- Verification after sharded initialization ---")
print("Sharded Kernel's JAX array (.value) sharding:", sharded_linear_model.kernel.value.sharding)
print("Sharded Kernel's JAX array shape:", sharded_linear_model.kernel.value.shape)
print("Sharded Bias's JAX array (.value) sharding:", sharded_linear_model.bias.value.sharding)
print("Sharded Bias's JAX array shape:", sharded_linear_model.bias.value.shape)

# For kernel (128, 256) with P(None, 'model') on a ('data':1, 'model':8) mesh:
# The 'model' axis (size 8) shards the second dimension (256). So, 256/8 = 32.
# Each device buffer for the kernel should have shape (128, 32).
print("\nKernel device buffers:")
for i, db in enumerate(sharded_linear_model.kernel.value.addressable_shards):
  print(f"  Buffer {i}: Device={db.device}, Shape={db.data.shape}")
jax.debug.visualize_array_sharding(sharded_linear_model.kernel.value)

# For bias (256,) with P('model') on a ('data':1, 'model':8) mesh:
# The 'model' axis (size 8) shards the first dimension (256). So, 256/8 = 32.
# Each device buffer for the bias should have shape (32,).
print("\nBias device buffers:")
for i, db in enumerate(sharded_linear_model.bias.value.addressable_shards):
  print(f"  Buffer {i}: Device={db.device}, Shape={db.data.shape}")
jax.debug.visualize_array_sharding(sharded_linear_model.bias.value)
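
With the parameters laid out across devices, the module can be used like any other NNX module. Below is an optional forward-pass check (variable names x_small/y_small are just for this snippet): we replicate a small input across the mesh and let JAX's sharding propagation decide the output layout, so we print the output sharding rather than assert a particular result.

# Optional: use the sharded parameters in a plain forward pass.
x_small = jax.device_put(jnp.ones((4, 128), dtype=jnp.float32),
                         NamedSharding(mesh_ex3, P()))  # replicate the input on the mesh
with mesh_ex3:
  y_small = sharded_linear_model(x_small)
print("Forward output shape:", y_small.shape)            # (4, 256)
print("Forward output sharding (chosen by propagation):", y_small.sharding)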

Exercise 4: Sharding a Mini FeedForward Block

Let's apply these concepts to a slightly more complex block, akin to a part of a Transformer's FeedForward network: LayerNorm -> Linear1 -> GELU -> Linear2. We'll focus on model parallelism for the Linear layers and LayerNorm parameters.

Instructions:
  1. Define an NNXFeedForward module containing:
     - nnx.LayerNorm: shard its scale and bias parameters along the 'model' axis (P('model')).
     - nnx.Linear (linear1): kernel P(None, 'model'), bias P('model').
     - nnx.Linear (linear2): kernel P(None, 'model'), bias P('model'). (Sharding the kernel's input dimension instead, P('model', None), is another common choice; here we keep the same output-dimension sharding as linear1 for consistency with typical FFN output sharding.)
  2. Use the sharded initialization workflow (a create_sharded_ffn_model function decorated with @nnx.jit) to initialize this NNXFeedForward module.
  3. Define a 2D mesh, e.g., (2, 4) with axes ('data', 'model'). The 'model' axis will be used for sharding the parameters as defined.
  4. Instantiate and verify the sharding of parameters within this NNXFeedForward model.
# Exercise 4: Instructions Cell

class NNXFeedForward(nnx.Module):
  def __init__(self, embed_dim: int, ff_dim: int, *, rngs: nnx.Rngs):

    # TODO: 1. Define LayerNorm and Linear layers with sharding metadata
    # LayerNorm: scale P('model'), bias P('model')
    # Linear1 (embed_dim -> ff_dim): kernel P(None, 'model'), bias P('model')
    # Linear2 (ff_dim -> embed_dim): kernel P(None, 'model'), bias P('model')
    # (Note: For Linear2 kernel P('model', None) would shard along input dim,
    #  P(None, 'model') shards along output dim. Let's be consistent with typical model sharding.)

    self.layernorm = nnx.LayerNorm(
        num_features=embed_dim,
        epsilon=1e-6,
        scale_init=nnx.with_metadata(nnx.initializers.ones, sharding=P('model')),
        bias_init=nnx.with_metadata(nnx.initializers.zeros, sharding=P('model')),
        rngs=rngs # LayerNorm uses this Rngs for its own initializers
    )

    self.linear1 = SimpleLinear( # Reusing SimpleLinear for convenience
        in_features=embed_dim,
        out_features=ff_dim,
        rngs=rngs # Pass down the Rngs object
    )
    # Ensure SimpleLinear's sharding annotations are: kernel P(None, 'model'), bias P('model')

    self.linear2 = SimpleLinear(
        in_features=ff_dim,
        out_features=embed_dim,
        rngs=rngs
    )
    # Ensure SimpleLinear's sharding annotations are: kernel P(None, 'model'), bias P('model')
    # If we wanted linear2 kernel to be P('model', None), we'd need to modify SimpleLinear
    # or create a new Linear variant. For now, let's assume SimpleLinear is P(None, 'model') for kernel.


  def __call__(self, x: jax.Array, training: bool = False):
    x_norm = self.layernorm(x)
    x_ff = nnx.gelu(self.linear1(x_norm))
    output = self.linear2(x_ff)
    return output

# TODO: 2. Implement the sharded initialization function for NNXFeedForward
@nnx.jit(static_argnums=(1, 2))
def create_sharded_ffn_model(rngs_for_creation, embed_dim, ff_dim):
  model = NNXFeedForward(embed_dim=embed_dim, ff_dim=ff_dim, rngs=rngs_for_creation)
  state = nnx.state(model)
  pspecs = nnx.spmd.get_partition_spec(state)
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)
  return model

# TODO: 3. Define a 2D mesh ('data', 'model'), e.g., (2, 4)
mesh_devices_ex4 = np.array(jax.devices()).reshape((2, 4))
mesh_ex4 = Mesh(mesh_devices_ex4, axis_names=('data', 'model')) # 'model' axis has size 4

# TODO: 4. Instantiate and verify sharding
rngs_ffn_init = nnx.Rngs(0, params=jax.random.key(2)) # Explicitly create 'params' stream
with mesh_ex4:
    sharded_ffn_model = create_sharded_ffn_model(rngs_ffn_init, embed_dim=128, ff_dim=512)

print("\n--- FFN Model Parameter Sharding Verification ---")
print("LayerNorm scale sharding:", sharded_ffn_model.layernorm.scale.value.sharding)
print("LayerNorm bias sharding:", sharded_ffn_model.layernorm.bias.value.sharding)
print("Linear1 kernel sharding:", sharded_ffn_model.linear1.kernel.value.sharding)
print("Linear2 kernel sharding:", sharded_ffn_model.linear2.kernel.value.sharding)

# Example: LayerNorm scale is (embed_dim=128,). Sharded with P('model')
# Mesh 'model' axis has size 4. So, 128 / 4 = 32 elements per device.
print(f"\nLayerNorm scale ({sharded_ffn_model.layernorm.scale.value.shape}) device buffers:")
for i, db in enumerate(sharded_ffn_model.layernorm.scale.value.addressable_shards):
    print(f"  Buffer {i}: Device={db.device}, Shape={db.data.shape}") # Expected shape (32,)

jax.debug.visualize_array_sharding(sharded_ffn_model.layernorm.scale.value)

# Example: Linear1 kernel is (embed_dim=128, ff_dim=512). Sharded with P(None, 'model')
# Mesh 'model' axis has size 4. ff_dim (512) / 4 = 128.
# Expected shape on device (128, 128).
print(f"\nLinear1 kernel ({sharded_ffn_model.linear1.kernel.value.shape}) device buffers:")
for i, db in enumerate(sharded_ffn_model.linear1.kernel.value.addressable_shards):
    print(f"  Buffer {i}: Device={db.device}, Shape={db.data.shape}")

jax.debug.visualize_array_sharding(sharded_ffn_model.linear1.kernel.value)

Exercise 4: Solution

# Exercise 4: Solution Cell

# We need to ensure SimpleLinear used inside NNXFeedForward has the correct sharding.
# Let's redefine it or make it flexible if needed.
# The SimpleLinear from Ex2 already has:
# kernel: P(None, 'model'), bias: P('model')
# This is suitable for our FFN.

class NNXFeedForward(nnx.Module):
  def __init__(self, embed_dim: int, ff_dim: int, *, rngs: nnx.Rngs):
    # 1. Define LayerNorm and Linear layers with sharding metadata
    # LayerNorm takes an Rngs object directly if it has its own params/dropout.
    # For nnx.LayerNorm, scale_init/bias_init are callables that take a key.
    # We can use nnx.with_metadata with these initializers.
    self.layernorm = nnx.LayerNorm(
        num_features=embed_dim,
        epsilon=1e-6,
        # nnx.LayerNorm will call these initializers with a key from its rngs
        scale_init=nnx.with_metadata(nnx.initializers.ones, sharding=P('model')),
        bias_init=nnx.with_metadata(nnx.initializers.zeros, sharding=P('model')),
        rngs=rngs # Pass the Rngs for LayerNorm to use for its initializers
    )

    # For SimpleLinear, we pass the Rngs object, and it extracts the 'params' key.
    self.linear1 = SimpleLinear(
        in_features=embed_dim,
        out_features=ff_dim,
        rngs=rngs # Each submodule gets its own Rngs
    )
    # SimpleLinear is defined with: kernel P(None, 'model'), bias P('model')

    self.linear2 = SimpleLinear(
        in_features=ff_dim,
        out_features=embed_dim,
        rngs=rngs
    )
    # SimpleLinear is defined with: kernel P(None, 'model'), bias P('model')

  def __call__(self, x: jax.Array, training: bool = False):
    x_norm = self.layernorm(x)
    x_ff = nnx.gelu(self.linear1(x_norm)) # linear1.__call__
    output = self.linear2(x_ff)          # linear2.__call__
    return output

# 2. Implement the sharded initialization function for NNXFeedForward
@nnx.jit(static_argnums=(1, 2))
def create_sharded_ffn_model(rngs_for_creation, embed_dim, ff_dim):
  print(f"Inside JIT (FFN): Instantiating NNXFeedForward({embed_dim}, {ff_dim})")
  model = NNXFeedForward(embed_dim=embed_dim, ff_dim=ff_dim, rngs=rngs_for_creation)
  state = nnx.state(model)
  pspecs = nnx.spmd.get_partition_spec(state)
  # print(f"Inside JIT (FFN): PSPECS = {pspecs}")
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)
  return model

# 3. Define a 2D mesh ('data', 'model'), e.g., (2, 4)
# 'data' axis size 2, 'model' axis size 4.
mesh_devices_ex4 = np.array(jax.devices()).reshape((2, 4))
mesh_ex4 = Mesh(mesh_devices_ex4, axis_names=('data', 'model'))

# 4. Instantiate and verify sharding
# Top-level Rngs for the FFN model creation
rngs_ffn_init = nnx.Rngs(0, params=jax.random.key(2)) # Explicitly create 'params' stream

print(f"Executing create_sharded_ffn_model within mesh: {mesh_ex4}")
with mesh_ex4:
    sharded_ffn_model = create_sharded_ffn_model(rngs_ffn_init, embed_dim=128, ff_dim=512)
print("Sharded FFN model created.")

print("\n--- FFN Model Parameter Sharding Verification ---")
# LayerNorm parameters are typically 'scale' and 'bias'
print("LayerNorm scale sharding:", sharded_ffn_model.layernorm.scale.value.sharding)
print("LayerNorm bias sharding:", sharded_ffn_model.layernorm.bias.value.sharding)
print("Linear1 kernel sharding:", sharded_ffn_model.linear1.kernel.value.sharding)
print("Linear1 bias sharding:", sharded_ffn_model.linear1.bias.value.sharding)
print("Linear2 kernel sharding:", sharded_ffn_model.linear2.kernel.value.sharding)
print("Linear2 bias sharding:", sharded_ffn_model.linear2.bias.value.sharding)


# Example: LayerNorm scale is (embed_dim=128,). Sharded with P('model')
# Mesh 'model' axis has size 4. So, 128 / 4 = 32 elements per device.
print(f"\nLayerNorm scale ({sharded_ffn_model.layernorm.scale.value.shape}) device buffers:")
for i, db in enumerate(sharded_ffn_model.layernorm.scale.value.addressable_shards):
    print(f"  Buffer {i}: Device={db.device}, Shape={db.data.shape}") # Expected shape (32,)

jax.debug.visualize_array_sharding(sharded_ffn_model.layernorm.scale.value)

# Example: Linear1 kernel is (embed_dim=128, ff_dim=512). Sharded with P(None, 'model')
# Mesh 'model' axis has size 4. ff_dim (512) / 4 = 128.
# Expected shape on device (128, 128).
print(f"\nLinear1 kernel ({sharded_ffn_model.linear1.kernel.value.shape}) device buffers:")
for i, db in enumerate(sharded_ffn_model.linear1.kernel.value.addressable_shards):
    print(f"  Buffer {i}: Device={db.device}, Shape={db.data.shape}")

jax.debug.visualize_array_sharding(sharded_ffn_model.linear1.kernel.value)
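
As an optional sanity check before Exercise 5, you can push a data-parallel batch through the sharded FFN: the batch is split along 'data' while the weights stay split along 'model', and JAX's sharding propagation decides the layout of the activations (so we print it rather than assert a specific result). The names dummy_x/ffn_out are just for this snippet.

# Optional: one forward pass mixing data-parallel inputs with model-parallel weights.
dummy_x = np.random.rand(16, 128).astype(np.float32)  # (batch, embed_dim)
with mesh_ex4:
  dummy_x_sharded = jax.device_put(dummy_x, NamedSharding(mesh_ex4, P('data', None)))
  ffn_out = sharded_ffn_model(dummy_x_sharded)
print("FFN output shape:", ffn_out.shape)              # (16, 128)
print("FFN output sharding (via propagation):", ffn_out.sharding)
jax.debug.visualize_array_sharding(ffn_out)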

Exercise 5: Sharding Input Data and a Mock Training Step

A distributed training loop involves:

  • Sharding input data batches.
  • A JIT-compiled training step that operates on the sharded model and sharded data.
  • Gradients are computed, and optimizer updates are applied in a distributed manner.

Instructions:
  1. Use the sharded_ffn_model and mesh_ex4 (2x4, axes ('data', 'model')) from Exercise 4.
  2. Create a dummy input batch (e.g., a NumPy array of shape (batch_size=16, embed_dim=128)).
  3. Define a NamedSharding to shard this batch along the 'data' axis of mesh_ex4 (i.e., P('data', None)).
  4. Use jax.device_put to shard the input batch. Also create sharded dummy labels with P('data').
  5. Define a train_step function decorated with @nnx.jit. This function should:
     - Take the sharded model, an optimizer (e.g., nnx.Optimizer), the sharded batch, and the sharded labels as input.
     - Define a loss_fn that takes the stateful model, performs a forward pass, and computes a simple mean loss (e.g., using optax.softmax_cross_entropy_with_integer_labels).
     - Use nnx.value_and_grad(loss_fn)(model) to get the loss and gradients.
     - Call optimizer.update(model, grads) to apply the gradients.
     - Return the loss.
  6. Create a simple nnx.Optimizer (e.g., Adam) for the sharded_ffn_model. Crucially, the optimizer state should also be sharded consistently with the parameters it optimizes; NNX's Optimizer typically handles this when created from a sharded model.
  7. Execute the train_step once with the sharded inputs and model within the mesh_ex4 context and print the resulting loss. (The actual loss numerics are not important; focus on the setup.)
# Exercise 5: Instructions Cell

# 1. Use sharded_ffn_model and mesh_ex4 from previous exercise.
# Ensure they are available in this cell's scope.
# If not, you might need to re-run parts of Ex 4 or redefine them here.
# For simplicity, let's assume sharded_ffn_model and mesh_ex4 are accessible.
if 'sharded_ffn_model' not in globals() or 'mesh_ex4' not in globals():
    print("Please re-run Exercise 4 to define sharded_ffn_model and mesh_ex4.")
    # As a fallback for running this cell independently:
    _mesh_devices_ex4 = np.array(jax.devices()).reshape((2, 4))
    mesh_ex4 = Mesh(_mesh_devices_ex4, axis_names=('data', 'model'))
    _rngs_ffn_init = nnx.Rngs(params=jax.random.key(2), layernorm_params=jax.random.key(3),
                             linear1_params=jax.random.key(4), linear2_params=jax.random.key(5))
    with mesh_ex4:
        sharded_ffn_model = create_sharded_ffn_model(_rngs_ffn_init, embed_dim=128, ff_dim=512)


# 2. Create a dummy input batch and labels
BATCH_SIZE = 16
EMBED_DIM = 128 # Must match sharded_ffn_model.layernorm.num_features
NUM_CLASSES = 10 # For dummy labels

numpy_batch = np.random.rand(BATCH_SIZE, EMBED_DIM).astype(np.float32)
numpy_labels = np.random.randint(0, NUM_CLASSES, size=(BATCH_SIZE,)).astype(np.int32)

# TODO: 3. Define NamedSharding for input batch (shard along 'data') and labels
# Batch: P('data', None)
# Labels: P('data')
batch_input_sharding = NamedSharding(mesh_ex4, P('data', None))
label_input_sharding = NamedSharding(mesh_ex4, P('data'))


# TODO: 4. Shard the input batch and labels using jax.device_put (within mesh context is best)
# This should ideally be done inside the loop or just before calling train_step,
# and within the mesh context if the jax.device_put itself needs that context for device assignment.
# For jax.device_put, the mesh context isn't strictly necessary if NamedSharding already has the mesh.
with mesh_ex4: # Good practice to do device_put within mesh context
    sharded_batch = jax.device_put(numpy_batch, batch_input_sharding)
    sharded_labels = jax.device_put(numpy_labels, label_input_sharding)

print("Sharded batch sharding:", sharded_batch.sharding)
print("Sharded labels sharding:", sharded_labels.sharding)


# TODO: 5. Define the train_step function
@nnx.jit
def train_step(model: NNXFeedForward, optimizer: nnx.Optimizer, batch: jax.Array, labels: jax.Array):
  # Define loss_fn for nnx.value_and_grad
  # It operates on the stateful NNX model directly
  def loss_fn(mdl_stateful: NNXFeedForward):
    logits = mdl_stateful(batch, training=True) # Forward pass
    # The FFN output is (BATCH_SIZE, EMBED_DIM), not (BATCH_SIZE, NUM_CLASSES).
    # A real model would end with a classification head projecting to NUM_CLASSES;
    # for this dummy loss we simply take the first NUM_CLASSES features as logits.
    if logits.shape[-1] != NUM_CLASSES:
        logits_for_loss = logits[:, :NUM_CLASSES]
    else:
        logits_for_loss = logits

    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits_for_loss, labels))
    return loss

  # nnx.value_and_grad handles model state splitting/merging
  loss_value, grads = nnx.value_and_grad(loss_fn)(model)

  # Optimizer updates model parameters (and its own state) in-place (conceptually for NNX)
  # The actual update happens via nnx.update internally by the optimizer on the model's state.
  optimizer.update(model, grads) # This will update sharded_ffn_model's parameters

  return loss_value


# TODO: 6. Create an nnx.Optimizer for the sharded_ffn_model
# The optimizer will manage the sharded parameters and its own sharded state.
# When optimizer is created from a model with sharded state, its own state (e.g. momentum)
# should also be sharded appropriately.
# For nnx.Optimizer, we pass the model itself.
optimizer = nnx.Optimizer(sharded_ffn_model, optax.adam(learning_rate=1e-3), wrt=nnx.Param)


# TODO: 7. Execute the train_step once within the mesh context.
with mesh_ex4:
    loss = train_step(sharded_ffn_model, optimizer, sharded_batch, sharded_labels)

print(f"\nComputed loss: {loss}")
# Also verify that parameters in sharded_ffn_model have been updated (not easily visible by value change without running more steps)
# but the optimizer.update call should have functioned on sharded grads and params.

Exercise 5: Solution

# Exercise 5: Solution Cell

# 1. Use sharded_ffn_model and mesh_ex4 from previous exercise.
# (Assuming they are available from Exercise 4 execution)
if 'sharded_ffn_model' not in globals() or 'mesh_ex4' not in globals():
    print("Fallback: Re-defining sharded_ffn_model and mesh_ex4 for Exercise 5.")
    _mesh_devices_ex4_sol = np.array(jax.devices()).reshape((2, 4))
    mesh_ex4 = Mesh(_mesh_devices_ex4_sol, axis_names=('data', 'model'))
    _rngs_ffn_init_sol = nnx.Rngs(
        params=jax.random.key(2),
        layernorm_params=jax.random.key(30),
        linear1_params=jax.random.key(40),
        linear2_params=jax.random.key(50)
    )
    with mesh_ex4:
        sharded_ffn_model = create_sharded_ffn_model(_rngs_ffn_init_sol, embed_dim=128, ff_dim=512)
    print("Fallback definitions complete.")


# 2. Create a dummy input batch and labels
BATCH_SIZE = 16 # Global batch size
EMBED_DIM = 128 # Must match sharded_ffn_model's input embed_dim
# Output of FFN is also EMBED_DIM. For classification, a final Linear layer to NUM_CLASSES is needed.
# For this exercise, we'll make a dummy adjustment if NUM_CLASSES doesn't match.
NUM_CLASSES = 10 # Example number of classes for dummy loss

numpy_batch = np.random.rand(BATCH_SIZE, EMBED_DIM).astype(np.float32)
# Labels for classification, shape (BATCH_SIZE,)
numpy_labels = np.random.randint(0, NUM_CLASSES, size=(BATCH_SIZE,)).astype(np.int32)

# 3. Define NamedSharding for input batch and labels
# Batch shape (BATCH_SIZE, EMBED_DIM), sharded P('data', None) -> ('data' shards BATCH_SIZE)
batch_input_sharding = NamedSharding(mesh_ex4, P('data', None))
# Labels shape (BATCH_SIZE,), sharded P('data') -> ('data' shards BATCH_SIZE)
label_input_sharding = NamedSharding(mesh_ex4, P('data'))

# 4. Shard the input batch and labels using jax.device_put
# This is typically done inside the training loop for each new batch.
# Performing it within the mesh context ensures devices align if mesh is complex.
with mesh_ex4:
    sharded_batch = jax.device_put(numpy_batch, batch_input_sharding)
    sharded_labels = jax.device_put(numpy_labels, label_input_sharding)

print("Sharded batch object:", sharded_batch)
print("Sharded batch sharding:", sharded_batch.sharding)
# With ('data':2, 'model':4) mesh and P('data', None) for (16, 128) batch:
# 'data' axis (size 2) shards dim 0 (16). 16/2 = 8.
# Each device buffer will have shape (8, 128).
print(f"Sharded batch per-device shape: {sharded_batch.addressable_shards[0].data.shape}")

print("\nSharded labels object:", sharded_labels)
print("Sharded labels sharding:", sharded_labels.sharding)
# With ('data':2, 'model':4) mesh and P('data') for (16,) labels:
# 'data' axis (size 2) shards dim 0 (16). 16/2 = 8.
# Each device buffer will have shape (8,).
print(f"Sharded labels per-device shape: {sharded_labels.addressable_shards[0].data.shape}")


# 5. Define the train_step function
@nnx.jit
def train_step(model: NNXFeedForward, optimizer: nnx.Optimizer, batch: jax.Array, labels: jax.Array):
  # This loss_fn is defined inside train_step to capture 'batch' and 'labels'
  # It takes the stateful NNX model as its argument.
  def loss_fn(mdl_stateful: NNXFeedForward):
    # Forward pass through the model
    logits = mdl_stateful(batch, training=True) # model's __call__

    # The FFN output is (BATCH_SIZE, EMBED_DIM); a classification loss needs
    # (BATCH_SIZE, NUM_CLASSES). A real model would end with a final nnx.Linear
    # layer projecting to NUM_CLASSES. Here we apply a fixed, untrained random
    # projection purely so the shapes line up for the dummy loss.
    current_out_features = logits.shape[-1]
    if current_out_features != NUM_CLASSES:
        dummy_projection_key = jax.random.key(99)  # Fixed key so the projection is reproducible inside JIT
        projection_matrix = jax.random.normal(dummy_projection_key, (current_out_features, NUM_CLASSES))
        projected_logits = logits @ projection_matrix
    else:
        projected_logits = logits

    # Compute loss. JAX automatically handles SPMD for this calculation
    # if inputs (logits, labels) are sharded. Gradients will also be sharded,
    # and all-reduce for gradients across data-parallel dimension is inserted by JAX.
    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(projected_logits, labels))
    return loss

  # Compute loss and gradients. nnx.value_and_grad handles splitting/merging model state.
  loss_value, grads = nnx.value_and_grad(loss_fn)(model)

  # Apply gradients. The optimizer updates the model's parameters in-place (conceptually).
  # If model parameters are sharded, optimizer state (e.g., momentum) is also sharded,
  # and updates are applied in a distributed manner.
  optimizer.update(model, grads)

  return loss_value

# 6. Create an nnx.Optimizer for the sharded_ffn_model
# Pass the sharded model to the optimizer and specify what to differentiate wrt (nnx.Param).
# NNX ensures that the optimizer's state (like Adam's moments) is initialized
# with the same sharding as the parameters.
optimizer = nnx.Optimizer(sharded_ffn_model, optax.adam(learning_rate=1e-3), wrt=nnx.Param)
print(f"\nOptimizer created. Optimizer state sharding should match param sharding.")
# You can inspect optimizer.state to see its structure and (if concrete) sharding.


# 7. Execute the train_step once within the mesh context.
# All inputs (model params via optimizer, batch, labels) are sharded.
# JAX/XLA compiles the train_step for SPMD execution.
print(f"\nExecuting train_step within mesh: {mesh_ex4}")
with mesh_ex4:
    # JIT compilation happens on the first call
    loss = train_step(sharded_ffn_model, optimizer, sharded_batch, sharded_labels)
    # Subsequent calls would be faster
    # loss_2 = train_step(sharded_ffn_model, optimizer, sharded_batch, sharded_labels)


print(f"\nComputed loss from sharded train_step: {loss}")
# print(f"Computed loss (2nd step): {loss_2}")
# Note: The actual loss value isn't meaningful due to dummy projection and random data.
# The key is that the distributed computation ran.
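
In a real run you would repeat this step over many batches, sharding each fresh host batch the same way before handing it to the jitted step. Below is a minimal sketch of such a loop, reusing the shardings, model, and optimizer defined above (with random stand-in data, so the loss values are not meaningful):

# Sketch of a short training loop: shard each host batch, then call the jitted step.
with mesh_ex4:
  for step in range(3):
    np_x = np.random.rand(BATCH_SIZE, EMBED_DIM).astype(np.float32)
    np_y = np.random.randint(0, NUM_CLASSES, size=(BATCH_SIZE,)).astype(np.int32)
    step_batch = jax.device_put(np_x, batch_input_sharding)
    step_labels = jax.device_put(np_y, label_input_sharding)
    step_loss = train_step(sharded_ffn_model, optimizer, step_batch, step_labels)
    print(f"step {step}: loss = {step_loss}")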

Exercise 6: Preparing for Sharded Checkpointing

To save and load huge sharded models without OOM, libraries like Orbax are used. Orbax needs to know the target NamedSharding for each parameter in order to restore it correctly, and nnx.spmd.get_named_sharding helps generate this.

Instructions:
  1. Use the sharded_ffn_model and mesh_ex4 from the previous exercises.
  2. Get the state structure of the model using nnx.state(sharded_ffn_model).
  3. Use nnx.spmd.get_named_sharding(state, mesh_ex4) to generate the PyTree of NamedSharding objects.
  4. Print the generated NamedSharding PyTree for a few parameters (e.g., LayerNorm scale, Linear1 kernel) to inspect them. This output is what you'd pass to Orbax for restoration.
# Exercise 6: Instructions Cell

# 1. Use sharded_ffn_model and mesh_ex4
if 'sharded_ffn_model' not in globals() or 'mesh_ex4' not in globals():
    print("Fallback: Re-defining sharded_ffn_model and mesh_ex4 for Exercise 6.")
    _mesh_devices_ex6 = np.array(jax.devices()).reshape((2, 4))
    mesh_ex4 = Mesh(_mesh_devices_ex6, axis_names=('data', 'model')) # Re-assign to global mesh_ex4 if needed
    _rngs_ffn_init_ex6 = nnx.Rngs(
        params=jax.random.key(20),
        layernorm_params=jax.random.key(300),
        linear1_params=jax.random.key(400),
        linear2_params=jax.random.key(500)
    )
    with mesh_ex4: # Ensure mesh_ex4 is correctly used
        sharded_ffn_model = create_sharded_ffn_model(_rngs_ffn_init_ex6, embed_dim=128, ff_dim=512)


# TODO: 2. Get the state structure of the model
# This can be from the concrete sharded model, or an abstract model from nnx.eval_shape
model_state_structure = nnx.state(sharded_ffn_model)

# TODO: 3. Generate the target NamedSharding PyTree
# This uses the .sharding PartitionSpec metadata stored in the model's state
# and combines it with the provided mesh.
target_named_shardings_tree = nnx.spmd.get_named_sharding(model_state_structure, mesh_ex4)


# TODO: 4. Print some of the generated NamedSharding objects
print("\n--- Target NamedShardings for Checkpointing ---")
print("NamedSharding for LayerNorm scale:")
nnx.display(target_named_shardings_tree['layernorm']['scale'])

print("\nNamedSharding for Linear1 kernel:")
nnx.display(target_named_shardings_tree['linear1']['kernel'])

print("\nNamedSharding for Linear2 bias:")
nnx.display(target_named_shardings_tree['linear2']['bias'])

# These NamedSharding objects tell a checkpointing library (like Orbax)
# exactly how each parameter should be laid out across the devices upon restoration.

Exercise 6: Solution

# Exercise 6: Solution Cell

# 1. Use sharded_ffn_model and mesh_ex4
if 'sharded_ffn_model' not in globals() or 'mesh_ex4' not in globals():
    print("Fallback: Re-defining sharded_ffn_model and mesh_ex4 for Exercise 6.")
    # Ensure mesh_ex4 is the one associated with sharded_ffn_model
    _mesh_devices_ex6_sol = np.array(jax.devices()).reshape((2, 4))
    mesh_ex4 = Mesh(_mesh_devices_ex6_sol, axis_names=('data', 'model')) # Make sure this is the correct mesh
    _rngs_ffn_init_ex6_sol = nnx.Rngs(
        params=jax.random.key(201),
        layernorm_params=jax.random.key(301),
        linear1_params=jax.random.key(401),
        linear2_params=jax.random.key(501)
    )
    with mesh_ex4: # Use the correct mesh_ex4
        sharded_ffn_model = create_sharded_ffn_model(_rngs_ffn_init_ex6_sol, embed_dim=128, ff_dim=512)
    print("Fallback definitions complete for Ex6.")


# 2. Get the state structure of the model
# This PyTree has the same structure as the model's parameters,
# and each leaf (parameter state) contains the .sharding PartitionSpec metadata.
model_state_structure = nnx.state(sharded_ffn_model)

# 3. Generate the target NamedSharding PyTree
# nnx.spmd.get_named_sharding combines the PartitionSpec from metadata
# with the provided 'mesh_ex4' to create a full NamedSharding object for each parameter.
target_named_shardings_tree = nnx.spmd.get_named_sharding(model_state_structure, mesh_ex4)

# 4. Print some of the generated NamedSharding objects
print("\n--- Target NamedShardings for Checkpointing (from nnx.spmd.get_named_sharding) ---")

# For LayerNorm's scale parameter
# Original sharding metadata (PartitionSpec): P('model')
# Mesh: ('data':2, 'model':4)
# Expected NamedSharding: NamedSharding(mesh=mesh_ex4, spec=P('model'))
print("\nNamedSharding for LayerNorm scale:")
# The path to scale might be model_state_structure['layernorm']['scale'].sharding
# So target_named_shardings_tree should have a similar path.
nnx.display(target_named_shardings_tree['layernorm']['scale'])


# For Linear1's kernel parameter
# Original sharding metadata (PartitionSpec): P(None, 'model')
# Expected NamedSharding: NamedSharding(mesh=mesh_ex4, spec=P(None, 'model'))
print("\nNamedSharding for Linear1 kernel:")
nnx.display(target_named_shardings_tree['linear1']['kernel'])


# For Linear2's bias parameter
# Original sharding metadata (PartitionSpec): P('model')
# Expected NamedSharding: NamedSharding(mesh=mesh_ex4, spec=P('model'))
print("\nNamedSharding for Linear2 bias:")
nnx.display(target_named_shardings_tree['linear2']['bias'])

# This target_named_shardings_tree would be passed to something like
# orbax.checkpoint.StandardRestore(target_named_shardings_tree)
# to tell Orbax how to reconstruct the sharded arrays when loading from a checkpoint.
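
As a quick, library-independent sanity check, you can walk the leaves of the generated tree and confirm that each parameter received a NamedSharding built from mesh_ex4 and its annotated PartitionSpec. The sketch below assumes the tree's leaves are the NamedSharding objects themselves (the isinstance guard skips anything else):

# Optional: inspect every NamedSharding leaf in the generated tree.
named_sharding_leaves = jax.tree_util.tree_leaves(
    target_named_shardings_tree,
    is_leaf=lambda x: isinstance(x, NamedSharding),
)
for ns in named_sharding_leaves:
  if isinstance(ns, NamedSharding):  # skip any non-sharding leaves, just in case
    print(f"mesh axes={ns.mesh.axis_names}, spec={ns.spec}")
print(f"Parameter leaves with a target NamedSharding: {len(named_sharding_leaves)}")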

Conclusion & Feedback

Congratulations on completing the exercises!

You've now practiced:

  • Using JAX sharding primitives (Mesh, PartitionSpec, NamedSharding, device_put).
  • Annotating Flax NNX modules with sharding metadata.
  • The critical sharded initialization workflow to avoid OOM errors.
  • Applying these concepts to a multi-layer NNX module.
  • Sharding input data for a distributed training step.
  • Preparing sharding information for distributed checkpointing.

These are foundational skills for scaling up your JAX and Flax NNX models.

Further Learning:

  • JAX Documentation: https://jax.readthedocs.io/
  • Flax NNX Documentation: https://flax.readthedocs.io/en/latest/nnx/index.html
  • JAX SPMD Guide: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
  • Orbax (for checkpointing): https://orbax.readthedocs.io/

Please send us feedback at https://goo.gle/jax-training-feedback