jax.numpy.broadcast_arrays

jax.numpy.broadcast_arrays#

jax.numpy.broadcast_arrays(*args)[源代码][源代码]#

将数组广播到公共形状。

JAX 实现的 numpy.broadcast_arrays()。JAX 使用 NumPy 风格的广播规则,你可以在 NumPy 广播 中了解更多。

参数:

args (ArrayLike) – 零个或多个类数组对象,用于广播。

返回:

包含输入的广播副本的数组列表。

返回类型:

list[Array]

参见

示例

>>> x = jnp.arange(3)
>>> y = jnp.int32(1)
>>> jnp.broadcast_arrays(x, y)
[Array([0, 1, 2], dtype=int32), Array([1, 1, 1], dtype=int32)]
>>> x = jnp.array([[1, 2, 3]])
>>> y = jnp.array([[10],
...                [20]])
>>> x2, y2 = jnp.broadcast_arrays(x, y)
>>> x2
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)
>>> y2
Array([[10, 10, 10],
       [20, 20, 20]], dtype=int32)