术语表

术语表#

数组#

JAX 的 numpy.ndarray 的类似物。参见 jax.Array

CPU#

简称 中央处理器 的 CPU 是大多数计算机中可用的标准计算架构。JAX 可以在 CPU 上运行计算,但通常在 GPUTPU 上可以实现更好的性能。

设备#

用于指代 JAX 进行计算时使用的 CPUGPUTPU 的通用名称。

前向模式自动微分#

参见 JVP

函数式编程#

一种编程范式,其中程序通过应用和组合 纯函数 来定义。JAX 是为与函数式程序一起使用而设计的。

GPU#

简称 图形处理单元 的 GPU 最初专门用于与屏幕图像渲染相关的操作,但现在用途更加广泛。JAX 能够针对 GPU 进行数组的快速操作(另请参阅 CPUTPU)。

jaxpr#

缩写为 JAX 表达式 ,jaxpr 是 JAX 生成的一种计算的中间表示,它被传递给 XLA 进行编译和执行。更多讨论和示例请参见 理解 Jaxprs

JIT#

缩写为 Just In Time 编译,JAX 中的 JIT 通常指的是将数组操作编译为 XLA,最常通过使用 jax.jit() 来实现。

JVP#

简称 Jacobian 向量积,有时也称为 前向模式 自动微分。更多详情,请参阅 jacobian-vector-product。在 JAX 中,JVP 是一个通过 jax.jvp() 实现的 变换。另请参阅 VJP

原始的#

原语是 JAX 程序中使用的基本计算单元。jax.lax 中的大多数函数都表示单个原语。当在 jaxpr 中表示计算时,jaxpr 中的每个操作都是一个原语。

纯函数#

纯函数是一个其输出仅基于输入且没有副作用的函数。JAX 的 变换 模型设计用于处理纯函数。另见 函数式编程

pytree#

pytree 是一种抽象,它让 JAX 能够以统一的方式处理元组、列表、字典以及数组值的其他更一般的容器。有关更详细的讨论,请参阅 使用 pytrees

反向模式自动微分#

参见 VJP

SPMD#

简称 单程序多数据 ,它指的是一种并行计算技术,其中相同的计算(例如,神经网络的前向传播)在不同的输入数据(例如,批次中的不同输入)上并行运行在不同的设备(例如,多个TPU)上。jax.pmap() 是JAX 变换 ,实现了SPMD并行。

静态#

JIT 编译中,一个未被追踪的值(参见 Tracer)。有时也指对静态值的编译时计算。

TPU#

简称 Tensor Processing Unit,TPU 是专门为深度学习应用中使用的 N 维张量进行快速运算而设计的芯片。JAX 能够针对 TPU 进行数组的快速运算(另见 CPUGPU)。

追踪器#

一个对象,用作 JAX 数组 的占位符,以确定 Python 函数执行的操作序列。在内部,JAX 通过 jax.core.Tracer 类实现这一点。

变换#

高阶函数:即,一个以函数为输入并输出一个转换后函数的函数。JAX 中的例子包括 jax.jit()jax.vmap()jax.grad()

VJP#

简称 向量雅可比积 ,有时也称为 反向模式 自动微分。更多详情,请参见 向量雅可比积 。在 JAX 中,VJP 是一种通过 jax.vjp() 实现的 变换 。另请参见 JVP

XLA#

简称 加速线性代数,XLA 是一个针对线性代数操作的特定领域编译器,它是 即时 编译 JAX 代码的主要后端。参见 https://www.tensorflow.org/xla/

弱类型#

一种具有与Python标量相同类型提升语义的JAX数据类型;参见 弱类型