GPU 性能提示#

本文档专注于神经网络工作负载的性能优化技巧

矩阵乘法精度#

在最近的GPU代系中,例如Nvidia A100代系或更新的,使用 bfloat16 精度进行大多数计算可能是个好主意。例如,如果使用 Flax,可以使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16) 实例化 Dense 层。以下是一些代码示例:

XLA 性能标志#

备注

JAX-Toolbox 还有一个关于 NVIDIA XLA 性能标志 的页面。

XLA 标志的存在及其确切行为可能依赖于 jaxlib 版本。

截至 jaxlib==0.4.18 版本(发布于 2023年10月6日),设置这些 XLA 标志可以提高性能。有些与 GPU 之间的通信有关,因此仅在多设备上运行计算时相关,而其他一些则与每个设备上的代码生成有关。

这些中的某些可能在未来的版本中默认设置。

这些标志可以通过 XLA_FLAGS 外壳环境变量来设置。例如,我们可以将以下内容添加到 Python 文件的顶部:

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)

更多示例,请参阅 Nvidia GPU 上 Pax 训练推荐的 XLA 标志

代码生成标志#

  • –xla_gpu_enable_triton_softmax_fusion 此标志启用基于Triton代码生成的模式匹配的自动softmax融合。默认值为False。

  • –xla_gpu_triton_gemm_any 对任何它支持的GEMM(矩阵乘法)使用基于Triton的GEMM(matmul)发射器。默认值为False。

通信标志#

  • –xla_gpu_enable_latency_hiding_scheduler 此标志启用延迟隐藏调度器,以有效地重叠异步通信与计算。默认值为 False。

  • –xla_gpu_enable_pipelined_collectives 在使用流水线并行时,此标志启用与第 (i+1) 层权重 AllGather 与第 i 层计算的重叠。它还启用了与第 (i+1) 层权重 Reduce/ReduceScatter 与第 i 层计算的重叠。默认值为 False。当此标志开启时存在一些错误。

  • –xla_gpu_collective_permute_decomposer_threshold 这个标志在执行 GSPMD 流水线 时非常有用。设置一个非零阈值会将 CollectivePermute 分解为 CollectivePermuteReceiveDoneCollectivePermuteSendDone 对,从而可以在每个对应的 ReceiveDone/SendDone 对之间执行计算,从而实现更多的重叠。默认阈值为 0,不进行分解。将其设置为阈值 > 0,例如 --xla_gpu_collective_permute_decomposer_threshold=1024,可以启用此功能。

  • –xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志调整何时将多个小的 AllGather/ReduceScatter/AllReduce 组合成一个大的 AllGather/ReduceScatter/AllReduce,以减少跨设备通信的时间。例如,对于基于 Transformer 的工作负载的 AllGather/ReduceScatter 阈值,考虑将它们调高,以便至少组合一个 Transformer 层的权重 AllGather/ReduceScatter。默认情况下,combine_threshold_bytes 设置为 256。

NCCL 标志#

这些 Nvidia NCCL 标志值可能对在 Nvidia GPU 上的单主机多设备计算有用:

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 })

这些 NCCL 标志可以提高单主机通信速度。这些标志目前似乎对多主机通信没有用处。

多进程#

我们建议每个GPU使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加快即时计算的速度。jax.distributed.initialize() API 在SLURM下运行时会自动理解这种配置。然而,这只是一个经验法则,根据您的使用情况,测试每个GPU一个进程和每个节点一个进程可能会有所帮助。