jax.tree_util.treedef_is_leaf

jax.tree_util.treedef_is_leaf#

jax.tree_util.treedef_is_leaf(treedef)[源代码]#

如果 treedef 表示一个叶子节点,则返回 True。

参数:

treedef (PyTreeDef) – 树以检查

返回:

如果 treedef 是一个叶子(即只有一个节点),则为 True;否则为 False。

返回类型:

bool

示例

>>> import jax
>>> tree1 = jax.tree.structure(1)
>>> jax.tree_util.treedef_is_leaf(tree1)
True
>>> tree2 = jax.tree.structure([1, 2])
>>> jax.tree_util.treedef_is_leaf(tree2)
False