GPU 性能提示#
本文档专注于神经网络工作负载的性能优化技巧
矩阵乘法精度#
在最近的GPU代系中,例如Nvidia A100代系或更新的,使用 bfloat16
精度进行大多数计算可能是个好主意。例如,如果使用 Flax,可以使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16)
实例化 Dense
层。以下是一些代码示例:
在 Flax LM1B 示例 中,
Dense
模块 使用可配置的 dtype 实例化,其 默认值 为 bfloat16。在 MaxText 中,
DenseGeneral
模块也 通过可配置的 dtype 实例化,其 默认值为 bfloat16。
XLA 性能标志#
备注
JAX-Toolbox 还有一个关于 NVIDIA XLA 性能标志 的页面。
XLA 标志的存在及其确切行为可能依赖于 jaxlib
版本。
截至 jaxlib==0.4.18
版本(发布于 2023年10月6日),设置这些 XLA 标志可以提高性能。有些与 GPU 之间的通信有关,因此仅在多设备上运行计算时相关,而其他一些则与每个设备上的代码生成有关。
这些中的某些可能在未来的版本中默认设置。
这些标志可以通过 XLA_FLAGS
外壳环境变量来设置。例如,我们可以将以下内容添加到 Python 文件的顶部:
import os
os.environ['XLA_FLAGS'] = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
'--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
更多示例,请参阅 Nvidia GPU 上 Pax 训练推荐的 XLA 标志。
代码生成标志#
–xla_gpu_enable_triton_softmax_fusion 此标志启用基于Triton代码生成的模式匹配的自动softmax融合。默认值为False。
–xla_gpu_triton_gemm_any 对任何它支持的GEMM(矩阵乘法)使用基于Triton的GEMM(matmul)发射器。默认值为False。
通信标志#
–xla_gpu_enable_latency_hiding_scheduler 此标志启用延迟隐藏调度器,以有效地重叠异步通信与计算。默认值为 False。
–xla_gpu_enable_pipelined_collectives 在使用流水线并行时,此标志启用与第 (i+1) 层权重
AllGather
与第 i 层计算的重叠。它还启用了与第 (i+1) 层权重Reduce
/ReduceScatter
与第 i 层计算的重叠。默认值为 False。当此标志开启时存在一些错误。–xla_gpu_collective_permute_decomposer_threshold 这个标志在执行 GSPMD 流水线 时非常有用。设置一个非零阈值会将
CollectivePermute
分解为CollectivePermuteReceiveDone
和CollectivePermuteSendDone
对,从而可以在每个对应的ReceiveDone
/SendDone
对之间执行计算,从而实现更多的重叠。默认阈值为 0,不进行分解。将其设置为阈值 > 0,例如--xla_gpu_collective_permute_decomposer_threshold=1024
,可以启用此功能。–xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志调整何时将多个小的
AllGather
/ReduceScatter
/AllReduce
组合成一个大的AllGather
/ReduceScatter
/AllReduce
,以减少跨设备通信的时间。例如,对于基于 Transformer 的工作负载的AllGather
/ReduceScatter
阈值,考虑将它们调高,以便至少组合一个 Transformer 层的权重AllGather
/ReduceScatter
。默认情况下,combine_threshold_bytes
设置为 256。
NCCL 标志#
这些 Nvidia NCCL 标志值可能对在 Nvidia GPU 上的单主机多设备计算有用:
os.environ.update({
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
})
这些 NCCL 标志可以提高单主机通信速度。这些标志目前似乎对多主机通信没有用处。
多进程#
我们建议每个GPU使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加快即时计算的速度。jax.distributed.initialize()
API 在SLURM下运行时会自动理解这种配置。然而,这只是一个经验法则,根据您的使用情况,测试每个GPU一个进程和每个节点一个进程可能会有所帮助。