jax.numpy.broadcast_shapes#
- jax.numpy.broadcast_shapes(*shapes: Sequence[int]) tuple[int, ...] [源代码][源代码]#
- jax.numpy.broadcast_shapes(*shapes: Sequence[int | core.Tracer]) tuple[int | core.Tracer, ...]
将输入形状广播到公共输出形状。
JAX 实现的
numpy.broadcast_shapes()
。JAX 使用 NumPy 风格的广播规则,你可以在 NumPy 广播 中了解更多。- 参数:
shapes – 指定为整数序列的0个或多个形状
- 返回:
作为整数元组的广播形状。
参见
jax.numpy.broadcast_arrays()
: 将数组广播到公共形状。jax.numpy.broadcast_to()
: 将数组广播到指定形状。
示例
一些兼容的形状:
>>> jnp.broadcast_shapes((1,), (4,)) (4,) >>> jnp.broadcast_shapes((3, 1), (4,)) (3, 4) >>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1)) (5, 3, 4)
不兼容的形状:
>>> jnp.broadcast_shapes((3, 1), (4, 1)) Traceback (most recent call last): ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]