并行编程简介#

本教程作为JAX中单程序多数据(SPMD)代码的设备并行性入门。SPMD是一种并行性技术,其中相同的计算(例如神经网络的前向传播)可以在不同的输入数据(例如,一批中的不同输入)上,在多个设备(如若干GPU或Google TPU)上并行运行。

本教程涵盖三种并行计算模式:

利用这些SPMD的思维方式,您可以将为单个设备编写的函数转换为可以在多个设备上并行运行的函数。

如果您在Google Colab笔记本中运行这些示例,请确保您的硬件加速器是最新的Google TPU,通过检查您的笔记本设置:运行时 > 更改运行时类型 > 硬件加速器 > TPU v2(该选项提供八个可用设备)。

import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

关键概念:数据分片#

所有以下分布式计算方法的关键是数据分片的概念,它描述了数据在可用设备上的布局。

JAX如何理解数据在设备上的布局?JAX的数据类型,即jax.Array不可变数组数据结构,表示具有跨一个或多个设备的物理存储的数组,并帮助使并行性成为JAX的核心特性。 jax.Array对象的设计考虑了分布式数据和计算。每个jax.Array都有一个相关的jax.sharding.Sharding对象,描述每个全局设备所需的全局数据的哪一片。当你从头创建一个jax.Array时,你还需要创建它的Sharding

在最简单的情况下,数组在单个设备上进行分片,如下所示:

import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}
arr.sharding
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))

为了更直观地展示存储布局,jax.debug 模块提供了一些辅助工具来可视化数组的分片。例如,jax.debug.visualize_array_sharding() 显示数组是如何存储在单个设备的内存中的:

jax.debug.visualize_array_sharding(arr)
                                                  
                                                  
                                                  
                                                  
                                                  
                      TPU 0                       
                                                  
                                                  
                                                  
                                                  
                                                  

要创建一个具有非平凡分片的数组,可以为数组定义一个 jax.sharding 规范,并将其传递给 jax.device_put()

在这里,定义一个 NamedSharding,该类指定了具有命名轴的 N 维设备网格,其中 jax.sharding.Mesh 允许精确的设备放置:

from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))

将这个 Sharding 对象传递给 jax.device_put(),您可以获得一个分片数组:

arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]
                                                
                                                
   TPU 0       TPU 1       TPU 2       TPU 3    
                                                
                                                
                                                
                                                
                                                
   TPU 6       TPU 7       TPU 4       TPU 5    
                                                
                                                
                                                

这里的设备编号不是按数字顺序排列的,因为网格反映了设备的基础环形拓扑结构。

1. 通过 jit 自动并行化#

一旦你有了分片的数据,进行并行计算最简单的方法就是将数据传递给一个经过 jax.jit() 编译的函数!在 JAX 中,你只需要指定你希望代码的输入和输出如何进行分区,编译器将会自动处理:1) 内部所有内容的分区;2) 设备间的通信编译。

jit 背后的 XLA 编译器包括针对多设备计算优化的启发式算法。在最简单的情况下,这些启发式算法归结为 计算跟随数据

为了演示 JAX 中的自动并行化是如何工作的,下面是一个使用 jax.jit() 装饰的分阶段函数的示例:这是一个简单的元素级函数,每个分片的计算将在与该分片关联的设备上执行,输出也将以相同的方式进行分片:

@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match:", result.sharding == arr_sharded.sharding)
shardings match: True

随着计算变得愈加复杂,编译器决定如何最好地传播数据的分片。

在这里,您对x的前导轴进行求和,并可视化结果值如何在多个设备上存储(使用jax.debug.visualize_array_sharding()):

@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
 TPU 0,6  TPU 1,7  TPU 2,4  TPU 3,5 
                                    
[48. 52. 56. 60. 64. 68. 72. 76.]

结果是部分复制的:即数组的前两个元素在设备 06 上复制,第二个元素在 17 上,依此类推。

2. 带约束的半自动分片#

如果您希望在特定计算中对分片有一定的控制,JAX 提供了 with_sharding_constraint() 函数。您可以使用 jax.lax.with_sharding_constraint()(代替 jax.device_put())以及 jax.jit() 来更好地控制编译器如何约束中间值和输出的分布。

例如,假设在上述 f_contract 中,您希望输出不再是部分复制的,而是希望在八个设备上完全分片:

@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  mesh = jax.make_mesh((8,), ('x',))
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  
                                                                        
[48. 52. 56. 60. 64. 68. 72. 76.]

这为您提供了一个具有特定输出分片的函数。

3. 使用 shard_map 手动并行化#

在上述自动并行方法中,您可以编写一个函数,就像您在处理完整数据集一样,而 jit 将会在多个设备上拆分该计算。相比之下,通过 jax.experimental.shard_map.shard_map(),您需要编写一个处理单个数据分片的函数,shard_map 将构建完整函数。

shard_map 通过在特定的 mesh 设备上映射一个函数而工作(shard_map 在分片上进行映射)。在下面的示例中:

  • 与之前一样,jax.sharding.Mesh 允许精确的设备放置,轴名称参数用于逻辑和物理轴名称。

  • in_specs 参数确定分片大小。out_specs 参数标识如何将块重新组装在一起。

注意: jax.experimental.shard_map.shard_map() 代码可以在 jax.jit() 内部工作,如果您需要的话。

from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((8,), ('x',))

f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)
Array([ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116896,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804226, -0.99998045, -0.07314599,  1.8403342 ,  2.9812148 ,
        2.3005757 ,  0.42419332, -0.92279506, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244075, -0.81115675,
        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,
       -0.97606325,  0.19192469], dtype=float32)

您编写的函数只能“看到”数据的单个批次,您可以通过打印设备本地形状来检查这一点:

x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
global shape: x.shape=(32,)
device local shape: x.shape=(4,)

因为每个函数只“看见”设备本地的数据部分,这意味着像聚合这样的函数需要额外考虑。

例如,这里是一个shard_mapjax.numpy.sum()的样子:

def f(x):
  return jnp.sum(x, keepdims=True)

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32)

您的函数 f 在每个分片上单独操作,结果的求和反映了这一点。

如果您想在分片之间求和,需要通过集合操作明确请求,例如 jax.lax.psum()

def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
Array(496, dtype=int32)

因为输出不再具有分片维度,设置 out_specs=P()(请记住,out_specs 参数确定了在 shard_map 中如何将块重新组合在一起)。

比较三种方法#

在我们心中牢记这些概念,让我们比较一下简单神经网络层的三种方法。

首先像这样定义你的标准函数:

@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)
import numpy as np
rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)
Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32)

您可以使用 jax.jit() 以分布式方式自动运行此操作,并传递适当分片的数据。

如果您以相同的方式对 xweights 的首轴进行分片,那么矩阵乘法将会自动并行执行:

mesh = jax.make_mesh((8,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))

x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)

layer(x_sharded, weights_sharded, bias)
Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32)

或者,您可以在函数中使用 jax.lax.with_sharding_constraint() 来自动分配未分片的输入:

@jax.jit
def layer_auto(x, weights, bias):
  x = jax.lax.with_sharding_constraint(x, sharding)
  weights = jax.lax.with_sharding_constraint(weights, sharding)
  return layer(x, weights, bias)

layer_auto(x, weights, bias)  # 传递未分片的输入
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)

最后,您可以使用 shard_map 做同样的事情,使用 jax.lax.psum() 来表示矩阵乘积所需的跨分片集合操作:

from functools import partial

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)

下一步#

本教程作为对 JAX 中分片和并行计算的简要介绍。

要深入了解每种 SPMD 方法,请查看以下文档: