jax.tree_util.register_static#
- jax.tree_util.register_static(cls)[源代码]#
将 cls 注册为没有叶子的 pytree。
实例被
jax.jit()
、jax.pmap()
等视为静态。这可以作为使用jit
的static_argnums
和static_argnames
kwargs、pmap
的static_broadcasted_argnums
等将输入标记为静态的替代方法。- 参数:
cls (type[H]) – 要注册为静态的类型。必须可哈希,如 https://docs.python.org/3/glossary.html#term-hashable 中所定义。
- 返回:
输入类
cls
在添加到 JAX 的 pytree 注册表后保持不变。这使得register_static
可以用作装饰器。- 返回类型:
type[H]
示例
>>> import jax >>> @jax.tree_util.register_static ... class StaticStr(str): ... pass
这个静态字符串现在可以直接在
jax.jit()
编译的函数中使用,而无需使用static_argnums
标记变量为静态:>>> @jax.jit ... def f(x, y, s): ... return x + y if s == 'add' else x - y ... >>> f(1, 2, StaticStr('add')) Array(3, dtype=int32, weak_type=True)