JAX PRNG 设计#

我们想要一个伪随机数生成器设计,

  1. 表达力强 在于它使用方便,并且不会限制用户编写具有完全所需行为的数值程序的能力。

  2. 以与后端无关的方式启用 可重复 的程序执行,

  3. 具有不受 @jit 编译边界和设备后端影响的语义,

  4. 启用使用 SIMD 硬件 生成数组值的向量化

  5. 可并行化 因为它不会在随机函数调用之间增加顺序约束,而这些调用本来就没有数据依赖关系。

  6. 可扩展至 多副本、多核和分布式计算

  7. 与 JAX 和 XLA 的语义 和设计理念相契合(这些理念最终是由其他实际考虑所驱动的)。

作为这些的推论,我们认为设计应该是功能性的。另一个推论是,至少在当前硬件限制下,我们将在软件中进行PRNG。

TLDR JAX PRNG = Threefry 计数器 PRNG + 一个函数式面向数组的 分割模型

内容#

三种编程模型和示例程序#

这是一个类似于在Numpy程序中常用的有状态全局PRNG的玩具示例:

def foo(): return bar() + baz()
def bar(): return rand(RNG, (3, 4))
def baz(): return rand(RNG, (3, 4))
def main():
  global RNG
  RNG = RandomState(0)
  return foo()

为了在这里实现可重复性,我们需要控制 bar()baz() 的评估顺序,即使它们之间没有显式的数据依赖关系。这种源自可重复性(#2)的排序要求违反了并行性(#5),并且不符合 JAX 或 XLA 的功能语义(#6),在功能语义中,子表达式可以按任意顺序评估。即使我们不要求可重复性,从而允许任何评估顺序,由于需要更新共享状态,跨调用的并行化(#5)仍然会变得困难。此外,因为相同的 PRNG 状态需要在 Python 和任何编译代码中访问和维护,这种模型可能会导致实现编译不变性(#3)和扩展到多个副本(#6)的工程挑战。最后,表达能力受限(#1),因为没有方法让 foo() 调用 bar()baz() 而不影响其自身的(隐式)PRNG 状态。

模型是否支持向量化(#4)取决于一些额外的细节。在Numpy中,PRNG向量化受限于一个顺序等价保证

In [1]: rng = np.random.RandomState(0)

In [2]: rng.randn(2)
Out[2]: array([1.76405235, 0.40015721])

In [3]: rng = np.random.RandomState(0)

In [4]: np.stack([rng.randn() for _ in range(2)])
Out[4]: array([1.76405235, 0.40015721])

为了在生成数组的原始PRNG函数调用中允许向量化(#4)(例如,使用形状参数调用rand()),我们放弃了这种顺序等价保证。这种向量化可以通过本节讨论的三种编程模型中的任何一种来支持,尽管它促使我们按照下一节中描述的基于计数器的PRNG来实现。

有状态的PRNG用户编程模型并不乐观。以下是一个功能模型的例子,但缺少了我们称之为“分裂”的关键成分:

def foo(rng_1):
   y, rng_2 = baz(rng_1)
   z, rng_3 = bar(rng_2)
   return y + z, rng_3

def bar(x, rng):
  val, new_rng = rand(rng, (3, 4))
  return val, new_rng

def baz(x, rng):
  val, new_rng = rand(rng, (3, 4))
  return val, new_rng

def main():
  foo(RandomState(0))

此模型明确地将PRNG状态传递给所有生成随机值的函数(无论是基本函数还是非基本函数):也就是说,每个随机函数都必须同时接受和返回状态。现在,foo()中对baz()的调用和对bar()的调用之间存在明确的数据依赖关系,因此数据流(以及顺序)变得明确,并且符合JAX现有的语义(#7),这与之前的模型不同。这种明确的传递还可以使语义在编译边界上保持不变(#3)。

显式线程化对程序员来说是不方便的。但更糟糕的是,它实际上并没有提高表达能力(#1):foo() 在保持其自身 PRNG 状态的同时,仍然无法调用 bar() 或 baz()。在不了解其调用者或所调用子程序的情况下,函数必须在各处防御性地传递和返回 rng 状态。此外,它也没有改善并行化(#5)或扩展到多个副本(#6)的前景,因为一切仍然是顺序的,即使是在函数式编程意义上明确了顺序。

简而言之,通过显式地线程化状态来使代码功能化,不足以实现我们的表达性(#1)和性能(#5, #6)目标。

前述模型的关键问题在于序列化程度过高。为了减少顺序依赖,我们使用 功能性 可分割 PRNGs。分割是一种机制,用于将一个新的PRNG状态‘分叉’成两个PRNG状态,同时保持通常所需的PRNG特性(两个新的流在计算上是可并行的,并且产生独立的随机值,即它们表现得像 多流)。

def foo(rng_1):
   rng_2, rng_3 = split(rng_1, 2)
   return bar(rng_2) + baz(rng_3)

def bar(x, rng):
  return rand(rng, (3, 4))

def baz(x, rng):
  return rand(rng, (3, 4))

def main():
  foo(RandomState(0))

需要注意的几点:

  1. 对 bar() 和 baz() 的调用之间没有顺序依赖性,它们可以按任意顺序评估而不影响结果的值,这解决了剩余的性能目标 (#5, #6)。

  2. 函数不需要返回PRNG的更新版本,并且可以直接调用随机子程序而不影响现有的PRNG状态,从而提高了从其他函数模型中获得的表达能力(#1)。

这个例子没有展示出来,但由于选择(2)的结果,推进PRNG状态的唯一方法是调用split()。也就是说,我们有两种方法来实现(1),它们的不同之处在于是否让用户程序显式调用split(),如上面的例子所示,还是让用户程序显式处理线程。我们倾向于前者,即显式分割的版本,因为我们可以在其基础上轻松实现显式线程处理的版本。

设计#

我们可以使用 基于计数器的伪随机数生成器(PRNG) 设计,特别是 Threefry 哈希函数,如 Parallel random numbers: as easy as 1, 2, 3 中所述。我们使用计数器来实现高效的向量化:对于给定的键,我们可以通过将哈希函数映射到整数范围 [k + 1, …, k + sample_size] 上来以向量化方式生成值数组。我们使用键和哈希函数来实现 可分割的伪随机数生成器(PRNGs):也就是说,分割是一种从现有键生成两个新键的方法。

type Sample = Int256
type Key = Sample  -- important identification for splitting
type Count = Int32

hash :: Key -> Count -> Int256  -- output type equal to Key and Sample

split :: Key -> (Key, Key)
split key = (hash key 0, hash key 1)

draw_samples :: Key -> Int -> [Sample]
draw_samples key n = map (hash key) [1..n]

令人惊讶的是,抽取样本与分割非常相似!关键在于输出类型的差异(尽管类型已被标识):在一种情况下,该值用于形成感兴趣的随机样本(例如,将随机位转换为表示随机正态的浮点数),而在另一种情况下,该值用作进一步哈希的键。

哈希函数参数中的不对称性,类型为 Key 和 Count,在于后者通过任意量推进是微不足道的且计算成本低廉,因为我们只需增加整数值,而前者只能通过哈希推进。这就是我们使用 count 参数进行矢量化的原因。

更现实的示例用户程序#

以下是主机上训练循环可能的样子,当步骤需要PRNG时(可能用于dropout或VAE训练):

rng = lax.rng.new_rng()
for i in xrange(num_steps):
  rng, rng_input = lax.rng.split(rng)
  params = compiled_update(rng_input, params, next(batches))

请注意,我们正在让用户明确地分割 rng,但实际上 rng 根本不需要从代码中返回。

以下是如何使用这个PRNG模型与stax神经网络构建库来实现dropout的方法:

def Dropout(rate, mode='train'):
  def init_fun(input_shape):
    return input_shape, ()
  def apply_fun(rng, params, inputs):
    if mode == 'train':
      keep = lax.random.bernoulli(rng, rate, inputs.shape)
      return np.where(keep, inputs / rate, 0)
    else:
      return inputs
  return init_fun, apply_fun

这里的 rng 值只是用于哈希的键,而不是一个特殊对象。rng 参数被传递给每个 apply_fun,因此它需要在串行和并行组合器中进行拆分处理:

def serial(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(input_shape):
    ...
  def apply_fun(rng, params, inputs):
    rngs = split(rng, len(layers))
    for rng, param, apply_fun in zip(rngs, params, apply_funs):
      inputs = apply_fun(rng, param, inputs)
    return inputs
  return init_fun, apply_fun

def parallel(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(input_shape):
    ...
  def apply_fun(rng, params, inputs):
    rngs = split(rng, len(layers))
    return [f(r, p, x) for f, r, p, x in zip(apply_funs, rngs, params, inputs)]
  return init_fun, apply_fun

这里我们使用了一个简单的扩展版本的 split,它可以生成多个副本。

权衡与替代方案#

  1. 我们没有利用任何设备硬件的PRNG

    • 我们目前对所有后端硬件PRNG的状态控制还不够充分。

    • 即使我们这样做了,它也会依赖于后端,并且我们可能需要在随机调用之间引入顺序依赖,以确保确定性的顺序和因此的可重复性。

    • 我们不知道有任何工作负载会使软件伪随机数生成器成为瓶颈。

    • 我们可以考虑提供一个额外的API,允许访问硬件PRNG,供那些愿意放弃其他需求(如严格可重复性)的用户使用。

  2. 我们放弃了顺序等价保证,即在一次调用中创建一个随机数组产生的值与每次创建一个随机元素的扁平数组相同。

    • 此属性可能与向量化(高优先级)不兼容。

    • 我们不知道有任何用户或示例认为这个属性是重要的。

    • 用户可以在该API之上编写一个层来提供这种保证。

  3. 我们不能完全遵循 numpy.random API。