分布式数组和自动并行化#
本教程讨论了通过 jax.Array
实现并行性,这是JAX v0.4.1及更高版本中可用的统一数组对象模型。
from typing import Optional
import numpy as np
import jax
import jax.numpy as jnp
⚠️ 警告:此笔记本需要8个设备才能运行。
if len(jax.local_devices()) < 8:
raise Exception("Notebook requires 8 devices to run")
介绍和快速示例#
通过阅读本教程笔记本,您将了解 jax.Array
,这是一种统一的数据类型,用于表示数组,即使在物理存储跨多个设备的情况下。您还将学习如何使用 jax.Array
结合 jax.jit
来提供基于编译器的自动并行化。
在我们逐步思考之前,先来看一个快速示例。
首先,我们将创建一个跨多个设备分片的 jax.Array
:
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
# 创建一个分片对象以将值分布到多个设备上:
mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)),
axis_names=('x', 'y'))
# 创建一个随机值数组:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# 并使用 jax.device_put 将其分布到各个设备上:
y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
接下来,我们将对其进行计算,并可视化结果值如何分布在多个设备上:
z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
jnp.sin
应用的评估自动在存储输入值(和输出值)的设备上进行并行化:
# `x` 存在于单个设备上
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
The slowest run took 8.96 times longer than the fastest. This could mean that an intermediate result is being cached.
25.2 ms ± 30.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
# `y` 被分片存储在8个设备上。
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
2.4 ms ± 61.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
现在让我们更详细地看一下这些内容!
分片
描述了数组值在设备之间的内存布局#
分片基础及 NamedSharding
子类#
为了在多个设备上并行计算,我们首先必须将输入数据分布在多个设备上。
在JAX中,Sharding
对象描述了分布式内存布局。它们可以与jax.device_put
一起使用,以生成具有分布式布局的值。
例如,这里有一个具有单设备Sharding
的值:
import jax
x = jax.random.normal(jax.random.key(0), (8192, 8192))
jax.debug.visualize_array_sharding(x)
┌───────────────────────┐
│ │
│ │
│ │
│ │
│ TPU 0 │
│ │
│ │
│ │
│ │
└───────────────────────┘
在这里,我们使用 jax.debug.visualize_array_sharding
函数来显示值 x
存储在内存中的位置。所有的 x
都存储在一个设备上,因此这个可视化结果看起来非常无聊!
但是我们可以通过使用 jax.device_put
和一个 Sharding
对象将 x
分片到多个设备上。首先,我们使用 mesh_utils.create_device_mesh
创建一个包含设备的 numpy.ndarray
,该函数会考虑硬件拓扑结构来确定 Device
的顺序:
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
P = PartitionSpec
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
我们可以定义一个辅助函数来简化事情:
devices = mesh_utils.create_device_mesh((4, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))
def mesh_sharding(
pspec: PartitionSpec, mesh: Optional[Mesh] = None,
) -> NamedSharding:
if mesh is None:
mesh = default_mesh
return NamedSharding(mesh, pspec)
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
在这里,我们使用 P('a', 'b')
来表示 x
的第一和第二个轴应在设备网格轴 'a'
和 'b'
上进行分片。我们可以很方便地切换到 P('b', 'a')
来在不同的设备上对 x
的轴进行分片:
y = jax.device_put(x, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(y)
┌───────┬───────┬───────┬───────┐ │ │ │ │ │ │ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │ │ │ │ │ │ │ │ │ │ │ ├───────┼───────┼───────┼───────┤ │ │ │ │ │ │ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │ │ │ │ │ │ │ │ │ │ │ └───────┴───────┴───────┴───────┘
# 这里的 `None` 表示 `x` 在其第二个维度上未进行分片,
# and since the Mesh axis name 'b' is not mentioned, shards are
# 在其上进行了复制。
y = jax.device_put(x, mesh_sharding(P('a', None)))
jax.debug.visualize_array_sharding(y)
┌───────────────────────┐ │ TPU 0,1 │ ├───────────────────────┤ │ TPU 2,3 │ ├───────────────────────┤ │ TPU 6,7 │ ├───────────────────────┤ │ TPU 4,5 │ └───────────────────────┘
在这里,由于 P('a', None)
没有提到 Mesh
轴名称 'b'
,我们在轴 'b'
上得到了复制。这里的 None
只是作为一个占位符,用来与值 x
的第二个轴对齐,并不表示在任何网格轴上的分片。(作为简写,可以省略尾部的 None
,因此 P('a', None)
和 P('a')
是相同的。但明确表示并没有坏处!)
要仅在 x
的第二个轴上进行分片,我们可以在 PartitionSpec
中使用 None
占位符:
y = jax.device_put(x, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(y)
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
y = jax.device_put(x, mesh_sharding(P(None, 'a')))
jax.debug.visualize_array_sharding(y)
┌───────┬───────┬───────┬───────┐ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ └───────┴───────┴───────┴───────┘
对于固定的网格,我们甚至可以将一个逻辑轴 x
划分到多个设备网格轴上:
y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)
┌───────────────────────┐ │ TPU 0 │ ├───────────────────────┤ │ TPU 1 │ ├───────────────────────┤ │ TPU 2 │ ├───────────────────────┤ │ TPU 3 │ ├───────────────────────┤ │ TPU 6 │ ├───────────────────────┤ │ TPU 7 │ ├───────────────────────┤ │ TPU 4 │ ├───────────────────────┤ │ TPU 5 │ └───────────────────────┘
使用 NamedSharding
使得定义设备网格变得简单,只需一次定义并给其轴命名,然后在每个 device_put
的 PartitionSpec
中根据需要引用这些名称。
计算在数据分片后进行,并自动并行化#
使用分片输入数据,编译器可以为我们提供并行计算。特别是,使用 jax.jit
装饰的函数可以在不将数据复制到单个设备上的情况下,对分片数组进行操作。相反,计算遵循分片:基于输入数据的分片,编译器决定中间值和输出值的分片,并对它们的求值进行并行化,必要时还会插入通信操作。
例如,最简单的计算是逐元素计算:
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
x = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
print('input sharding:')
jax.debug.visualize_array_sharding(x)
y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y)
input sharding:
output sharding:
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
对于元素级操作 jnp.sin
,编译器选择了与输入相同的输出分片。此外,编译器自动将计算并行化,因此每个设备从其输入分片并行计算其输出分片。
换句话说,尽管我们将 jnp.sin
的计算写成好像是单台机器要执行它,编译器实际上为我们分割了计算并在多个设备上执行。
我们也可以对不仅仅是元素级操作进行相同处理。考虑一个具有分片输入的矩阵乘法:
y = jax.device_put(x, NamedSharding(mesh, P('a', None)))
z = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)
w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)
lhs sharding:
rhs sharding:
out sharding:
┌───────────────────────┐ │ TPU 0,1 │ ├───────────────────────┤ │ TPU 2,3 │ ├───────────────────────┤ │ TPU 6,7 │ ├───────────────────────┤ │ TPU 4,5 │ └───────────────────────┘
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
这里编译器选择了输出分片,以便最大程度地并行计算:在不需要通信的情况下,每个设备已经拥有计算其输出分片所需的输入分片。
我们如何能确保它实际上是并行运行的?我们可以做一个简单的计时实验:
x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
┌───────────────────────┐
│ │
│ │
│ │
│ │
│ TPU 0 │
│ │
│ │
│ │
│ │
└───────────────────────┘
np.allclose(jnp.dot(x_single, x_single),
jnp.dot(y, z))
True
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
49.7 ms ± 349 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
7.47 ms ± 44.8 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
即使复制一个分片的 Array
,产生的结果也具有输入的分片:
w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
因此,计算遵循数据放置:当我们使用 jax.device_put
显式地对数据进行分片,并对这些数据应用函数时,编译器会尝试并行化计算并决定输出分片。这种对分片数据的策略是JAX 跟随显式设备放置的策略的一个推广。
当显式分片不一致时,JAX 报错#
但如果计算的两个参数被显式放置在不同的设备集合上,或者设备顺序不兼容呢? 在这些模糊情况下,会引发错误:
import textwrap
from termcolor import colored
def print_exception(e):
name = colored(f'{type(e).__name__}', 'red', force_color=True)
print(textwrap.fill(f'{name}: {str(e)}'))
sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))
sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))
y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x1 of jax.numpy.add with shape int32[24] and
device ids [0, 1, 2, 3] on platform TPU and argument x2 of
jax.numpy.add with shape int32[24] and device ids [4, 5, 6, 7] on
platform TPU
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]
sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))
sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))
y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x1 of jax.numpy.add with shape int32[24] and
device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and argument x2 of
jax.numpy.add with shape int32[24] and device ids [0, 1, 2, 3, 6, 7,
4, 5] on platform TPU
我们说通过 jax.device_put
明确放置或分片的数组是 提交的,因此不会被自动移动。有关更多信息,请参见 设备放置常见问题。
当数组 没有 通过 jax.device_put
明确放置或分片时,它们会被_未提交_地放置在默认设备上。与提交数组不同,未提交数组可以自动移动和重新分片:也就是说,未提交数组可以作为计算的参数,即使其他参数被明确放置在不同设备上。
例如,jnp.zeros
、jnp.arange
和 jnp.array
的输出是未提交的:
y = jax.device_put(x, sharding1)
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!')
no error!
限制 jit
编译代码中中间结果的分片#
虽然编译器会试图决定一个函数的中间值和输出应该如何进行分片,但我们也可以使用 jax.lax.with_sharding_constraint
给它提供一些提示。使用 jax.lax.with_sharding_constraint
的方式与 jax.device_put
类似,只是我们在分阶段的(即被 jit
装饰的)函数内部使用它。
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
@jax.jit
def f(x):
x = x + 1
y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))
return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
┌───────┬───────┬───────┬───────┐ │ │ │ │ │ │ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │ │ │ │ │ │ │ │ │ │ │ ├───────┼───────┼───────┼───────┤ │ │ │ │ │ │ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │ │ │ │ │ │ │ │ │ │ │ └───────┴───────┴───────┴───────┘
@jax.jit
def f(x):
x = x + 1
y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
┌───────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3,4,5,6,7 │ │ │ │ │ │ │ │ │ └───────────────────────┘
通过添加 with_sharding_constraint
,我们限制了输出的分片。除了尊重特定中间值上的注释外,编译器还将使用注释来决定其他值的分片。
对计算结果进行注释通常是一个良好的实践,例如基于值最终如何被使用。
示例:神经网络#
⚠️ 警告:以下内容旨在演示使用 jax.Array
的自动分片传播,但可能未能反映真实示例的最佳实践。 例如,真实示例可能需要更多地使用 with_sharding_constraint
。
我们可以使用 jax.device_put
和 jax.jit
的计算跟随分片特性来并行化神经网络中的计算。以下是一些简单的示例,基于这个基本的神经网络:
import jax
import jax.numpy as jnp
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jnp.maximum(outputs, 0)
return outputs
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))
def init_layer(key, n_in, n_out):
k1, k2 = jax.random.split(key)
W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
b = jax.random.normal(k2, (n_out,))
return W, b
def init_model(key, layer_sizes, batch_size):
key, *keys = jax.random.split(key, len(layer_sizes))
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
key, *keys = jax.random.split(key, 3)
inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
return params, (inputs, targets)
layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
8路批数据并行性#
mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch')
sharding = NamedSharding(mesh, P('batch'))
replicated_sharding = NamedSharding(mesh, P())
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, replicated_sharding)
loss_jit(params, batch)
Array(23.469475, dtype=float32)
step_size = 1e-5
for _ in range(30):
grads = gradfun(params, batch)
params = [(W - step_size * dW, b - step_size * db)
for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.760109
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
53.8 ms ± 1.14 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
351 ms ± 81.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
4路批数据并行和2路模型张量并行#
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
┌───────┐ │TPU 0,1│ ├───────┤ │TPU 2,3│ ├───────┤ │TPU 6,7│ ├───────┤ │TPU 4,5│ └───────┘
┌───────┐ │TPU 0,1│ ├───────┤ │TPU 2,3│ ├───────┤ │TPU 6,7│ ├───────┤ │TPU 4,5│ └───────┘
replicated_sharding = NamedSharding(mesh, P())
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
W1 = jax.device_put(W1, replicated_sharding)
b1 = jax.device_put(b1, replicated_sharding)
W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))
b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))
W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))
b3 = jax.device_put(b3, replicated_sharding)
W4 = jax.device_put(W4, replicated_sharding)
b4 = jax.device_put(b4, replicated_sharding)
params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
jax.debug.visualize_array_sharding(W2)
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
jax.debug.visualize_array_sharding(W3)
┌───────────────────────┐ │ │ │ TPU 0,2,4,6 │ │ │ │ │ ├───────────────────────┤ │ │ │ TPU 1,3,5,7 │ │ │ │ │ └───────────────────────┘
print(loss_jit(params, batch))
10.760109
step_size = 1e-5
for _ in range(30):
grads = gradfun(params, batch)
params = [(W - step_size * dW, b - step_size * db)
for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.752513
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
┌───────────────────────┐ │ │ │ TPU 0,2,4,6 │ │ │ │ │ ├───────────────────────┤ │ │ │ TPU 1,3,5,7 │ │ │ │ │ └───────────────────────┘
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
51.4 ms ± 454 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
锐利的部分#
生成随机数#
JAX 提供了一个功能性、确定性的 随机数生成器。它是 jax.random
模块 中各种采样函数的基础,例如 jax.random.uniform
。
JAX 的随机数是通过基于计数器的 PRNG 生成的,因此原则上,随机数生成应该是对计数器值的一个纯映射。纯映射在原则上是一个微不足道的可分割操作。它不应需要跨设备的通信,也不应在设备之间有任何冗余计算。
然而,现有的稳定 RNG 实现由于历史原因并不是自动可分割的。
考虑以下示例,在此示例中,一个函数生成随机均匀数,并将其逐元素添加到输入中:
@jax.jit
def f(key, x):
numbers = jax.random.uniform(key, x.shape)
return x + numbers
key = jax.random.key(42)
mesh = Mesh(jax.devices(), 'x')
x_sharding = NamedSharding(mesh, P('x'))
x = jax.device_put(jnp.arange(24), x_sharding)
在分区输入上,函数 f
产生的输出也是分区的:
jax.debug.visualize_array_sharding(f(key, x))
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐ │ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │ └───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
但如果我们检查在这个分区输入上对 f
的编译计算,我们会发现确实涉及了一些通信:
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? True
一种解决方法是通过配置 JAX 的实验性升级标志 jax_threefry_partitionable
。当该标志开启时,”集体置换”操作将不再出现在编译的计算中:
jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? False
输出仍然被分区:
jax.debug.visualize_array_sharding(f(key, x))
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐ │ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │ └───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
然而,jax_threefry_partitionable
选项有一个注意事项,即_生成的随机值可能与未设置该标志时不同_,尽管它们是由相同的随机密钥生成的:
jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))
print()
jax.config.update('jax_threefry_partitionable', True)
print('Partitionable:')
print(f(key, x))
Stable:
[ 0.72503686 1.8532515 2.983416 3.083253 4.0332246 5.4782867
6.1720605 7.6900277 8.602836 9.810046 10.861367 11.907651
12.330483 13.456195 14.808557 15.960099 16.067581 17.739723
18.335474 19.46401 20.390276 21.116539 22.858128 23.223194 ]
Partitionable:
[ 0.48870957 1.6797972 2.6162715 3.561016 4.4506445 5.585866
6.0748096 7.775133 8.698959 9.818634 10.350306 11.87282
12.925881 13.86013 14.477554 15.818481 16.711355 17.586697
18.073738 19.777622 20.404566 21.119123 22.026257 23.63918 ]
在 jax_threefry_partitionable
模式下,JAX PRNG 仍然是确定性的,但其实现是新的(并在开发中)。对于给定的密钥,生成的随机值将在特定的 JAX 版本(或 main
分支上的特定提交)中保持相同,但可能在不同版本之间有所变化。