jax.tree.unflatten

目录

jax.tree.unflatten#

jax.tree.unflatten(treedef, leaves)[源代码][源代码]#

从 treedef 和叶子重建一个 pytree。

tree_flatten() 相反。

参数:
  • treedef (tree_util.PyTreeDef) – 重建的树定义

  • leaves (Iterable[tree_util.Leaf]) – 用于重建的叶子可迭代对象。该可迭代对象必须与treedef的叶子匹配。

返回:

重建的 pytree,包含了根据 treedef 描述的结构放置的 leaves

返回类型:

Any

示例

>>> import jax
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
>>> newvals = [100, 200, 300, 400, 500]
>>> jax.tree.unflatten(treedef, newvals)
[100, (200, 300), [400, 500]]