Introduction to Flax NNX

Welcome to the Flax NNX Colab Notebook! This notebook provides hands-on exercises designed to help PyTorch users transition to Flax NNX and the JAX ecosystem.

We'll cover core concepts and build simple models.

!pip install -q jax-ai-stack==2025.9.3

import jax
import jax.numpy as jnp
import flax
from flax import nnx
import optax
from typing import Any, Dict, Tuple # For type hints

# Report the installed library versions so outputs can be matched to the pinned stack.
for _label, _module in (("JAX", jax), ("Flax", flax)):
    print(f"{_label} version: {_module.__version__}")
# @title Exercise 1:  Understanding Modules and Parameters (Coding Exercise)

# Instructions:
# 1. Create a simple NNX Module called `MyLinearLayer`.
# 2. It should have an `nnx.Param` called `weight` (initialized randomly with shape [input_size, output_size]).
# 3. It should have an `nnx.Param` called `bias` (initialized with zeros with shape [output_size]).
# 4. The forward pass (`__call__` method) should perform a linear transformation: `x @ self.weight.value + self.bias.value`.
# 5. Instantiate the layer with `input_size=10` and `output_size=5`.
# 6. Print the shape of the `weight` and `bias` parameters.
#    Hint: draw the random key for `weight` from the `rngs` argument (e.g. `rngs.params()`).

class MyLinearLayer(nnx.Module):
    """Exercise stub: a fully connected layer computing `x @ weight + bias`."""

    def __init__(self, input_size: int, output_size: int, *, rngs: nnx.Rngs):
        # Create `self.weight` and `self.bias` as `nnx.Param`s here.
        pass # FILL IN THIS PART

    def __call__(self, x: jax.Array):
        # Return the linear transformation of `x` described in step 4.
        pass # FILL IN THIS PART

# Instantiate the layer
# Replace each 'FILL IN THIS PART' string with the integer size from the instructions.
key = jax.random.PRNGKey(0)
linear_layer = MyLinearLayer(
    input_size='FILL IN THIS PART',
    output_size='FILL IN THIS PART',
    rngs=nnx.Rngs(key))

# Print the shapes of the parameters
# (each `nnx.Param` wraps an array; read it through `.value`)
print("Weight shape:", 'FILL IN THIS PART')
print("Bias shape:", 'FILL IN THIS PART')

# Example usage:
# A batch of one example with 10 features, so `input_size` must be 10 for this to run.
dummy_input = jnp.ones((1, 10))
output = linear_layer(dummy_input)
print("Output shape:", output.shape)
# @title Exercise 1 Solution

class MyLinearLayer(nnx.Module):
    """A minimal dense layer computing ``x @ W + b``.

    Args:
      input_size: number of features in the last axis of ``x``.
      output_size: number of features produced.
      rngs: NNX RNG container; one key is drawn from its ``params`` stream to
        sample the weight matrix.
    """

    def __init__(self, input_size: int, output_size: int, *, rngs: nnx.Rngs):
        weight_shape = (input_size, output_size)
        init_key = rngs.params()
        # Weights are standard-normal samples; biases start at zero.
        self.weight = nnx.Param(jax.random.normal(init_key, weight_shape))
        self.bias = nnx.Param(jnp.zeros((output_size,)))

    def __call__(self, x: jax.Array):
        w = self.weight.value
        b = self.bias.value
        return x @ w + b

# Build the layer: 10 input features -> 5 output features, seeded deterministically.
key = jax.random.PRNGKey(0)
linear_layer = MyLinearLayer(
    input_size=10,
    output_size=5,
    rngs=nnx.Rngs(key),
)

# Each nnx.Param wraps a JAX array that is exposed through `.value`.
print("Weight shape:", linear_layer.weight.value.shape)  # expected (10, 5)
print("Bias shape:", linear_layer.bias.value.shape)      # expected (5,)

# Push a single 10-feature example through the layer.
dummy_input = jnp.ones((1, 10))
output = linear_layer(dummy_input)
print("Output shape:", output.shape)
# @title Exercise 2: State Management (Coding Exercise)
# Instructions:
# 1. Create an NNX Module called `CounterModule`.
# 2. It should have a Python instance attribute called `count` initialized to 0.
# 3. The `__call__` method should increment the `count` by 1 and return the new value.
# 4. Instantiate the module.
# 5. Call the module multiple times and print the returned value.
# 6. Use `nnx.split` and `nnx.merge` to save and load the module's state. Verify that the counter resumes from where it left off.

class CounterModule(nnx.Module):
    """Exercise stub: a module that counts how many times it has been called."""

    def __init__(self):
        # Initialize `self.count` to 0 here.
        pass # FILL IN THIS PART

    def __call__(self):
        # Increment `self.count` by 1 and return the new value.
        pass # FILL IN THIS PART

# Instantiate the module
pass # FILL IN THIS PART.  Name it "counter"

# Call the module and print the value
print("First call:", counter())
print("Second call:", counter())

# Split the module into graphdef and state.
# Call `nnx.split` with `nnx.Variable` as the filter; it returns a
# (graphdef, state) pair. (This cell does not parse until the blanks below are filled.)
graphdef, state =  # FILL IN THIS PART

# Merge the graphdef and state to create a new module
new_counter =  # FILL IN THIS PART

# Call the new module and print the value
print("After split and merge, first call:", new_counter())
print("After split and merge, second call:", new_counter())
# @title Exercise 2 Solution

class CounterModule(nnx.Module):
    """Counts how many times the module has been called.

    The counter is wrapped in an ``nnx.Variable`` so that it is part of the
    module's *state*: ``nnx.split(module, nnx.Variable)`` returns it in the
    ``State`` half, which is what you would checkpoint or pass through NNX
    transforms. A bare Python int would instead be treated as a static
    attribute, stored in the ``GraphDef``, and the ``State`` would be empty.
    """

    def __init__(self):
        self.count = nnx.Variable(0)

    def __call__(self):
        """Increment the counter and return its new value (1 on the first call)."""
        self.count.value += 1
        return self.count.value

# Create a fresh counter and call it twice.
counter = CounterModule()
for _label in ("First call:", "Second call:"):
    print(_label, counter())

# Separate the module into its structure (graphdef) and the nnx.Variable-filtered
# state, then rebuild an equivalent module from the two halves.
graphdef, state = nnx.split(counter, nnx.Variable)
new_counter = nnx.merge(graphdef, state)

# The rebuilt module should continue counting from where the original stopped.
for _label in ("After split and merge, first call:", "After split and merge, second call:"):
    print(_label, new_counter())
# @title Exercise 3: Explicit Random Number Generation (Coding Exercise)

# Instructions:
# 1. Create an NNX Module called `RandomNormalLayer`.
# 2. Its `__init__` method should receive a `size` argument defining the size of the random vector to generate,
#    and a keyword-only `rngs: nnx.Rngs` argument. Store both on the module.
# 3. The `__call__` method should draw a fresh key from the stored `rngs`, use `jax.random.normal` to generate
#    a new random normal tensor of shape `(size,)`, assign it to `self.random_vector`, and return it.
#    (Generating the tensor only once in `__init__` would make every call return the same values.)
# 4. Instantiate the layer with a size of 10, passing in the rngs parameter with a jax.random.PRNGKey.
# 5. Call the module twice and observe that the returned values are different.

# CREATE RandomNormalLayer

# Instantiate the module
key = # USE jax.random.PRNGKey to create a new key
random_layer = RandomNormalLayer(size='SIZE HERE', rngs=nnx.Rngs(key))

# Call the module and print the value
print("First call:", random_layer())
print("Second call:", random_layer())
# @title Exercise 3 Solution

class RandomNormalLayer(nnx.Module):
    """Returns a freshly sampled standard-normal vector on every call.

    Args:
      size: length of the vector produced by ``__call__``.
      rngs: NNX RNG container. It is kept on the module so that each call can
        draw a new key from it, which is why consecutive calls differ.
    """

    def __init__(self, size: int, *, rngs: nnx.Rngs):
        self.size = size
        self.rngs = rngs

    def __call__(self):
        # Every `rngs.params()` call yields a new key, so each call samples new values.
        self.random_vector = jax.random.normal(self.rngs.params(), (self.size,))
        return self.random_vector

# Build a layer that emits 10-element vectors from a fixed seed.
key = jax.random.PRNGKey(0)
random_layer = RandomNormalLayer(size=10, rngs=nnx.Rngs(key))

# Each call draws a new key, so the two printed vectors differ.
for _label in ("First call:", "Second call:"):
    print(_label, random_layer())
# @title Exercise 4: Building a Simple CNN (Coding Exercise)

# Instructions:
# 1. Create an NNX Module representing a simple CNN with the following layers:
#    - Convolutional layer (nnx.Conv) with 32 filters, kernel size 3, and stride 1.
#    - ReLU activation.
#    - Max pooling layer (nnx.max_pool) with window size 2 and stride 2.
#    - Flatten layer (jax.numpy.reshape).
#    - Linear layer (nnx.Linear) to map to 10 output classes.
# 2. Initialize the CNN with appropriate input and output shapes.
# 3. Perform a forward pass with a dummy input and print the output shape.

class SimpleCNN(nnx.Module):
    """Exercise stub: conv -> ReLU -> max-pool -> flatten -> linear classifier."""

    def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
        # nnx.Conv's first two positional arguments are the number of input
        # channels and the number of output channels (filters). The stride is a
        # separate `strides` keyword and already defaults to 1.
        self.conv = nnx.Conv('IN_FEATURES', 'OUT_FEATURES', kernel_size=('X', 'X'), rngs=rngs)
        # 6272 = 14 * 14 * 32: a 28x28 input keeps its spatial size through the
        # conv, is halved by the 2x2 pool, and carries 32 channels.
        self.linear = nnx.Linear(in_features=6272, out_features=num_classes, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.conv(x)
        print(f'{x.shape = }') # For debug
        x = nnx.relu(x)
        print(f'{x.shape = }') # For debug
        # Fill in the pooling window and stride, e.g. replace each 'X' with an integer.
        x = nnx.max_pool(x, window_shape=('X', 'X'), strides=('X', 'X'))
        print(f'{x.shape = }') # For debug
        x = x.reshape(x.shape[0], -1)  # flatten
        print(f'{x.shape = }') # For debug
        x = self.linear(x)
        return x

# Instantiate the CNN
# Replace 'OUTPUT CLASSES' with the number of classes from the instructions.
key = jax.random.PRNGKey(0)
cnn = SimpleCNN(num_classes='OUTPUT CLASSES', rngs=nnx.Rngs(key))

# Dummy input
# One 28x28 single-channel image, laid out as (batch, height, width, channels).
dummy_input = jnp.ones((1, 28, 28, 1))

# Forward pass
output = cnn(dummy_input)
print("Output shape:", output.shape)
# @title Exercise 4 Solution

class SimpleCNN(nnx.Module):
    """Conv -> ReLU -> 2x2 max-pool -> flatten -> Linear classifier.

    Args:
      num_classes: number of output logits.
      rngs: NNX RNG container used to initialize the conv and linear layers.
      in_channels: channels in the input images (default 1, e.g. grayscale).
      input_hw: (height, width) of the input images (default (28, 28)). Used to
        derive the flattened feature size feeding the linear layer instead of
        hard-coding it; the defaults reproduce the original 6272 features.
    """

    def __init__(
        self,
        num_classes: int,
        *,
        rngs: nnx.Rngs,
        in_channels: int = 1,
        input_hw: tuple[int, int] = (28, 28),
    ):
        num_filters = 32
        self.conv = nnx.Conv(in_channels, num_filters, kernel_size=(3, 3), rngs=rngs)
        # The conv (stride 1, nnx.Conv's default 'SAME' padding) keeps the spatial
        # size; the 2x2 / stride-2 pool then halves it, rounding down.
        height, width = input_hw
        flat_features = (height // 2) * (width // 2) * num_filters  # 14*14*32 = 6272 for 28x28
        self.linear = nnx.Linear(in_features=flat_features, out_features=num_classes, rngs=rngs)

    def __call__(self, x: jax.Array):
        # Assumes channels-last input (batch, height, width, channels), as in the dummy input.
        x = self.conv(x)
        print(f'{x.shape = }')
        x = nnx.relu(x)
        print(f'{x.shape = }')
        x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        print(f'{x.shape = }')
        x = x.reshape(x.shape[0], -1)  # flatten all but the batch axis
        print(f'{x.shape = }')
        x = self.linear(x)
        return x

# A single all-ones 28x28 grayscale image: (batch, height, width, channels).
dummy_input = jnp.ones((1, 28, 28, 1))

# Build a 10-class classifier from a fixed PRNG seed.
key = jax.random.PRNGKey(0)
cnn = SimpleCNN(
    num_classes=10,
    rngs=nnx.Rngs(key),
)

# The forward pass prints each intermediate shape, then the final logits shape.
output = cnn(dummy_input)
print("Output shape:", output.shape)
# @title Exercise 5: Training Loop with Optax (Coding Exercise)

# Instructions:
# 1. Define a simple model (e.g., a linear layer).
# 2. Create an nnx.Optimizer, making sure to specify which variable types to
#    update using the now required wrt argument (e.g., wrt=nnx.Param).
# 3. Implement a training step function that:
#    - Calculates the loss (e.g., mean squared error).
#    - Computes gradients using `nnx.value_and_grad`.
#    - Updates the model's state using `optimizer.update(model, grads)`.
# 4. Run the training loop for a few steps.

# Define a simple model
class LinearModel(nnx.Module):
    """Exercise stub: a model made of a single linear layer."""

    def __init__(self, *, rngs: nnx.Rngs):
        # Replace the placeholder with an `nnx.Linear` layer; the data below has
        # one input feature and one target value per example.
        self.linear = 'LINEAR LAYER HERE'

    def __call__(self, x: jax.Array):
        return self.linear(x)

# Instantiate the model
key = jax.random.PRNGKey(0)
model = LinearModel(rngs=nnx.Rngs(key))

# Create an Optax optimizer
# First build a gradient transformation with `optax` (e.g. SGD), then wrap it
# together with the model in `nnx.Optimizer`, passing `wrt=nnx.Param`.
tx = 'OPTAX SGD HERE'
optimizer = nnx.Optimizer('WRAP THE OPTIMIZER')

# Dummy data
# A single (1, 1) input/target pair: the model should learn to map 2.0 to 4.0.
x = jnp.array([[2.0]])
y = jnp.array([[4.0]])

# Training step function
@nnx.jit
def train_step(model, optimizer, x, y):
    """Run one optimizer update on ``model`` and return ``(loss, model)``.

    ``optimizer.update`` applies the gradients to the model's parameters; the
    same model object is handed back so the caller can rebind it.
    """

    def mse_loss(m):
        # Mean squared error between the model's prediction and the target.
        residual = m(x) - y
        return jnp.mean(residual ** 2)

    grad_fn = nnx.value_and_grad(mse_loss)
    loss, grads = grad_fn(model)
    optimizer.update(model, grads)
    return loss, model

# Run a short training loop, reporting the loss after every step.
num_steps = 10
for step in range(1, num_steps + 1):
    loss, model = train_step(model, optimizer, x, y)
    print(f"Step {step}, Loss: {loss}")

print("Trained model output:", model(x))
# @title Exercise 5 Solution

class LinearModel(nnx.Module):
    """Single-feature regression model: one ``nnx.Linear`` mapping 1 -> 1 feature."""

    def __init__(self, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(1, 1, rngs=rngs)

    def __call__(self, x: jax.Array):
        layer = self.linear
        return layer(x)

# Build the model from a fixed seed.
key = jax.random.PRNGKey(0)
model = LinearModel(rngs=nnx.Rngs(key))

# Plain SGD with learning rate 0.01, applied only to nnx.Param variables.
tx = optax.sgd(learning_rate=0.01)
optimizer = nnx.Optimizer(model, tx=tx, wrt=nnx.Param)

# One training example: learn to map 2.0 -> 4.0.
x, y = jnp.array([[2.0]]), jnp.array([[4.0]])

# Training step function
@nnx.jit
def train_step(model, optimizer, x, y):
    """Run one optimizer update on ``model`` and return ``(loss, model)``.

    ``optimizer.update`` applies the gradients to the model's parameters; the
    same model object is handed back so the caller can rebind it.
    """

    def mse_loss(m):
        # Mean squared error between the model's prediction and the target.
        residual = m(x) - y
        return jnp.mean(residual ** 2)

    grad_fn = nnx.value_and_grad(mse_loss)
    loss, grads = grad_fn(model)
    optimizer.update(model, grads)
    return loss, model

# Run a short training loop, reporting the loss after every step.
num_steps = 10
for step in range(1, num_steps + 1):
    loss, model = train_step(model, optimizer, x, y)
    print(f"Step {step}, Loss: {loss}")

print("Trained model output:", model(x))

Congratulations!

You've now worked through the fundamentals of Flax NNX!

Remember to consult the official documentation for more in-depth details:

  • Flax NNX: (Part of the Flax documentation) https://flax.readthedocs.io
  • JAX: https://jax.readthedocs.io

Keep practicing, and happy JAXing!

Please send us feedback at https://goo.gle/jax-training-feedback