JAX 类型注解路线图#

  • 作者: jakevdp

  • 日期: 2022年8月

背景#

Python 3.0 引入了可选的函数注解(PEP 3107),这些注解在 Python 3.5 发布前后被编码用于静态类型检查(PEP 484)。在某种程度上,类型注解和静态类型检查已经成为许多 Python 开发工作流程中不可或缺的一部分,为此我们在 JAX API 的多个地方添加了注解。JAX 中类型注解的现状有些零散,由于更基本的设计问题,添加更多注解的努力受到了阻碍。本文档试图总结这些问题,并为 JAX 中类型注解的目标和非目标生成路线图。

为什么我们需要这样一个路线图?更好/更全面的类型注解是用户(无论是内部还是外部)的常见请求。此外,我们经常收到外部用户的拉取请求(例如,PR #9917PR #10322),他们希望改进JAX的类型注解:对于审查代码的JAX团队成员来说,这些贡献是否有益并不总是显而易见的,特别是当它们引入复杂的协议来解决JAX使用Python进行全面注解的固有挑战时。本文档详细介绍了JAX在包内类型注解方面的目标和建议。

为什么使用类型注解?#

有许多原因可能导致一个Python项目希望为其代码库添加注释;我们将在本文档中将其总结为一级、二级和三级。

一级:作为文档的注解#

当最初在 PEP 3107 中引入时,类型注解的部分动机是能够将它们用作函数参数类型和返回类型的简洁、内联文档。JAX 长期以来一直以这种方式使用注解;一个例子是创建类型名称别名为 Any 的常见模式。可以在 lax/slicing.py [source] 中找到一个例子:

Array = Any
Shape = core.Shape

def slice(operand: Array, start_indices: Sequence[int],
          limit_indices: Sequence[int],
          strides: Optional[Sequence[int]] = None) -> Array:
  ...

出于静态类型检查的目的,数组类型注释中使用 Array = Any 对参数值没有任何约束(Any 等同于没有注释),但它确实为开发者提供了一种有用的代码内文档形式。

为了生成的文档,别名的名称会丢失(jax.lax.sliceHTML 文档 将操作数报告为类型 Any),因此文档的好处不会超出源代码(尽管我们可以启用一些 sphinx-autodoc 选项来改进这一点:参见 autodoc_type_aliases)。

这种类型注解的一个好处是,使用 Any 注解一个值永远不会出错,因此它将以文档的形式为开发者和用户提供实际的好处,而不会增加满足任何特定静态类型检查器更严格需求的复杂性。

第二级:智能自动补全的注释#

许多现代IDE利用类型注解作为智能代码补全系统的输入。一个例子是VSCode的Pylance扩展,它使用微软的pyright静态类型检查器作为VSCode的IntelliSense补全的信息来源。

这种类型检查需要比上面使用的简单别名更进一步;例如,知道 slice 函数返回一个名为 ArrayAny 别名并不会为代码补全引擎添加任何有用的信息。然而,如果我们用 DeviceArray 返回类型注释该函数,自动补全将知道如何填充结果的命名空间,从而能够在开发过程中提供更相关的自动补全建议。

JAX 已经开始在某些地方添加这种类型的注解;一个例子是 jax.random 包中的 jnp.ndarray 返回类型 [source]:

def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
  ...

在这种情况下,jnp.ndarray 是一个抽象基类,它前向声明了 JAX 数组的属性和方法(查看源码),因此 VSCode 中的 Pylance 可以在此函数的结果上提供完整的自动补全功能。以下是显示结果的截图:

VSCode Intellisense 截图

在自动完成字段中列出的是 ndarray 抽象类声明的所有方法和属性。我们将在下面进一步讨论为什么需要创建这个抽象类,而不是直接使用 DeviceArray 进行注释。

第三级:静态类型检查的注解#

如今,当人们考虑Python代码中类型注解的用途时,静态类型检查通常是首先想到的事情。虽然Python不会进行任何运行时类型检查,但存在几个成熟的静态类型检查工具,可以在CI测试套件中执行此操作。对于JAX来说,最重要的工具如下:

  • python/mypy 在开源Python世界中或多或少是标准。JAX目前在其Github Actions CI检查中对部分源文件运行mypy。

  • google/pytype 是Google的静态类型检查器,在Google内部依赖JAX的项目中经常使用这个工具。

  • microsoft/pyright 作为在 VSCode 中用于 Pylance 完成功能的静态类型检查器,非常重要。

完全静态类型检查是所有类型注解应用中最严格的,因为它会在类型注解不精确时暴露错误。一方面,这很好,因为静态类型分析可能会捕捉到错误的类型注解(例如,DeviceArray 方法在 jnp.ndarray 抽象类中缺失的情况)。

另一方面,这种严格性可能会使那些经常依赖鸭子类型而不是严格类型安全API的包中的类型检查过程变得非常脆弱。你目前在JAX代码库中会发现数百处代码注释,如 #type: ignore(用于mypy)或 #pytype: disable(用于pytype)。这些通常代表出现了类型问题的地方;它们可能是JAX类型注解中的不准确之处,或者是静态类型检查器在正确跟踪代码控制流方面的能力不足。偶尔,它们是由于pytype或mypy行为中的真实且微妙的错误。在极少数情况下,它们可能是因为JAX使用了Python模式,这些模式在Python的静态类型注解语法中难以甚至无法表达。

JAX 的类型注解挑战#

JAX 目前有混合了不同风格的类型注解,并且针对上述讨论的三种类型注解级别。部分原因在于,JAX 的源代码对 Python 的类型注解系统提出了许多独特的挑战。我们将在下面概述这些挑战。

挑战 1:pytype、mypy 和开发者摩擦#

JAX 目前面临的一个挑战是,包开发必须满足两个不同静态类型检查系统的约束,即 pytype(用于内部 CI 和内部 Google 项目)和 mypy(用于外部 CI 和外部依赖项)。尽管这两个类型检查器在行为上有广泛的交集,但每个都有其独特的边缘情况,这在 JAX 代码库中大量的 #type: ignore#pytype: disable 语句中得到了证明。

这会在开发中产生摩擦:内部贡献者可能会迭代直到测试通过,结果发现导出后他们的 pytype 批准的代码违反了 mypy 的检查。对于外部贡献者来说,情况往往相反:最近的例子是 #9596,它在通过 Google 内部的 pytype 检查后不得不回滚。每次我们将类型注释从第 1 级(到处都是 Any)移动到第 2 级或第 3 级(更严格的注释),都会增加这种令人沮丧的开发者体验的潜在可能性。

挑战 2:数组鸭子类型#

注释 JAX 代码的一个特殊挑战是其大量使用鸭子类型。函数输入标记为 Array 的类型通常可以是多种不同类型之一:JAX DeviceArray、NumPy np.ndarray、NumPy 标量、Python 标量、Python 序列、具有 __array__ 属性的对象、具有 __jax_array__ 属性的对象,或任何形式的 jax.Tracer。因此,简单的注释如 def func(x: DeviceArray) 将不足以应对,并且会导致许多有效用例的误报。这意味着 JAX 函数的类型注释将不会简短或简单,而是需要我们有效地开发一组类似于 numpy.typing 中的 JAX 特定类型扩展。

挑战3:转换和装饰器#

JAX 的 Python API 严重依赖于函数变换(jit()vmap()grad() 等),这种类型的 API 对静态类型分析提出了特别的挑战。装饰器的灵活注解在 mypy 包中一直是一个长期存在的问题,直到最近通过引入 ParamSpec 才得以解决,这在 PEP 612 中讨论并在 Python 3.10 中添加。由于 JAX 遵循 NEP 29,它不能依赖 Python 3.10 的功能,直到 2024 年中期之后。在此期间,可以使用协议作为部分解决方案(JAX 在 #9950 中为 jit 和其他方法添加了此功能),并且可以通过 typing_extensions 包使用 ParamSpec(原型在 #9999 中),尽管这目前暴露了 mypy 中的基本错误(参见 python/mypy#12593)。总而言之:目前尚不清楚在 Python 类型注解工具的当前限制下,JAX 的函数变换 API 是否可以适当地注解。

挑战4:数组注解缺乏粒度#

这里的另一个挑战是所有面向数组的Python API所共有的,并且已经成为JAX讨论的一部分已有数年(参见 #943)。类型注释涉及对象的Python类或类型,而在基于数组的语言中,类的属性通常更为重要。在NumPy、JAX和类似包的情况下,我们通常希望注释特定的数组形状和数据类型。

例如,jnp.linspace 函数的参数必须是标量值,但在 JAX 中,标量由零维数组表示。因此,为了防止注解引发误报,我们必须允许这些参数为 任意 数组。另一个例子是 jax.random.choice 的第二个参数,当 shape=() 时,它必须具有 dtype=int。Python 计划通过可变类型泛型(参见 PEP 646,计划用于 Python 3.11)来启用这种粒度的类型注解,但与 ParamSpec 一样,对该功能的支持需要一段时间才能稳定。

在此期间,有一些第三方项目可能会有所帮助,特别是 google/jaxtyping,但这使用了非标准的注解,可能不适合用于注解核心 JAX 库本身。总的来说,数组类型粒度挑战的问题不如其他挑战那么严重,因为主要影响是数组类注解将不如它们本来可以的那样具体。

挑战 5:从 NumPy 继承的不精确 API#

JAX 的用户界面 API 大部分继承自 jax.numpy 子模块中的 NumPy。NumPy 的 API 是在静态类型检查成为 Python 语言的一部分之前开发的,并且遵循 Python 的历史建议,使用 鸭子类型/EAFP 编码风格,其中不鼓励在运行时进行严格的类型检查。作为一个具体的例子,考虑 numpy.tile() 函数,其定义如下:

def tile(A, reps):
  try:
    tup = tuple(reps)
  except TypeError:
    tup = (reps,)
  d = len(tup)
  ...

这里的 意图reps 将包含一个 int 或一个 int 值序列,但 实现 允许 tup 为任何可迭代对象。在为这种鸭子类型代码添加注解时,我们可以采取两种方式之一:

  1. 我们可能会选择注释函数的 意图 ,这里可能是类似 reps: Union[int, Sequence[int]] 的内容。

  2. 相反,我们可能选择注释函数的 实现 ,这里可能看起来像 reps: Union[ConvertibleToInt, Iterable[ConvertibleToInt]] ,其中 ConvertibleToInt 是一个特殊协议,涵盖了我们的函数将输入转换为整数的具体机制(即通过 __int____index____array__ 等)。还要注意,严格来说, Iterable 在这里是不够的,因为Python中有一些对象可以作为可迭代对象进行鸭子类型,但不满足针对 Iterable 的静态类型检查(即,通过 __getitem__ 而不是 __iter__ 进行迭代的对象。)

#1 的优点,即注释意图,在于这些注释在传达 API 合同时对用户更有用;而对于开发者来说,灵活性为必要时的重构留出了空间。缺点(特别是对于像 JAX 这样的渐进式类型 API)是,很可能存在运行正确但会被类型检查器标记为不正确的用户代码。对现有鸭子类型 API 进行渐进式类型化意味着当前的注释隐含为 Any,因此将其更改为更严格的类型可能会对用户表现为破坏性更改。

广义上讲,注释意图更好地服务于第1级类型检查,而注释实现更好地服务于第3级,而第2级则更像是一个混合体(在IDE中的注释方面,意图和实现都很重要)。

JAX 类型注解路线图#

在这个框架(第1/2/3级)和JAX特定的挑战下,我们可以开始制定在JAX项目中实现一致类型注释的路线图。

指导原则#

对于JAX类型注解,我们将遵循以下原则:

类型注解的用途#

我们希望能够尽可能支持完整的 一级、二级和三级 类型注解。特别是,这意味着我们应该对公共API函数的输入和输出都有严格的类型注解。

意图注释#

JAX 类型注解通常应指示 API 的 意图,而不是实现细节,以便注解能够有效地传达 API 的契约。这意味着有时在运行时有效的输入可能不会被静态类型检查器识别为有效(例如,传递一个任意迭代器来代替注解为 Shape = Sequence[int] 的形状)。

输入应为宽松类型#

JAX 函数和方法的输入应尽可能合理地允许类型:例如,虽然形状通常是元组,但接受形状的函数应接受任意序列。同样,接受 dtype 的函数不需要 np.dtype 类的实例,而是任何可转换为 dtype 的对象。这可能包括字符串、内置标量类型或标量对象构造函数,如 np.float64jnp.float64。为了使这一点在整个包中尽可能统一,我们将添加一个 jax.typing 模块,其中包含常见的类型规范,从广泛的类别开始,例如:

  • ArrayLike 将是一个可以隐式转换为数组的对象的联合类型:例如,jax 数组、numpy 数组、JAX 跟踪器,以及 python 或 numpy 标量。

  • DTypeLike 是任何可以隐式转换为 dtype 的联合类型:例如,numpy dtypes、numpy dtype 对象、jax dtype 对象、字符串和内置类型。

  • ShapeLike 可以是任何可以转换为形状的联合类型:例如,整数或类似整数对象的序列。

  • 等等。

请注意,这些通常会比 numpy.typing 中使用的等效协议更简单。例如,在 DTypeLike 的情况下,JAX 不支持结构化 dtypes,因此 JAX 可以使用更简单的实现。同样,在 ArrayLike 中,JAX 通常不支持使用列表或元组代替数组,因此类型定义将比 NumPy 的类似定义更简单。

输出应为严格类型化#

相反,函数和方法的输出应尽可能严格地类型化:例如,对于返回数组的 JAX 函数,输出应使用类似于 jnp.ndarray 的类型进行注解,而不是 ArrayLike。返回 dtype 的函数应始终注解为 np.dtype,而返回形状的函数应始终为 Tuple[int] 或严格类型化的 NamedShape 等效类型。为此,我们将在 jax.typing 中实现上述宽容类型的几个严格类型化模拟,即:

  • ArrayNDArray(见下文)用于类型注解时,实际上等同于 Union[Tracer, jnp.ndarray],应使用它来注解数组输出。

  • DTypenp.dtype 的别名,可能还具有表示 JAX 内部使用的键类型和其他泛化的能力。

  • Shape 本质上是一个 Tuple[int, ...],可能会有一些额外的灵活性来适应动态形状。

  • NamedShapeShape 的扩展,允许在 JAX 内部使用命名形状。

  • 等等。

我们还将探讨是否可以放弃 jax.numpy.ndarray 的当前实现,转而将 ndarray 作为 Array 或类似的别名。

倾向于简单#

除了在 jax.typing 中收集的常见类型协议外,我们应该倾向于简单性。我们应该避免为传递给 API 函数的参数构建过于复杂的协议,而是在无法简洁地指定 API 的完整类型规范的情况下,使用简单的联合类型,如 Union[simple_type, Any]。这是一种在避免不必要复杂性的同时,实现第 1 级和第 2 级注解目标的折中方案。

避免不稳定的输入机制#

为了避免因内部/外部CI差异而增加不必要的开发阻力,我们希望在使用类型注解构造时保持保守:特别是对于最近引入的机制,如 ParamSpecPEP 612)和可变类型泛型(PEP 646),我们希望等到mypy和其他工具的支持成熟和稳定后再依赖它们。

这一影响的一个结果是,目前,当函数被 JAX 转换如 jitvmapgrad 等装饰时,JAX 实际上会 剥离所有注解 从被装饰的函数中。虽然这很不幸,但在撰写本文时,mypy 与 ParamSpec 提供的潜在解决方案存在一系列不兼容问题(参见 ParamSpec mypy 错误追踪),因此我们认为目前它还不适合在 JAX 中全面采用。一旦对这些功能的支持稳定下来,我们将在未来重新审视这个问题。

同样地,目前我们将避免添加 jaxtyping 项目提供的更复杂和细粒度的数组类型注解。这是一个我们未来可以重新考虑的决定。

数组 类型设计考虑#

如上所述,JAX 中数组的类型注释提出了一个独特的挑战,这是由于 JAX 广泛使用鸭子类型,即在 jax 变换中传递和返回 Tracer 对象来代替实际的数组。这变得越来越令人困惑,因为用于类型注释的对象通常与用于运行时实例检查的对象重叠,并且可能或可能不对应于所讨论对象的实际类型层次结构。对于 JAX,我们需要为以下两种情况提供鸭子类型的对象:静态类型注释运行时实例检查

以下讨论将假设 jax.Array 是设备上数组的运行时类型,尽管目前尚未实现,但一旦 #12016 的工作完成,就会成为现实。

静态类型注解#

我们需要提供一个可以用于鸭子类型注解的对象。假设我们暂时称这个对象为 ArrayAnnotation,我们需要一个解决方案,使得 mypypytype 能够满足以下情况:

@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
  assert isinstance(x, core.Tracer)
  return x

这可以通过多种方法实现,例如:

  • 使用类型联合:ArrayAnnotation = Union[Array, Tracer]

  • 创建一个接口文件,声明 TracerArray 应被视为 ArrayAnnotation 的子类。

  • 重构 ArrayTracer ,使得 ArrayAnnotation 成为两者的真正基类。

运行时实例检查#

我们还必须提供一个可以在鸭子类型运行时 isinstance 检查中使用的对象。假设我们暂时称这个对象为 ArrayInstance,我们需要一个能通过以下运行时检查的解决方案:

def f(x):
  return isinstance(x, ArrayInstance)
x = jnp.array([1, 2, 3])
assert f(x)       # x will be an array
assert jit(f)(x)  # x will be a tracer

再次,有几种机制可以用于此:

  • 覆盖 type(ArrayInstance).__instancecheck__ 以对 ArrayTracer 对象都返回 True;这是 jnp.ndarray 当前的实现方式(来源)。

  • ArrayInstance 定义为一个抽象基类,并动态注册到 ArrayTracer

  • 重构 ArrayTracer 以便 ArrayInstance 成为 ArrayTracer 的真正基类

我们需要做出的一个决定是 ArrayAnnotationArrayInstance 应该是相同还是不同的对象。这里有一些先例;例如在核心 Python 语言规范中,typing.Dicttyping.List 存在是为了注解,而内置的 dictlist 用于实例检查。然而,DictList 在较新的 Python 版本中已被 弃用,转而支持使用 dictlist 进行注解和实例检查。

跟随 NumPy 的脚步#

在 NumPy 的情况下,np.typing.NDArray 用于类型注解,而 np.ndarray 用于实例检查(以及数组类型标识)。鉴于这一点,遵循 NumPy 的先例并实现以下内容可能是合理的:

  • jax.Array 是设备上数组的实际类型。

  • jax.typing.NDArray 是用于鸭子类型数组注释的对象。

  • jax.numpy.ndarray 是用于鸭子类型数组实例检查的对象。

对于NumPy的高级用户来说,这可能会感觉有些自然,然而这种三分法很可能会造成混淆:例如,选择哪种方式来进行实例检查和注解并不立即明确。

统一实例检查和注解#

另一种方法是通过上述提到的重写机制来统一类型检查和注解。

选项 1: 部分统一#

部分统一可能看起来像这样:

  • jax.Array 是设备上数组的实际类型。

  • jax.typing.Array 是用于鸭子类型数组注解的对象(通过 ArrayTracer 上的 .pyi 接口)。

  • jax.typing.Array 也是用于鸭子类型实例检查的对象(通过其元类中的 __isinstance__ 覆盖)

在这种方法中,jax.numpy.ndarray 将成为一个简单的别名 jax.typing.Array,以保持向后兼容性。

选项 2:通过覆盖实现完全统一#

或者,我们可以通过覆盖实现完全统一:

  • jax.Array 是设备上数组的实际类型。

  • jax.Array 也是用于鸭子类型数组注解的对象(通过 Tracer 上的 .pyi 接口)

  • jax.Array 也是用于鸭子类型实例检查的对象(通过其元类中的 __isinstance__ 重写)

在这里,jax.numpy.ndarray 将成为一个简单的别名 jax.Array 以实现向后兼容。

选项 3:通过类层次结构实现完全统一#

最后,我们可以选择通过重构类层次结构并使用面向对象的对象层次结构替换鸭子类型来实现完全统一:

  • jax.Array 是设备上数组的实际类型

  • jax.Array 也是用于数组类型注释的对象,通过确保 Tracer 继承自 jax.Array

  • jax.Array 也是通过相同机制用于实例检查的对象

这里 jnp.ndarray 可能是 jax.Array 的别名。这种最终方法在某些方面是最纯粹的,但从面向对象设计的角度来看,它有些勉强(Tracer Array?)。

选项 4:通过类层次结构进行部分统一#

我们可以通过让 Tracer 和设备数组类继承自一个共同的基类,来使类层次结构更加合理。例如:

  • jax.ArrayTracer 的基类,也是设备上数组的实际类型,可能是 jax._src.ArrayImpl 或类似类型。

  • jax.Array 是用于数组类型注释的对象

  • jax.Array 也是用于实例检查的对象

这里 jnp.ndarray 将是 Array 的别名。从面向对象的角度来看,这可能更纯粹,但与选项2和3相比,它放弃了 type(x) is jax.Array 将评估为 True 的概念。

评估#

考虑到每种潜在方法的总体优缺点:

  • 从用户的角度来看,统一的方法(选项2和3)可以说是最好的,因为它们消除了记住用于实例检查或注释的对象的认知开销:jax.Array 是你需要知道的一切。

  • 然而,选项2和选项3都引入了一些奇怪和/或令人困惑的行为。选项2依赖于可能令人困惑的实例检查覆盖,这些覆盖对于在pybind11中定义的类支持不佳。选项3要求Tracer成为数组的子类。这打破了继承模型,因为它将要求Tracer对象携带Array对象的所有负担(数据缓冲区、分片、设备等)。

  • 选项4在面向对象编程的意义上更纯粹,并且避免了任何典型实例检查或类型注释行为的覆盖需求。权衡之处在于设备上数组的实际类型变成了某种独立的东西(这里为 jax._src.ArrayImpl)。但绝大多数用户永远不需要直接接触这个私有实现。

这里存在不同的权衡,但在讨论之后,我们决定选择选项4作为我们前进的方向。

实施计划#

为了推进类型注解,我们将执行以下操作:

  1. 迭代此 JEP 文档,直到开发者和利益相关者都认同。

  2. 创建一个私有的 jax._src.typing (目前不提供任何公共API),并将上述提到的简单类型的第一层放入其中:

    • 暂时使用别名 Array = Any ,因为这需要更多的思考。

    • ArrayLike: 作为普通 jax.numpy 函数输入的有效类型的联合

    • DType / DTypeLike (注意:numpy 使用驼峰式的 DType;为了便于使用,我们应该遵循这一约定)

    • Shape / NamedShape / ShapeLike

    这方面的开始在 #12300 中完成。

  3. 开始着手于一个遵循上一节选项4的 jax.Array 基类。最初这将用Python定义,并使用当前在 jnp.ndarray 实现中找到的动态注册机制,以确保 isinstance 检查的正确行为。每个跟踪器和类似数组的类的 pyi 覆盖将确保类型注释的正确行为。然后可以将 jnp.ndarray 设为 jax.Array 的别名。

  4. 作为测试,根据上述指南,使用这些新的类型定义全面注释 jax.lax 中的函数。

  5. 继续逐个模块添加额外的注释,重点关注公共API函数。

  6. 同时,开始在 pybind11 中重新实现一个 jax.Array 基类,以便 ArrayImplTracer 可以继承它。使用 pyi 定义以确保静态类型检查器识别该类的适当属性。

  7. 一旦 jax.Arrayjax._src.ArrayImpl 完全落地,移除这些临时的 Python 实现。

  8. 当一切准备就绪时,创建一个公开的 jax.typing 模块,使上述类型对用户可用,并附带使用 JAX 代码的注释最佳实践文档。

我们将在 #12049 中跟踪这项工作,这个 JEP 的编号由此而来。