jax.numpy.broadcast_shapes

jax.numpy.broadcast_shapes#

jax.numpy.broadcast_shapes(*shapes: Sequence[int]) tuple[int, ...][源代码][源代码]#
jax.numpy.broadcast_shapes(*shapes: Sequence[int | core.Tracer]) tuple[int | core.Tracer, ...]

将输入形状广播到公共输出形状。

JAX 实现的 numpy.broadcast_shapes()。JAX 使用 NumPy 风格的广播规则,你可以在 NumPy 广播 中了解更多。

参数:

shapes – 指定为整数序列的0个或多个形状

返回:

作为整数元组的广播形状。

参见

示例

一些兼容的形状:

>>> jnp.broadcast_shapes((1,), (4,))
(4,)
>>> jnp.broadcast_shapes((3, 1), (4,))
(3, 4)
>>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
(5, 3, 4)

不兼容的形状:

>>> jnp.broadcast_shapes((3, 1), (4, 1))  
Traceback (most recent call last):
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]