分布式数据加载#
本高级指南展示了如何在分布式数据加载时执行操作——当你在 多主机或多进程环境 中运行 JAX,并且 JAX 计算所需的数据分布在多个进程中时。本文档涵盖了如何思考分布式数据加载的整体方法,然后如何将其应用于 数据并行(较简单)和 模型并行(较复杂)的工作负载。
分布式数据加载通常更高效(数据在进程间分割),但与以下替代方案相比也更复杂:1) 在一个进程中加载全部全局数据,将其分割并通过RPC发送所需部分到其他进程;2) 在所有进程中加载全部全局数据,每个进程仅使用所需部分。加载全部全局数据通常更简单但成本更高。例如,在机器学习中,训练循环可能会在等待数据时被阻塞,并且每个进程都会使用额外的网络带宽。
备注
在使用分布式数据加载时,重要的是每个设备(例如,每个GPU或TPU)都能访问到运行计算所需的输入数据分片。这通常使得分布式数据加载比上述替代方案更复杂且更具挑战性,难以正确实现。如果错误的数据分片最终出现在错误的设备上,计算仍可以无错误地运行,因为计算无法知道输入数据“应该”是什么。然而,由于输入数据与预期不同,最终结果通常会不正确。
加载 jax.Array
的一般方法#
考虑一个从非JAX生成的原始数据创建单个 jax.Array
的情况。这些概念适用于加载批量数据记录之外,例如任何不是直接由JAX计算生成的多进程 jax.Array
。示例包括:1) 从检查点加载模型权重;或 2) 加载一个大型空间分片图像。
每个 jax.Array
都有一个关联的 Sharding
,它描述了每个全局设备所需的全局数据的哪个分片。当你从头创建一个 jax.Array
时,你还需要创建它的 Sharding
。这就是 JAX 如何理解数据在设备之间的布局方式。你可以创建任何你想要的 Sharding
。实际上,你通常会根据你正在实现的并行策略来选择一个 Sharding
(你将在本指南后面的部分更详细地了解数据和模型并行性)。你也可以根据每个进程中原始数据的生成方式来选择一个 Sharding
。
一旦你定义了一个 Sharding
,你可以使用 addressable_devices()
来提供一个设备列表,这些设备需要在当前进程中加载数据。(注意:术语“可寻址设备”是“本地设备”的更一般版本。目标是确保每个进程的数据加载器为该进程的所有本地设备提供正确的数据。)
示例#
例如,考虑一个 (64, 128)
的 jax.Array
,你需要将其分片到4个进程中,每个进程有2个设备(总共8个设备)。这将导致8个唯一的数据分片,每个设备一个。有很多方法可以对这个 jax.Array
进行分片。你可以在 jax.Array
的第二个维度上执行1D分片,给每个设备一个 (64, 16)
的分片,如下所示:
在上图中,每个数据分片都有自己的颜色,以指示哪个进程需要加载该分片。例如,您假设进程 0
的 2 个设备包含分片 A
和 B
,对应于全局数据的第一个 (64, 32)
部分。
你可以选择不同的分片分布到设备上。例如:
这里是另一个例子 — 一个二维分片:
然而,无论 jax.Array
是如何分片的,您都必须确保每个进程的数据加载器被提供/加载了全局数据所需的片段。有几种高级方法可以实现这一点:1) 在每个进程中加载全局数据;2) 使用每个设备的数据管道;3) 使用每个进程的整合数据管道;4) 以某种方便的方式加载数据,然后在计算内部重新分片。
选项 1:在每个进程中加载全局数据#
使用此选项,每个进程:
加载所需的完整值;并且
仅将所需的碎片传输到该进程的本地设备。
这不是一种高效的分布式数据加载方法,因为每个进程都会丢弃其本地设备不需要的数据,并且总摄入的数据量可能超过必要。但这个选项是可行的,并且实现相对简单,而对于某些工作负载(例如,如果全局数据量较小),性能开销可能是可以接受的。
选项 2:使用每个设备的数据管道#
在这个选项中,每个进程为其每个本地设备设置一个数据加载器(即,每个设备为其所需的数据分片获取自己的数据加载器)。
这在数据加载方面是高效的。有时,考虑每个设备独立处理比一次性考虑一个进程的所有本地设备更为简单(参见下面的 选项3:使用一个综合的每进程数据管道)。然而,拥有多个并发数据加载器有时会导致性能问题。
选项 3:使用一个整合的每进程数据管道#
如果你选择此选项,每个进程:
设置一个单一的数据加载器,该加载器加载其所有本地设备所需的数据;然后
在传输到每个本地设备之前,对本地数据进行分片。
这是分布式加载的最有效方式。然而,这也是最复杂的,因为需要逻辑来确定每个设备需要哪些数据,并创建一个只加载所有这些数据(理想情况下,不加载任何额外数据)的单一数据加载。
选项 4:以某种方便的方式加载数据,在计算内部重新分片#
这个选项解释起来更具挑战性,但通常比上面的选项(1到3)更容易实现。
设想一个场景,其中设置加载器来加载完全符合你需要的数据(无论是按设备还是按进程)非常困难,甚至是不可能的。然而,仍然可能为每个进程设置一个数据加载器,该加载器加载 1 / num_processes
的数据,只是分片不正确。
然后,继续你之前提到的2D示例分片,假设每个进程加载数据的一列更容易:
然后,您可以使用表示每列数据的 Sharding
创建一个 jax.Array
,将该数据直接传递到计算中,并使用 jax.lax.with_sharding_constraint()
立即将列分片输入重新分片到所需的分片。由于数据在计算内部重新分片,它将在加速器通信链路上重新分片(例如,TPU ICI 或 NVLink)。
选项4与选项3(使用一个统一的过程数据管道)有类似的好处:
每个进程仍然有一个单独的数据加载器;并且
全局数据在所有进程中只加载一次;并且
全局数据具有额外的优势,即在数据加载方式上提供了更大的灵活性。
然而,这种方法使用加速器互连带宽来执行重新分片,这可能会减慢某些工作负载的速度。选项4还要求输入数据除了目标 Sharding
外,还必须表示为单独的 Sharding
。
复制#
复制描述了一个过程,其中多个设备具有相同的数据分片。上述一般选项(选项1至4)在复制中仍然有效。唯一的区别是某些过程可能会最终加载相同的数据分片。本节描述了完全复制和部分复制。
完全复制#
全量复制 是一个过程,其中所有设备都有数据的完整副本(即,数据“分片”是整个数组值)。
在下面的示例中,由于总共有8个设备(每个进程2个),您最终将得到完整数据的8个副本。每个数据副本都是未分片的,即副本存在于单个设备上:
部分复制#
部分复制 描述了一个过程,其中数据有多个副本,并且每个副本在多个设备上分片。对于给定的数组值,通常有许多可能的方式来执行部分复制(注意:对于给定的数组形状,总是有一个完全复制的 Sharding
)。
以下是两个可能的示例。
在下面的第一个示例中,每个副本被分片到进程的两个本地设备上,总共4个副本。这意味着每个进程都需要加载完整的全局数据,因为其本地设备将拥有数据的完整副本。
在下面的第二个示例中,每个副本仍然分布在两个设备上,但每对设备分布在两个不同的进程中。进程 0
(粉色)和进程 1
(黄色)都需要加载数据的第一行,而进程 2
(绿色)和进程 3
(蓝色)都需要加载数据的第二行:
既然你已经了解了创建 jax.Array
的高级选项,让我们将它们应用于机器学习应用的数据加载。
数据并行#
在 纯数据并行 (不包括模型并行)中:
你在每个设备上复制模型;并且
每个模型副本(即每个设备)接收不同的每个副本数据批次。
当将输入数据表示为单个 jax.Array
时,该数组包含此步骤中所有副本的数据(这称为 全局批次),其中 jax.Array
的每个分片包含一个每个副本的批次。您可以将其表示为所有设备上的 1D 分片(请查看下面的示例)——换句话说,全局批次由所有每个副本的批次沿着批次轴连接在一起组成。
应用此框架,您可以得出结论,进程 0
应获得全局批次的第一季度(8个中的2个),进程 1
应获得第二季度,依此类推。
但是你怎么知道第一季度是什么?而且你如何确保进程 0
获得第一季度?幸运的是,数据并行有一个非常重要的技巧,这意味着你不必回答这些问题,并且使整个设置更简单。
关于数据并行的重要技巧#
诀窍在于你不需要关心哪个副本批量落在哪个副本上。因此,哪个进程加载批量并不重要。原因是,由于每个设备对应一个执行相同操作的模型副本,所以全局批量中哪个设备获得哪个副本批量并不重要。
这意味着你可以自由地重新排列全局批次中的每个副本批次。换句话说,你可以自由地随机化每个设备获取的数据分片。
例如:
通常,如上所示重新排列 jax.Array
的数据分片并不是一个好主意 – 你实际上是在置换 jax.Array
的值!然而,对于数据并行性,全局批次顺序并不重要,正如之前提到的,你可以自由地重新排列全局批次中的每个副本批次。
这简化了数据加载,因为这意味着每个设备只需要一个独立的每副本批次流,这可以通过为每个进程创建一个独立的管道并在生成的每进程批次中将其分块为每副本批次来轻松地在大多数数据加载器中实现。
这是 Option 2: 按进程整合的数据管道 的一个实例。你也可以使用其他选项(如 0、1 和 3,这些在本文档前面部分有介绍),但这个选项相对简单且高效。
以下是如何使用 tf.data 实现此设置的示例:
import jax
import tensorflow as tf
import numpy as np
################################################################################
# Step 1: setup the Dataset for pure data parallelism (do once)
################################################################################
# Fake example data (replace with your Dataset)
ds = tf.data.Dataset.from_tensor_slices(
[np.ones((16, 3)) * i for i in range(100)])
ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
################################################################################
# Step 2: create a jax.Array of per-replica batches from the per-process batch
# produced from the Dataset (repeat every step). This can be used with batches
# produced by different data loaders as well!
################################################################################
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()
per_process_batch_size = per_process_batch.shape[0] # adjust if your batch dim
# isn't 0
per_replica_batch_size = per_process_batch_size // jax.local_device_count()
assert per_process_batch_size % per_replica_batch_size == 0, \
"This example doesn't implement padding."
per_replica_batches = np.split(per_process_batch, jax.local_device_count())
# Thanks to the very important trick about data parallelism, no need to care what
# order the devices appear in the sharding.
sharding = jax.sharding.PositionalSharding(jax.devices())
# PositionalSharding must have same rank as data being sharded.
sharding = sharding.reshape((jax.device_count(),) +
(1,) * (per_process_batch.ndim - 1))
global_batch_size = per_replica_batch_size * jax.device_count()
global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:])
global_batch_array = jax.make_array_from_single_device_arrays(
global_batch_shape, sharding,
# Thanks again to the very important trick, no need to care which device gets
# which per-replica batch.
arrays=[jax.device_put(batch, device)
for batch, device
in zip(per_replica_batches, sharding.addressable_devices)])
assert global_batch_array.shape == global_batch_shape
assert (global_batch_array.addressable_shards[0].data.shape ==
per_replica_batches[0].shape)
数据 + 模型并行#
在 模型并行 中,您将每个模型副本分片到多个设备上。如果您使用 纯模型并行(不使用数据并行):
只有一个模型副本在所有设备上分片;并且
数据通常会在所有设备上完全复制。
本指南考虑了一个使用 数据并行和模型并行 的案例:
你将每个多模型副本分片到多个设备上;并且
你在每个模型副本上部分复制数据——同一模型副本中的每个设备获取相同的每个副本批次,而跨模型副本的设备获取不同的每个副本批次。
进程内的模型并行#
为了数据加载的目的,最简单的方法可以是在单个进程的本地设备中对每个模型副本进行分片。
在这个例子中,我们改为每个进程使用4个设备(而不是每个进程使用2个设备)。考虑一个场景,其中每个模型副本在单个进程的2个本地设备上进行分片。这将导致每个进程有2个模型副本,总共4个模型副本,如下所示:
这里,再次,输入数据表示为一个单一的 jax.Array
,具有1D分片,其中每个分片是一个每个副本的批次,但有一个例外:
与纯数据并行的情况不同,您引入了部分复制,并对1D分片的全局批次制作了2份副本。
这是因为每个模型副本由2个设备组成,每个设备都需要一个副本批次的副本。
将每个模型副本保持在单个进程中可以使事情变得更简单,因为你可以重用上述纯数据并行设置,除了你还需要复制每个副本的批次:
备注
在将每个副本的批次复制到正确的设备上也是非常重要的! 虽然数据并行性的重要技巧意味着你不在乎哪个批次最终在哪个副本上,但你确实关心一个副本只得到一个批次。
例如,这是可以的:
然而,如果你不注意将每个批次加载到哪个本地设备上,你可能会意外地创建未复制的数据,尽管 Sharding
(以及并行策略)表明数据是复制的:
如果你不小心创建了一个在单个进程中应该被复制的未复制数据的 jax.Array
,JAX 将会抛出一个错误(不过,对于跨进程的模型并行来说,这并不总是成立的;请参见下一节)。
以下是如何使用 tf.data
实现进程模型并行和数据并行的示例:
import jax
import tensorflow as tf
import numpy as np
################################################################################
# Step 1: Set up the Dataset with a different data shard per-process (do once)
# (same as for pure data parallelism)
################################################################################
# Fake example data (replace with your Dataset)
per_process_batches = [np.ones((16, 3)) * i for i in range(100)]
ds = tf.data.Dataset.from_tensor_slices(per_process_batches)
ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
################################################################################
# Step 2: Create a jax.Array of per-replica batches from the per-process batch
# produced from the Dataset (repeat every step)
################################################################################
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()
num_model_replicas_per_process = 2 # set according to your parallelism strategy
num_model_replicas_total = num_model_replicas_per_process * jax.process_count()
per_process_batch_size = per_process_batch.shape[0] # adjust if your batch dim
# isn't 0
per_replica_batch_size = (per_process_batch_size //
num_model_replicas_per_process)
assert per_process_batch_size % per_replica_batch_size == 0, \
"This example doesn't implement padding."
per_replica_batches = np.split(per_process_batch,
num_model_replicas_per_process)
# Create an example `Mesh` for per-process data parallelism. Make sure all devices
# are grouped by process, and then resize so each row is a model replica.
mesh_devices = np.array([jax.local_devices(process_idx)
for process_idx in range(jax.process_count())])
mesh_devices = mesh_devices.reshape(num_model_replicas_total, -1)
# Double check that each replica's devices are on a single process.
for replica_devices in mesh_devices:
num_processes = len(set(d.process_index for d in replica_devices))
assert num_processes == 1
mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"])
# Shard the data across model replicas. You don't shard across the
# data_parallelism mesh axis, meaning each per-replica shard will be replicated
# across that axis.
sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("model_replicas"))
global_batch_size = per_replica_batch_size * num_model_replicas_total
global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:])
# Create the final jax.Array using jax.make_array_from_callback. The callback
# will be called for each local device, and passed the N-D numpy-style index
# that describes what shard of the global data that device should receive.
#
# You don't need care exactly which index is passed in due to the very important data
# parallelism, but you do use the index argument to make sure you replicate each
# per-replica batch correctly -- the `index` argument will be the same for
# devices in the same model replica, and different for devices in different
# model replicas.
index_to_batch = {}
def callback(index: tuple[slice, ...]) -> np.ndarray:
# Python `slice` objects aren't hashable, so manually create dict key.
index_key = tuple((slice_.start, slice_.stop) for slice_ in index)
if index_key not in index_to_batch:
# You don't care which per-replica batch goes to which replica, just take the
# next unused one.
index_to_batch[index_key] = per_replica_batches[len(index_to_batch)]
return index_to_batch[index_key]
global_batch_array = jax.make_array_from_callback(
global_batch_shape, sharding, callback)
assert global_batch_array.shape == global_batch_shape
assert (global_batch_array.addressable_shards[0].data.shape ==
per_replica_batches[0].shape)
跨进程的模型并行#
当模型副本分布在不同的进程中时,情况会变得更加有趣,无论是:
因为单个副本无法适应一个进程;或者
因为设备分配并不是那样设置的。
例如,回到之前4个进程每个进程2个设备的设置,如果你像这样分配设备给副本:
这与之前的每个进程模型并行示例采用了相同的并行策略——4个模型副本,每个副本在2个设备上分片。唯一的区别在于设备分配——每个副本的两个设备分布在不同的进程中,每个进程只负责每个副本批次的一个副本(但有两个副本)。
像这样将模型副本分发到不同的进程中,看起来可能是一种随意且不必要的行为(在这个例子中可以说是这样),但在实际部署中,可能会采用这种设备分配方式,以充分利用设备间的通信链接。
数据加载现在变得更加复杂,因为需要在进程之间进行额外的协调。在纯数据并行和每个进程的模型并行情况下,每个进程加载一个唯一的数据流是很重要的。现在某些进程必须加载相同的数据,而有些则必须加载不同的数据。在上面的例子中,进程 0
和 2
(分别用粉色和绿色表示)必须加载相同的 2 个每副本批次,而进程 1
和 3
(分别用黄色和蓝色表示)也必须加载相同的 2 个每副本批次(但与进程 0
和 2
的批次不同)。
此外,重要的是每个进程不会混淆其每个副本的两个批次。虽然你不关心哪个批次落在哪个副本上(数据并行的非常重要的技巧),但你需要确保副本中的所有设备都获得相同的批次。例如,这是不好的:
备注
截至2023年8月,JAX 无法检测 jax.Array
在进程间的分片是否应被复制但未被复制,并且在计算运行时会产生错误结果。因此,请务必小心不要这样做!
要在每个设备上获得正确的每个副本批处理,您需要将全局输入数据表示为以下 jax.Array
: