Welcome to the Flax NNX Colab Notebook! This notebook provides hands-on exercises designed to help PyTorch users transition to Flax NNX and the JAX ecosystem.
We'll cover core concepts and build simple models.
!pip install -q jax-ai-stack==2025.9.3
import jax
import jax.numpy as jnp
import flax
from flax import nnx
import optax
print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")
# @title Exercise 1: Understanding Modules and Parameters (Coding Exercise)
# Instructions:
# 1. Create a simple NNX Module called `MyLinearLayer`.
# 2. It should have an `nnx.Param` called `weight` (initialized randomly with shape [input_size, output_size]).
# 3. It should have an `nnx.Param` called `bias` (initialized with zeros with shape [output_size]).
# 4. The forward pass (`__call__` method) should perform a linear transformation: `x @ self.weight.value + self.bias.value`.
# 5. Instantiate the layer with `input_size=10` and `output_size=5`.
# 6. Print the shape of the `weight` and `bias` parameters.
class MyLinearLayer(nnx.Module):
    def __init__(self, input_size: int, output_size: int, *, rngs: nnx.Rngs):
        pass  # FILL IN THIS PART
    def __call__(self, x: jax.Array):
        pass  # FILL IN THIS PART
# Instantiate the layer
key = jax.random.PRNGKey(0)
linear_layer = MyLinearLayer(
    input_size='FILL IN THIS PART',
    output_size='FILL IN THIS PART',
    rngs=nnx.Rngs(key))
# Print the shapes of the parameters
print("Weight shape:", 'FILL IN THIS PART')
print("Bias shape:", 'FILL IN THIS PART')
# Example usage:
dummy_input = jnp.ones((1, 10))
output = linear_layer(dummy_input)
print("Output shape:", output.shape)
# @title Exercise 1 Solution
class MyLinearLayer(nnx.Module):
    def __init__(self, input_size: int, output_size: int, *, rngs: nnx.Rngs):
        # Weight: random normal init, shape [input_size, output_size]
        self.weight = nnx.Param(jax.random.normal(rngs.params(), (input_size, output_size)))
        # Bias: zero init, shape [output_size]
        self.bias = nnx.Param(jnp.zeros((output_size,)))
    def __call__(self, x: jax.Array):
        # Linear transformation: y = x @ W + b
        return x @ self.weight.value + self.bias.value
# Instantiate the layer
key = jax.random.PRNGKey(0)
linear_layer = MyLinearLayer(input_size=10, output_size=5, rngs=nnx.Rngs(key))
# Print the shapes of the parameters
print("Weight shape:", linear_layer.weight.value.shape)
print("Bias shape:", linear_layer.bias.value.shape)
# Example usage:
dummy_input = jnp.ones((1, 10))
output = linear_layer(dummy_input)
print("Output shape:", output.shape)
# @title Exercise 2: State Management (Coding Exercise)
# Instructions:
# 1. Create an NNX Module called `CounterModule`.
# 2. It should have a `count` attribute wrapped in `nnx.Variable` and initialized to 0, so the counter is tracked as NNX state.
# 3. The `__call__` method should increment the `count` by 1 and return the new value.
# 4. Instantiate the module.
# 5. Call the module multiple times and print the returned value.
# 6. Use `nnx.split` and `nnx.merge` to save and load the module's state. Verify that the counter resumes from where it left off.
class CounterModule(nnx.Module):
    def __init__(self):
        pass  # FILL IN THIS PART
    def __call__(self):
        pass  # FILL IN THIS PART
# Instantiate the module
counter = None  # FILL IN THIS PART: instantiate CounterModule
# Call the module and print the value
print("First call:", counter())
print("Second call:", counter())
# Split the module into graphdef and state.
# Hint: nnx.split takes a filter; use nnx.Variable to capture the counter state.
graphdef, state = None, None  # FILL IN THIS PART (use nnx.split)
# Merge the graphdef and state to create a new module
new_counter = None  # FILL IN THIS PART (use nnx.merge)
# Call the new module and print the value
print("After split and merge, first call:", new_counter())
print("After split and merge, second call:", new_counter())
# @title Exercise 2 Solution
class CounterModule(nnx.Module):
    def __init__(self):
        # Wrap the counter in nnx.Variable so NNX tracks it as mutable state.
        self.count = nnx.Variable(0)
    def __call__(self):
        self.count.value += 1
        return self.count.value
# Instantiate the module
counter = CounterModule()
# Call the module and print the value
print("First call:", counter())
print("Second call:", counter())
# Split the module into graphdef and state
graphdef, state = nnx.split(counter, nnx.Variable)
# Merge the graphdef and state to create a new module
new_counter = nnx.merge(graphdef, state)
# Call the new module and print the value
print("After split and merge, first call:", new_counter())
print("After split and merge, second call:", new_counter())
# @title Exercise 3: Explicit Random Number Generation (Coding Exercise)
# Instructions:
# 1. Create an NNX Module called `RandomNormalLayer`.
# 2. Its `__init__` method should receive a `size` argument defining the size of the random vector to generate.
# 3. The `__init__` method should also receive a `rngs: nnx.Rngs` argument and store it on the module.
# 4. The `__call__` method should draw a fresh key from the stored rngs stream and use
#    jax.random.normal to generate and return a new random normal vector of shape (size,).
# 5. Instantiate the layer with a size of 10, passing in the rngs parameter with a jax.random.PRNGKey.
# 6. Call the module twice and observe that the returned values are different.
# CREATE RandomNormalLayer
# Instantiate the module
key = None  # FILL IN THIS PART: use jax.random.PRNGKey to create a new key
random_layer = RandomNormalLayer(size='SIZE HERE', rngs=nnx.Rngs(key))
# Call the module and print the value
print("First call:", random_layer())
print("Second call:", random_layer())
# @title Exercise 3 Solution
class RandomNormalLayer(nnx.Module):
    def __init__(self, size: int, *, rngs: nnx.Rngs):
        self.size = size
        self.rngs = rngs  # keep the Rngs stream so each call can draw a fresh key
    def __call__(self):
        # A fresh key per call means successive calls return different values.
        return jax.random.normal(self.rngs.params(), (self.size,))
# Instantiate the module
key = jax.random.PRNGKey(0)
random_layer = RandomNormalLayer(size=10, rngs=nnx.Rngs(key))
# Call the module and print the value
print("First call:", random_layer())
print("Second call:", random_layer())
# @title Exercise 4: Building a Simple CNN (Coding Exercise)
# Instructions:
# 1. Create an NNX Module representing a simple CNN with the following layers:
# - Convolutional layer (nnx.Conv) with 32 filters, kernel size 3, and stride 1.
# - ReLU activation.
# - Max pooling layer (nnx.max_pool) with window size 2 and stride 2.
# - Flatten layer (jax.numpy.reshape).
# - Linear layer (nnx.Linear) to map to 10 output classes.
# 2. Initialize the CNN with appropriate input and output shapes.
# 3. Perform a forward pass with a dummy input and print the output shape.
class SimpleCNN(nnx.Module):
    def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
        # nnx.Conv's first two arguments are in_features and out_features (filters)
        self.conv = nnx.Conv('IN_FEATURES', 'NUM_FILTERS', kernel_size=('X, X'), rngs=rngs)
        self.linear = nnx.Linear(in_features=6272, out_features=num_classes, rngs=rngs)
    def __call__(self, x: jax.Array):
        x = self.conv(x)
        print(f'{x.shape = }')  # For debug
        x = nnx.relu(x)
        print(f'{x.shape = }')  # For debug
        x = nnx.max_pool(x, window_shape=('X, X'), strides=('X, X'))
        print(f'{x.shape = }')  # For debug
        x = x.reshape(x.shape[0], -1)  # flatten
        print(f'{x.shape = }')  # For debug
        x = self.linear(x)
        return x
# Instantiate the CNN
key = jax.random.PRNGKey(0)
cnn = SimpleCNN(num_classes='OUTPUT CLASSES', rngs=nnx.Rngs(key))
# Dummy input
dummy_input = jnp.ones((1, 28, 28, 1))
# Forward pass
output = cnn(dummy_input)
print("Output shape:", output.shape)
# @title Exercise 4 Solution
class SimpleCNN(nnx.Module):
    def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
        # 1 input channel (grayscale), 32 filters, 3x3 kernel, stride 1 (default)
        self.conv = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        # 6272 = 14 * 14 * 32: the flattened feature map after pooling (see below)
        self.linear = nnx.Linear(in_features=6272, out_features=num_classes, rngs=rngs)
    def __call__(self, x: jax.Array):
        x = self.conv(x)
        print(f'{x.shape = }')
        x = nnx.relu(x)
        print(f'{x.shape = }')
        x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        print(f'{x.shape = }')
        x = x.reshape(x.shape[0], -1)  # flatten
        print(f'{x.shape = }')
        x = self.linear(x)
        return x
# Instantiate the CNN
key = jax.random.PRNGKey(0)
cnn = SimpleCNN(num_classes=10, rngs=nnx.Rngs(key))
# Dummy input
dummy_input = jnp.ones((1, 28, 28, 1))
# Forward pass
output = cnn(dummy_input)
print("Output shape:", output.shape)
# @title Exercise 5: Training Loop with Optax (Coding Exercise)
# Instructions:
# 1. Define a simple model (e.g., a linear layer).
# 2. Create an nnx.Optimizer, specifying which variable types to update via the
#    `wrt` argument (required in recent Flax versions), e.g. wrt=nnx.Param.
# 3. Implement a training step function that:
# - Calculates the loss (e.g., mean squared error).
# - Computes gradients using `nnx.value_and_grad`.
# - Updates the model's state using `optimizer.update(model, grads)`.
# 4. Run the training loop for a few steps.
# Define a simple model
class LinearModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.linear = 'LINEAR LAYER HERE'
    def __call__(self, x: jax.Array):
        return self.linear(x)
# Instantiate the model
key = jax.random.PRNGKey(0)
model = LinearModel(rngs=nnx.Rngs(key))
# Create an Optax optimizer
tx = 'OPTAX SGD HERE'
optimizer = nnx.Optimizer('MODEL HERE', tx='TX HERE', wrt='VARIABLE TYPE HERE')
# Dummy data
x = jnp.array([[2.0]])
y = jnp.array([[4.0]])
# Training step function
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        y_pred = model(x)
        return jnp.mean((y_pred - y) ** 2)
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss, model
# Training loop
num_steps = 10
for i in range(num_steps):
    loss, model = train_step(model, optimizer, x, y)
    print(f"Step {i+1}, Loss: {loss}")
print("Trained model output:", model(x))
# @title Exercise 5 Solution
# Define a simple model
class LinearModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(in_features=1, out_features=1, rngs=rngs)
    def __call__(self, x: jax.Array):
        return self.linear(x)
# Instantiate the model
key = jax.random.PRNGKey(0)
model = LinearModel(rngs=nnx.Rngs(key))
# Create an Optax optimizer
tx = optax.sgd(learning_rate=0.01)
optimizer = nnx.Optimizer(model, tx=tx, wrt=nnx.Param)
# Dummy data
x = jnp.array([[2.0]])
y = jnp.array([[4.0]])
# Training step function
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        y_pred = model(x)
        return jnp.mean((y_pred - y) ** 2)
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss, model
# Training loop
num_steps = 10
for i in range(num_steps):
    loss, model = train_step(model, optimizer, x, y)
    print(f"Step {i+1}, Loss: {loss}")
print("Trained model output:", model(x))
Remember to consult the official documentation for more in-depth details: https://flax.readthedocs.io/ and https://jax.readthedocs.io/.
Keep practicing, and happy JAXing!
Please send us feedback at https://goo.gle/jax-training-feedback