如何在 JAX 中思考

[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinkinginjax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinkinginjax.ipynb)

JAX 提供了简单而强大的 API 用于编写加速数值计算代码,但要在 JAX 中高效工作,有时需要额外的考虑。本文档旨在帮助您从头开始了解 JAX 的工作原理,以便更有效地使用它。

JAX 与 NumPy

核心概念:
  • 为方便起见,JAX 提供了受 NumPy 启发的接口。
  • 通过鸭子类型,JAX 数组通常可以作为 NumPy 数组的直接替代品。
  • 与 NumPy 数组不同,JAX 数组始终是不可变的。

NumPy 提供了著名而强大的 API 用于处理数值数据。为方便起见,JAX 提供了与 NumPy API 非常相似的 jax.numpy,为入门 JAX 提供了便利。几乎所有可以用 numpy 完成的操作都可以用 jax.numpy 完成:

import matplotlib.pyplot as plt
import numpy as np

x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np);
import jax.numpy as jnp

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);

除了将 np 替换为 jnp 外,代码块是相同的,结果也相同。如我们所见,JAX 数组通常可以直接替代 NumPy 数组用于绘图等任务。

数组本身是作为不同的 Python 类型实现的:

type(x_np)
type(x_jnp)

Python 的鸭子类型允许 JAX 数组和 NumPy 数组在许多地方可以互换使用。

然而,JAX 和 NumPy 数组之间有一个重要区别:JAX 数组是不可变的,这意味着一旦创建,其内容就无法更改。

以下是在 NumPy 中修改数组的示例:

# NumPy: 可变数组
x = np.arange(10)
x[0] = 10
print(x)

在 JAX 中执行相同的操作会导致错误,因为 JAX 数组是不可变的:

%xmode minimal
# JAX: 不可变数组
x = jnp.arange(10)
x[0] = 10

为了更新单个元素,JAX 提供了一种索引更新语法,它返回一个更新后的副本:

y = x.at[0].set(10)
print(x)
print(y)

NumPy, lax 和 XLA: JAX API 分层

核心概念:
  • jax.numpy 是一个提供熟悉接口的高级包装器。
  • jax.lax 是一个更严格但通常更强大的低级 API。
  • 所有 JAX 操作都是根据 XLA(加速线性代数编译器)中的操作实现的。

如果您查看 jax.numpy 的源代码,您会发现所有操作最终都以 jax.lax 中定义的函数来表示。您可以将 jax.lax 视为一个更严格但通常更强大的用于处理多维数组的 API。

例如,虽然 jax.numpy 会隐式提升参数以允许混合数据类型之间的操作,但 jax.lax 不会:

import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API 隐式提升混合类型。
from jax import lax
lax.add(1, 1.0)  # jax.lax API 需要显式类型提升。

如果直接使用 jax.lax,在这种情况下您将必须显式进行类型提升:

lax.add(jnp.float32(1), 1.0)

除了这种严格性,jax.lax 还为一些比 NumPy 支持的更通用的操作提供了高效的 API。

例如,考虑一维卷积,在 NumPy 中可以这样表示:

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)

在底层,这个 NumPy 操作被转换为由 lax.conv_general_dilated 实现的更通用的卷积:

from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # 注意:显式提升
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # 等效于 NumPy 中的 padding='full'
result[0, 0]

这是一种批处理卷积操作,旨在高效处理深度神经网络中常用的卷积类型。它需要更多的样板代码,但比 NumPy 提供的卷积要灵活和可扩展得多(有关 JAX 卷积的更多详细信息,请参阅 JAX 中的卷积)。

从本质上讲,所有 jax.lax 操作都是对 XLA 中操作的 Python 包装器;例如,此处的卷积实现由 XLA:ConvWithGeneralPadding 提供。每个 JAX 操作最终都表示为这些基本 XLA 操作的组合,这使得即时 (JIT) 编译成为可能。

JIT 还是不 JIT

核心概念:
  • 默认情况下,JAX 按顺序逐个执行操作。
  • 使用即时 (JIT) 编译装饰器,可以一起优化和一次性运行操作序列。
  • 并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状在编译时是静态且已知的。

所有 JAX 操作都以 XLA 表示,这使得 JAX 可以使用 XLA 编译器非常高效地执行代码块。

例如,考虑这个用 jax.numpy 操作表示的函数,它对二维矩阵的行进行归一化:

import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

可以使用 jax.jit 转换创建此函数的即时编译版本:

from jax import jit
norm_compiled = jit(norm)

此函数返回与原始函数相同的结果,精度达到标准浮点精度:

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)

但是由于编译(包括操作融合、避免分配临时数组以及许多其他技巧),在 JIT 编译的情况下,执行时间可以快几个数量级(注意使用 block_until_ready() 来考虑 JAX 的异步分派):

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()

也就是说,jax.jit 确实有其局限性:特别是,它要求所有数组都具有静态形状。这意味着某些 JAX 操作与 JIT 编译不兼容。

例如,此操作可以在逐操作模式下执行:

def get_negatives(x):
  return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)

但是,如果您尝试在 jit 模式下执行它,则会返回错误:

jit(get_negatives)(x)

这是因为该函数生成的数组的形状在编译时是未知的:输出的大小取决于输入数组的值,因此与 JIT 不兼容。

JIT 机制:跟踪和静态变量

核心概念:
  • JIT 和其他 JAX 转换通过跟踪函数来确定其对特定形状和类型的输入的影响。
  • 您不希望被跟踪的变量可以标记为静态

要有效使用 jax.jit,了解其工作原理很有用。让我们在一个 JIT 编译的函数中放入几个 print() 语句,然后调用该函数:

@jit
def f(x, y):
  print("运行 f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  结果 = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

请注意,print 语句会执行,但它们打印的不是我们传递给函数的数据,而是代表它们的跟踪器对象。

这些跟踪器对象是 jax.jit 用来提取函数指定的操作序列的东西。基本跟踪器是编码数组形状dtype 的替代品,但与值无关。然后,这个记录的计算序列可以在 XLA 中高效地应用于具有相同形状和 dtype 的新输入,而无需重新执行 Python 代码。

当我们再次在匹配的输入上调用已编译的函数时,不需要重新编译,并且不会打印任何内容,因为结果是在已编译的 XLA 中计算的,而不是在 Python 中:

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)

提取的操作序列被编码在一个 JAX 表达式中,简称 jaxpr。您可以使用 jax.make_jaxpr 转换查看 jaxpr:

from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)

注意这一点的一个后果:因为 JIT 编译是在没有数组内容信息的情况下完成的,所以函数中的控制流语句不能依赖于跟踪的值。例如,这会失败:

@jit
def f(x, neg):
  return -x if neg else x

f(1, True)

如果您不希望跟踪某些变量,可以将它们标记为静态以用于 JIT 编译:

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  print(f'得到 {x}{neg}')
  return -x if neg else x

f(1, True)

请注意,使用不同的静态参数调用 JIT 编译的函数会导致重新编译,因此该函数仍然按预期工作:

f(jnp.int16(2), False)

了解哪些值和操作是静态的,哪些是跟踪的,是有效使用 jax.jit 的关键部分。

静态与跟踪操作

核心概念:
  • 就像值可以是静态的或跟踪的一样,操作也可以是静态的或跟踪的。
  • 静态操作在编译时在 Python 中求值;跟踪的操作在运行时在 XLA 中编译和求值。
  • 对于希望是静态的操作,请使用 numpy;对于希望被跟踪的操作,请使用 jax.numpy

静态值和跟踪值之间的这种区别使得思考如何保持静态值的静态性变得很重要。考虑这个函数:

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)

这会失败,并显示一个错误,指出找到了一个跟踪器,而不是一个一维整数类型的具体值序列。让我们在该函数中添加一些 print 语句,以了解发生这种情况的原因:

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # 注释掉此行以避免错误:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)

请注意,虽然 x 被跟踪,但 x.shape 是一个静态值。但是,当我们在此静态值上使用 jnp.arrayjnp.prod 时,它会变成一个跟踪值,此时它不能在需要静态输入的函数(如 reshape())中使用(回想一下:数组形状必须是静态的)。

一个有用的模式是对应静态执行的操作使用 numpy(即在编译时完成),而对要跟踪的操作使用 jax.numpy(即在运行时优化和执行)。对于此函数,它可能如下所示:

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)

因此,JAX 程序中的一个标准约定是 import numpy as npimport jax.numpy as jnp,这样两个接口都可用,以便更好地控制操作是以静态方式(使用 numpy,在编译时一次)还是以跟踪方式(使用 jax.numpy,在运行时优化)执行。