使用Tensor Parallel (TP)进行大规模Transformer模型训练
创建于:2024年4月19日 | 最后更新:2024年8月19日 | 最后验证:2024年11月5日
作者: Wanchao Liang, Tianyu Liu
注意
在github上查看和编辑本教程。
本教程演示了如何使用Tensor Parallel和Fully Sharded Data Parallel在数百到数千个GPU上训练一个大型Transformer类模型。
先决条件:
已安装 PyTorch 2.3.0 或更高版本,支持 CUDA/Linux
Tensor Parallel 如何工作?
Tensor Parallel (TP) 最初在 Megatron-LM 论文中提出,它是一种用于训练大规模Transformer模型的高效模型并行技术。本教程中提到的 Sequence Parallel (SP) 是 Tensor Parallel 的一种变体,它在序列维度上对 nn.LayerNorm
或 RMSNorm
进行分片,以进一步节省训练期间的激活内存。随着模型变得更大,激活内存成为瓶颈,因此在 Tensor Parallel 训练中,通常会将 Sequence Parallel 应用于 LayerNorm
或 RMSNorm
层。

图1. 展示了在Transformer模型的MLP和自注意力层上以Tensor Parallel风格进行的分片,其中注意力/MLP中的矩阵乘法通过分片计算进行(图片来源)
在高层次上,PyTorch Tensor Parallel 的工作方式如下:
分片初始化
确定要应用于每个层的
ParallelStyle
,并通过调用parallelize_module
来分片初始化的模块。并行化的模块将把它们的模型参数交换为DTensors,DTensor将负责使用分片计算来运行并行化的模块。
运行时前进/后退
根据用户为每个
ParallelStyle
指定的输入/输出DTensor布局,它将运行适当的通信操作来转换输入/输出的DTensor布局(例如allreduce
、allgather
和reduce_scatter
)。运行分片计算以节省并行化层的计算/内存(例如,
nn.Linear
,nn.Embedding
)。
何时以及为何应该应用Tensor Parallel
PyTorch 完全分片数据并行(FSDP)已经具备将模型训练扩展到特定数量 GPU 的能力。然而,当涉及到在模型大小和 GPU 数量方面进一步扩展模型训练时,会出现许多额外的挑战,可能需要将张量并行与 FSDP 结合使用:
随着世界大小(GPU数量)变得非常大(超过128/256个GPU),FSDP集合(如
allgather
)主要由环形延迟主导。 通过在FSDP之上实现TP/SP,可以将FSDP世界大小减少8倍,通过仅在主机间应用FSDP,从而相应地减少延迟成本。由于收敛性和GPU内存限制,达到了数据并行性的极限,无法将全局批处理大小提高到超过GPU数量,Tensor/Sequence Parallel 是唯一已知的方法来“估算”全局批处理大小并继续使用更多GPU进行扩展。这意味着模型大小和GPU数量都可以继续扩展。
对于某些类型的模型,当本地批量大小变小时,TP/SP可以产生更适合浮点运算(FLOPS)的矩阵乘法形状。
那么,在预训练时,达到这些限制有多容易?到目前为止,预训练一个拥有数十亿或数万亿标记的大型语言模型(LLM)可能需要数月时间,即使使用数千个GPU也是如此。
在大规模训练LLM时,总会遇到限制1。例如,Llama 2 70B使用2k个GPU训练了35天,需要在2k规模上进行多维并行。
当Transformer模型变得更大(例如Llama2 70B)时,它也会很快达到限制2。由于内存和收敛性的限制,即使本地
batch_size=1
也不能单独使用FSDP。例如,Llama 2的全局批量大小为1K,因此在2K GPU上不能单独使用数据并行。
如何应用张量并行
PyTorch Tensor Parallel APIs 提供了一组模块级原语(ParallelStyle
)来配置模型每一层的分片,包括:
ColwiseParallel
和RowwiseParallel
:以列或行的方式分片nn.Linear
和nn.Embedding
。SequenceParallel
: 在nn.LayerNorm
,nn.Dropout
,RMSNormPython
等上执行分片计算。PrepareModuleInput
和PrepareModuleOutput
: 配置模块输入/输出的分片布局,并进行适当的通信操作。
为了演示如何使用PyTorch原生的Tensor Parallel API,让我们来看一个常见的Transformer模型。在本教程中,我们使用最新的Llama2模型作为参考的Transformer模型实现,因为它在社区中也广泛使用。
由于Tensor Parallel将单个张量分散在一组设备上,我们需要首先设置分布式环境(例如NCCL通信器)。 Tensor Parallelism是一种类似于PyTorch DDP/FSDP的单程序多数据(SPMD)分片算法,它在底层利用PyTorch DTensor 来执行分片。它还利用DeviceMesh抽象(在底层管理ProcessGroups)进行设备管理和分片。 要了解如何使用DeviceMesh设置多维并行,请参阅本教程。Tensor Parallel通常在每个主机内工作,因此让我们首先初始化一个连接主机内8个GPU的DeviceMesh。
from torch.distributed.device_mesh import init_device_mesh
tp_mesh = init_device_mesh("cuda", (8,))
现在我们已经初始化了DeviceMesh,让我们详细看一下Llama 2模型架构,并了解我们应该如何进行张量并行分片。
这里我们关注核心的TransformerBlock
,其中Transformer模型通过堆叠相同的TransformerBlock
来扩展模型。
核心的TransformerBlock
由一个Attention
层和一个FeedForward
层组成。让我们首先看一下更简单的FeedForward
层。
对于FeedForward
层,它由三个线性层组成,执行SwiGLU风格的MLP,查看其前向函数:
# forward in the FeedForward layer
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
它同时执行w1
和w3
的矩阵乘法,然后使用组合的w1/w3线性投影结果进行w2
的矩阵乘法。这意味着我们可以使用Tensor Parallelism论文中的思想,以列方式分片w1/w3线性层,并以行方式分片w2
线性层,这样在所有三层结束时只进行一次allreduce
通信。使用PyTorch原生的Tensor Parallel,我们可以简单地创建一个parallelize_plan
,如下所示:
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
layer_tp_plan = {
# by default ColwiseParallel input layouts is replicated
# and RowwiseParallel output layouts is replicated
"feed_foward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(),
"feed_forward.w3": ColwiseParallel(),
}
这就是我们如何使用PyTorch Tensor Parallel API配置FeedForward
层的分片方式。请注意,用户只需指定如何分片各个层,而通信(例如,allreduce
)将在后台自动进行。
接下来是Attention
层。它由wq
、wk
、wv
线性层组成,用于将输入投影到q
/k
/v
,然后使用wo
线性层执行注意力机制和输出投影。这里的张量并行旨在对q/k/v投影进行列向分片,对wo
线性投影进行行向分片。因此,我们可以将Attention计划添加到我们刚刚起草的tp_plan
中:
layer_tp_plan = {
# by default ColwiseParallel input layouts is replicated
# and RowwiseParallel output layouts is replicated
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(),
"feed_forward.w3": ColwiseParallel(),
}
这几乎是我们需要将Tensor Parallelism应用于TransformerBlock
的layer_tp_plan
。然而,我们需要注意的是,当按列分片线性层时,线性层的输出将在最后一个张量维度上分片,而按行分片的线性层直接接受在最后一个维度上分片的输入。
如果在按列分片的线性和按行分片的线性之间有更多的张量操作(如视图操作),我们需要调整相关的形状相关操作以适应分片形状。
对于Llama模型,在注意力层中有几个与形状相关的视图操作。特别是对于wq
/ wk
/ wv
线性层的列并行,激活张量在num_heads
维度上进行分片,因此我们需要将num_heads
调整为本地num_heads
。
最后,我们需要调用parallelize_module
API来使每个TransformerBlock
的计划生效。在底层,它将Attention
和FeedForward
层中的模型参数分配到DTensors,并在必要时为模型输入和输出(分别在每个模块之前和之后)注册通信钩子:
for layer_id, transformer_block in enumerate(model.layers):
layer_tp_plan = {...} # i.e. the plan we just generated
# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_tp_plan,
)
现在我们已经详细阐述了每个TransformerBlock
的分片计划,通常在第一个层中有一个nn.Embedding
和一个最终的nn.Linear
投影层,用户可以选择对第一个nn.Embedding
进行行分片或列分片,并对最后一个nn.Linear
投影层进行列分片,同时指定适当的输入和输出布局。
以下是一个示例:
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
),
"output": ColwiseParallel(
output_layouts=Replicate(),
),
}
)
注意
如果要分区的模型太大,无法放入CPU内存中,可以使用meta
设备初始化(例如,首先在meta设备上初始化模型,对层进行分片,然后具体化模型),或者在Transformer模型初始化期间逐层并行化TransformerBlock
。
将序列并行应用于LayerNorm/RMSNorm
层
序列并行工作在上述张量并行的基础上。与仅对Attention
模块和FeedForward
模块内的张量进行分片并保持其模块输入和输出(即前向传播中的激活和后向传播中的梯度)复制的张量并行相比,序列并行在序列维度上保持它们的分片状态。
在一个典型的TransformerBlock
中,前向函数结合了归一化层(LayerNorm
或RMSNorm
)、注意力层、前馈层和残差连接。例如:
# forward in a TransformerBlock
def forward(self, x):
h = x + self.attention(self.attention_norm(x))
out = h + self.feed_forward(self.ffn_norm(h))
return out
在大多数使用场景中,激活(和梯度)的形状在Attention
和FeedForward
模块之外是[batch size, sequence length, hidden dimension]
。在DTensor的语言中,序列并行使用Shard(1)
布局进行模块的前向/后向激活计算。
根据前面的代码示例,下面的代码展示了我们如何将序列并行应用于TransformerBlock
中的归一化层:
首先,让我们导入Sequence Parallel所需的依赖项:
from torch.distributed.tensor.parallel import (
PrepareModuleInput,
SequenceParallel,
)
接下来让我们调整layer_tp_plan
以在RMSNorm
层上启用序列并行:
layer_tp_plan = {
# Now the input and output of SequenceParallel has Shard(1) layouts,
# to represent the input/output tensors sharded on the sequence dimension
"attention_norm": SequenceParallel(),
"attention": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(1)),
"ffn_norm": SequenceParallel(),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
"feed_forward.w3": ColwiseParallel(),
}
可以看到我们现在使用PrepareModuleInput
将模块输入布局从Shard(1)
修改为Replicate()
,并将它们的输出布局标记为Shard(1)
。
就像在张量并行中发生的那样,只需要指定输入和输出的张量分片布局,层之间的通信将自动发生。
请注意,使用序列并行时,我们假设TransformerBlock
的输入和输出始终在序列维度上进行分片,以便多个TransformerBlocks
可以无缝连接。
这可以通过明确指定起始nn.Embedding
层的输出和最终nn.Linear
投影层的输入为Shard(1)
来促进:
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Replicate()
),
}
)
应用并行损失
损失并行是一种相关技术,用于在计算损失函数时节省内存和通信,因为模型输出通常非常大。在损失并行中,当模型输出在(通常很大的)词汇维度上分片时,可以高效地计算交叉熵损失,而无需将所有模型输出收集到每个GPU上。这不仅显著减少了内存消耗,还通过减少通信开销和并行进行分片计算来提高训练速度。下图简要说明了损失并行如何通过分片计算避免将所有模型输出收集到每个GPU上。
在PyTorch Tensor Parallel API中,可以通过上下文管理器loss_parallel
启用Loss Parallel,使用它可以直接使用torch.nn.functional.cross_entropy
或torch.nn.CrossEntropyLoss
而无需修改代码的其他部分。
要应用Loss Parallel,模型预测(通常形状为[batch size, sequence length, vocabulary size]
)应在词汇维度上进行分片。这可以通过标记最后一个线性投影层输出的布局来轻松完成:
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
# use DTensor as the output
use_local_output=False,
),
},
)
在上面的代码中,我们还在输出之前对规范层应用了序列并行。我们应用use_local_output=False
让输出保持为DTensor,以便与loss_parallel
上下文管理器一起工作。之后,可以简单地调用交叉熵损失函数,如下所示。请注意,反向计算也需要在上下文中进行。
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel
pred = model(input_ids)
with loss_parallel():
# assuming pred and labels are of the shape [batch, seq, vocab]
loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
loss.backward()
将张量并行与完全分片数据并行结合
既然我们已经展示了如何将Tensor/Sequence Parallel应用于模型,让我们也来看看Tensor Parallel和Fully Sharded Data Parallel如何协同工作。 由于Tensor Parallelism会引发阻塞计算的通信,我们希望确保它在快速通信通道(如NVLink)中运行。 在实践中,我们通常在每个主机内应用Tensor Parallel,并在主机之间应用Fully Sharded Data Parallel。
这种二维并行模式可以通过二维DeviceMesh轻松表达,我们只需要将每个“子”DeviceMesh传递给每个单独的并行API:
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
mesh_2d = init_device_mesh("cuda", (8, 8))
tp_mesh = mesh_2d["tp"] # a submesh that connects intra-host devices
dp_mesh = mesh_2d["dp"] # a submesh that connects inter-host devices
model = Model(...)
tp_plan = {...}
# apply Tensor Parallel intra-host on tp_mesh
model_tp = parallelize_module(model, tp_mesh, tp_plan)
# apply FSDP inter-host on dp_mesh
model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True, ...)
这将使我们能够轻松地在每个主机内应用Tensor Parallel(主机内),并在主机间应用FSDP(主机间),而无需对Llama模型进行任何代码更改。 Tensor(模型)并行和数据并行技术相结合,提供了继续增加模型规模并使用大量GPU高效训练的能力。
结论
本教程演示了如何使用Tensor Parallel与Fully Sharded Data Parallel结合,在数百到数千个GPU上训练一个大型Transformer类模型。 它解释了如何将Tensor Parallel应用于模型的不同部分,而无需更改模型本身的代码。Tensor Parallel是一种用于大规模训练的高效模型并行技术。
要查看本教程中解释的完整端到端代码示例,请参考pytorch/examples仓库中的Tensor Parallel examples。