jax.tree_util.register_static

jax.tree_util.register_static#

jax.tree_util.register_static(cls)[源代码]#

cls 注册为没有叶子的 pytree。

实例被 jax.jit()jax.pmap() 等视为静态。这可以作为使用 jitstatic_argnumsstatic_argnames kwargs、pmapstatic_broadcasted_argnums 等将输入标记为静态的替代方法。

参数:

cls (type[H]) – 要注册为静态的类型。必须可哈希,如 https://docs.python.org/3/glossary.html#term-hashable 中所定义。

返回:

输入类 cls 在添加到 JAX 的 pytree 注册表后保持不变。这使得 register_static 可以用作装饰器。

返回类型:

type[H]

示例

>>> import jax
>>> @jax.tree_util.register_static
... class StaticStr(str):
...   pass

这个静态字符串现在可以直接在 jax.jit() 编译的函数中使用,而无需使用 static_argnums 标记变量为静态:

>>> @jax.jit
... def f(x, y, s):
...   return x + y if s == 'add' else x - y
...
>>> f(1, 2, StaticStr('add'))
Array(3, dtype=int32, weak_type=True)