Omnistaging#

mattjj@ 2020年9月25日

这更像是一个升级指南,而不是设计文档。

内容#

太长不看#

发生了什么?#

JAX 的追踪基础设施有一个名为“omnistaging”的变更(google/jax#3370),在 jax==0.2.0 版本中被启用。这一变更提升了内存性能、追踪执行时间,并简化了 jAX 的内部机制,但可能会导致一些现有代码无法正常运行。通常,代码无法运行是由于代码本身存在问题,因此从长远来看,修复这些错误是最好的选择,但作为临时解决方案,也可以禁用 omnistaging。我们很乐意帮助您进行修复!

我如何知道 omnistaging 是否破坏了我的代码?#

判断是否由 omnistaging 引起的最简单方法是禁用 omnistaging,看看问题是否消失。请参阅下面的 开启 omnistaging 时可能出现的问题 部分。

我目前如何禁用 omnistaging?#

注意:这适用于 JAX 版本 0.2.0 到 0.2.11;在 JAX 版本 0.2.12 及以上,无法禁用 omnistaging

暂时可以通过以下方式禁用 omnistaging:

  1. 将shell环境变量 JAX_OMNISTAGING 设置为假值;

  2. 如果你的代码使用 absl 解析标志,请将布尔标志 jax_omnistaging 设置为假值。

  3. 在你的主文件顶部使用以下语句:

jax.config.disable_omnistaging()

如何修复由omnistaging暴露的错误?#

到目前为止,使用 jax.numpy 计算形状值或其他跟踪时恒定值是 omnistaging 最常见的问题。请参阅下面的代码块以获取快速示例,有关详细信息以及其他问题的完整说明,请参阅 启用 omnistaging 时可能出现的问题 部分。

而不是这样:

@jit
def f(x):
  input_size = jnp.prod(x.shape)
  if input_size > 100:
    ...

这样做:

import numpy as np

@jit
def f(x):
  input_size = np.prod(x.shape)
  if input_size > 100:
    ...

与其将 jax.numpy 视为 numpy 的直接替代品,现在更好的做法是仅在您希望在加速器(如您的 GPU)上执行计算时使用 jax.numpy 操作。

什么是“omnistaging”以及它为什么有用?#

Omnistaging 是 JAX 核心升级的名称,旨在将更多计算从逐操作的 Python 转移到 XLA,并避免在 jitpmap 和控制流原语中进行任何“跟踪时恒定折叠”。因此,Omnistaging 通过减少跟踪期间的碎片化和为 XLA 生成更少的编译时大常量,从而提高了 JAX 的内存性能(有时显著提高)。它还可以通过消除跟踪时的逐操作执行来提高跟踪性能。此外,Omnistaging 简化了 JAX 核心内部结构,修复了许多长期存在的错误,并为即将到来的重要功能奠定了基础。

“omnistaging”这个名字意味着尽可能地进行所有阶段的处理。

玩具示例#

jitpmap 这样的 JAX 变换会将计算分阶段交给 XLA 处理。也就是说,我们将这些变换应用于包含多个基本操作的函数,使得这些操作不再从 Python 中逐个执行,而是成为一次端到端优化的 XLA 计算的一部分。

但究竟哪些操作会被分阶段执行?在 omnistaging 之前,JAX 仅基于数据依赖性来分阶段执行计算。以下是一个示例函数,紧接着是它在 omnistaging 更改 之前 分阶段生成的 XLA HLO 程序:

from jax import jit
import jax.numpy as jnp

@jit
def f(x):
  y = jnp.add(1, 1)
  return x * y

f(3)
ENTRY jit_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(2)
  multiply.4 = s32[] multiply(parameter.1, constant.3)
  ROOT tuple.5 = (s32[]) tuple(multiply.4)
}

注意,add 操作没有被暂存。相反,我们只看到一个乘法。

以下是该函数在 omnistaging 更改后的生成的 HLO:

ENTRY jit_f.8 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(1)
  constant.4 = s32[] constant(1)
  add.5 = s32[] add(constant.3, constant.4)
  multiply.6 = s32[] multiply(parameter.1, add.5)
  ROOT tuple.7 = (s32[]) tuple(multiply.6)
}

稍微少一些的玩具示例#

以下是一个较少玩具化的示例,当我们想要创建布尔掩码时,这在实践中可能会出现:

import jax.numpy as jnp
from jax import lax

@jit
def select_tril(x):
  mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
  return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12).reshape((3, 4))
select_tril(x)

Before 全阶段之前:

ENTRY jit_select_tril.8 {
  constant.3 = pred[] constant(false)
  constant.1 = pred[3,4]{1,0} constant({...})
  parameter.2 = s32[3,4]{1,0} parameter(0)
  constant.4 = s32[] constant(0)
  broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
  select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
  ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}

select 操作被分阶段执行,但构建常量 mask 的操作没有。与其被分阶段执行,构建 mask 的操作在 Python 跟踪时逐个执行,XLA 只看到一个表示 mask 值的编译时常量 constant.1。这很不幸,因为如果我们分阶段执行了构建 mask 的操作,XLA 本可以将它们融合到 select 中,并完全避免物化结果。结果是我们最终浪费了内存来存储一个可能很大的常量,浪费时间分派多个未融合的逐个 XLA 计算,甚至可能导致内存碎片化。

(对应于 jnp.zeros_like(x) 的零数组构造的 broadcast 被分阶段执行,因为 JAX 对来自 google/jax#1668 的非常简单的表达式是懒惰的。在 omnistaging 之后,我们可以移除那个懒惰的子语言并简化 JAX 的内部实现。)

创建 mask 未被分阶段执行的原因是,在 omnistaging 之前,jit 基于数据依赖性进行操作。也就是说,jit 只会分阶段执行那些在函数中与参数有数据依赖关系的操作。控制流原语和 pmap 的行为类似。在 select_tril 的情况下,构造常量 mask 的操作与参数 x 没有数据依赖关系,因此它们不会被分阶段执行;只有 lax.select 调用具有数据依赖关系。

通过 omnistaging,在 jit 转换函数的动态上下文中,所有 jax.numpy 调用都被分阶段输出到 XLA。也就是说,在 omnistaging 之后,XLA 看到的 select_tril 计算是

ENTRY jit_select_tril.16 {
  constant.4 = pred[] constant(false)
  iota.1 = s32[3]{0} iota(), iota_dimension=0
  broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
  reshape.7 = s32[3]{0} reshape(broadcast.5)
  broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
  iota.2 = s32[4]{0} iota(), iota_dimension=0
  broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
  reshape.9 = s32[4]{0} reshape(broadcast.6)
  broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
  compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
  parameter.3 = s32[3,4]{1,0} parameter(0)
  constant.12 = s32[] constant(0)
  broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
  select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
  ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}

当 omnistaging 开启时可能会出现什么问题?#

由于在 jitpmap 的动态上下文中,所有 jax.numpy 操作从 Python 转移到 XLA 进行分阶段处理,一些之前可以运行的代码可能会开始引发明显的错误。如下所述,这些行为在 omnistaging 之前就已经存在问题,但 omnistaging 将其变成了硬错误。

使用 jax.numpy 进行形状计算#

示例#

from jax import jit
import jax.numpy as jnp

@jit
def ex1(x):
  size = jnp.prod(jnp.array(x.shape))
  return x.reshape((size,))

ex1(jnp.ones((3, 4)))

错误信息#

[... full traceback ...]
  File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
    raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:

  operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
    from line ex1.py:6 (ex1)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>

解释#

在使用 omnistaging 时,我们不能像上面使用 jnp.prod 那样使用 jax.numpy 进行形状计算,因为在 jit 函数的动态上下文中,这些操作将被作为执行时计算的值从 Python 中移出,然而我们需要它们在编译时(因此也是跟踪时)是常量。

在全阶段之前,这段代码不会引发错误,但它是一个常见的性能问题:jnp.prod 计算会在跟踪时在设备上执行,这意味着额外的编译、传输、同步、分配,以及潜在的内存碎片化。

解决方案#

解决方案很简单,就是使用原始的 numpy 进行这些形状计算。我们不仅避免了错误,还保持了计算在主机上进行(并且开销更低)。

这个问题在代码中很常见,因此我们尝试让错误信息特别清晰。除了显示抽象跟踪器值导致问题的堆栈跟踪(完整堆栈跟踪中的 jnp.reshape 行,位于 omni.py:10),我们还解释了为什么这个值首先变成了跟踪器,通过指向导致它成为抽象跟踪器的上游原始操作(jnp.prod 中的 reduce_prod,位于 omni.py:9)以及跟踪器所属的 jit 装饰函数(ex1,位于 omni.py:6)。

副作用#

示例#

from jax import jit
from jax import random

key = random.PRNGKey(0)

def init():
  global key
  key, subkey = random.split(key)
  return random.normal(subkey, ())

print(init())  # -1.2515389
print(init())  # -0.58665067

init = jit(init)
print(init())  # 0.48648298
print(init())  # 0.48648298  !!

最后一次调用虽然有重复的随机性但没有硬错误,因为我们没有重新执行 Python。但如果我们查看 key,我们会看到一个转义的追踪器 当 omnistaging 开启时

print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>

在全阶段之前,random.split 调用不会被阶段化,因此我们不会得到一个逃逸的追踪器。代码仍然会有缺陷,因为jit函数不会重现原始函数的语义(由于重复使用相同的PRNG键),最终是由于副作用。

启用 omnistaging 后,如果我们再次访问 key,我们将得到一个逃逸的 tracer 错误:

random.normal(key, ())

错误信息#

[... full stack trace …]
  File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
    raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).

解释#

我们发现的第二大类 omnistaging 问题与副作用代码有关。这些代码已经通过转换有副作用的函数而使 JAX 保修失效,但由于 omnistaging 之前的“跟踪时恒定折叠”行为,一些有副作用的函数仍然可以正确运行。Omnistaging 捕获了更多这些错误。

解决方案#

解决方案是识别依赖于副作用的 JAX 转换函数,并重写它们以避免副作用。

基于XLA优化的微小数值差异#

由于 omnistaging 将更多计算转移到 XLA 进行,而不是在跟踪时执行部分计算,这可能会导致浮点操作的重新排序。因此,我们观察到数值行为发生了变化,导致在启用 omnistaging 时,容差过紧的测试失败。

依赖于JAX内部API的更改#

Omnistaging 对 JAX 的核心代码进行了一些重大修订,包括删除或更改内部函数。任何依赖于这些内部 JAX API 的代码在启用 omnistaging 时都可能中断,无论是构建错误(来自 pytype)还是运行时错误。

触发 XLA 编译时错误#

由于全阶段涉及将更多代码分阶段到XLA,我们发现它在某些后端触发了现有的XLA编译时错误。对于这些错误,最好的做法是报告它们,以便我们与XLA团队合作进行修复。