jax.numpy.dsplit

目录

jax.numpy.dsplit#

jax.numpy.dsplit(ary, indices_or_sections)[源代码][源代码]#

将数组深度分割为子数组。

JAX 实现的 numpy.dsplit()

详情请参阅 jax.numpy.split() 的文档。dsplit 等同于 split 并设置 axis=2

示例

>>> x = jnp.arange(12).reshape(3, 1, 4)
>>> print(x)
[[[ 0  1  2  3]]

 [[ 4  5  6  7]]

 [[ 8  9 10 11]]]
>>> x1, x2 = jnp.dsplit(x, 2)
>>> print(x1)
[[[0 1]]

 [[4 5]]

 [[8 9]]]
>>> print(x2)
[[[ 2  3]]

 [[ 6  7]]

 [[10 11]]]

参见

参数:
  • ary (ArrayLike)

  • indices_or_sections (int | Sequence[int] | ArrayLike)

返回类型:

list[Array]