This notebook shows a simple workflow: a model is loaded from Hugging Face into JAX and then served using vLLM. For brevity we leave out the actual fine-tuning or other alterations in JAX, since those are covered in other tutorials. This is right on the edge of what can be done in a free Colab GPU instance, so we restart the runtime before moving on to vLLM to free up memory. As a bonus, this notebook contains a JAX implementation of the Llama 3.2 model, which can be interesting in its own right.
!pip install -q jax-ai-stack==2025.9.3
!pip uninstall -y jax
!pip install -q "jax[cuda]==0.7.0"
!pip install -q vllm # We'll need it later
import jax
print(f"JAX version: {jax.__version__}")
We'll download the model weights in Safetensors format.
!huggingface-cli login
import os
from huggingface_hub import snapshot_download
model_id = "meta-llama/Llama-3.2-1B"
path_to_model_weights = os.path.join('/content', model_id)
snapshot_download(repo_id=model_id, local_dir=path_to_model_weights)
# Load the weights from the Safetensors file in Flax format
import jax
from pathlib import Path
from safetensors import safe_open
def load_safetensors():
    weights = {}
    safetensors_files = Path(path_to_model_weights).glob('*.safetensors')
    for file in safetensors_files:
        with safe_open(file, framework="flax") as f:
            for key in f.keys():
                print(f"Loading {key}")
                weights[key] = f.get_tensor(key)
    return weights
weights = load_safetensors()
# # Install the JAX AI Stack for GPU
# !pip install -q jax[cuda] jax-ai-stack
import jax
print(jax.devices())
print(jax.__version__)
from flax import nnx
from dataclasses import dataclass
import jax.numpy as jnp
@dataclass
class LlamaConfig:
    dim: int = 2048
    n_layers: int = 16
    n_heads: int = 32
    n_kv_heads: int = 8
    intermediate_size: int = 8192  # Llama 3.2 1B uses an MLP intermediate size of 8192
    vocab_size: int = 128256
    multiple_of: int = 256
    norm_eps: float = 1e-05
    rope_theta: float = 500000.0

    @property
    def head_dim(self) -> int:
        return self.dim // self.n_heads
config = LlamaConfig()
class LlamaRMSNorm(nnx.Module):
    def __init__(self, dim: int, rngs=None):
        # Initialize to ones (the identity scale); the checkpoint values are loaded later
        self.norm_weights = nnx.Param(jnp.ones((dim,), dtype=jnp.bfloat16))

    @nnx.jit()
    def __call__(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.astype(jnp.float32)
        squared_mean = jnp.mean(jnp.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = hidden_states * jnp.reciprocal(jnp.sqrt(squared_mean + config.norm_eps))
        return self.norm_weights * hidden_states.astype(input_dtype)
class LlamaRotaryEmbedding(nnx.Module):
    def __init__(self, dim, base=10000, rngs=None):
        self.dim = dim
        self.base = base

    @nnx.jit()
    def __call__(self, position_ids):
        inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
        inv_freq_expanded = jnp.expand_dims(inv_freq, axis=(0, 1))
        position_ids_expanded = jnp.expand_dims(position_ids, axis=(0, 2)).astype(jnp.float32)
        freqs = jnp.einsum('bij,bjk->bijk', position_ids_expanded, inv_freq_expanded)
        emb = jnp.concatenate([freqs, freqs], axis=-1)
        cos = jnp.cos(emb).squeeze(2).astype(jnp.bfloat16)
        sin = jnp.sin(emb).squeeze(2).astype(jnp.bfloat16)
        return cos, sin
class LlamaAttention(nnx.Module):
    def __init__(self, layer_idx, rngs=None):
        self.q_proj = nnx.Linear(config.dim, config.n_heads * config.head_dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
        self.k_proj = nnx.Linear(config.dim, config.n_kv_heads * config.head_dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
        self.v_proj = nnx.Linear(config.dim, config.n_kv_heads * config.head_dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
        self.o_proj = nnx.Linear(config.n_heads * config.head_dim, config.dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
        self.rotary_emb = LlamaRotaryEmbedding(config.head_dim, base=config.rope_theta, rngs=rngs)

    # Alternative implementation:
    # https://github.com/google/flax/blob/5d896bc1a2c68e2099d147cd2bc18ebb6a46a0bd/examples/gemma/positional_embeddings.py#L45
    def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
        cos = jnp.expand_dims(cos, axis=unsqueeze_dim)
        sin = jnp.expand_dims(sin, axis=unsqueeze_dim)
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

    def rotate_half(self, x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return jnp.concatenate([-x2, x1], axis=-1)

    def repeat_kv(self, hidden_states, n_repeat):
        # Expand the KV heads so they line up with the query heads (grouped-query attention)
        batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
        if n_repeat == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].repeat(n_repeat, axis=2)
        return hidden_states.reshape(batch, n_kv_heads * n_repeat, seq_len, head_dim)

    @nnx.jit()
    def __call__(self, x, position_ids):
        batch_size, seq_len, _ = x.shape
        query = self.q_proj(x).reshape(batch_size, seq_len, config.n_heads, config.head_dim).transpose((0, 2, 1, 3))
        key = self.k_proj(x).reshape(batch_size, seq_len, config.n_kv_heads, config.head_dim).transpose((0, 2, 1, 3))
        value = self.v_proj(x).reshape(batch_size, seq_len, config.n_kv_heads, config.head_dim).transpose((0, 2, 1, 3))
        # Assuming batch_size=1
        cos, sin = self.rotary_emb(position_ids[0])
        query, key = self.apply_rotary_pos_emb(query, key, cos, sin)
        key = self.repeat_kv(key, config.n_heads // config.n_kv_heads)
        value = self.repeat_kv(value, config.n_heads // config.n_kv_heads)
        attn_weights = jnp.matmul(query, jnp.transpose(key, (0, 1, 3, 2)))
        attn_weights = attn_weights.astype(jnp.float32) / jnp.sqrt(config.head_dim)
        # Causal mask so each position only attends to itself and earlier positions
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        attn_weights = jnp.where(causal_mask, attn_weights, jnp.finfo(jnp.float32).min)
        attn_weights = jax.nn.softmax(attn_weights, axis=-1).astype(jnp.bfloat16)
        attn_output = jnp.matmul(attn_weights, value).transpose((0, 2, 1, 3)).reshape(batch_size, seq_len, -1)
        output = self.o_proj(attn_output)
        return output
class LlamaMLP(nnx.Module):
    def __init__(self, layer_idx, rngs=None):
        self.gate_proj = nnx.Linear(config.dim, config.intermediate_size, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
        self.up_proj = nnx.Linear(config.dim, config.intermediate_size, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
        self.down_proj = nnx.Linear(config.intermediate_size, config.dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)

    @nnx.jit()
    def __call__(self, x):
        return self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x))
class LlamaTransformerBlock(nnx.Module):
    def __init__(self, layer_idx, rngs=None):
        self.input_layernorm = LlamaRMSNorm(dim=config.dim, rngs=rngs)
        self.attention = LlamaAttention(layer_idx=layer_idx, rngs=rngs)
        self.post_attention_layernorm = LlamaRMSNorm(dim=config.dim, rngs=rngs)
        self.mlp = LlamaMLP(layer_idx=layer_idx, rngs=rngs)

    @nnx.jit()
    def __call__(self, x, position_ids):
        residual = x
        x = self.input_layernorm(x)
        x = self.attention(x, position_ids)
        x = residual + x
        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = residual + x
        return x
class LlamaForCausalLM(nnx.Module):
    def __init__(self, rngs=None):
        self.token_embed = nnx.Embed(num_embeddings=config.vocab_size, features=config.dim, param_dtype=jnp.bfloat16, rngs=rngs)
        self.layers = [LlamaTransformerBlock(layer_idx=idx, rngs=rngs) for idx in range(config.n_layers)]
        self.lm_head = nnx.Linear(config.dim, config.vocab_size, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)
        # The final norm acts on the full hidden dimension, not the per-head dimension
        self.norm = LlamaRMSNorm(dim=config.dim, rngs=rngs)

    @nnx.jit()
    def __call__(self, input_ids, position_ids):
        assert input_ids.shape[0] == 1, "Only batch size 1 is supported"
        x = self.token_embed(input_ids)
        for layer in self.layers:
            x = layer(x, position_ids)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits
model = LlamaForCausalLM(rngs=nnx.Rngs(0))
state = nnx.state(model)
nnx.display(state) # This can be very useful
Because of differences in the layer definitions between PyTorch and JAX/Flax NNX, we need to alter the shapes of some of the weights. Here's a quick summary: a PyTorch Linear stores its weight as [out_features, in_features], while a Flax NNX Linear kernel is [in_features, out_features], so each linear weight below is transposed with .T. For convolution layers (which Llama doesn't have) the equivalent mapping would be [outC, inC, kH, kW] -> [kH, kW, inC, outC], i.e. kernel = jnp.transpose(kernel, (2, 3, 1, 0)).
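To make the transpose concrete, here is a small check using the weights and state objects we already have. The checkpoint key name is the standard Hugging Face Llama layout, the same one used by the conversion function below:
# k_proj makes the transpose visible because it is not square:
# the HF/PyTorch weight is [out_features, in_features] = (512, 2048) for Llama 3.2 1B,
# while the Flax NNX kernel is [in_features, out_features] = (2048, 512).
hf_weight = weights['model.layers.0.self_attn.k_proj.weight']
flax_kernel = state['layers'][0]['attention']['k_proj']['kernel'].value
print(hf_weight.shape, flax_kernel.shape)
assert hf_weight.T.shape == flax_kernel.shape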
# This is specific to the format of a Hugging Face Llama 3.2 checkpoint
def update_from_HF_checkpoint(state: nnx.State, weights: dict) -> None:
    for wholekey in weights:
        keys = wholekey.split('.')
        # Our attention module is named 'attention' rather than 'self_attn'
        if keys[1] == 'layers':
            if keys[3] == 'self_attn':
                keys[3] = 'attention'
        if keys[1] == 'layers' and keys[3] == 'attention':
            state['layers'][int(keys[2])][keys[3]][keys[4]]['kernel'].value = weights[wholekey].T
        elif keys[1] == 'layers' and keys[3] == 'mlp':
            state['layers'][int(keys[2])][keys[3]][keys[4]]['kernel'].value = weights[wholekey].T
        elif keys[1] == 'layers' and keys[3] == 'input_layernorm':
            state['layers'][int(keys[2])][keys[3]]['norm_weights'].value = weights[wholekey]
        elif keys[1] == 'layers' and keys[3] == 'post_attention_layernorm':
            state['layers'][int(keys[2])][keys[3]]['norm_weights'].value = weights[wholekey]
        elif keys[1] == 'embed_tokens':
            # Llama 3.2 ties the input embedding and the output projection weights
            state['token_embed'].embedding.value = weights[wholekey]
            state['lm_head'].kernel.value = weights[wholekey].T
        elif keys[1] == 'norm':
            state['norm'].norm_weights.value = weights[wholekey]
update_from_HF_checkpoint(state, weights)
nnx.update(model, state)
# nnx.display(state)
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
input_text = "The capital of Japan is"
input_ids = tokenizer(input_text, return_tensors="jax")["input_ids"]
position_ids = jnp.asarray([jnp.arange(input_ids.shape[1])])
for _ in range(15):
    logits = model(input_ids, position_ids)
    next_token = jnp.argmax(logits[:, -1, :], axis=-1)
    input_ids = jnp.concatenate([input_ids, next_token[:, None]], axis=1)
    position_ids = jnp.asarray([jnp.arange(input_ids.shape[1])])
    print(f"Generated token: {next_token[0]}")
print(tokenizer.decode(input_ids[0]))
We loaded up our JAX model, and although we didn't make any changes to it in this notebook, in real life we might have done some fine-tuning, alignment, etc. Now we need to export the updated weights so that we can serve the model with vLLM.
state = nnx.state(model) # We already have it, but just to illustrate
# This is specific to the format of a Hugging Face Llama 3.2 checkpoint
def model_state_to_HF_weights(state: nnx.State) -> dict:
    weights_dict = {}
    weights_dict['model.embed_tokens.weight'] = state['token_embed'].embedding.value
    weights_dict['model.norm.weight'] = state['norm'].norm_weights.value
    for idx, layer in enumerate(state['layers'].values()):
        weights_dict[f'model.layers.{idx}.input_layernorm.weight'] = layer['input_layernorm'].norm_weights.value
        weights_dict[f'model.layers.{idx}.post_attention_layernorm.weight'] = layer['post_attention_layernorm'].norm_weights.value
        weights_dict[f'model.layers.{idx}.self_attn.k_proj.weight'] = layer['attention']['k_proj'].kernel.value.T
        weights_dict[f'model.layers.{idx}.self_attn.o_proj.weight'] = layer['attention']['o_proj'].kernel.value.T
        weights_dict[f'model.layers.{idx}.self_attn.q_proj.weight'] = layer['attention']['q_proj'].kernel.value.T
        weights_dict[f'model.layers.{idx}.self_attn.v_proj.weight'] = layer['attention']['v_proj'].kernel.value.T
        weights_dict[f'model.layers.{idx}.mlp.down_proj.weight'] = layer['mlp']['down_proj'].kernel.value.T
        weights_dict[f'model.layers.{idx}.mlp.gate_proj.weight'] = layer['mlp']['gate_proj'].kernel.value.T
        weights_dict[f'model.layers.{idx}.mlp.up_proj.weight'] = layer['mlp']['up_proj'].kernel.value.T
    return weights_dict
new_weights = model_state_to_HF_weights(state)
import torch
import numpy as np
# vLLM wants the weight dictionary flattened
def flatten_weight_dict(torch_params, prefix=""):
    flat_params = {}
    for key, value in torch_params.items():
        new_key = f"{prefix}{key}" if prefix else key
        if isinstance(value, dict):
            flat_params.update(flatten_weight_dict(value, new_key + "."))
        else:
            flat_params[new_key] = value
    return flat_params
servable_weights = flatten_weight_dict(new_weights)
# Replace the old model with the new model. Note that we could also
# keep the old and save the new model to a new directory
from safetensors.flax import save_file
save_file(servable_weights, path_to_model_weights + '/model.safetensors')
We're right on the edge of our GPU memory for a T4 Colab instance, so restart the session now (Runtime > Restart session) to free up memory before moving on to vLLM.
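If you want to see how close we are, you can optionally check the current GPU memory usage before restarting:
# Report GPU memory usage on the Colab instance
!nvidia-smi --query-gpu=memory.used,memory.total --format=csv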
While Safetensors is the required format for the model's weights, vLLM has two other critical requirements that determine compatibility: the weights must be accompanied by the model's configuration and tokenizer files, and the architecture must be one that vLLM supports.

A .safetensors file only contains the model's weights (the numerical parameters). To function, a model also needs its configuration and tokenizer files. When you point vLLM to a model, it expects a complete directory (or a Hugging Face repository identifier) that includes:

- config.json: defines the model's architecture, size, and other essential parameters. vLLM reads this first to check for compatibility.
- tokenizer.json (and related files): defines how to convert text into tokens that the model can understand.
- model.safetensors (or sharded versions): the file(s) containing the actual model weights.

If your architecture isn't supported out of the box, vLLM can also serve custom models: check out the [instructions here](https://docs.vllm.ai/en/latest/models/supported_models.html#custom-models) to learn how. We'll do a quick check for the required files right before loading the model below.
!nvcc --version
%env CUDA_HOME=/usr/local/cuda-12.5
!pip install -q vllm
# Need to restore these after restarting the session
import os
model_id = "meta-llama/Llama-3.2-1B"
path_to_model_weights = os.path.join('/content', model_id)
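Before loading, a quick sanity check that the directory contains everything vLLM expects. This is a minimal sketch; the tokenizer file names are assumed to follow the Llama 3.2 repo layout:
# Check that the config, tokenizer, and weight files are all present in the model directory
expected_files = ['config.json', 'tokenizer.json', 'tokenizer_config.json', 'model.safetensors']
for name in expected_files:
    present = os.path.exists(os.path.join(path_to_model_weights, name))
    print(f"{name}: {'found' if present else 'MISSING'}")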
# Load the model into vLLM
from vllm import LLM, SamplingParams
llm = LLM(model=path_to_model_weights, load_format="safetensors", dtype="half")
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {generated_text}")