jax.tree_util.tree_transpose

jax.tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)

Alias of jax.tree.transpose().

Parameters:
  • outer_treedef (PyTreeDef)

  • inner_treedef (PyTreeDef | None)

  • pytree_to_transpose (Any)

Return type:

Any
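The function transforms a pytree whose structure is (outer, inner) into one whose structure is (inner, outer). The minimal sketch below illustrates this by turning a list of dicts into a dict of lists. The outer structure is a two-element list and the inner structure is a dict with keys "a" and "b". Both treedefs are built with `jax.tree_util.tree_structure` from placeholder leaves. The names `records`, `transposed` and `back` are illustrative only, and the example assumes a JAX version that provides `jax.tree_util.tree_transpose` with the signature shown above.

```python
import jax

# A tree with structure (outer=list of 2, inner=dict with keys "a", "b").
records = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]

# Build treedefs from placeholder leaves; only the structure matters.
outer_treedef = jax.tree_util.tree_structure([0, 0])
inner_treedef = jax.tree_util.tree_structure({"a": 0, "b": 0})

# Swap the nesting: list-of-dicts becomes dict-of-lists.
transposed = jax.tree_util.tree_transpose(outer_treedef, inner_treedef, records)
print(transposed)  # {'a': [1, 3], 'b': [2, 4]}

# Transposing again with the treedefs swapped recovers the original tree.
back = jax.tree_util.tree_transpose(inner_treedef, outer_treedef, transposed)
print(back == records)  # True
```

The number of leaves in `pytree_to_transpose` must equal the product of the leaf counts of the two treedefs (here 2 × 2 = 4). Otherwise the call raises an error.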