JEP 18137: JAX NumPy 和 SciPy 包装器的范围#
Jake VanderPlas
2023年10月
到目前为止,jax.numpy
和 jax.scipy
的预期范围相对不明确。本文档为这些包提出了一个明确的范围,以更好地指导和评估未来的贡献,并激励移除一些超出范围的代码。
背景#
从一开始,JAX 的目标就是为在 XLA 中执行代码提供一个类似 NumPy 的 API,项目开发的一个重要部分是构建 jax.numpy
和 jax.scipy
命名空间,作为基于 JAX 的 NumPy 和 SciPy API 实现。一直以来都有一个隐含的理解,即 numpy
和 scipy
的某些部分不在 JAX 的范围内,但这个范围并没有很好地定义。这可能会导致贡献者的困惑和挫折,因为没有明确的答案来判断潜在的 jax.numpy
和 jax.scipy
贡献是否会被 JAX 接受。
为什么要限制范围?#
为了避免遗漏这一点,我们应该明确指出:任何被包含在像 JAX 这样的项目中的代码,都会给开发者带来一个虽小但非零的持续维护负担。一个项目随着时间的推移能否成功,直接关系到维护者能否继续为项目的所有部分进行维护:记录功能、回答问题、修复错误等。对于任何软件工具的长期成功和可持续性,维护者仔细权衡任何特定贡献是否会在考虑到项目目标和资源的情况下,对项目产生净正面影响,是至关重要的。
评估标准#
本文档提出了六个轴,用于评估任何特定的 numpy
或 scipy
API 是否应纳入 JAX。在所有轴上表现强劲的 API 是纳入 JAX 包的优秀候选者;在六个轴中的任何一个上表现出的明显弱点都是反对纳入 JAX 的好理由。
轴 1: XLA 对齐#
我们考虑的第一个轴是所提出的API与原生XLA操作的对齐程度。例如,jax.numpy.exp()
函数基本上直接反映了 jax.lax.exp
。numpy
、scipy.special
、numpy.linalg
、scipy.linalg
等模块中的大量函数都符合这一标准:在考虑将这些函数纳入JAX时,它们通过了XLA对齐检查。
另一方面,有一些函数,如 numpy.unique()
,它们并不直接对应于任何 XLA 操作,并且在某些情况下,与 JAX 当前的计算模型根本不兼容,该模型要求数组具有静态形状(例如,unique
返回依赖于值的动态数组形状)。在考虑将这些函数纳入 JAX 时,它们不会通过 XLA 对齐检查。
我们还考虑了纯函数语义的需求,作为此轴的一部分。例如,numpy.random
是基于隐式更新的基于状态的随机数生成器构建的,这与基于 XLA 构建的 JAX 计算模型从根本上不兼容。
轴 2: 数组 API 对齐#
我们考虑的第二个轴心集中在 Python Array API 标准:在某种意义上,这是一个社区驱动的纲要,概述了哪些数组操作是面向数组编程的核心,涵盖了广泛的用户社区。如果 numpy
或 scipy
中的某个 API 被列在 Array API 标准中,这是一个强烈的信号,表明 JAX 应该包含它。以上面的例子为例,Array API 标准包括 numpy.unique()
的几个变体(unique_all
、unique_counts
、unique_inverse
、unique_values
),这表明尽管该函数与 XLA 不完全对齐,但它对 Python 用户社区的重要性足以让 JAX 或许应该实现它。
轴 3: 下游实现的存在#
对于不符合轴1或轴2的功能,纳入JAX的一个重要考虑因素是是否存在提供该功能的良好支持的下游包。这方面的一个很好的例子是scipy.optimize
:虽然JAX确实包含了一组最小的scipy.optimize
功能包装器,但在JAXopt包中存在一个更完整的处理,该包由JAX合作者积极维护。在这种情况下,我们应该倾向于引导用户和贡献者使用这些专门的包,而不是在JAX本身中重新实现这些API。
轴 4:实现的复杂性与鲁棒性#
对于与XLA不一致的功能,一个考虑因素是所提议实现的程度复杂性。这在某种程度上与轴1一致,但无论如何都值得指出。已经向JAX贡献了许多函数,这些函数的实现相对复杂,难以验证,并引入了过大的维护负担;一个例子是jax.scipy.special.bessel_jn()
:截至本JEP的撰写,其当前实现是一个非直接的迭代近似,在某些领域存在收敛问题,而提出的修复引入了进一步的复杂性。如果我们更仔细地权衡了实现的复杂性和鲁棒性,在接受贡献时,我们可能不会选择接受这个贡献到包中。
轴 5:功能性 vs. 面向对象的 API#
JAX 最适合与函数式 API 而不是面向对象的 API 一起使用。面向对象的 API 通常会隐藏不纯的语义,使得它们通常难以很好地实现。NumPy 和 SciPy 通常坚持使用函数式 API,但有时会提供面向对象的便利包装器。
这方面的一个例子是 numpy.polynomial.Polynomial
,它封装了较低级别的操作,如 numpy.polyadd()
、numpy.polydiv()
等。一般来说,当同时存在功能性和面向对象的API时,JAX 应该避免为面向对象的API提供封装,而是为功能性API提供封装。
在仅存在面向对象API的情况下,JAX应避免提供包装器,除非在其他方面有强有力的理由。
轴 6: 对 JAX 用户和利益相关者的“重要性”#
在JAX中包含NumPy/SciPy API的决定也应考虑到该算法对一般用户社区的重要性。诚然,量化谁是“利益相关者”以及这种重要性应如何衡量是困难的;但我们包括这一点是为了明确指出,关于在JAX的NumPy和SciPy包装器中包含什么内容的任何决定都将涉及一定程度的判断,这种判断无法轻易量化。
对于现有的API,在github中搜索其使用情况可能有助于确定其重要性或无关紧要;例如,我们可以回到上面讨论的 jax.scipy.special.bessel_jn()
:搜索显示,这个函数在github上只有 少数几个使用,这可能部分是由于之前提到的精度问题。
评估:范围是什么?#
在本节中,我们将尝试根据上述标准评估 NumPy 和 SciPy 的 API,包括当前 JAX API 中的一些示例。这不会是所有现有函数和类的全面列表,而是按子模块和主题进行更一般的讨论,并附上相关示例。
NumPy API#
✅ numpy
命名空间#
我们认为主 numpy
命名空间中的函数基本上都在 JAX 的范围内,这是由于它与 XLA(轴 1)和 Python Array API(轴 2)的总体一致性,以及它对 JAX 用户社区的普遍重要性(轴 6)。一些函数可能是边缘情况(例如 numpy.intersect1d()
、np.setdiff1d()
、np.union1d()
可能不符合部分标准),但为了简单起见,我们声明主 numpy 命名空间中的所有数组函数都在 JAX 的范围内。
✅ numpy.linalg
& numpy.fft
#
The numpy.linalg
和 numpy.fft
子模块包含许多与XLA提供的功能广泛对齐的函数。其他的则有复杂的设备特定降低,但代表了利益相关者(轴6)的重要性超过复杂性的情况。因此,我们认为这两个子模块都在JAX的范围内。
❌ numpy.random
#
numpy.random
不在 JAX 的范围内,因为基于状态的随机数生成器与 JAX 的计算模型根本不兼容。我们转而关注 jax.random
,它使用基于计数器的伪随机数生成器提供类似的功能。
❌ numpy.ma
& numpy.polynomial
#
The numpy.ma
和 numpy.polynomial
子模块主要关注通过面向对象接口提供可以通过其他功能手段表达的计算(轴 5);因此,我们认为它们不在 JAX 的范围内。
❌ numpy.testing
#
NumPy 的测试功能实际上只对主机端计算有意义,因此我们在 JAX 中不包含任何针对它的包装器。尽管如此,JAX 数组与 numpy.testing
兼容,并且在 JAX 测试套件中经常使用它。
SciPy API#
SciPy 在顶层命名空间中没有函数,但包含许多子模块。我们下面逐一考虑这些子模块,省略已弃用的模块。
❌ scipy.cluster
#
The scipy.cluster
模块包含用于层次聚类、k-means 及相关算法的工具。这些在几个方面表现较弱,更适合由下游包来实现。JAX 中已经存在一个函数(jax.scipy.cluster.vq.vq()
),但在 github 上没有明显的使用:这表明聚类对 JAX 用户来说并不广泛重要。
建议:弃用并移除 jax.scipy.cluster.vq()
。
❌ scipy.constants
#
The scipy.constants
模块包含数学和物理常数。这些常数可以直接与 JAX 一起使用,因此没有理由在 JAX 中重新实现这一点。
❌ scipy.datasets
#
The scipy.datasets
模块包含获取和加载数据集的工具。这些获取的数据集可以直接与 JAX 一起使用,因此没有理由在 JAX 中重新实现这一点。
✅ scipy.fft
#
The scipy.fft
模块包含的功能与 XLA 提供的功能大致一致,并且在其他方面也表现良好。因此,我们认为它们在 JAX 的范围内。
❌ scipy.integrate
#
The scipy.integrate
模块包含用于数值积分的函数。其中较为复杂的函数(如 quad
, dblquad
, ode
)由于基于动态评估次数的循环算法,超出了 JAX 的轴 1 和轴 4 的范围。jax.experimental.ode.odeint()
与此相关,但功能较为有限且不在任何积极开发中。
JAX 目前包含 jax.scipy.integrate.trapezoid()
,但这只是因为 numpy.trapz()
最近被弃用,转而支持这个。对于任何特定的输入,它的实现可以用一行 jax.numpy
表达式替换,因此它并不是一个特别有用的 API 来提供。
基于轴1、2、4和6,scipy.integrate
应被视为JAX的范围之外。
建议:移除 jax.scipy.integrate.trapezoid()
,该功能在 JAX 0.4.14 中添加。
❌ scipy.interpolate
#
The scipy.interpolate
模块提供了用于在一维或多维中插值的低级和面向对象的例程。这些API在上述的多个轴上评分较低:它们是基于类的而不是低级的,并且除了最简单的方法外,其他方法都不能有效地用XLA操作来表示。
JAX 目前确实有 scipy.interpolate.RegularGridInterpolator
的包装器。如果今天我们考虑这个贡献,我们可能会根据上述标准拒绝它。但由于这段代码相当稳定,因此继续维护它并没有太大的弊端。
今后,我们应该考虑 scipy.interpolate
的其他成员不在 JAX 的范围内。
❌ scipy.io
#
The scipy.io
子模块与文件输入/输出有关。没有理由在 JAX 中重新实现这一点。
✅ scipy.linalg
#
The scipy.linalg
子模块包含的功能与XLA提供的大致一致,而快速线性代数对JAX用户社区来说非常重要。因此,我们认为它属于JAX的范畴。
❌ scipy.ndimage
#
The scipy.ndimage
子模块包含一组用于处理图像数据的工具。其中许多工具与 scipy.signal
中的工具重叠(例如卷积和滤波)。JAX 目前提供了一个 scipy.ndimage
API,即 jax.scipy.ndimage.map_coordinates()
。此外,JAX 在 jax.image
模块中提供了一些与图像相关的工具。deepmind 生态系统包括 dm-pix,这是 JAX 中一组功能更全面的图像处理工具。考虑到所有这些因素,我建议 scipy.ndimage
应被视为 JAX 核心范围之外的内容;我们可以将感兴趣的用户和贡献者指向 dm-pix。我们可以考虑将 map_coordinates
移动到 dm-pix
或其他适当的包中。
❌ scipy.odr
#
The scipy.odr
模块为执行正交距离回归提供了 ODRPACK
的对象导向封装。目前尚不清楚这能否使用现有的 JAX 原语清晰地表达,因此我们认为它超出了 JAX 本身的范围。
❌ scipy.optimize
#
The scipy.optimize
模块提供了高级和低级的优化接口。这种功能对许多 JAX 用户来说非常重要,JAX 很早就创建了 jax.scipy.optimize
包装器。然而,这些例程的开发者很快意识到 scipy.optimize
API 过于受限,不同的团队开始致力于 JAXopt 包和 Optimistix 包,每个包都包含一套更全面、测试更好的 JAX 优化例程。
由于这些得到良好支持的外部包,我们现在认为 scipy.optimize
超出了 JAX 的范围。
建议:弃用 jax.scipy.optimize
或将其作为可选的 JAXopt 或 Optimistix 依赖项的轻量级包装器。
🟡 scipy.signal
#
The scipy.signal
模块是混合的:一些函数完全在 JAX 的范围内(例如 correlate
和 convolve
,它们是 lax.conv_general_dilated
的用户友好包装器),而许多其他函数则完全超出范围(特定领域的工具,没有可行的路径降低到 XLA)。对 jax.scipy.signal
的潜在贡献将需要逐案权衡。
🟡 scipy.sparse
#
The scipy.sparse
子模块主要包含用于存储和操作各种格式稀疏矩阵和数组的数据结构。此外,scipy.sparse.linalg
包含许多适用于稀疏矩阵、密集矩阵和线性算子的无矩阵求解器。
JAX 的计算模型(例如,许多操作依赖于动态大小的缓冲区)与 scipy.sparse
数组和矩阵数据结构不一致,因此它们不在 JAX 的范围内。JAX 开发了 jax.experimental.sparse
模块作为一组更符合 JAX 计算约束的数据结构。因此,我们认为 scipy.sparse
中的数据结构不在 JAX 的范围内。
另一方面,scipy.sparse.linalg
已被证明是一个有趣的领域,而 jax.scipy.sparse.linalg
包含了 bicgstab
、cg
和 gmres
求解器。这些对 JAX 用户社区(轴 6)很有用,但在其他轴上表现不佳。它们非常适合迁移到下游库中;一个潜在的选择可能是 Lineax,它基于 JAX 构建了多个线性求解器。
建议:探索将稀疏求解器移入 Lineax,并将 scipy.sparse
视为 JAX 的范围之外。
❌ scipy.spatial
#
The scipy.spatial
模块主要包含面向对象的接口,用于空间/距离计算和最近邻搜索。对于 JAX 来说,这大部分是超出范围的。
The scipy.spatial.transform
子模块提供了操作三维空间旋转的工具。它是一个相对复杂的面向对象接口,可能更适合由下游项目来实现。JAX 目前包含 jax.scipy.spatial.transform
中的 Rotation
和 Slerp
的部分实现;这些是面向对象的包装器,引入了非常大的 API 接口,但用户非常少。我们认为它们超出了 JAX 本身的范围,用户更适合由一个假设的下游项目来服务。
The scipy.spatial.distance
子模块包含了一系列有用的距离度量方法,为这些方法提供 JAX 包装器可能很诱人。然而,通过使用 jit 和 vmap,用户在需要时可以轻松地从头定义这些方法的高效版本,因此将它们添加到 JAX 中并没有特别的好处。
建议:考虑弃用并移除 Rotation
和 Slerp
API,并考虑将 scipy.spatial
作为一个整体排除在未来贡献范围之外。
✅ scipy.special
#
The scipy.special
模块包含了许多更专门化函数的实现。在许多情况下,这些函数都在范围内:例如,像 gammaln
、betainc
、digamma
等函数直接对应于可用的 XLA 原语,并且根据 Axis 1 和其他标准显然在范围内。
其他功能需要更复杂的实现;上述提到的例子之一是 bessel_jn
。尽管在轴1和轴2上不一致,这些函数在轴6上往往非常强大:scipy.special
提供了各种领域计算所需的基本函数,因此即使实现复杂的函数也应倾向于在范围内,只要实现设计良好且稳健。
有几种现有的函数包装器我们应该仔细研究一下;例如:
jax.scipy.special.lpmn()
: 这个函数通过一个复杂的 fori_loop 生成勒让德多项式,这种方式与 scipy API 不匹配(例如,对于scipy
,z
必须是一个标量,而对于 JAX,z
必须是一个一维数组)。该函数很少有可发现的用途,因此在维度 1、2、4 和 6 上是一个较弱的候选者。jax.scipy.special.lpmn_values()
:这与上述lmpn
有类似的弱点。jax.scipy.special.sph_harm()
:这是基于 lpmn 构建的,并且同样具有与相应scipy
函数不同的 API。jax.scipy.special.bessel_jn()
: 如上文在第4轴讨论的,它在实现稳健性和使用率方面存在弱点。我们可能会考虑用一个新的、更稳健的实现来替代它(例如 #17038)。
*建议:重构并提高 bessel_jn
的鲁棒性和测试覆盖率。如果 lpmn
、lpmn_values
和 sph_harm
不能被修改以更接近 scipy
API,考虑弃用它们。
✅ scipy.stats
#
The scipy.stats
模块包含广泛的统计函数,包括离散和连续分布、汇总统计和假设检验。JAX 目前将其中的一些封装在 jax.scipy.stats
中,主要包括大约 20 种统计分布,以及其他一些函数(mode
, rankdata
, gaussian_kde
)。通常这些与 JAX 非常契合:分布通常可以用高效的 XLA 操作来表达,API 简洁且功能性强。
我们目前没有任何假设检验函数的封装,可能是因为这些对JAX的主要用户群体来说用处不大。
关于分布,在某些情况下,tensorflow_probability
提供了类似的功能,未来我们可能会考虑是否要弃用 scipy.stats 的分布,转而支持该实现。
建议:今后,我们应该将统计分布和汇总统计视为范围内,并考虑假设检验及相关功能通常为范围外。