jax.tree_util.register_pytree_with_keys_class

jax.tree_util.register_pytree_with_keys_class#

jax.tree_util.register_pytree_with_keys_class(cls)[源代码]#

扩展了在 pytrees 中被视为内部节点的类型集合。

此函数类似于 register_pytree_node_class ,但需要一个定义了如何使用键进行扁平化的类。

它是 register_pytree_with_keys 的一个薄包装,并提供了一个面向类的接口:

参数:

cls (Typ) – 注册为 pytree 的类型

返回:

输入类 cls 在添加到 JAX 的 pytree 注册表后保持不变地返回。这个返回值使得 register_pytree_node_class 可以用作装饰器。

返回类型:

Typ

参见

示例

>>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey
>>> @register_pytree_with_keys_class
... class Special:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten_with_keys(self):
...     return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)