jax.tree_util.register_pytree_node

jax.tree_util.register_pytree_node#

jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func)[源代码]#

扩展了在 pytrees 中被视为内部节点的类型集合。

参见 示例用法

参数:
  • nodetype (type[T]) – 一个注册为 pytree 的 Python 类型。

  • flatten_func (Callable[[T], tuple[_Children, _AuxData]]) – 在展平过程中使用的一个函数,接受类型为 nodetype 的值并返回一对,其中 (1) 是用于递归展平的子节点的可迭代对象,(2) 是一些可哈希的辅助数据,存储在 treedef 中并传递给 unflatten_func

  • unflatten_func (Callable[[_AuxData, _Children], T]) – 一个接受两个参数的函数:由 flatten_func 返回并存储在 treedef 中的辅助数据,以及未扁平化的子节点。该函数应返回 nodetype 的一个实例。

返回类型:

None

参见

示例

首先,我们将定义一个自定义类型:

>>> class MyContainer:
...   def __init__(self, size):
...     self.x = jnp.zeros(size)
...     self.y = jnp.ones(size)
...     self.size = size

如果我们在即时编译(JIT)函数中尝试使用这个类型,我们会得到一个错误,因为JAX还不知道如何处理这种类型:

>>> m = MyContainer(size=5)
>>> def f(m):
...   return m.x + m.y + jnp.arange(m.size)
>>> jax.jit(f)(m)  
Traceback (most recent call last):
  ...
TypeError: Cannot interpret value of type <class 'jax.tree_util.MyContainer'> as an abstract array; it does not have a dtype attribute

为了使我们的对象被 JAX 识别,我们必须将其注册为一个 pytree:

>>> def flatten_func(obj):
...   children = (obj.x, obj.y)  # children must contain arrays & pytrees
...   aux_data = (obj.size,)  # aux_data must contain static, hashable data.
...   return (children, aux_data)
...
>>> def unflatten_func(aux_data, children):
...   # Here we avoid `__init__` because it has extra logic we don't require:
...   obj = object.__new__(MyContainer)
...   obj.x, obj.y = children
...   obj.size, = aux_data
...   return obj
...
>>> jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func)

现在定义了这个之后,我们可以在即时编译的函数中使用这种类型的实例。

>>> jax.jit(f)(m)
Array([1., 2., 3., 4., 5.], dtype=float32)