jax.extend: 一个用于扩展的模块#

@froystig, @sharadmv, @jakevdp, @yashk2810

2023年5月

import jax.extend as jex

几个项目依赖于 JAX 的代码库内部,通常是为了使用其核心机制(例如,编写一个 对其 IR 的转换)或扩展它(例如,定义新的原语)。这些依赖面临的两个挑战是(a)我们的内部结构并非都为外部使用而设计,(b)绕过 JAX 的公共 API 是 不受支持的。换句话说,我们的内部结构经常被用作库,但既没有结构化也没有像库那样更新。

本提案考虑 引入一个 jax.extend 模块,该模块定义了 JAX 内部组件的库视图。我们将其视为第二级 API,仍然保证 没有兼容性政策,但希望在发生更改时更容易发现这些变化。

jax.extend 的受众包括与 JAX 相关的 Python 库,如 Oryxjax-triton 以及其他许多库,以及那些实验性项目,如函数变换、自动微分系统、数值编程的编译器前端等。

本说明概述了 jax.extend 可能的外观,目前和最终。它没有详细阐述,而是建议我们开始 迭代开发 该模块。

请注意,jax.extendjax.experimental 不同,后者是新功能和正在进行中的想法的试验场。通常,jax.experimental 中的工作最终会进入另一个 JAX 模块,或者被完全移除。

无兼容性政策#

为了保持开发开销低,jax.extend 将不遵循公共的 API 兼容性 政策。它将不承诺任何弃用窗口,也不保证版本之间的向后兼容性。每个版本都可能破坏现有的调用者,而没有简单的补救措施(例如,没有重新引入先前行为的标志)。我们将依赖 更新日志 来指出这些变化。

需要定期升级代码以配合 JAX 发布的 jax.extend 调用者可能会发现,在发布之间固定 JAX 版本作为中间步骤是有用的。这是当前依赖 JAX 内部的项目中的常见习惯。不同的是,现在有了变更日志公告的帮助,并且在库设计和命名方面有了更好的意图。

迭代开发#

没有兼容性政策使得开始实施变得更容易:第一天,我们可以从内部包如 jax._src 和今天的 jax.corejax.interpreters 中移动一些符号。然后我们可以从那里迭代改进。

可能的模块概述#

我们可以想象,最终 jax.extend 将包含以下模块:

  • core – 基本元素,Jaxpr IR 等。

  • interpreters – 核心转换(例如自动微分、批处理)和降低。

  • random – 随机位生成、密钥分割和折叠、密钥数组。

  • sharding – 围绕分布式数组的额外功能。

我们最初在模块中可能还有其他符号,例如 jex.api_util,因为我们正在努力移除或替换它们。其他符号将随着时间的推移决定。例如,jex.lib 可能提供一个进入 jaxlib 的入口点(并且在短期内会这样做),但我们是否希望长期保留它尚不清楚。

以下是对这些可能包含内容的初步思考。

jax.extend.core#

这应该至少使调用者能够定义新的 JAX 原语并处理 Jaxpr IR(jax.make_jaxpr(...) 的输出)。支持这一点可能涉及提供:

  • 访问现有的核心系统原语,例如今天的 jax._src.lax.add_p

  • 访问IR类型,例如当前的 jax._src.core.ShapedArray

  • 用于检查和美化打印 jaxprs 的函数。

  • 用于显式构建 jaxprs 的函数,而不是通过 jax.make_jaxpr (或不!)分阶段构建 Python 函数。

在初始化时,该模块将包含比定义基本元素和规则所需的更多的符号,包括在设置 “最终样式转换” 中使用的各种名称,例如当前的 jax._src.core.TraceTracer 类。我们可以重新考虑 jex.core 是否也应该支持最终样式扩展以及初始样式方法,以及它是否可以通过比完全暴露 TraceTracer 更窄的API来实现这一点。Oryx 可能会帮助指导这些决策。

我们也可以考虑将 make_jaxpr 本身迁移到 jex.core

jax.extend.interpreters#

此模块将提供一种注册各种基本变换规则的方法——定义它们在自动微分、批处理、降低等操作下的行为。

它最初将反映 jax._src.interpreters 中的模块,提供 adbatchingpartial_eval(用于将 Python 分阶段转换为 Jaxpr,以及在 AD 中的线性化)、mlirpxlaxla。前三个可能可以通过 jex.core 中的单一原语扩展 API 替换。后三个用于降低的模块,可能会被简化为一个模块。

今天,为了编写转换规则,例如用于自动微分和批处理,调用者可能需要与跟踪器相关的符号,例如 JVPTracerBatchTracer。这可能在以后可以避免,并允许我们从 jex 中移除跟踪器类型。

此模块加上 jex.core 应当足以复制今天的自定义原始教程(例如 我们的dfm的)。例如,定义一个原始操作及其在 jax.jit 下的行为可以如下进行(在短期内):

from jax.extend import core	         # Previously: from jax import core
from jax.extend.interpreters import mlir        # ... and similarly

mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)

def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results

mlir.register_lowering(mul_add_p, mul_add_mlir)

import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)

jax.extend.random#

此模块可以公开我们定义新RNG实现的机制,以及用于处理PRNG密钥内部(参见问题 #9263)的函数,例如当前的 jax._src.prng.random_wraprandom_unwrap

它还可以公开支持内置RNG实现的键控哈希函数,例如 jax._src.prng.threefry_2x32

jax.extend.sharding#

此模块可能暴露用于分片分布式数组的底层工具。

我们目前只考虑了一项内容。XLA 编译器的数组分片格式比 JAX 提供的那些 更具表达力。我们可以将其作为 jex.sharding.XlaOpShardingProto 提供,对应于今天的 jax._src.lib.xla_client.OpSharding 内部实现。