jax.tree 模块

目录

jax.tree 模块#

用于处理树状容器数据结构的工具。

The jax.tree 命名空间包含 jax.tree_util 中实用程序的别名。

函数列表#

all(tree, *[, is_leaf])

对树的所有叶子调用 all()。

flatten(tree[, is_leaf])

展平一个 pytree。

leaves(tree[, is_leaf])

获取 pytree 的叶子节点。

map(f, tree, *rest[, is_leaf])

将一个多输入函数映射到 pytree 参数上,以生成一个新的 pytree。

reduce()

对树的叶子调用 reduce()。

structure(tree[, is_leaf])

获取 pytree 的 treedef。

transpose(outer_treedef, inner_treedef, ...)

将具有树结构(外部,内部)的树转换为具有结构(内部,外部)的树。

unflatten(treedef, leaves)

从 treedef 和叶子重建一个 pytree。