持久编译缓存#

JAX 有一个可选的编译程序磁盘缓存。如果启用,JAX 会将编译程序的副本存储在磁盘上,这可以在重复运行相同或相似任务时节省重新编译的时间。

用法#

快速开始#

import jax
import jax.numpy as jnp

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

@jax.jit
def f(x):
  return x + 1

x = jnp.zeros((2, 2))
f(x)

设置缓存目录#

当设置了 缓存位置 时,编译缓存被启用。这应该在第一次编译之前完成。设置位置如下:

(1) 使用环境变量

在shell中,运行脚本之前:

export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"

或者在 Python 脚本的顶部:

import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"

(2) 使用 jax.config.update()

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

(3) 使用 set_cache_dir()

from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("/tmp/jax_cache")

缓存阈值#

  • jax_persistent_cache_min_compile_time_secs: 只有当编译时间超过指定值时,计算结果才会被写入持久缓存。默认值为1.0秒。

  • jax_persistent_cache_min_entry_size_bytes:将被缓存在持久编译缓存中的条目的最小大小(以字节为单位):

    • -1: 禁用大小限制并防止覆盖。

    • 保留为默认值 (0) 以允许覆盖。通常,覆盖将确保最小尺寸对于用于缓存的文件系统是最优的。

    • > 0: 实际所需的最小尺寸;无覆盖。

请注意,函数要被缓存,两个条件都需要满足。

Google Cloud#

在Google Cloud上运行时,编译缓存可以放置在Google Cloud Storage(GCS)存储桶中。我们推荐以下配置:

  • 在运行工作负载的同一区域中创建存储桶。

  • 在与工作负载的虚拟机(VM)相同的项目中创建存储桶。确保设置权限,使虚拟机(VM)可以写入存储桶。

  • 对于较小的负载,不需要复制。较大的负载可能会从复制中受益。

  • 使用“Standard”作为存储桶的默认存储类别。

  • 将软删除策略设置为最短:7天。

  • 将对象生命周期设置为工作负载运行的预期持续时间。例如,如果预计工作负载运行10天,则将对象生命周期设置为10天。这应涵盖在整个运行期间发生的重启。使用 age 作为生命周期条件,Delete 作为操作。详情请参阅 对象生命周期管理。如果未设置对象生命周期,缓存将继续增长,因为没有实现驱逐机制。

  • 所有加密策略都受支持。

假设 gs://jax-cache 是 GCS 存储桶,设置缓存位置如下:

jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")

工作原理#

缓存键是编译函数的签名,包含以下参数:

  • 由JAX函数哈希的非优化HLO捕获的函数执行的计算

  • jaxlib 版本

  • 相关的 XLA 编译标志

  • 设备配置通常通过设备数量和设备拓扑结构来捕获。目前对于GPU,拓扑结构仅包含GPU名称的字符串表示。

  • 用于压缩编译后可执行文件的压缩算法

  • jax._src.cache_key.custom_hook() 生成的字符串。这个函数可以被重新赋值为用户定义的函数,从而可以改变生成的字符串。默认情况下,这个函数总是返回一个空字符串。

多节点缓存#

程序第一次运行时(持久缓存为冷/空),所有进程都会进行编译,但只有全局通信组中排名为0的进程会写入持久缓存。在后续运行中,所有进程都会尝试从持久缓存中读取,因此持久缓存必须位于共享文件系统(例如:NFS)或远程存储(例如:GFS)中。如果持久缓存仅对排名为0的进程本地化,那么在后续运行中,除排名为0的进程外的所有进程都会因编译缓存未命中而再次编译。

记录缓存活动#

检查持久编译缓存中究竟发生了什么,对于调试可能会有帮助。以下是一些如何开始的建议。

用户可以通过放置相关源文件来启用日志记录。

import os
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"

在脚本的顶部。

检查缓存未命中#

为了检查和理解为什么会出现缓存未命中,JAX 包含了一个配置标志,可以启用所有缓存未命中(包括持久编译缓存未命中)及其解释的日志记录。虽然目前这仅针对追踪缓存未命中实现,但最终目标是解释所有缓存未命中。可以通过设置以下配置来启用此功能。

jax.config.update("jax_explain_cache_misses", True)

陷阱#

目前发现了一些陷阱:

  • 目前,持久缓存不适用于具有主机回调的函数。在这种情况下,完全避免缓存。

    • 这是因为 HLO 包含一个指向回调的指针,即使计算和计算基础设施完全相同,每次运行时也会发生变化。

  • 目前,持久化缓存不适用于使用实现自定义分区的基础类型的函数。

    • 函数的HLO包含一个指向custom_partitioning回调的指针,并且在不同运行中会导致相同的计算产生不同的缓存键。

    • 在这种情况下,缓存仍然会进行,但每次都会生成不同的键,从而使缓存失效。

围绕 custom_partitioning 工作#

如前所述,编译缓存不适用于由实现 custom_partitioning 的原语组成的函数。然而,对于那些确实实现了 custom_partitioning 的原语,可以使用 shard_map 来绕过 custom_partitioning,从而使编译缓存按预期工作:

让我们假设我们有一个函数 F ,它实现了一个layernorm,然后使用一个实现 custom_partitioning 的原始 LayerNorm 进行矩阵乘法:

import jax

def F(x1, x2, gamma, beta):
   ln_out = LayerNorm(x1, gamma, beta)
   return ln_out @ x2

如果我们仅仅编译这个函数而没有使用 shard_map,那么每次运行相同代码时,layernorm_matmul_without_shard_map 的缓存键都会不同:

layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_sharding=(...))(x1, x2, gamma, beta)

然而,如果我们用 shard_map 包装 layernorm 原语并定义一个执行相同计算的函数 G,那么 layernorm_matmul_with_shard_map 的缓存键每次都会相同,尽管 LayerNorm 实现了 custom_partitioning

import jax
from jax.experimental.shard_map import shard_map

def G(x1, x2, gamma, beta, mesh, ispecs, ospecs):
   ln_out = shard_map(LayerNorm, mesh, in_specs=ispecs, out_specs=ospecs, check_rep=False)(x1, x2, gamma, beta)
   return ln_out @ x2

ispecs = jax.sharding.PartitionSpec(...)
ospecs = jax.sharding.PartitionSpec(...)
mesh = jax.sharding.Mesh(...)
layernorm_matmul_with_shard_map = jax.jit(G, static_argnames=['mesh', 'ispecs', 'ospecs'])(x1, x2, gamma, beta, mesh, ispecs, ospecs)

请注意,实现 custom_partitioning 的原语必须被包裹在 shard_map 中以实现此解决方案。仅将外部函数 F 包裹在 shard_map 中是不够的。