jax.tree_util.treedef_children

jax.tree_util.treedef_children#

jax.tree_util.treedef_children(treedef)[源代码]#

返回直接子节点的 treedefs 列表

参数:

treedef (PyTreeDef) – 一个单独的 PyTreeDef

返回:

表示 treedef 子节点的 PyTreeDefs 列表。

返回类型:

list[PyTreeDef]

示例

>>> import jax
>>> x = [(1, 2), 3, {'a': 4}]
>>> treedef = jax.tree.structure(x)
>>> jax.tree_util.treedef_children(treedef)
[PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})]
>>> _ == [jax.tree.structure(vals) for vals in x]
True