JEP 9263: 类型化键 & 可插拔随机数生成器#

Jake VanderPlas, Roy Frostig

2023年8月

概述#

今后,JAX中的RNG键将更加类型安全和可定制。不再通过长度为2的 uint32 数组来表示单个PRNG键,而是将其表示为一个具有特殊RNG数据类型的标量数组,该数据类型满足 jnp.issubdtype(key.dtype, jax.dtypes.prng_key)

目前,旧样式的 RNG 键仍可以使用 jax.random.PRNGKey() 创建:

>>> key = jax.random.PRNGKey(0)
>>> key
Array([0, 0], dtype=uint32)
>>> key.shape
(2,)
>>> key.dtype
dtype('uint32')

从现在开始,可以使用 jax.random.key() 创建新样式的 RNG 键:

>>> key = jax.random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> key.shape
()
>>> key.dtype
key<fry>

这个(标量形状的)数组与任何其他 JAX 数组的行为相同,只是它的元素类型是一个键(及其关联的元数据)。我们也可以制作非标量键数组,例如通过将 jax.vmap() 应用于 jax.random.key()

>>> key_arr = jax.vmap(jax.random.key)(jnp.arange(4))
>>> key_arr
Array((4,), dtype=key<fry>) overlaying:
[[0 0]
 [0 1]
 [0 2]
 [0 3]]
>>> key_arr.shape
(4,)

除了切换到一个新的构造函数外,大多数与PRNG相关的代码应该继续按预期工作。你可以继续像以前一样在jax.random API中使用密钥;例如:

# split
new_key, subkey = jax.random.split(key)

# random number generation
data = jax.random.uniform(key, shape=(5,))

然而,并非所有数值操作都适用于键数组。它们现在会故意引发错误:

>>> key = key + 1  
Traceback (most recent call last):
TypeError: add does not accept dtypes key<fry>, int32.

如果由于某些原因你需要恢复底层缓冲区(旧式键),你可以使用 jax.random.key_data() 来实现。

>>> jax.random.key_data(key)
Array([0, 0], dtype=uint32)

对于旧式键,key_data() 是一个恒等操作。

这对用户意味着什么?#

对于 JAX 用户,此更改目前不需要任何代码更改,但我们希望您会发现升级是值得的,并切换到使用类型化键。要尝试这一点,请将 jax.random.PRNGKey() 的使用替换为 jax.random.key()。这可能会在您的代码中引入一些破坏,这些破坏属于以下几类之一:

  • 如果你的代码对键执行不安全/不支持的操作(例如索引、算术、转置等;请参见下面的类型安全部分),此更改将会捕捉到它。你可以更新你的代码以避免这些不支持的操作,或者使用 jax.random.key_data()jax.random.wrap_key_data() 以不安全的方式操作原始键缓冲区。

  • 如果你的代码包含关于 key.shape 的显式逻辑,你可能需要更新此逻辑以考虑尾随键缓冲区维度不再是形状的显式部分这一事实。

  • 如果你的代码包含关于 key.dtype 的显式逻辑,你需要将其升级以使用新的公共 API 来推理 RNG dtypes,例如 dtypes.issubdtype(dtype, dtypes.prng_key)

  • 如果你调用一个基于 JAX 的库,而该库尚未处理类型化的 PRNG 键,你可以暂时使用 raw_key = jax.random.key_data(key) 来恢复原始缓冲区,但请保持一个 TODO,以便在下游库支持类型化 RNG 键时移除此代码。

在未来某个时候,我们计划弃用 jax.random.PRNGKey() 并要求使用 jax.random.key()

检测新样式类型键#

要检查一个对象是否是新样式类型的PRNG键,可以使用 jax.dtypes.issubdtypejax.numpy.issubdtype

>>> typed_key = jax.random.key(0)
>>> jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
True
>>> raw_key = jax.random.PRNGKey(0)
>>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)
False

PRNG 键的类型注解#

推荐用于旧式和新式 PRNG 键的类型注解是 jax.Array。PRNG 键通过其 dtype 与其他数组区分开来,并且目前无法在类型注解中指定 JAX 数组的 dtype。以前可以使用 jax.random.KeyArrayjax.random.PRNGKeyArray 作为类型注解,但这些在类型检查下始终被别名为 Any,因此 jax.Array 具有更高的特异性。

注意:jax.random.KeyArrayjax.random.PRNGKeyArray 在 JAX 版本 0.4.16 中已被弃用,并在 JAX 版本 0.4.24 中被移除。

JAX 库作者的注意事项#

如果你维护一个基于 JAX 的库,你的用户也是 JAX 用户。要知道,JAX 将继续支持 jax.random 中的“原始”旧式键,因此调用者可能期望它们在任何地方都被接受。如果你更倾向于在你的库中要求新式类型键,那么你可能希望按照以下方式进行检查:

from jax import dtypes

def ensure_typed_key_array(key: Array) -> Array:
  if dtypes.issubdtype(key.dtype, dtypes.prng_key):
    return key
  else:
    raise TypeError("New-style typed JAX PRNG keys required")

动机#

这一变化的两个主要动机是可定制性和安全性。

自定义 PRNG 实现#

JAX 目前使用单一的、全局配置的 PRNG 算法。PRNG 密钥是一个无符号 32 位整数的向量,jax.random API 消耗它来生成伪随机流。任何更高秩的 uint32 数组都被解释为这样的密钥缓冲区数组,其中尾随维度表示密钥。

随着我们引入了替代的PRNG实现,这种设计的缺点变得更加明显,这些实现必须通过设置全局或局部配置标志来选择。不同的PRNG实现有不同大小的密钥缓冲区,以及不同的生成随机比特的算法。通过全局标志来确定这种行为容易出错,尤其是在整个进程中使用多个密钥实现时。

我们的新方法是将实现作为 PRNG 键类型的一部分,即键数组的元素类型。使用新的键 API,以下是在默认的 threefry2x32 实现(在纯 Python 中实现并通过 JAX 编译)和非默认的 rbg 实现(对应于单个 XLA 随机位生成操作)下生成伪随机值的示例:

>>> key = jax.random.key(0, impl='threefry2x32')  # this is the default impl
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.9653214 , 0.31468165, 0.63302994], dtype=float32)

>>> key = jax.random.key(0, impl='rbg')
>>> key
Array((), dtype=key<rbg>) overlaying:
[0 0 0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)

安全PRNG密钥使用#

PRNG 密钥实际上只支持少数几个操作,即密钥派生(例如拆分)和随机数生成。PRNG 旨在生成独立的伪随机数,前提是密钥被正确拆分并且每个密钥只被使用一次。

以其他方式操作或消耗关键数据的代码通常表明存在意外的错误,而将关键数组表示为原始的 uint32 缓冲区则允许了沿着这些线的误用。以下是我们遇到的一些误用示例:

键缓冲索引#

访问底层整数缓冲区使得尝试以非标准方式派生密钥变得容易,有时会带来意想不到的糟糕后果:

# Incorrect
key = random.PRNGKey(999)
new_key = random.PRNGKey(key[1])  # identical to the original key!
# Correct
key = random.PRNGKey(999)
key, new_key = random.split(key)

如果这个键是用 random.key(999) 生成的新样式类型化键,那么对键缓冲区的索引将会出错。

关键算术#

键算术是另一种同样危险的方式,用于从其他键派生键。通过直接操作键数据来避免 jax.random.split()jax.random.fold_in() 的方式派生键,会产生一批键——取决于PRNG的实现——这些键可能在批次内生成相关的随机数:

# Incorrect
key = random.PRNGKey(0)
batched_keys = key + jnp.arange(10, dtype=key.dtype)[:, None]
# Correct
key = random.PRNGKey(0)
batched_keys = random.split(key, 10)

使用 random.key(0) 创建的新样式类型化键通过禁止对键进行算术操作来解决这个问题。

无意中交换了关键缓冲区#

使用“原始”旧式键数组时,很容易意外地交换批次(前导)维度和键缓冲区(尾随)维度。这可能会导致产生相关伪随机性的键。我们随着时间观察到的模式可以归结为这一点:

# Incorrect
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=1)(keys)
# Correct
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=0)(keys)

这里的错误很微妙。通过映射 in_axes=1,这段代码通过将批次中每个键缓冲区的一个元素组合起来创建新键。生成的键彼此不同,但以一种非标准的方式“派生”。再次强调,PRNG 的设计和测试并不是为了从这样的键批次中产生独立的随机流。

使用 random.key(0) 创建的新样式类型化键通过隐藏单个键的缓冲区表示来解决这个问题,而是将键视为键数组的非透明元素。键数组没有尾随的“缓冲区”维度来进行索引、转置或映射。

密钥重用#

与基于状态的PRNG API(如numpy.random)不同,JAX的函数式PRNG在使用后不会隐式更新密钥。

# Incorrect
key = random.PRNGKey(0)
x = random.uniform(key, (100,))
y = random.uniform(key, (100,))  # Identical values!
# Correct
key = random.PRNGKey(0)
key1, key2 = random.split(random.key(0))
x = random.uniform(key1, (100,))
y = random.uniform(key2, (100,))

我们正在积极开发工具来检测和防止意外的密钥重用。这仍然是进行中的工作,但它依赖于类型化的密钥数组。现在升级到类型化密钥为我们引入了这些安全功能,因为我们在构建它们。

类型化PRNG键的设计#

类型化的 PRNG 键在 JAX 中作为扩展 dtypes 的实例实现,其中新的 PRNG dtypes 是子 dtypes。

扩展的数据类型#

从用户的角度来看,扩展的 dtype dt 具有以下用户可见的属性:

  • jax.dtypes.issubdtype(dt, jax.dtypes.extended) 返回 True:这是用于检测一个 dtype 是否为扩展 dtype 的公共 API。

  • 它有一个类级别的属性 dt.type,返回 numpy.generic 层次结构中的一个类型类。这与 np.dtype('int32').type 返回 numpy.int32 类似,后者不是一个 dtype,而是一个标量类型,并且是 numpy.generic 的子类。

  • 与 numpy 标量类型不同,我们不允许实例化 dt.type 标量对象:这与 JAX 决定将标量值表示为零维数组的决策一致。

从非公开的实现角度来看,扩展的 dtype 具有以下属性:

  • 它的类型是私有基类 jax._src.dtypes.ExtendedDtype 的子类,该非公开基类用于扩展数据类型。ExtendedDtype 的实例类似于 np.dtype 的实例,例如 np.dtype('int32')

  • 它有一个私有的 _rules 属性,允许 dtype 定义在特定操作下的行为方式。例如,jax.lax.full(shape, fill_value, dtype)dtype 是扩展 dtype 时,会委托给 dtype._rules.full(shape, fill_value, dtype)

为什么要在PRNG之外普遍引入扩展数据类型?我们在内部其他地方也重用了这种相同的扩展数据类型机制。例如,jax._src.core.bint对象,一个用于动态形状实验工作的有界整数类型,是另一种扩展数据类型。在最近的JAX版本中,它满足上述属性(参见jax/_src/core.py#L1789-L1802)。

PRNG 数据类型#

PRNG dtypes 被定义为扩展 dtypes 的一个特例。具体来说,这一变化引入了一个新的公共标量类型类 jax.dtypes.prng_key,它具有以下属性:

>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended)
True

PRNG 键数组随后具有以下属性的 dtype:

>>> key = jax.random.key(0)
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.extended)
True
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
True

除了针对扩展数据类型的一般情况中提到的 key.dtype._rules 之外,PRNG 数据类型还定义了 key.dtype._impl,其中包含了定义 PRNG 实现的元数据。PRNG 实现目前由非公开的 jax._src.prng.PRNGImpl 类定义。目前,PRNGImpl 并不意味着是一个公开的 API,但我们可能会很快重新审视这一点,以允许完全自定义的 PRNG 实现。

进展#

以下是一个非全面的实现上述设计的关键拉取请求列表。主要跟踪问题是 #9263

  • 通过 PRNGImpl 实现可插拔的 PRNG:#6899

  • 实现 PRNGKeyArray,不带 dtype: #11952

  • PRNGKeyArray 添加一个带有 _rules 属性的“自定义元素” dtype 属性:#12167

  • 将“自定义元素类型”重命名为“不透明数据类型”:#12170

  • 重构 bint 以使用不透明 dtype 基础设施:#12707

  • 添加 jax.random.key 以直接创建类型化的键:#16086

  • keyPRNGKey 添加 impl 参数:#16589

  • 将“不透明数据类型”重命名为“扩展数据类型”,并定义 jax.dtypes.extended#16824

  • 介绍 jax.dtypes.prng_key 并将 PRNG 数据类型与扩展数据类型统一:#16781

  • 添加一个 jax_legacy_prng_key 标志,以在使用旧版(原始)PRNG 键时支持警告或错误:#17225