jax.tree_util.注册数据类#
- jax.tree_util.register_dataclass(nodetype, data_fields, meta_fields, drop_fields=())[源代码]#
扩展了在 pytrees 中被视为内部节点的类型集合。
这与
register_pytree_with_keys_class
不同,因为 C++ 注册表使用优化的 C++ dataclass 内置而不是参数函数。有关注册 pytrees 的更多信息,请参阅 扩展 pytrees。
- 参数:
nodetype (Typ) – a Python type to treat as an internal pytree node. This is assumed to have the semantics of a
dataclass
: namely, class attributes represent the whole of the object state, and can be passed as keywords to the class constructor to create a copy of the object. All defined attributes should be listed amongmeta_fields
ordata_fields
.meta_fields (Sequence[str]) – 辅助数据字段名称。这些字段 必须 包含静态、可哈希、不可变对象,因为这些对象用于生成 JIT 缓存键。特别是,
meta_fields
不能包含jax.Array
或numpy.ndarray
对象。data_fields (Sequence[str]) – 数据字段名称。这些字段 必须 是 JAX 兼容的对象,例如数组(
jax.Array
或numpy.ndarray
)、标量或其叶子为数组或标量的 pytree。请注意,data_fields
可以是None
,因为 JAX 将其识别为空的 pytree。drop_fields (Sequence[str])
- 返回:
输入类
nodetype
在添加到 JAX 的 pytree 注册表后保持不变地返回。这个返回值允许register_dataclass
进行部分评估并作为装饰器使用,如下例所示。- 返回类型:
Typ
示例
>>> from dataclasses import dataclass >>> from functools import partial >>> >>> @partial(jax.tree_util.register_dataclass, ... data_fields=['x', 'y'], ... meta_fields=['op']) ... @dataclass ... class MyStruct: ... x: jax.Array ... y: jax.Array ... op: str ... >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
现在这个类已经注册,它可以与
jax.tree_util
中的函数一起使用:>>> leaves, treedef = jax.tree.flatten(m) >>> leaves [Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)] >>> treedef PyTreeDef(CustomNode(MyStruct[('add',)], [*, *])) >>> jax.tree.unflatten(treedef, leaves) MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
特别是,此注册允许
m
在包裹在jax.jit()
和其他 JAX 变换中的代码中无缝传递:>>> @jax.jit ... def compiled_func(m): ... if m.op == 'add': ... return m.x + m.y ... else: ... raise ValueError(f"{m.op=}") ... >>> compiled_func(m) Array([1., 2., 3.], dtype=float32)