jax.tree_util
模块#
用于处理树状容器数据结构的工具。
此模块提供了一组用于处理树状数据结构的实用函数,例如嵌套的元组、列表和字典。我们称这些结构为 pytrees。它们是树,因为它们是递归定义的(任何非 pytree 都是 pytree,即叶子,任何 pytree 的 pytree 也是 pytree),并且可以递归操作(对象身份等价性不会被映射操作保留,并且结构不能包含引用循环)。
被认为是 pytree 节点的 Python 类型集合(例如,可以被映射,而不是被视为叶子)是可扩展的。存在一个单一的模块级类型注册表,并且类层次结构被忽略。通过注册一个新的 pytree 节点类型,该类型实际上对本文件中的实用函数变得透明。
该模块的主要目的是实现用户定义的数据结构与 JAX 变换(例如 jit)之间的互操作性。这并不是一个通用的树状数据结构处理库。
参见 JAX pytrees 笔记 中的示例。
函数列表#
|
一个在 pytrees 中工作的 functools.partial 版本。 |
|
测试给定可迭代对象中的所有元素是否都是叶子节点。 |
|
从嵌套的可迭代结构构建树定义 |
|
扩展了在 pytrees 中被视为内部节点的类型集合。 |
|
扩展了在 pytrees 中被视为内部节点的类型集合。 |
扩展了在 pytrees 中被视为内部节点的类型集合。 |
|
|
扩展了在 pytrees 中被视为内部节点的类型集合。 |
扩展了在 pytrees 中被视为内部节点的类型集合。 |
|
|
将 cls 注册为没有叶子的 pytree。 |
|
类似于 |
|
获取一个类似 |
|
将一个多输入函数映射到 pytree 键路径和参数上,以生成一个新的 pytree。 |
|
返回直接子节点的 treedefs 列表 |
|
如果 treedef 表示一个叶子节点,则返回 True。 |
|
从子树定义的可迭代对象中创建一个元组树定义。 |
类型变量。 |
|
内置的不可变序列。 |
|
|
帮助程序,用于美化打印键的元组。 |
遗留 API#
这些API现在通过 jax.tree
访问。
|
别名 |
|
|
|
|
|
别名 |
|
|
|
|
|
|
|