jax.numpy.kron

目录

jax.numpy.kron#

jax.numpy.kron(a, b)[源代码][源代码]#

计算两个输入数组的Kronecker积。

JAX 实现的 numpy.kron()

Kronecker 积是对两个任意大小的矩阵进行的操作,它产生一个分块矩阵。第一个矩阵 a 的每个元素都乘以第二个矩阵 b 的整个矩阵。如果 a 的形状是 (m, n),而 b 的形状是 (p, q),那么结果矩阵的形状将是 (m * p, n * q)。

参数:
  • a (ArrayLike) – 第一个输入数组,形状任意。

  • b (ArrayLike) – 具有任意形状的第二个输入数组。

返回:

一个表示输入 ab 的克罗内克积的新数组。输出的形状是输入形状的逐元素乘积。

返回类型:

Array

参见

示例

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([[5, 6],
...                [7, 8]])
>>> jnp.kron(a, b)
Array([[ 5,  6, 10, 12],
       [ 7,  8, 14, 16],
       [15, 18, 20, 24],
       [21, 24, 28, 32]], dtype=int32)