并行编程简介#
本教程作为JAX中单程序多数据(SPMD)代码的设备并行性入门。SPMD是一种并行性技术,其中相同的计算(例如神经网络的前向传播)可以在不同的输入数据(例如,一批中的不同输入)上,在多个设备(如若干GPU或Google TPU)上并行运行。
本教程涵盖三种并行计算模式:
通过
jax.jit()
的自动并行性:编译器选择最佳计算策略(也称为“编译器掌舵”)。使用
jax.jit()
和jax.lax.with_sharding_constraint()
的半自动并行性使用
jax.experimental.shard_map.shard_map()
的完全手动并行性:shard_map
允许逐设备代码和显式通信集合
利用这些SPMD的思维方式,您可以将为单个设备编写的函数转换为可以在多个设备上并行运行的函数。
如果您在Google Colab笔记本中运行这些示例,请确保您的硬件加速器是最新的Google TPU,通过检查您的笔记本设置:运行时 > 更改运行时类型 > 硬件加速器 > TPU v2(该选项提供八个可用设备)。
import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
关键概念:数据分片#
所有以下分布式计算方法的关键是数据分片的概念,它描述了数据在可用设备上的布局。
JAX如何理解数据在设备上的布局?JAX的数据类型,即jax.Array
不可变数组数据结构,表示具有跨一个或多个设备的物理存储的数组,并帮助使并行性成为JAX的核心特性。 jax.Array
对象的设计考虑了分布式数据和计算。每个jax.Array
都有一个相关的jax.sharding.Sharding
对象,描述每个全局设备所需的全局数据的哪一片。当你从头创建一个jax.Array
时,你还需要创建它的Sharding
。
在最简单的情况下,数组在单个设备上进行分片,如下所示:
import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}
arr.sharding
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))
为了更直观地展示存储布局,jax.debug
模块提供了一些辅助工具来可视化数组的分片。例如,jax.debug.visualize_array_sharding()
显示数组是如何存储在单个设备的内存中的:
jax.debug.visualize_array_sharding(arr)
TPU 0
要创建一个具有非平凡分片的数组,可以为数组定义一个 jax.sharding
规范,并将其传递给 jax.device_put()
。
在这里,定义一个 NamedSharding
,该类指定了具有命名轴的 N 维设备网格,其中 jax.sharding.Mesh
允许精确的设备放置:
from jax.sharding import PartitionSpec as P
mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))
将这个 Sharding
对象传递给 jax.device_put()
,您可以获得一个分片数组:
arr_sharded = jax.device_put(arr, sharding)
print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
[[ 0. 1. 2. 3. 4. 5. 6. 7.]
[ 8. 9. 10. 11. 12. 13. 14. 15.]
[16. 17. 18. 19. 20. 21. 22. 23.]
[24. 25. 26. 27. 28. 29. 30. 31.]]
TPU 0 TPU 1 TPU 2 TPU 3 TPU 6 TPU 7 TPU 4 TPU 5
这里的设备编号不是按数字顺序排列的,因为网格反映了设备的基础环形拓扑结构。
1. 通过 jit
自动并行化#
一旦你有了分片的数据,进行并行计算最简单的方法就是将数据传递给一个经过 jax.jit()
编译的函数!在 JAX 中,你只需要指定你希望代码的输入和输出如何进行分区,编译器将会自动处理:1) 内部所有内容的分区;2) 设备间的通信编译。
jit
背后的 XLA 编译器包括针对多设备计算优化的启发式算法。在最简单的情况下,这些启发式算法归结为 计算跟随数据。
为了演示 JAX 中的自动并行化是如何工作的,下面是一个使用 jax.jit()
装饰的分阶段函数的示例:这是一个简单的元素级函数,每个分片的计算将在与该分片关联的设备上执行,输出也将以相同的方式进行分片:
@jax.jit
def f_elementwise(x):
return 2 * jnp.sin(x) + 1
result = f_elementwise(arr_sharded)
print("shardings match:", result.sharding == arr_sharded.sharding)
shardings match: True
随着计算变得愈加复杂,编译器决定如何最好地传播数据的分片。
在这里,您对x
的前导轴进行求和,并可视化结果值如何在多个设备上存储(使用jax.debug.visualize_array_sharding()
):
@jax.jit
def f_contract(x):
return x.sum(axis=0)
result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
TPU 0,6 TPU 1,7 TPU 2,4 TPU 3,5
[48. 52. 56. 60. 64. 68. 72. 76.]
结果是部分复制的:即数组的前两个元素在设备 0
和 6
上复制,第二个元素在 1
和 7
上,依此类推。
2. 带约束的半自动分片#
如果您希望在特定计算中对分片有一定的控制,JAX 提供了 with_sharding_constraint()
函数。您可以使用 jax.lax.with_sharding_constraint()
(代替 jax.device_put()
)以及 jax.jit()
来更好地控制编译器如何约束中间值和输出的分布。
例如,假设在上述 f_contract
中,您希望输出不再是部分复制的,而是希望在八个设备上完全分片:
@jax.jit
def f_contract_2(x):
out = x.sum(axis=0)
mesh = jax.make_mesh((8,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
return jax.lax.with_sharding_constraint(out, sharding)
result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
TPU 0 TPU 1 TPU 2 TPU 3 TPU 6 TPU 7 TPU 4 TPU 5
[48. 52. 56. 60. 64. 68. 72. 76.]
这为您提供了一个具有特定输出分片的函数。
比较三种方法#
在我们心中牢记这些概念,让我们比较一下简单神经网络层的三种方法。
首先像这样定义你的标准函数:
@jax.jit
def layer(x, weights, bias):
return jax.nn.sigmoid(x @ weights + bias)
import numpy as np
rng = np.random.default_rng(0)
x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))
layer(x, weights, bias)
Array([0.02138912, 0.893112 , 0.59892005, 0.97742504], dtype=float32)
您可以使用 jax.jit()
以分布式方式自动运行此操作,并传递适当分片的数据。
如果您以相同的方式对 x
和 weights
的首轴进行分片,那么矩阵乘法将会自动并行执行:
mesh = jax.make_mesh((8,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)
layer(x_sharded, weights_sharded, bias)
Array([0.02138912, 0.893112 , 0.59892005, 0.97742504], dtype=float32)
或者,您可以在函数中使用 jax.lax.with_sharding_constraint()
来自动分配未分片的输入:
@jax.jit
def layer_auto(x, weights, bias):
x = jax.lax.with_sharding_constraint(x, sharding)
weights = jax.lax.with_sharding_constraint(weights, sharding)
return layer(x, weights, bias)
layer_auto(x, weights, bias) # 传递未分片的输入
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)
最后,您可以使用 shard_map
做同样的事情,使用 jax.lax.psum()
来表示矩阵乘积所需的跨分片集合操作:
from functools import partial
@jax.jit
@partial(shard_map, mesh=mesh,
in_specs=(P('x'), P('x', None), P(None)),
out_specs=P(None))
def layer_sharded(x, weights, bias):
return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)
layer_sharded(x, weights, bias)
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)
下一步#
本教程作为对 JAX 中分片和并行计算的简要介绍。
要深入了解每种 SPMD 方法,请查看以下文档: