jax.tree_util.build_tree#
- jax.tree_util.build_tree(treedef, xs)[源代码]#
从嵌套的可迭代结构构建树定义
- 参数:
treedef (PyTreeDef) – 要构建的 PyTreeDef 结构。
xs (Any) – 嵌套的可迭代对象,其数量与 treedef 的参数数量匹配
- 返回:
由 treedef 定义结构的 对象
- 返回类型:
Any
示例
>>> import jax >>> tree = [(1, 2), {'a': 3, 'b': 4}] >>> treedef = jax.tree.structure(tree)
build_tree
和jax.tree_util.tree_unflatten()
都可以用新值重建树,但build_tree
是以嵌套结构而不是扁平结构来接收这些值:>>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]]) [(10, 11), {'a': 12, 'b': 13}] >>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13]) [(10, 11), {'a': 12, 'b': 13}]