形状多态性#

当在JIT模式下使用JAX时,一个函数将被跟踪、降低到StableHLO,并为每种输入类型和形状的组合进行编译。在导出一个函数并在另一个系统上反序列化后,我们不再有可用的Python源代码,因此无法重新跟踪和重新降低它。形状多态性是JAX导出的一项功能,允许某些导出函数用于整个输入形状族。这些函数在导出期间被跟踪和降低一次,Exported对象包含在许多具体输入形状上编译和执行函数所需的信息。我们通过在导出时指定包含维度变量(符号形状)的形状来实现这一点,如下例所示:

>>> import jax
>>> from jax import export
>>> from jax import numpy as jnp
>>> def f(x):  # f: f32[a, b]
...   return jnp.concatenate([x, x], axis=1)

>>> # We construct symbolic dimension variables.
>>> a, b = export.symbolic_shape("a, b")

>>> # We can use the symbolic dimensions to construct shapes.
>>> x_shape = (a, b)
>>> x_shape
(a, b)

>>> # Then we export with symbolic shapes:
>>> exp: export.Exported = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct(x_shape, jnp.int32))
>>> exp.in_avals
(ShapedArray(int32[a,b]),)
>>> exp.out_avals
(ShapedArray(int32[a,2*b]),)

>>> # We can later call with concrete shapes (with a=3 and b=4), without re-tracing `f`.
>>> res = exp.call(np.ones((3, 4), dtype=np.int32))
>>> res.shape
(3, 8)

请注意,这些函数仍然会根据每次调用的具体输入形状重新编译。只有跟踪和降低过程会被保存。

在上述示例中,jax.export.symbolic_shape() 用于将符号形状的字符串表示解析为维度表达式对象(类型为 _DimExpr),这些对象可以替代整数常量来构造形状。维度表达式对象重载了大多数整数运算符,因此在大多数情况下,您可以将它们用作整数常量。更多详情请参见 使用维度变量进行计算

此外,我们提供了 jax.export.symbolic_args_specs(),可用于根据多态形状规范构建 jax.ShapeDtypeStruct 对象的 pytrees:

>>> def f1(x, y): # x: f32[a, 1], y : f32[a, 4]
...  return x + y

>>> # Assuming you have some actual args with concrete shapes
>>> x = np.ones((3, 1), dtype=np.int32)
>>> y = np.ones((3, 4), dtype=np.int32)
>>> args_specs = export.symbolic_args_specs((x, y), "a, ...")
>>> exp = export.export(jax.jit(f1))(* args_specs)
>>> exp.in_avals
(ShapedArray(int32[a,1]), ShapedArray(int32[a,4]))

注意多态形状规范 "a, ..." 如何包含占位符 ...,该占位符从参数 (x, y) 的具体形状中填充。占位符 ... 表示 0 个或更多维度,而占位符 _ 表示一个维度。jax.export.symbolic_args_specs() 支持参数的 pytrees,这些参数用于填充 dtypes 和任何占位符。该函数将构造一个与传递给它的参数结构匹配的参数规范 pytree(jax.ShapeDtypeStruct)。在一种规范应适用于多个参数的情况下,多态形状规范可以是 pytree 前缀,如上例所示。参见 如何将可选参数匹配到参数

一些形状规格的示例:

  • ("(b, _, _)", None) 可以用于一个有两个参数的函数,第一个参数是一个带有批处理主导维度的3D数组,该维度应该是符号化的。第一个参数的其他维度和第二个参数的形状是根据实际参数专门化的。注意,如果第一个参数是3D数组的pytree,所有数组都有相同的主导维度但可能有不同的尾随维度,同样的规范也会有效。第二个参数的值None意味着该参数不是符号化的。同样地,可以使用...

  • ("(batch, ...)", "(batch,)") 指定两个参数的前导维度匹配,第一个参数的秩至少为1,第二个参数的秩为1。

形状多态性的正确性#

我们希望信任的是,当编译并执行适用于任何具体形状时,导出的程序能产生与原始 JAX 程序相同的结果。更准确地说:

对于任何 JAX 函数 f 和任何包含符号形状的参数规范 arg_spec,以及任何形状与 arg_spec 匹配的具体参数 arg

  • 如果 JAX 本地执行在具体参数上成功:res = f(arg)

  • 如果导出成功且包含符号形状:exp = export.export(f)(arg_spec)

  • 然后编译并运行导出将成功,并得到相同的结果:res == exp.call(arg)

理解 f(arg) 可以自由地重新调用 JAX 的追踪机制至关重要,实际上它确实为每个不同的具体 arg 形状这样做,而 exp.call(arg) 的执行不能再使用 JAX 追踪(这种执行可能发生在 f 的源代码不可用的环境中)。

确保这种正确性是困难的,在最困难的情况下,导出会失败。本章的其余部分描述了如何处理这些失败。

使用维度变量进行计算#

JAX 跟踪所有中间结果的形状。当这些形状依赖于维度变量时,JAX 将它们计算为涉及维度变量的符号维度表达式。维度变量代表大于或等于 1 的整数值。符号表达式可以表示对维度表达式和整数应用算术运算符(加、减、乘、整除、取模,包括 NumPy 变体 np.sumnp.prod 等)的结果(intnp.int 或任何可通过 operator.index 转换的内容)。这些符号维度随后可以用于 JAX 原语和 API 的形状参数中,例如在 jnp.reshapejnp.arange、切片索引等中。

例如,在以下展平二维数组的代码中,计算 x.shape[0] * x.shape[1] 将符号维度 4 * b 作为新形状计算:

>>> f = lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],))
>>> arg_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.int32)
>>> exp = export.export(jax.jit(f))(arg_spec)
>>> exp.out_avals
(ShapedArray(int32[4*b]),)

可以通过 jnp.array(x.shape[0]) 甚至 jnp.array(x.shape) 将维度表达式显式转换为 JAX 数组。这些操作的结果可以作为常规 JAX 数组使用,但不能再用作形状中的维度。

>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))
>>> exp.call(jnp.arange(3, dtype=np.int32))
Array([3, 4, 5], dtype=int32)

>>> exp = export.export(jax.jit(lambda x: x.reshape(jnp.array(x.shape[0]) + 2)))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))  
Traceback (most recent call last):
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].

当符号维度与 非整数 进行算术运算时,例如 floatnp.floatnp.ndarray 或 JAX 数组,它会自动转换为使用 jnp.array 的 JAX 数组。例如,在下面的函数中,所有 x.shape[0] 的出现都会隐式转换为 jnp.array(x.shape[0]),因为它们参与了与非整数标量或 JAX 数组的运算:

>>> exp = export.export(jax.jit(
...     lambda x: (5. + x.shape[0],
...                x.shape[0] - np.arange(5, dtype=jnp.int32),
...                x + x.shape[0] + jnp.sin(x.shape[0]))))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), jnp.int32))
>>> exp.out_avals
(ShapedArray(float32[], weak_type=True),
 ShapedArray(int32[5]),
 ShapedArray(float32[b], weak_type=True))

>>> exp.call(jnp.ones((3,), jnp.int32))
 (Array(8., dtype=float32, weak_type=True),
  Array([ 3, 2, 1, 0, -1], dtype=int32),
  Array([4.14112, 4.14112, 4.14112], dtype=float32, weak_type=True))

另一个典型的例子是在计算平均值时(观察 x.shape[0] 如何自动转换为 JAX 数组):

>>> exp = export.export(jax.jit(
...     lambda x: jnp.sum(x, axis=0) / x.shape[0]))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b, c"), jnp.int32))
>>> exp.call(jnp.arange(12, dtype=jnp.int32).reshape((3, 4)))
Array([4., 5., 6., 7.], dtype=float32)

形状多态性存在时的错误#

大多数 JAX 代码假设 JAX 数组的形状是整数元组,但在形状多态性中,某些维度可能是符号表达式。这可能导致多种错误。例如,我们可能会遇到常见的 JAX 形状检查错误:

>>> v, = export.symbolic_shape("v,")
>>> export.export(jax.jit(lambda x, y: x + y))(
...     jax.ShapeDtypeStruct((v,), dtype=np.int32),
...     jax.ShapeDtypeStruct((4,), dtype=np.int32))
Traceback (most recent call last):
TypeError: add got incompatible shapes for broadcasting: (v,), (4,).

>>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))(
...     jax.ShapeDtypeStruct((v, 4), dtype=np.int32))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (v,).

我们可以通过指定参数的形状为 (v, v) 来修复上述的 matmul 示例。

符号尺寸的比较部分支持#

在 JAX 内部,有许多涉及形状的相等和不相等比较,例如,用于形状检查或甚至为某些原语选择实现。比较支持如下:

  • 相等性支持有一个注意事项:如果两个符号维度在所有维度变量的估值下表示相同的值,那么相等性评估为 True,例如,对于 b + b == 2*b;否则相等性评估为 False。有关此行为的重要的后果讨论,请参见 下文

  • 不等式总是等式的否定。

  • 不平等部分支持,类似于部分相等的方式。然而,在这种情况下,我们考虑维度变量严格取正整数。例如,b >= 1b >= 02 * a + b >= 3True,而 b >= 2a >= ba - b >= 0 是不确定的,会导致异常。

在比较操作无法解析为布尔值的情况下,我们会引发 InconclusiveDimensionOperation。例如,

import jax
>>> export.export(jax.jit(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'a + 1' >= 'b' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.

如果你确实遇到了 InconclusiveDimensionOperation,你可以尝试几种策略:

  • 如果你的代码使用了内置的 maxmin,或者是 np.maxnp.min,那么你可以将这些替换为 core.max_dimcore.min_dim,它们的效果是将不等式比较延迟到编译时,当形状已知时进行。

  • 尝试使用 core.max_dimcore.min_dim 重写条件语句,例如,将 d if d > 0 else 0 改写为 core.max_dim(d, 0)

  • 尝试重写代码,使其不那么依赖于维度应该是整数的事实,而是依赖于符号维度在大多数算术操作中可以像整数一样进行鸭子类型。例如,不要写 int(d) + 5,而是写 d + 5

  • 指定符号约束,如下所述。

用户指定的符号约束#

默认情况下,JAX 假设所有维度变量的取值范围大于或等于 1,并且它会尝试从这一点推导出其他简单的非等式,例如:

  • a + 2 >= 3,

  • a * 2 >= 1,

  • a + b + c >= 3,

  • a // 4 >= 0, a**2 >= 1, 等等。

如果你将符号形状规范更改为添加隐式约束以限制维度大小,你可以避免一些不等式比较失败。例如,

  • 你可以使用 2*b 作为尺寸,以约束其为偶数且大于或等于2。

  • 你可以使用 b + 15 作为维度来限制其至少为 16。例如,如果没有 + 15 部分,以下代码将会失败,因为 JAX 会希望验证切片大小最多与轴大小相同。

>>> _ = export.export(jax.jit(lambda x: x[0:16]))(
...    jax.ShapeDtypeStruct(export.symbolic_shape("b + 15"), dtype=np.int32))

这种隐式的符号约束用于决定比较,并在编译时进行检查,如下文所述。

你也可以指定 显式 符号约束:

>>> # Introduce dimension variable with constraints.
>>> a, b = export.symbolic_shape("a, b",
...                              constraints=("a >= b", "b >= 16"))
>>> _ = export.export(jax.jit(lambda x: x[:x.shape[1], :16]))(
...    jax.ShapeDtypeStruct((a, b), dtype=np.int32))

约束与隐含约束一起形成一个合取。你可以指定 >=<=== 约束。目前,JAX 对符号约束的推理支持有限:

  • 你从变量大于或等于常数的形式的约束中获得最大收益。例如,从约束 a >= 16b >= 8 中,我们可以推断出 a + 2*b >= 32

  • 当约束涉及更复杂的表达式时,您获得的权力是有限的,例如,从 a >= b + 8 我们可以推断出 a - b >= 8 但不能推断出 a >= 9。我们未来可能会在这个领域有所改进。

  • 等式约束被视为重写规则:每当遇到 == 左侧的符号表达式时,它将被重写为右侧的表达式。例如,floordiv(a, b) == c 通过将所有 floordiv(a, b) 的出现替换为 c 来工作。等式约束的左侧顶层不得包含加法或减法。有效的左侧示例包括 a * b,或 4 * a,或 floordiv(a + c, b)

>>> # Introduce dimension variable with equality constraints.
>>> a, b, c, d = export.symbolic_shape("a, b, c, d",
...                                    constraints=("a * b == c + d",))
>>> 2 * b * a
2*d + 2*c

>>> a * b * b
b*d + b*c

符号约束也可以帮助解决 JAX 推理机制中的限制。例如,在下面的代码中,JAX 将尝试证明切片大小 x.shape[0] % 3,即符号表达式 mod(b, 3),小于或等于轴大小,即 b。对于所有严格正值的 b,这恰好是正确的,但这不是 JAX 的符号比较规则可以证明的。因此,以下代码会引发错误:

from jax import lax
>>> b, = export.symbolic_shape("b")
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((b,), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'b' >= 'mod(b, 3)' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.

一种选择是限制代码仅在轴大小为 3 的倍数时工作(通过在形状中将 b 替换为 3*b)。然后,JAX 将能够将取模操作 mod(3*b, 3) 简化为 0。另一种选择是添加一个带有 JAX 试图证明的确切不确定不等式的符号约束:

>>> b, = export.symbolic_shape("b",
...                            constraints=["b >= mod(b, 3)"])
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> _ = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((b,), dtype=np.int32))

就像隐式约束一样,显式符号约束在编译时也会被检查,使用与下文解释的相同机制。

符号维度范围#

符号约束存储在 αn jax.export.SymbolicScope 对象中,该对象是为每次调用 jax.export.symbolic_shapes() 隐式创建的。你必须小心不要混合使用不同作用域的符号表达式。例如,以下代码将失败,因为 a1a2 使用了不同的作用域(由不同的 jax.export.symbolic_shape() 调用创建):

>>> a1, = export.symbolic_shape("a,")
>>> a2, = export.symbolic_shape("a,", constraints=("a >= 8",))

>>> a1 + a2  
Traceback (most recent call last):
ValueError: Invalid mixing of symbolic scopes for linear combination.
Expected  scope 4776451856 created at <doctest shape_poly.md[31]>:1:6 (<module>)
and found for 'a' (unknown) scope 4776979920 created at <doctest shape_poly.md[32]>:1:6 (<module>) with constraints:
  a >= 8

从对 jax.export.symbolic_shape() 的单次调用产生的符号表达式共享一个作用域,并且可以在算术运算中混合使用。结果也将共享相同的作用域。

你可以重复使用作用域:

>>> a, = export.symbolic_shape("a,", constraints=("a >= 8",))
>>> b, = export.symbolic_shape("b,", scope=a.scope)  # Reuse the scope of `a`

>>> a + b  # Allowed
b + a

你也可以显式地创建作用域:

>>> my_scope = export.SymbolicScope()
>>> c, = export.symbolic_shape("c", scope=my_scope)
>>> d, = export.symbolic_shape("d", scope=my_scope)
>>> c + d  # Allowed
d + c

JAX 跟踪使用部分由形状键控的缓存,打印相同的符号形状如果在不同的作用域中使用,将被视为不同。

相等比较的注意事项#

相等比较对于 b + 1 == bb == 0 返回 False(在这种情况下,可以确定维度变量的所有值的维度是不同的),但也包括 b == 1a == b。这是不合理的,我们应该引发 core.InconclusiveDimensionOperation,因为在某些估值下结果应该是 True,而在其他估值下应该是 False。我们选择使相等性完全化,从而允许不合理性,因为否则在哈希维度表达式或包含它们的对象(形状、core.AbstractValuecore.Jaxpr)时可能会出现哈希冲突导致的虚假错误。除了哈希错误外,相等性的部分语义还会导致以下表达式出错 b == a or b == bb in [a, b],即使我们改变比较的顺序可以避免错误。

形式的代码 if x.shape[0] != 1: raise NiceErrorMessage 在这种相等性处理下是合理的,但形式的代码 if x.shape[0] != 1: return 1 是不合理的。

维度变量必须可以从输入形状中解析#

目前,当导出对象被调用时,传递维度变量值的唯一方式是通过数组参数的形状间接传递。例如,b 的值可以从类型为 f32[b] 的第一个参数的形状中推断出来。这对于大多数用例来说效果良好,并且它反映了 JIT 函数的调用约定。

有时你可能希望导出一个由整数值参数化的函数,该值决定了程序中的一些形状。例如,我们可能希望导出下面定义的函数 my_top_k,该函数由 k 的值参数化,k 决定了结果的形状。以下尝试将导致错误,因为维度变量 k 无法从输入 x: i32[4, 10] 的形状中推导出来:

>>> def my_top_k(k, x):  # x: i32[4, 10], k <= 10
...   return lax.top_k(x, k)[0]  # : i32[4, 3]
>>> x = np.arange(40, dtype=np.int32).reshape((4, 10))

>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`.
>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x)
>>> exp_static_k.in_avals[0]
ShapedArray(int32[4,10])

>>> exp_static_k.out_avals[0]
ShapedArray(int32[4,3])

>>> # When calling the exported function we pass only the non-static arguments
>>> exp_static_k.call(x)
Array([[ 9,  8,  7],
       [19, 18, 17],
       [29, 28, 27],
       [39, 38, 37]], dtype=int32)

>>> # Now attempt to export with symbolic `k` so that we choose `k` after export.
>>> k, = export.symbolic_shape("k", constraints=["k <= 10"])
>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x)  
Traceback (most recent call last):
UnexpectedDimVar: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments

在未来,我们可能会增加一种额外的机制来传递维度变量的值,除了通过输入形状隐式传递之外。同时,上述用例的解决方法是,将函数参数 k 替换为一个形状为 (0, k) 的数组,这样 k 可以从数组的输入形状中推导出来。第一个维度为 0 是为了确保整个数组为空,并且在调用导出的函数时不会产生性能损失。

>>> def my_top_k_with_dimensions(dimensions, x):  # dimensions: i32[0, k], x: i32[4, 10]
...   return my_top_k(dimensions.shape[1], x)
>>> exp = export.export(jax.jit(my_top_k_with_dimensions))(
...     jax.ShapeDtypeStruct((0, k), dtype=np.int32),
...     x)
>>> exp.in_avals
(ShapedArray(int32[0,k]), ShapedArray(int32[4,10]))

>>> exp.out_avals[0]
ShapedArray(int32[4,k])

>>> # When we invoke `exp` we must construct and pass an array of shape (0, k)
>>> exp.call(np.zeros((0, 3), dtype=np.int32), x)
Array([[ 9,  8,  7],
       [19, 18, 17],
       [29, 28, 27],
       [39, 38, 37]], dtype=int32)

另一种可能出现错误的情况是,某些维度变量确实出现在输入形状中,但它们以 JAX 目前无法解决的非线性表达式出现:

>>> a, = export.symbolic_shape("a")
>>> export.export(jax.jit(lambda x: x.shape[0]))(
...    jax.ShapeDtypeStruct((a * a,), dtype=np.int32))  
Traceback (most recent call last):
ValueError: Cannot solve for values of dimension variables {'a'}.
We can only solve linear uni-variate constraints.
Using the following polymorphic shapes specifications: args[0].shape = (a^2,).
Unprocessed specifications: 'a^2' for dimension size args[0].shape[0].

形状断言错误#

JAX 假设维度变量的取值范围是严格正整数,并且在为具体输入形状编译代码时会检查这一假设。

例如,给定符号输入形状 (b, b, 2*d),JAX 将在使用实际参数 arg 调用时生成代码来检查以下断言:

  • arg.shape[0] >= 1

  • arg.shape[1] == arg.shape[0]

  • arg.shape[2] % 2 == 0

  • arg.shape[2] // 2 >= 1

例如,当我们对形状为 (3, 3, 5) 的参数调用导出时,我们得到的错误如下:

>>> def f(x):  # x: f32[b, b, 2*d]
...   return x
>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b, b, 2*d"), dtype=np.int32))   
>>> exp.call(np.ones((3, 3, 5), dtype=np.int32))  
Traceback (most recent call last):
ValueError: Input shapes do not match the polymorphic shapes specification.
Division had remainder 1 when computing the value of 'd'.
Using the following polymorphic shapes specifications:
  args[0].shape = (b, b, 2*d).
Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), .
Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.

这些错误在编译前的预处理步骤中出现。

符号维度的除法部分支持#

JAX 将尝试简化除法和取模操作,例如,(a * b + a) // (b + 1) == a(6 * a + 4) % 3 == 1。特别是,JAX 将处理以下情况:(a) 没有余数,或 (b) 除数是常数,在这种情况下可能会有一个常数余数。

例如,下面的代码在尝试计算 reshape 操作的推断维度时会导致除法错误:

>>> b, = export.symbolic_shape("b")
>>> export.export(jax.jit(lambda x: x.reshape((2, -1))))(
...     jax.ShapeDtypeStruct((b,), dtype=np.int32))
Traceback (most recent call last):
jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2, -1).
The remainder mod(b, - 2) should be 0.

请注意,以下操作将会成功:

>>> b, = export.symbolic_shape("b")
>>> # We specify that the first dimension is a multiple of 4
>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))(
...     jax.ShapeDtypeStruct((4*b,), dtype=np.int32))
>>> exp.out_avals
(ShapedArray(int32[2,2*b]),)

>>> # We specify that some other dimension is even
>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))(
...     jax.ShapeDtypeStruct((b, 5, 6), dtype=np.int32))
>>> exp.out_avals
(ShapedArray(int32[2,15*b]),)

调试#

首先,查看 调试 文档。此外,您可以调试形状细化,这是在编译具有维度变量或多平台支持的模块时调用的。

如果在形状细化过程中出现错误,您可以设置 JAX_DUMP_IR_TO 环境变量以查看形状细化之前的 HLO 模块的转储(命名为 ..._before_refine_polymorphic_shapes.mlir)。此模块应已具有静态输入形状。

要启用形状优化所有阶段的日志记录,您可以在OSS中设置环境变量 TF_CPP_VMODULE=refine_polymorphic_shapes=3(在Google内部,您传递 --vmodule=refine_polymorphic_shapes=3):

# Log from python
JAX_DUMP_IR_TO=/tmp/export.dumps/ TF_CPP_VMODULE=refine_polymorphic_shapes=3 python tests/shape_poly_test.py ShapePolyTest.test_simple_unary -v=3