jax.tree_util.tree_unflatten

jax.tree_util.tree_unflatten(treedef, leaves)

Alias of jax.tree.unflatten().

Parameters:
  treedef (PyTreeDef)
  leaves (Iterable[Leaf])

Return type:
  Any
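tree_unflatten is the inverse of jax.tree_util.tree_flatten: given the PyTreeDef produced by flattening and an iterable of leaves (one per leaf slot in that structure), it rebuilds a pytree with the same shape. The sketch below is a minimal illustration, not taken from the original page. It uses plain Python integers as leaves, and the last line assumes a JAX version recent enough to provide the jax.tree module (where jax.tree.unflatten lives).

```python
import jax

# A small pytree: a dict holding a tuple and a scalar.
tree = {"a": (1, 2), "b": 3}

# Flatten into a list of leaves plus a PyTreeDef describing the structure.
# Dict keys are visited in sorted order, so the leaves come out as [1, 2, 3].
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Rebuild the same structure with new leaves (here, each leaf doubled).
new_tree = jax.tree_util.tree_unflatten(treedef, [x * 2 for x in leaves])
print(new_tree)  # {'a': (2, 4), 'b': 6}

# Unflattening the original leaves recovers the original pytree.
round_trip = jax.tree_util.tree_unflatten(treedef, leaves)

# The alias gives the same result (assumes jax.tree is available).
via_alias = jax.tree.unflatten(treedef, leaves)
```

A common pattern is to flatten once, transform the flat list of leaves, and unflatten with the same treedef; the number of leaves passed in must match the number of leaf slots that treedef describes.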