jax.tree_util.tree_leaves_with_path

jax.tree_util.tree_leaves_with_path#

jax.tree_util.tree_leaves_with_path(tree, is_leaf=None)[源代码]#

获取一个类似 tree_leaves 的 pytree 的叶子,并返回每个叶子的键路径。

参数:
  • tree (Any) – 一个 pytree。如果它包含自定义类型,则必须使用 register_pytree_with_keys 进行注册。

  • is_leaf (Callable[[Any], bool] | None)

返回:

一个键-叶对列表,每个键-叶对包含一个叶及其键路径。

返回类型:

list[tuple[KeyPath, Any]]