使用Join上下文管理器进行不均匀输入的分布式训练
创建日期:2021年8月4日 | 最后更新:2023年1月9日 | 最后验证:2024年11月5日
作者: Andrew Gu
注意
在github上查看和编辑本教程。
注意
Join
在 PyTorch 1.10 中作为原型功能引入。此 API 可能会发生变化。
在本教程中,您将看到:
Join 上下文管理器的概述。
一个关于如何使用
DistributedDataParallel
上下文管理器的示例。一个示例,展示了如何同时使用
DistributedDataParallel
和ZeroRedundancyOptimizer
的上下文管理器。一个向上下文管理器传递关键字参数的示例。
深入探讨Join上下文管理器的工作原理。
一个展示如何使玩具类与上下文管理器兼容的示例。
什么是 Join
?
在分布式数据并行入门 - 基本用例中,您看到了使用DistributedDataParallel执行数据并行训练的通用框架。这在每个反向传递中隐式调度了全归约操作,以跨等级同步梯度。这种集体通信需要进程组中的所有等级参与,因此如果一个等级的输入较少,则其他等级将挂起或出错(取决于后端)。更一般地说,对于任何执行每次迭代同步集体通信的类,这个问题都会持续存在。
Join
是一个上下文管理器,用于围绕每个等级的训练循环,以便于处理输入不均匀的训练。该上下文管理器允许那些提前耗尽输入的等级(即提前join)来遮蔽那些尚未加入的等级执行的集体通信。通信被遮蔽的方式由钩子指定。
使用Join
与DistributedDataParallel
PyTorch的DistributedDataParallel与Join
上下文管理器开箱即用。以下是一个使用示例:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
def worker(rank):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
# Rank 1 gets one more input than rank 0
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
num_inputs = 0
with Join([model]):
for input in inputs:
num_inputs += 1
loss = model(input).sum()
loss.backward()
print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
def main():
mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
main()
这将产生以下输出(其中来自等级0和等级1的print()
可能会任意排序):
Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!
注意
DistributedDataParallel 在引入这个通用的 Join
上下文管理器之前,提供了自己的 join() 上下文管理器。在上述示例中,使用 with Join([model]):
等同于使用 with model.join():
。现有的 DistributedDataParallel.join()
的一个限制是它不允许多个参与类,例如 DistributedDataParallel
和 ZeroRedundancyOptimizer 一起使用。
使用Join
与DistributedDataParallel
和ZeroRedundancyOptimizer
Join
上下文管理器不仅适用于单个类,还可以与多个类一起使用。PyTorch 的 ZeroRedundancyOptimizer
也与该上下文管理器兼容,因此在这里,我们探讨如何修改前面的示例以同时使用 DistributedDataParallel
和 ZeroRedundancyOptimizer
:
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam
def worker(rank):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
optim = ZeRO(model.parameters(), Adam, lr=0.01)
# Rank 1 gets one more input than rank 0
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
num_inputs = 0
# Pass both `model` and `optim` into `Join()`
with Join([model, optim]):
for input in inputs:
num_inputs += 1
loss = model(input).sum()
loss.backward()
optim.step()
print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
这将产生与之前相同的输出。显著的变化是额外将ZeroRedundancyOptimizer
实例传递到Join()
中。
传递关键字参数
类可能提供关键字参数,这些参数在运行时修改它们在上下文管理器中的行为。例如,DistributedDataParallel
提供了一个参数divide_by_initial_world_size
,它决定了梯度是除以初始世界大小还是有效世界大小(即未加入的等级数)。这些关键字参数可以直接传递给上下文管理器。
with Join([model, optim], divide_by_initial_world_size=False):
for input in inputs:
...
警告
传递给上下文管理器的关键字参数在所有参与的类之间共享。这不应该是一个限制,因为我们不期望有多个Joinable
需要相同参数的不同设置的情况。尽管如此,这是需要注意的一点。
Join
是如何工作的?
现在我们已经看了一些关于如何使用Join
上下文管理器的初步示例,让我们更深入地了解它是如何工作的。这将提供对其提供的全部功能的更深入理解,并为你准备使自己的自定义类兼容。在这里,我们将讨论Join
类以及支持类Joinable
和JoinHook
。
Joinable
首先,与Join
上下文管理器兼容的类必须继承自抽象基类Joinable
。特别是,一个Joinable
必须实现:
join_hook(self, **kwargs) -> JoinHook
这将返回JoinHook
实例给Joinable
,确定加入的进程应如何影子化由Joinable
执行的每次迭代的集体通信。
join_device(self) -> torch.device
这将返回一个设备,供Join
上下文管理器使用,以执行集体通信,例如torch.device("cuda:0")
或torch.device("cpu")
。
join_process_group(self) -> ProcessGroup
这将返回由Join
上下文管理器使用的进程组,以执行集体通信。
特别是,join_device
和 join_process_group
是必需的属性,以确保上下文管理器可以调度已加入和未加入进程之间的集体通信。一种用法是使用 all-reduce 在每次迭代中计算未加入进程的数量。另一种用法是实现 throw_on_early_termination=True
所需的机制,我们将在下面进一步解释。
DistributedDataParallel
和 ZeroRedundancyOptimizer
已经继承自 Joinable
并实现了上述方法,这就是为什么我们可以在前面的示例中直接使用它们。
Joinable
类应确保调用 Joinable
构造函数,因为它初始化了一个 JoinConfig
实例,该实例由上下文管理器内部使用以确保正确性。这将作为字段 _join_config
保存在每个 Joinable
中。
JoinHook
接下来,让我们分解JoinHook
类。一个JoinHook
提供了两个进入上下文管理器的入口点:
main_hook(self) -> None
当存在尚未加入的排名时,每个已加入的排名会重复调用此钩子。它的目的是在每个训练迭代中(例如,在一次前向传递、反向传递和优化器步骤中)遮蔽由Joinable
执行的集体通信。
post_hook(self, is_last_joiner: bool) -> None
这个钩子在所有等级加入后被调用。它传递了一个额外的bool
参数is_last_joiner
,该参数指示该等级是否是最后加入的之一。该参数可能对同步有用。
为了具体说明这些钩子可能是什么样子,提供的ZeroRedundancyOptimizer
主钩子会像平常一样执行优化器步骤,因为加入的排名仍然负责更新和同步其参数的分片,而提供的DistributedDataParallel
后钩子会从最后加入的排名之一广播最终更新的模型,以确保所有排名上的模型是相同的。
Join
最后,让我们看看这些如何适应Join
类本身。
__init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)
正如我们在前面的例子中看到的,构造函数接收参与训练循环的Joinable
列表。这些应该是在每次迭代中执行集体通信的类。
enable
是一个 bool
类型的值,如果你知道不会有不均匀的输入,可以将其设置为 False
,在这种情况下,上下文管理器将变得类似于 contextlib.nullcontext()
的空操作。这也可能会禁用参与 Joinable
的与连接相关的计算。
throw_on_early_termination
是一个 bool
类型的参数,可以设置为 True
,以便在检测到不均匀输入时,每个进程立即抛出异常。这对于不符合上下文管理器要求的情况非常有用,最常见的情况是当存在来自不同类的集体通信时,这些通信可能会任意交错,例如在使用带有 SyncBatchNorm
层的模型时使用 DistributedDataParallel
。在这种情况下,应将此参数设置为 True
,以便应用程序逻辑可以捕获异常并确定如何继续。
核心逻辑发生在
__exit__()
方法中,该方法在存在未加入的等级时循环,调用每个Joinable
的主钩子,然后一旦所有等级都加入后,调用它们的后钩子。主钩子和后钩子都是按照Joinable
传入的顺序进行迭代的。上下文管理器需要从未加入的进程中获取心跳信号。因此,每个
Joinable
类在每次迭代的集体通信之前应调用Join.notify_join_context()
。上下文管理器将确保只有第一个传入的Joinable
实际发送心跳信号。
警告
如上所述,关于throw_on_early_termination
,
Join
上下文管理器与某些类的组合不兼容。Joinable
的JoinHook
必须是可序列化的,因为每个钩子在继续下一个之前必须完全执行。换句话说,两个钩子不能重叠。此外,目前,主钩子和后钩子都以相同的确定性顺序迭代。如果这似乎是一个主要限制,我们可能会修改API以允许自定义顺序。
使玩具类与Join
一起工作
由于前一节介绍了几个概念,让我们通过一个简单的例子来实践一下。在这里,我们将实现一个类,用于计算在其排名加入之前所有排名中看到的输入数量。这应该提供一个基本思路,说明如何使您自己的类与Join
上下文管理器兼容。
具体来说,以下代码让每个等级打印出(1)在它加入之前所有等级看到的输入数量,以及(2)所有等级的总输入数量。
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
class CounterJoinHook(JoinHook):
r"""
Join hook for :class:`Counter`.
Arguments:
counter (Counter): the :class:`Counter` object using this hook.
sync_max_count (bool): whether to sync the max count once all ranks
join.
"""
def __init__(
self,
counter,
sync_max_count
):
self.counter = counter
self.sync_max_count = sync_max_count
def main_hook(self):
r"""
Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
"""
t = torch.zeros(1, device=self.counter.device)
dist.all_reduce(t)
def post_hook(self, is_last_joiner: bool):
r"""
Synchronizes the max count across all :class:`Counter` s if
``sync_max_count=True``.
"""
if not self.sync_max_count:
return
rank = dist.get_rank(self.counter.process_group)
common_rank = self.counter.find_common_rank(rank, is_last_joiner)
if rank == common_rank:
self.counter.max_count = self.counter.count.detach().clone()
dist.broadcast(self.counter.max_count, src=common_rank)
class Counter(Joinable):
r"""
Example :class:`Joinable` that counts the number of training iterations
that it participates in.
"""
def __init__(self, device, process_group):
super(Counter, self).__init__()
self.device = device
self.process_group = process_group
self.count = torch.tensor([0], device=device).float()
self.max_count = torch.tensor([0], device=device).float()
def __call__(self):
r"""
Counts the number of inputs processed on this iteration by all ranks
by all-reducing a dim-1 one tensor; increments its own internal count.
"""
Join.notify_join_context(self)
t = torch.ones(1, device=self.device).float()
dist.all_reduce(t)
self.count += t
def join_hook(self, **kwargs) -> JoinHook:
r"""
Return a join hook that shadows the all-reduce in :meth:`__call__`.
This join hook supports the following keyword arguments:
sync_max_count (bool, optional): whether to synchronize the maximum
count across all ranks once all ranks join; default is ``False``.
"""
sync_max_count = kwargs.get("sync_max_count", False)
return CounterJoinHook(self, sync_max_count)
@property
def join_device(self) -> torch.device:
return self.device
@property
def join_process_group(self):
return self.process_group
def find_common_rank(self, rank, to_consider):
r"""
Returns the max rank of the ones to consider over the process group.
"""
common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
common_rank = common_rank.item()
return common_rank
def worker(rank):
assert torch.cuda.device_count() >= WORLD_SIZE
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
with Join([counter], sync_max_count=True):
for _ in inputs:
counter()
print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
print(f"{int(counter.max_count.item())} inputs processed across all ranks!")
def main():
mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
main()
由于rank 0看到5个输入,rank 1看到6个输入,因此产生以下输出:
10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!
一些需要强调的关键点:
一个
Counter
实例每次迭代执行一次全归约操作,因此主钩子也执行一次全归约操作以与其保持一致。Counter
类在其__call__()
方法的开头调用了Join.notify_join_context()
,因为这是在每次迭代的集体通信(即其全归约)之前的位置。is_last_joiner
参数用于确定后钩子中的广播源。我们传入
sync_max_count
关键字参数给上下文管理器,然后它被转发到Counter
的join钩子。