jax.tree_util.register_pytree_node_class

jax.tree_util.register_pytree_node_class#

jax.tree_util.register_pytree_node_class(cls)[源代码]#

扩展了在 pytrees 中被视为内部节点的类型集合。

此函数是 register_pytree_node 的一个薄包装,并提供了一个面向类的接口。

参数:

cls (Typ) – 注册为 pytree 的类型

返回:

输入类 cls 在添加到 JAX 的 pytree 注册表后保持不变地返回。这个返回值使得 register_pytree_node_class 可以用作装饰器。

返回类型:

Typ

参见

示例

在这里,我们将定义一个自定义容器,该容器将与 jax.jit() 和其他 JAX 转换兼容:

>>> import jax
>>> @jax.tree_util.register_pytree_node_class
... class MyContainer:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten(self):
...     return ((self.x, self.y), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)
...
>>> m = MyContainer(jnp.zeros(4), jnp.arange(4))
>>> def f(m):
...   return m.x + 2 * m.y
>>> jax.jit(f)(m)
Array([0., 2., 4., 6.], dtype=float32)