jax.numpy.block

目录

jax.numpy.block#

jax.numpy.block(arrays)[源代码][源代码]#

从嵌套的块列表中组装一个nd数组。

LAX-backend 实现的 numpy.block()

原始文档字符串如下。

在最内层列表中的块沿着最后一个维度(-1)连接(参见 concatenate),然后这些块沿着倒数第二个维度(-2)连接,依此类推,直到到达最外层列表。

块可以是任意维度,但不会按照常规规则进行广播。相反,会在前面插入大小为1的轴,以使所有块的 block.ndim 相同。这主要对处理标量有用,意味着像 np.block([v, 1]) 这样的代码是有效的,其中 v.ndim == 1

当嵌套列表达到两层深度时,这允许从其组成部分构建块矩阵。

Added in version 1.13.0.

参数:

arrays (nested list of array_like or scalars (but not tuples)) – 如果传递的是单个 ndarray 或标量(深度为 0 的嵌套列表),则原样返回(不会被复制)。元素形状必须在适当的轴上匹配(不进行广播),但必要时会在形状前添加前导 1 以使维度匹配。

返回:

block_array – 由给定块组装而成的数组。输出的维度等于以下最大值: * 所有输入的维度 * 输入列表嵌套的深度

返回类型:

ndarray