jax.random
模块#
用于伪随机数生成的工具。
The jax.random
包提供了许多用于确定性生成伪随机数序列的例程。
基本用法#
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
... key, subkey = jax.random.split(key)
... params = compiled_update(subkey, params, next(batches))
PRNG 键#
与用户可能习惯的 有状态 伪随机数生成器 (PRNGs) 不同,JAX 的随机函数都需要显式传递一个 PRNG 状态作为第一个参数。随机状态由我们称之为 键 的特殊数组元素类型描述,通常由 jax.random.key()
函数生成:
>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
此密钥随后可用于 JAX 的任何随机数生成例程中:
>>> random.uniform(key)
Array(0.41845703, dtype=float32)
请注意,使用一个键并不会修改它,因此重复使用相同的键将导致相同的结果:
>>> random.uniform(key)
Array(0.41845703, dtype=float32)
如果你需要一个新的随机数,你可以使用 jax.random.split()
来生成新的子键:
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.10536897, dtype=float32)
备注
类型化的键数组,如上面的 key<fry>
,是在 JAX v0.4.16 中引入的。在此之前,键通常以 uint32
数组表示,其最后一个维度表示键的位级表示。
两种形式的密钥数组仍然可以使用 jax.random
模块创建和使用。新样式的类型化密钥数组通过 jax.random.key()
创建。旧样式的 uint32
密钥数组通过 jax.random.PRNGKey()
创建。
要在两者之间转换,请使用 jax.random.key_data()
和 jax.random.wrap_key_data()
。当与 JAX 之外的系统交互(例如,将数组导出为可序列化格式),或当将密钥传递给假设旧格式的基于 JAX 的库时,可能需要旧密钥格式。
否则,建议使用键入的键。与键入的键相比,传统键的注意事项包括:
它们有一个额外的尾随维度。
它们具有数值类型(
uint32
),允许进行通常不适用于键的操作,例如整数算术。它们不携带关于 RNG 实现的信息。当遗留密钥传递给
jax.random
函数时,一个全局配置设置决定了 RNG 实现(见下面的“高级 RNG 配置”)。
要了解更多关于此升级以及关键类型的设计,请参阅 JEP 9263。
高级#
设计和背景#
TLDR: JAX PRNG = Threefry 计数器 PRNG + 一个面向数组的函数式 分裂模型
更多详情请参见 docs/jep/263-prng.md。
总结来说,除了其他要求,JAX PRNG 旨在:
确保可重复性,
在向量化(生成数组值)和多副本、多核计算方面都能很好地并行化。特别是它不应在随机函数调用之间使用顺序约束。
高级 RNG 配置#
JAX 提供了几种 PRNG 实现。可以通过 jax.random.key
的可选 impl
关键字参数选择特定的实现。如果没有向 key
构造函数传递 impl
选项,则实现由全局的 jax_default_prng_impl
配置标志决定。可用实现的名称字符串有:
"threefry2x32"
(默认): 一个基于计数器的伪随机数生成器,基于Threefry哈希函数的一个变种,如 Salmon等人在2011年的这篇论文 中所述。"rbg"
和"unsafe_rbg"
(实验性): 基于 XLA 的随机位生成器 (RBG) 算法 构建的 PRNGs。"rbg"
使用 XLA RBG 进行随机数生成,而对于密钥派生(如在jax.random.split
和jax.random.fold_in
中),它使用与"threefry2x32"
相同的方法。"unsafe_rbg"
使用 XLA RBG 进行生成以及密钥派生。
这些实验方案生成的随机数尚未经过经验随机性测试(例如 BigCrush)。
在
"unsafe_rbg"
中的密钥派生也没有经过实证测试。名称强调了“不安全”,因为密钥派生质量和生成质量尚未被充分理解。此外,
"rbg"
和"unsafe_rbg"
在jax.vmap
下表现异常。当对一批键进行随机函数映射时,其输出值可能与对相同键的真实映射不同。相反,在vmap
下,整个输出随机数批次仅从输入键批次中的第一个键生成。例如,如果keys
是一个包含8个键的向量,那么jax.vmap(jax.random.normal)(keys)
等于jax.random.normal(keys[0], shape=(8,))
。这种特殊性反映了XLA RBG有限批处理支持的变通方法。
使用默认RNG的替代方案的原因包括:
对于TPUs,编译可能会很慢。
在TPU上执行相对较慢。
自动分区:
为了使 jax.jit
能够有效地自动分区生成分片随机数数组(或键数组)的函数,所有 PRNG 实现都需要额外的标志:
对于
"threefry2x32"
和"rbg"
密钥派生,设置jax_threefry_partitionable=True
。对于
"unsafe_rbg"
和"rbg"
随机生成,设置 XLA 标志--xla_tpu_spmd_rng_bit_generator_unsafe=1
。
可以使用 XLA_FLAGS
环境变量来设置 XLA 标志,例如 XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
。
更多关于 jax_threefry_partitionable
的信息,请参见 https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
摘要:
属性 |
Threefry |
Threefry* |
rbg |
unsafe_rbg |
rbg** |
不安全的rbg** |
---|---|---|---|---|---|---|
在TPU上最快 |
✅ |
✅ |
✅ |
✅ |
||
高效可分片(使用 pjit) |
✅ |
✅ |
✅ |
|||
在分片中相同 |
✅ |
✅ |
✅ |
✅ |
||
在 CPU/GPU/TPU 上相同 |
✅ |
✅ |
||||
精确 |
✅ |
✅ |
(*): 设置 jax_threefry_partitionable=1
(**): 设置 XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
API 参考#
密钥创建与操作#
|
给定一个整数种子,创建一个伪随机数生成器(PRNG)键。 |
|
恢复伪随机数生成器(PRNG)密钥数组下的关键数据位。 |
|
将一组关键数据位包装成一个PRNG密钥数组。 |
|
将数据折叠到 PRNG 密钥中以形成新的 PRNG 密钥。 |
|
通过添加一个前导轴,将一个 PRNG 键拆分为 num 个新键。 |
|
克隆一个密钥以供重复使用 |
|
给定一个整数种子,创建一个遗留的PRNG密钥。 |
随机采样器#
|
从单位 Lp 球中均匀采样。 |
|
使用给定的形状和均值生成伯努利随机值。 |
|
使用给定的形状和浮点数据类型生成 Beta 随机值。 |
|
使用给定的形状和浮点数据类型生成二项随机值的示例。 |
|
以无符号整数形式表示的样本均匀位。 |
|
从分类分布中采样随机值。 |
|
使用给定的形状和浮点数据类型生成样本柯西随机值。 |
|
使用给定的形状和浮点数据类型生成卡方随机值的示例。 |
|
从给定的数组中生成一个随机样本。 |
|
使用给定的形状和浮点数据类型生成狄利克雷随机值。 |
|
从双边麦克斯韦分布中采样。 |
|
使用给定的形状和浮点数据类型生成指数随机值。 |
|
具有给定形状和浮点数据类型的样本 F 分布随机值。 |
|
给定形状和浮点数据类型的 Gamma 随机值示例。 |
|
从广义正态分布中抽取样本。 |
|
使用给定的形状和浮点数据类型生成几何随机值的示例。 |
|
使用给定的形状和浮点数据类型生成Gumbel随机值。 |
|
使用给定的形状和浮点数据类型生成拉普拉斯随机值的示例。 |
|
给定形状和浮点数据类型的示例对数伽马随机值。 |
|
使用给定的形状和浮点数据类型生成样本逻辑随机值。 |
|
使用给定的形状和浮点数据类型生成对数正态随机值的示例。 |
|
从单边麦克斯韦分布中采样。 |
|
使用给定的均值和协方差生成多元正态随机值的示例。 |
|
生成具有给定形状和浮点数据类型的标准正态随机值。 |
|
从正交群 O(n) 中均匀采样。 |
|
使用给定的形状和浮点数据类型生成帕累托随机值。 |
|
返回一个随机排列的数组或范围。 |
|
使用给定的形状和整数数据类型生成泊松随机值。 |
|
从 Rademacher 分布中采样。 |
|
在给定的形状/数据类型下,生成 [minval, maxval) 范围内的均匀随机值。 |
|
具有给定形状和浮点数据类型的示例瑞利随机值。 |
|
使用给定的形状和浮点数据类型生成学生 t 分布的随机值。 |
|
使用给定的形状和浮点数据类型生成三角形随机值的示例。 |
|
具有给定形状和数据类型的截断标准正态随机值示例。 |
|
在给定的形状/数据类型下,生成 [minval, maxval) 范围内的均匀随机值。 |
|
使用给定的形状和浮点数据类型生成 Wald 随机值。 |
|
从威布尔分布中抽取样本。 |