术语表#
- 数组#
JAX 的
numpy.ndarray
的类似物。参见jax.Array
。- CPU#
简称 中央处理器 的 CPU 是大多数计算机中可用的标准计算架构。JAX 可以在 CPU 上运行计算,但通常在 GPU 和 TPU 上可以实现更好的性能。
- 设备#
- 前向模式自动微分#
参见 JVP
- 函数式编程#
一种编程范式,其中程序通过应用和组合 纯函数 来定义。JAX 是为与函数式程序一起使用而设计的。
- GPU#
简称 图形处理单元 的 GPU 最初专门用于与屏幕图像渲染相关的操作,但现在用途更加广泛。JAX 能够针对 GPU 进行数组的快速操作(另请参阅 CPU 和 TPU)。
- jaxpr#
缩写为 JAX 表达式 ,jaxpr 是 JAX 生成的一种计算的中间表示,它被传递给 XLA 进行编译和执行。更多讨论和示例请参见 理解 Jaxprs。
- JIT#
缩写为 Just In Time 编译,JAX 中的 JIT 通常指的是将数组操作编译为 XLA,最常通过使用
jax.jit()
来实现。- JVP#
简称 Jacobian 向量积,有时也称为 前向模式 自动微分。更多详情,请参阅 jacobian-vector-product。在 JAX 中,JVP 是一个通过
jax.jvp()
实现的 变换。另请参阅 VJP。- 原始的#
原语是 JAX 程序中使用的基本计算单元。
jax.lax
中的大多数函数都表示单个原语。当在 jaxpr 中表示计算时,jaxpr 中的每个操作都是一个原语。- 纯函数#
- pytree#
pytree 是一种抽象,它让 JAX 能够以统一的方式处理元组、列表、字典以及数组值的其他更一般的容器。有关更详细的讨论,请参阅 使用 pytrees。
- 反向模式自动微分#
参见 VJP。
- SPMD#
简称 单程序多数据 ,它指的是一种并行计算技术,其中相同的计算(例如,神经网络的前向传播)在不同的输入数据(例如,批次中的不同输入)上并行运行在不同的设备(例如,多个TPU)上。
jax.pmap()
是JAX 变换 ,实现了SPMD并行。- 静态#
- TPU#
简称 Tensor Processing Unit,TPU 是专门为深度学习应用中使用的 N 维张量进行快速运算而设计的芯片。JAX 能够针对 TPU 进行数组的快速运算(另见 CPU 和 GPU)。
- 追踪器#
一个对象,用作 JAX 数组 的占位符,以确定 Python 函数执行的操作序列。在内部,JAX 通过
jax.core.Tracer
类实现这一点。- 变换#
高阶函数:即,一个以函数为输入并输出一个转换后函数的函数。JAX 中的例子包括
jax.jit()
、jax.vmap()
和jax.grad()
。- VJP#
简称 向量雅可比积 ,有时也称为 反向模式 自动微分。更多详情,请参见 向量雅可比积 。在 JAX 中,VJP 是一种通过
jax.vjp()
实现的 变换 。另请参见 JVP 。- XLA#
简称 加速线性代数,XLA 是一个针对线性代数操作的特定领域编译器,它是 即时 编译 JAX 代码的主要后端。参见 https://www.tensorflow.org/xla/。
- 弱类型#
一种具有与Python标量相同类型提升语义的JAX数据类型;参见 弱类型。