关键概念#
本节简要介绍JAX包的一些关键概念。
JAX 数组 (jax.Array
)#
JAX 中的默认数组实现是 jax.Array
。在许多方面,它类似于你可能从 NumPy 包中熟悉的 numpy.ndarray
类型,但它有一些重要的区别。
数组创建#
我们通常不会直接调用 jax.Array
构造函数,而是通过 JAX API 函数来创建数组。例如,jax.numpy
提供了熟悉的 NumPy 风格数组构造功能,如 jax.numpy.zeros()
、jax.numpy.linspace()
、jax.numpy.arange()
等。
import jax
import jax.numpy as jnp
x = jnp.arange(5)
isinstance(x, jax.Array)
True
如果你在代码中使用Python类型注解,jax.Array
是用于jax数组对象的适当注解(更多讨论见 jax.typing
)。
数组设备和分片#
JAX 数组对象有一个 devices
方法,可以让你检查数组内容的存储位置。在最简单的情况下,这将是单个 CPU 设备:
x.devices()
{CpuDevice(id=0)}
通常,一个数组可以跨多个设备进行 分片 ,其方式可以通过 sharding
属性进行检查:
x.sharding
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
这里数组位于单个设备上,但通常情况下,JAX 数组可以分布在多个设备上,甚至多个主机上。要了解更多关于分片数组和并行计算的信息,请参阅 分片计算
变换#
除了操作数组的函数外,JAX 还包括许多对 JAX 函数进行操作的 变换。这些包括
jax.vmap()
: 向量化变换;参见 自动向量化jax.grad()
: 梯度变换;参见 自动微分
以及其他几个。转换接受一个函数作为参数,并返回一个新的转换后的函数。例如,这里是如何即时编译一个简单的 SELU 函数:
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05
通常你会看到为了方便使用Python的装饰器语法应用的转换:
@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
跟踪#
变换背后的魔法是 Tracer 的概念。Tracer 是数组对象的抽象替代品,并传递给 JAX 函数,以便提取函数编码的操作序列。
你可以在转换后的JAX代码中打印任何数组值来看到这一点;例如:
@jax.jit
def f(x):
print(x)
return x + 1
x = jnp.arange(5)
result = f(x)
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)>
打印的值不是数组 x
,而是一个 Tracer
实例,它表示 x
的基本属性,如 shape
和 dtype
。通过使用追踪的值执行函数,JAX 可以在实际执行这些操作之前确定函数编码的操作序列:像 jit()
、vmap()
和 grad()
这样的变换可以将输入操作序列映射到变换后的操作序列。
Jaxprs#
JAX 有自己的操作序列中间表示,称为 jaxpr。jaxpr(JAX 表达式 的缩写)是函数程序的简单表示,由一系列 primitive 操作组成。
例如,考虑我们上面定义的 selu
函数:
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
我们可以使用 jax.make_jaxpr()
工具将此函数转换为给定特定输入的 jaxpr:
x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
b:bool[5] = gt a 0.0
c:f32[5] = exp a
d:f32[5] = mul 1.6699999570846558 c
e:f32[5] = sub d 1.6699999570846558
f:f32[5] = pjit[
name=_where
jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
j:f32[5] = select_n g i h
in (j,) }
] b a e
k:f32[5] = mul 1.0499999523162842 f
in (k,) }
将此与Python函数定义进行比较,我们可以看到它编码了函数所表示的操作的精确序列。我们将在后面的 JAX 内部:jaxpr 语言 中更深入地探讨jaxprs。
Pytrees#
JAX 函数和变换基本上操作于数组,但在实践中,编写处理数组集合的代码更为方便:例如,神经网络可能会将其参数组织在一个带有有意义键的数组字典中。JAX 不是逐个处理这些结构,而是依赖于 pytree 抽象来统一处理这些集合。
以下是一些可以被视为 pytrees 的对象示例:
# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
[1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
# Named tuple of parameters
from typing import NamedTuple
class Params(NamedTuple):
a: int
b: float
params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]
JAX 提供了许多用于处理 PyTrees 的通用工具;例如,函数 jax.tree.map()
可以用于将一个函数映射到树的每个叶子节点,而 jax.tree.reduce()
可以用于在树的叶子节点上应用归约操作。
你可以在 使用 pytrees 教程中了解更多。