jax.tree_util.treedef_children#
- jax.tree_util.treedef_children(treedef)[源代码]#
返回直接子节点的 treedefs 列表
- 参数:
treedef (PyTreeDef) – 一个单独的 PyTreeDef
- 返回:
表示 treedef 子节点的 PyTreeDefs 列表。
- 返回类型:
list[PyTreeDef]
示例
>>> import jax >>> x = [(1, 2), 3, {'a': 4}] >>> treedef = jax.tree.structure(x) >>> jax.tree_util.treedef_children(treedef) [PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})] >>> _ == [jax.tree.structure(vals) for vals in x] True