jax.tree_util.tree_unflatten


jax.tree_util.tree_unflatten(treedef, leaves)

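Alias of jax.tree.unflatten(). Reconstructs a pytree from a tree structure (treedef) and a sequence of leaves; it is the inverse of jax.tree_util.tree_flatten().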
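A minimal round-trip sketch, assuming a JAX version that provides the jax.tree module: flattening a pytree and passing the resulting treedef and leaves straight back reproduces the original structure. The dict and its values are illustrative only.

```python
import jax

# A small pytree: a dict holding a list and a tuple. None is treated as
# an empty subtree, not as a leaf.
tree = {"w": [1.0, 2.0], "b": (3.0, None)}

# Flatten into a flat list of leaves plus a PyTreeDef describing the structure.
# Dict entries are visited in sorted key order, so "b" comes before "w".
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)  # [3.0, 1.0, 2.0]

# Rebuild the pytree from the structure and the unchanged leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# The jax.tree.unflatten alias takes the same arguments.
rebuilt_alias = jax.tree.unflatten(treedef, leaves)

print(rebuilt == tree, rebuilt_alias == tree)  # True True
```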

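Parameters:
  • treedef (PyTreeDef) – the tree structure to rebuild, typically obtained from jax.tree_util.tree_flatten().

  • leaves (Iterable[Leaf]) – the leaves to place into that structure, in the order that flattening produces; their number must match the number of leaves treedef expects.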
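The leaves need not be the ones that came out of flattening; any values in the same order will do, which is the usual way to apply a transformation while keeping the structure. This is a sketch with made-up data; the mismatch check catches ValueError, which is what jaxlib raises as far as we know, so treat the exact exception type as an assumption.

```python
import jax

tree = {"a": 1, "b": (2, 3)}
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Build a new leaf list in the same order as `leaves` (each value is
# multiplied by 10) and place it into the original structure.
scaled = jax.tree_util.tree_unflatten(treedef, [x * 10 for x in leaves])
print(scaled)  # {'a': 10, 'b': (20, 30)}

# The structure records how many leaves it expects.
print(treedef.num_leaves)  # 3

# Supplying the wrong number of leaves is rejected.
try:
    jax.tree_util.tree_unflatten(treedef, [1, 2])
except ValueError as err:
    print("error:", err)
```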

Return type:

Any
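The return type is Any because the result takes whatever shape treedef describes. The sketch below, using a hypothetical Point class, shows three cases: a NamedTuple container, a bare leaf, and None.

```python
from typing import NamedTuple

import jax


# NamedTuple subclasses are treated as pytree nodes by JAX without registration.
class Point(NamedTuple):
    x: float
    y: float


# Unflattening rebuilds whatever container type the treedef describes.
_, point_def = jax.tree_util.tree_flatten(Point(1.0, 2.0))
p = jax.tree_util.tree_unflatten(point_def, [5.0, 6.0])
print(type(p).__name__, p)  # Point Point(x=5.0, y=6.0)

# A bare leaf has a trivial treedef; unflattening returns the leaf itself.
_, leaf_def = jax.tree_util.tree_flatten(3.0)
print(jax.tree_util.tree_unflatten(leaf_def, [7.0]))  # 7.0

# None is an empty pytree: no leaves, and unflattening gives back None.
none_leaves, none_def = jax.tree_util.tree_flatten(None)
print(none_leaves, jax.tree_util.tree_unflatten(none_def, []))  # [] None
```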