jax.tree.structure

目录

jax.tree.structure#

jax.tree.structure(tree, is_leaf=None)[源代码][源代码]#

获取 pytree 的 treedef。

参数:
  • tree (Any) – 要获取其叶子的 pytree

  • is_leaf (None | Callable[[Any], bool]) – 一个可选指定的函数,将在每次扁平化步骤中被调用。它应返回一个布尔值,指示是否应遍历当前对象,或者是否应立即停止,将整个子树视为叶子。

返回:

表示树结构的 PyTreeDef。

返回类型:

pytreedef

示例

>>> import jax
>>> jax.tree.structure([1, (2, 3), [4, 5]])
PyTreeDef([*, (*, *), [*, *]])