Serving a JAX model with vLLM and GPU

This notebook shows a simple workflow: a model is loaded from Hugging Face into JAX and then served using vLLM. For brevity we leave out the actual fine-tuning or other alterations in JAX, since these are covered in other tutorials. This is right at the edge of what can be done in a free Colab GPU instance, so we restart the session before serving with vLLM to free up memory. As a bonus, this notebook contains a JAX implementation of the Llama 3.2 model, which is interesting in its own right.

Do all the Pips

Let's get the downloads out of the way.
!pip install -q jax-ai-stack==2025.9.3
!pip uninstall -y jax
!pip install -q jax[cuda]==0.7.0
!pip install -q vllm # We'll need it later

import jax
print(f"JAX version: {jax.__version__}")

Hugging Face

Download the model from Hugging Face

We'll download the model weights in Safetensors format. Llama 3.2 is a gated model, so you'll need a Hugging Face account that has been granted access, plus an access token for the login below.

!huggingface-cli login
import os
from huggingface_hub import snapshot_download

model_id = "meta-llama/Llama-3.2-1B"
path_to_model_weights = os.path.join('/content', model_id)

snapshot_download(repo_id=model_id, local_dir=path_to_model_weights)
# Load the weights from the Safetensors file in Flax format

import jax
from pathlib import Path
from safetensors import safe_open

def load_safetensors():
  weights = {}
  safetensors_files = Path(path_to_model_weights).glob('*.safetensors')

  for file in safetensors_files:
    with safe_open(file, framework="flax") as f:
      for key in f.keys():
        print(f"Loading {key}")
        weights[key] = f.get_tensor(key)
  return weights

weights = load_safetensors()

Llama 3.2-1B JAX Implementation


import jax
print(jax.devices())
print(jax.__version__)
from flax import nnx
from dataclasses import dataclass
import jax.numpy as jnp

@dataclass
class LlamaConfig:
  # Hyperparameters for Llama 3.2 1B
  dim: int = 2048
  n_layers: int = 16
  n_heads: int = 32
  n_kv_heads: int = 8
  intermediate_size: int = 8192
  vocab_size: int = 128256
  multiple_of: int = 256
  norm_eps: float = 1e-05
  rope_theta: float = 500000.0

  def __post_init__(self):
    self.head_dim = self.dim // self.n_heads

config = LlamaConfig()

class LlamaRMSNorm(nnx.Module):

  def __init__(self, dim: int, rngs=None):
    # RMSNorm scale parameter; initialized to ones, then overwritten by the checkpoint weights
    self.norm_weights = nnx.Param(jnp.ones((dim,), dtype=jnp.bfloat16))

  @nnx.jit()
  def __call__(self, hidden_states):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.astype(jnp.float32)
    squared_mean = jnp.mean(jnp.square(hidden_states), axis=-1, keepdims=True)
    hidden_states = hidden_states * jnp.reciprocal(jnp.sqrt(squared_mean + config.norm_eps))
    return self.norm_weights * hidden_states.astype(input_dtype)

class LlamaRotaryEmbedding(nnx.Module):

  def __init__(self, dim, base=10000, rngs=None):
    self.dim = dim
    self.base = base

  @nnx.jit()
  def __call__(self, position_ids):
    inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
    inv_freq_expanded = jnp.expand_dims(inv_freq, axis=(0, 1))
    position_ids_expanded = jnp.expand_dims(position_ids, axis=(0, 2)).astype(jnp.float32)
    freqs = jnp.einsum('bij,bjk->bijk', position_ids_expanded, inv_freq_expanded)
    emb = jnp.concatenate([freqs, freqs], axis=-1)
    cos = jnp.cos(emb).squeeze(2).astype(jnp.bfloat16)
    sin = jnp.sin(emb).squeeze(2).astype(jnp.bfloat16)
    return cos, sin

class LlamaAttention(nnx.Module):

  def __init__(self, layer_idx, rngs=None):
    self.q_proj = nnx.Linear(config.dim, config.n_heads * config.head_dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
    self.k_proj = nnx.Linear(config.dim, config.n_kv_heads * config.head_dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
    self.v_proj = nnx.Linear(config.dim, config.n_kv_heads * config.head_dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
    self.o_proj = nnx.Linear(config.n_heads * config.head_dim, config.dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
    self.rotary_emb = LlamaRotaryEmbedding(config.head_dim, base=config.rope_theta, rngs=rngs)

  # Alternative implementation:
  # https://github.com/google/flax/blob/5d896bc1a2c68e2099d147cd2bc18ebb6a46a0bd/examples/gemma/positional_embeddings.py#L45
  def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
    cos = jnp.expand_dims(cos, axis=unsqueeze_dim)
    sin = jnp.expand_dims(sin, axis=unsqueeze_dim)
    q_embed = (q * cos) + (self.rotate_half(q) * sin)
    k_embed = (k * cos) + (self.rotate_half(k) * sin)
    return q_embed, k_embed

  def rotate_half(self, x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return jnp.concatenate([-x2, x1], axis=-1)

  def repeat_kv(self, hidden_states, n_repeat):
    batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_repeat == 1:
      return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].repeat(n_repeat, axis=2)
    return hidden_states.reshape(batch, n_kv_heads * n_repeat, seq_len, head_dim)

  @nnx.jit()
  def __call__(self, x, position_ids):
    batch_size, seq_len, _ = x.shape
    query = self.q_proj(x).reshape(batch_size, seq_len, config.n_heads, config.head_dim).transpose((0, 2, 1, 3))
    key = self.k_proj(x).reshape(batch_size, seq_len, config.n_kv_heads, config.head_dim).transpose((0, 2, 1, 3))
    value = self.v_proj(x).reshape(batch_size, seq_len, config.n_kv_heads, config.head_dim).transpose((0, 2, 1, 3))
    # Assuming batch_size=1
    cos, sin = self.rotary_emb(position_ids[0])
    query, key = self.apply_rotary_pos_emb(query, key, cos, sin)

    key = self.repeat_kv(key, config.n_heads // config.n_kv_heads)
    value = self.repeat_kv(value, config.n_heads // config.n_kv_heads)

    attn_weights = jnp.matmul(query, jnp.transpose(key, (0, 1, 3, 2)))
    attn_weights = attn_weights.astype(jnp.float32) / jnp.sqrt(config.head_dim)
    # Causal mask: each position may only attend to itself and earlier positions
    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    attn_weights = jnp.where(causal_mask[None, None, :, :], attn_weights, -jnp.inf)
    attn_weights = jax.nn.softmax(attn_weights, axis=-1).astype(jnp.bfloat16)
    attn_output = jnp.matmul(attn_weights, value).transpose((0, 2, 1, 3)).reshape(batch_size, seq_len, -1)
    output = self.o_proj(attn_output)
    return output

class LlamaMLP(nnx.Module):

  def __init__(self, layer_idx, rngs=None):
    self.gate_proj = nnx.Linear(config.dim, config.intermediate_size, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
    self.up_proj = nnx.Linear(config.dim, config.intermediate_size, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
    self.down_proj = nnx.Linear(config.intermediate_size, config.dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)

  @nnx.jit()
  def __call__(self, x):
    return self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x))

class LlamaTransformerBlock(nnx.Module):

  def __init__(self, layer_idx, rngs=None):
    self.input_layernorm = LlamaRMSNorm(dim=config.dim, rngs=rngs)
    self.attention = LlamaAttention(layer_idx=layer_idx, rngs=rngs)
    self.post_attention_layernorm = LlamaRMSNorm(dim=config.dim, rngs=rngs)
    self.mlp = LlamaMLP(layer_idx=layer_idx, rngs=rngs)

  @nnx.jit()
  def __call__(self, x, position_ids):
    residual = x
    x = self.input_layernorm(x)
    x = self.attention(x, position_ids)
    x = residual + x

    residual = x
    x = self.post_attention_layernorm(x)
    x = self.mlp(x)
    x = residual + x
    return x

class LlamaForCausalLM(nnx.Module):

  def __init__(self, rngs=None):
    self.token_embed = nnx.Embed(num_embeddings=config.vocab_size, features=config.dim, param_dtype=jnp.bfloat16, rngs=rngs)

    self.layers = [LlamaTransformerBlock(layer_idx=idx, rngs=rngs) for idx in range(config.n_layers)]
    self.lm_head = nnx.Linear(config.dim, config.vocab_size, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
    self.norm = LlamaRMSNorm(dim=config.dim, rngs=rngs)

  @nnx.jit()
  def __call__(self, input_ids, position_ids):
    assert input_ids.shape[0] == 1, "Only batch size 1 is supported"
    x = self.token_embed(input_ids)
    for layer in self.layers:
        x = layer(x, position_ids)
    x = self.norm(x)
    logits = self.lm_head(x)
    return logits
model = LlamaForCausalLM(rngs=nnx.Rngs(0))
state = nnx.state(model)
nnx.display(state) # This can be very useful

Map the PyTorch weights to Flax NNX

Because of differences in the layer definitions between PyTorch and JAX/Flax NNX we need to alter the shapes of some of the weights. Here's a quick summary:

  • Linear (FC): transpose the weight matrix. PyTorch stores a Linear weight as [out_features, in_features], while a Flax kernel is [in_features, out_features] (see the quick check after this list).
  • Convolutions: transpose from [outC, inC, kH, kW] to [kH, kW, inC, outC], e.g. kernel = jnp.transpose(kernel, (2, 3, 1, 0)).
  • Convolutions followed by FC layers: be careful when a model uses convolutions followed by fully connected layers (ResNet, VGG, etc.). In PyTorch the activations have shape [N, C, H, W] after the convolutions and are reshaped to [N, C*H*W] before being fed to the FC layers. When we port the weights to Flax, the activations after the convolutions have shape [N, H, W, C], so before reshaping them for the FC layers we have to transpose them to [N, C, H, W].
  • BatchNorm: no change.
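As a quick sanity check of the Linear convention (a minimal sketch; it assumes the weights dict and the still randomly initialized state from the cells above are in scope), compare one Hugging Face tensor against the corresponding NNX kernel. k_proj is non-square, so the transpose is easy to see:

# Illustrative only: PyTorch [out_features, in_features] vs. Flax [in_features, out_features]
print(weights['model.layers.0.self_attn.k_proj.weight'].shape)       # (512, 2048) in the HF checkpoint
print(state['layers'][0]['attention']['k_proj'].kernel.value.shape)  # (2048, 512) in the NNX model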
# This is specific to the format of a Hugging Face Llama 3.2 checkpoint

def update_from_HF_checkpoint(state: nnx.State, weights: dict) -> None:
  for wholekey in weights:
    keys = wholekey.split('.')
    if keys[1] == 'layers':
      if keys[3] == 'self_attn':
        keys[3] = 'attention'  # our module is named 'attention' rather than 'self_attn'
      if keys[3] in ('attention', 'mlp'):
        # PyTorch Linear weights are [out, in]; Flax kernels are [in, out], hence the transpose
        state['layers'][int(keys[2])][keys[3]][keys[4]]['kernel'].value = weights[wholekey].T
      elif keys[3] in ('input_layernorm', 'post_attention_layernorm'):
        state['layers'][int(keys[2])][keys[3]]['norm_weights'].value = weights[wholekey]
    elif keys[1] == 'embed_tokens':
      # Llama 3.2 1B ties the output head to the token embedding
      state['token_embed'].embedding.value = weights[wholekey]
      state['lm_head'].kernel.value = weights[wholekey].T
    elif keys[1] == 'norm':
      state['norm'].norm_weights.value = weights[wholekey]

update_from_HF_checkpoint(state, weights)
nnx.update(model, state)
# nnx.display(state)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)
input_text = "The capital of Japan is"

input_ids = tokenizer(input_text, return_tensors="jax")["input_ids"]
position_ids = jnp.asarray([jnp.arange(input_ids.shape[1])])

# Naive greedy decoding: rerun the full forward pass each step (no KV cache)
for _ in range(15):
  logits = model(input_ids, position_ids)
  next_token = jnp.argmax(logits[:, -1, :], axis=-1)
  input_ids = jnp.concatenate([input_ids, next_token[:, None]], axis=1)
  position_ids = jnp.asarray([jnp.arange(input_ids.shape[1])])
  print(f"Generated token: {next_token[0]}")

print(tokenizer.decode(input_ids[0]))

Get the updated model for serving

We loaded our JAX model, and although we didn't change it in this notebook, in real life we might have done some fine-tuning, alignment, etc. Now we need to pull the updated model state back out so that we can serve it with vLLM.

state = nnx.state(model) # We already have it, but just to illustrate

# This is specific to the format of a Hugging Face Llama 3.2 checkpoint

def model_state_to_HF_weights(state: nnx.State) -> dict:
  weights_dict = {}
  # lm_head is tied to the token embedding in Llama 3.2 1B, so only the embedding is exported
  weights_dict['model.embed_tokens.weight'] = state['token_embed'].embedding.value
  weights_dict['model.norm.weight'] = state['norm'].norm_weights.value

  for idx, layer in enumerate(state['layers'].values()):
    weights_dict[f'model.layers.{idx}.input_layernorm.weight'] = layer['input_layernorm'].norm_weights.value
    weights_dict[f'model.layers.{idx}.post_attention_layernorm.weight'] = layer['post_attention_layernorm'].norm_weights.value
    weights_dict[f'model.layers.{idx}.self_attn.k_proj.weight'] = layer['attention']['k_proj'].kernel.value.T
    weights_dict[f'model.layers.{idx}.self_attn.o_proj.weight'] = layer['attention']['o_proj'].kernel.value.T
    weights_dict[f'model.layers.{idx}.self_attn.q_proj.weight'] = layer['attention']['q_proj'].kernel.value.T
    weights_dict[f'model.layers.{idx}.self_attn.v_proj.weight'] = layer['attention']['v_proj'].kernel.value.T
    weights_dict[f'model.layers.{idx}.mlp.down_proj.weight'] = layer['mlp']['down_proj'].kernel.value.T
    weights_dict[f'model.layers.{idx}.mlp.gate_proj.weight'] = layer['mlp']['gate_proj'].kernel.value.T
    weights_dict[f'model.layers.{idx}.mlp.up_proj.weight'] = layer['mlp']['up_proj'].kernel.value.T
  return weights_dict

new_weights = model_state_to_HF_weights(state)

Now convert the new weights back to Safetensors in preparation for serving

# vLLM wants the weight dictionary flattened into dot-separated keys
def flatten_weight_dict(params, prefix=""):
    flat_params = {}
    for key, value in params.items():
        new_key = f"{prefix}{key}" if prefix else key
        if isinstance(value, dict):
            flat_params.update(flatten_weight_dict(value, new_key + "."))
        else:
            flat_params[new_key] = value
    return flat_params

servable_weights = flatten_weight_dict(new_weights)
# Replace the old model with the new model.  Note that we could also
# keep the old and save the new model to a new directory
from safetensors.flax import save_file
save_file(servable_weights, path_to_model_weights + '/model.safetensors')

Serving with vLLM

Runtime > Restart session to free memory

We're right at the edge of GPU memory on a T4 Colab instance, so restart the session now to release the memory JAX is holding before we load the model into vLLM.
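If you want to confirm that the memory actually came back after the restart, an optional quick check is:

!nvidia-smi --query-gpu=memory.used,memory.total --format=csv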

Which models can you serve with vLLM?

Safetensors is the weight format we're using here, but beyond the file format, vLLM has two other critical requirements that determine compatibility.

Model Architecture is Key

The most important factor is the model's architecture. vLLM achieves its high speed by using custom, highly optimized compute kernels for specific transformer architectures (Llama, Mixtral, Gemma, Phi-3, etc.).

  • Supported architectures only: if the model's architecture is not on vLLM's list of supported models, vLLM will not know how to load or run it, regardless of the file format. (vLLM can also run custom models; see below.)
  • Checking compatibility: you can check a model's architecture in its config.json file under the "architectures" or "model_type" field and compare it against the official vLLM supported models list.
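For example, here's a minimal sketch of that check, using the download location from earlier in this notebook:

import json, os

# Peek at the fields vLLM uses to decide whether it can run this checkpoint
config_path = os.path.join('/content', 'meta-llama/Llama-3.2-1B', 'config.json')
with open(config_path) as f:
    cfg = json.load(f)
print(cfg.get('architectures'), cfg.get('model_type'))  # ['LlamaForCausalLM'] llama for this model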

More Than Just Weights

A .safetensors file only contains the model's weights (the numerical parameters). To function, a model also needs its configuration and tokenizer files. When you point vLLM to a model, it expects a complete directory (or a Hugging Face repository identifier) that includes:
  • config.json: defines the model's architecture, size, and other essential parameters. vLLM reads this first to check for compatibility.
  • tokenizer.json (and related files): defines how to convert text into tokens that the model can understand.
  • model.safetensors (or sharded versions): the file(s) containing the actual model weights.
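In our case snapshot_download already fetched the config and tokenizer files alongside the weights, so the directory where we overwrote model.safetensors should be servable as-is. An optional quick look to confirm:

!ls /content/meta-llama/Llama-3.2-1B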
Can I serve a model not in the supported models list?

Yes! Check out the [instructions here](https://docs.vllm.ai/en/latest/models/supported_models.html#custom-models) to learn how to serve custom models.

CUDA

!nvcc --version
%env CUDA_HOME=/usr/local/cuda-12.5

Install vLLM

!pip install -q vllm

Serve the model with vLLM

# Need to restore these after restarting the session
import os

model_id = "meta-llama/Llama-3.2-1B"
path_to_model_weights = os.path.join('/content', model_id)
# Load the model into vLLM
from vllm import LLM, SamplingParams

# dtype="half" because the free-tier T4 GPU doesn't support bfloat16
llm = LLM(model=path_to_model_weights, load_format="safetensors", dtype="half")
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {generated_text}")