jax.tree.map#
- jax.tree.map(f, tree, *rest, is_leaf=None)[源代码][源代码]#
将一个多输入函数映射到 pytree 参数上,以生成一个新的 pytree。
- 参数:
f (Callable[..., Any]) – 接受
1 + len(rest)个参数的函数,将在 pytrees 的相应叶子节点上应用。tree (Any) – 一个要映射的 pytree,每个叶子提供第一个位置参数给
f。rest (Any) – 一个pytrees的元组,每个pytree与``tree``具有相同的结构,或者作为``tree``的前缀。
is_leaf (Callable[[Any], bool] | None) – 一个可选指定的函数,将在每次扁平化步骤中被调用。它应返回一个布尔值,指示是否应遍历当前对象,或者是否应立即停止,将整个子树视为叶子。
- 返回:
一个与
tree结构相同的新 pytree,但每个叶子的值由f(x, *xs)给出,其中x是tree中相应叶子的值,xs是rest中相应节点的值的元组。- 返回类型:
Any
示例
>>> import jax >>> jax.tree.map(lambda x: x + 1, {"x": 7, "y": 42}) {'x': 8, 'y': 43}
如果传递了多个输入,树的结构取自第一个输入;随后的输入只需要以
tree作为前缀:>>> jax.tree.map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) [[5, 7, 9], [6, 1, 2]]