jax.numpy.unstack

目录

jax.numpy.unstack#

jax.numpy.unstack(x, /, *, axis=0)[源代码][源代码]#

沿着给定轴将数组分割成一系列数组。

LAX-backend 实现的 numpy.unstack()

原始文档字符串如下。

axis 参数指定数组将被分割的维度。例如,如果 axis=0``(默认),它将是第一个维度,如果 ``axis=-1,它将是最后一个维度。

结果是沿着 axis 分割的数组元组。

Added in version 2.1.0.

参数:
  • x (ndarray) – 要解堆叠的数组。

  • axis (int, optional) – 数组将沿其分割的轴。默认值:0

返回:

unstacked – 未堆叠的数组。

返回类型:

tuple of ndarrays