jax.tree_util.注册数据类

jax.tree_util.注册数据类#

jax.tree_util.register_dataclass(nodetype, data_fields, meta_fields, drop_fields=())[源代码]#

扩展了在 pytrees 中被视为内部节点的类型集合。

这与 register_pytree_with_keys_class 不同,因为 C++ 注册表使用优化的 C++ dataclass 内置而不是参数函数。

有关注册 pytrees 的更多信息,请参阅 扩展 pytrees

参数:
  • nodetype (Typ) – a Python type to treat as an internal pytree node. This is assumed to have the semantics of a dataclass: namely, class attributes represent the whole of the object state, and can be passed as keywords to the class constructor to create a copy of the object. All defined attributes should be listed among meta_fields or data_fields.

  • meta_fields (Sequence[str]) – 辅助数据字段名称。这些字段 必须 包含静态、可哈希、不可变对象,因为这些对象用于生成 JIT 缓存键。特别是,meta_fields 不能包含 jax.Arraynumpy.ndarray 对象。

  • data_fields (Sequence[str]) – 数据字段名称。这些字段 必须 是 JAX 兼容的对象,例如数组(jax.Arraynumpy.ndarray)、标量或其叶子为数组或标量的 pytree。请注意,data_fields 可以是 None,因为 JAX 将其识别为空的 pytree。

  • drop_fields (Sequence[str])

返回:

输入类 nodetype 在添加到 JAX 的 pytree 注册表后保持不变地返回。这个返回值允许 register_dataclass 进行部分评估并作为装饰器使用,如下例所示。

返回类型:

Typ

示例

>>> from dataclasses import dataclass
>>> from functools import partial
>>>
>>> @partial(jax.tree_util.register_dataclass,
...          data_fields=['x', 'y'],
...          meta_fields=['op'])
... @dataclass
... class MyStruct:
...   x: jax.Array
...   y: jax.Array
...   op: str
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

现在这个类已经注册,它可以与 jax.tree_util 中的函数一起使用:

>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

特别是,此注册允许 m 在包裹在 jax.jit() 和其他 JAX 变换中的代码中无缝传递:

>>> @jax.jit
... def compiled_func(m):
...   if m.op == 'add':
...     return m.x + m.y
...   else:
...     raise ValueError(f"{m.op=}")
...
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)