XLA 编译器标志列表#

介绍#

本指南简要概述了XLA以及XLA与Jax的关系。如需深入了解,请参阅XLA文档。接着,它列出了为优化Jax程序性能而设计的常用XLA编译器标志。

XLA: Jax 背后的强大引擎#

XLA(加速线性代数)是一个专门为线性代数设计的编译器,它在Jax的性能和灵活性中起着关键作用。它通过将你的Python/NumPy类代码转换并编译为高效的机器指令,使Jax能够为各种硬件后端(CPU、GPU、TPU)生成优化的代码。

Jax 利用 XLA 的 JIT 编译能力,在运行时将您的 Python 函数转换为优化的 XLA 计算。

在 Jax 中配置 XLA:#

你可以在运行Python脚本或colab笔记本之前,通过设置XLA_FLAGS环境变量来影响Jax中的XLA行为。

对于 Colab 笔记本:

使用 os.environ['XLA_FLAGS'] 提供标志:

import os

# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'

对于Python脚本:

XLA_FLAGS 指定为命令行命令的一部分:

XLA_FLAGS='--flag1=value1 --flag2=value2'  python3 source.py

重要提示:

  • 在导入 Jax 或其他相关库之前设置 XLA_FLAGS。在后台初始化后更改 XLA_FLAGS 将无效,并且由于后台初始化时间未明确定义,通常在执行任何 Jax 代码之前设置 XLA_FLAGS 更为安全。

  • 尝试使用不同的标志来优化您特定用例的性能。

更多信息:

  • 关于XLA的完整且最新的文档可以在官方的 XLA文档 中找到。

  • 对于开源版XLA支持的后端(CPU、GPU),XLA标志及其默认值在 xla/debug_options_flags.cc 中定义,完整的标志列表可以在此处找到 这里

  • TPU 编译器标志不是 OpenXLA 的一部分,但以下列出了常用的选项。

  • 请注意,此标志列表并不详尽,且可能会有变动。这些标志是实现细节,不能保证它们会一直可用或保持当前的行为。

常见的 XLA 标志#

标志

类型

注释

xla_dump_to

字符串 (文件路径)

将放置预优化HLO文件和其他工件的文件夹(参见XLA工具)。

xla_enable_async_collective_permute

TristateFlag (真/假/自动)

将所有 collective-permute 操作重写为其异步变体。当设置为 auto 时,XLA 可以根据其他配置或条件自动开启异步 collective。

xla_enable_async_all_gather

TristateFlag (真/假/自动)

如果设置为 true,则启用异步全收集。如果设置为 auto,则仅在实现异步全收集的平台上启用。实现方式(如 BC-offload 或 continuation fusion)根据其他标志值选择。

xla_disable_hlo_passes

字符串(逗号分隔的传递名称列表)

要禁用的HLO传递的逗号分隔列表。这些名称必须与传递名称完全匹配(逗号周围没有空格)。

TPU XLA 标志#

标志

类型

注释

xla_tpu_enable_data_parallel_all_reduce_opt

布尔值 (true/false)

优化以增加用于数据并行分片的DCN(数据中心网络)all-reduce的重叠机会。

xla_tpu_data_parallel_opt_不同大小的操作

布尔值 (true/false)

即使在多个迭代的输出大小不匹配堆叠变量中的保存位置时,也支持跨多个迭代的数据并行操作的流水线。可能会增加内存压力。

xla_tpu_enable_async_collective_fusion

布尔值 (true/false)

启用融合异步集体通信与计算操作(输出/循环融合或卷积)的传递,这些操作在它们的-start和-done指令之间调度。

xla_tpu_enable_async_collective_fusion_fuse_all_gather

TristateFlag (真/假/自动)

启用在 AsyncCollectiveFusion 过程中融合所有收集操作。
如果设置为 auto,将根据目标启用。

xla_tpu_enable_async_collective_fusion_multiple_steps

布尔值 (true/false)

在 AsyncCollectiveFusion 过程中,允许在多个步骤(融合)中继续相同的异步集体操作。

xla_tpu_overlap_compute_collective_tc

布尔值 (true/false)

在单个 TensorCore 上启用计算和通信的重叠,即,相当于 MegaCore 融合的一个核心。

xla_tpu_spmd_rng_bit_generator_unsafe

布尔值 (true/false)

是否以分区方式运行 RngBitGenerator HLO,如果在计算的不同部分上使用不同的分片期望确定性结果,这是不安全的。

xla_tpu_megacore_fusion_allow_ags

布尔值 (true/false)

允许将 all-gathers 与卷积/all-reduces 融合。

xla_tpu_enable_ag_backward_pipelining

布尔值 (true/false)

管道全体收集(目前是超大规模全体收集)通过扫描循环向后进行。

GPU XLA 标志#

标志

类型

注释

xla_gpu_enable_latency_hiding_scheduler

布尔值 (true/false)

此标志启用延迟隐藏调度器,以有效重叠异步通信与计算。默认值为 False。

xla_gpu_enable_triton_gemm

布尔值 (true/false)

使用基于Triton的矩阵乘法。

xla_gpu_graph_level

标志 (0-3)

设置GPU图形级别的遗留标志。在新用例中使用 xla_gpu_enable_command_buffer。0 = 关闭;1 = 捕获融合和内存复制;2 = 捕获GEMM;3 = 捕获卷积。

xla_gpu_all_reduce_combine_threshold_bytes

整数 (字节)

这些标志调整何时将多个小的 AllGather / ReduceScatter / AllReduce 合并为一个大的 AllGather / ReduceScatter / AllReduce,以减少跨设备通信的时间。例如,对于基于 Transformer 的工作负载的 AllGather / ReduceScatter 阈值,考虑将它们调整得足够高,以便至少合并一个 Transformer 层的权重 AllGather / ReduceScatter。默认情况下,combine_threshold_bytes 设置为 256。

xla_gpu_all_gather_combine_threshold_bytes

整数 (字节)

参见上面的 xla_gpu_all_reduce_combine_threshold_bytes。

xla_gpu_reduce_scatter_combine_threshold_bytes

整数 (字节)

参见上面的 xla_gpu_all_reduce_combine_threshold_bytes。

xla_gpu_enable_pipelined_all_gather

布尔值 (true/false)

启用所有收集指令的流水线处理。

xla_gpu_enable_pipelined_reduce_scatter

布尔值 (true/false)

启用 reduce-scatter 指令的流水线处理。

xla_gpu_enable_pipelined_all_reduce

布尔值 (true/false)

启用所有归约指令的流水线处理。

xla_gpu_enable_while_loop_double_buffering

布尔值 (true/false)

为while循环启用双缓冲。

xla_gpu_enable_triton_softmax_fusion

布尔值 (true/false)

使用基于Triton的Softmax融合。

xla_gpu_enable_all_gather_combine_by_dim

布尔值 (true/false)

将所有具有相同收集维度或不考虑其维度的all-gather操作组合在一起。

xla_gpu_enable_reduce_scatter_combine_by_dim

布尔值 (true/false)

将reduce-scatter操作与相同维度或不考虑其维度进行组合。

附加阅读: