jax.numpy.maximum#
- jax.numpy.maximum(x, y, /)[源代码][源代码]#
返回输入数组中逐元素的最大值。
JAX implementation of
numpy.maximum.- 参数:
x (ArrayLike) – 输入数组或标量。
y (ArrayLike) – 输入数组或标量。
x和y应具有相同的形状或可广播兼容。
- 返回:
包含
x和y逐元素最大值的数组。- 返回类型:
备注
- 对于每一对元素,
jnp.maximum返回: 如果两个元素都是有限数,则返回较大的那个。
nan如果一个元素是nan。
参见
jax.numpy.minimum(): 返回输入数组元素级的最小值。jax.numpy.fmax(): 返回输入数组按元素的最大值,忽略 NaNs。jax.numpy.amax(): 返回沿给定轴的数组元素的最大值。jax.numpy.nanmax(): 返回沿给定轴的数组元素的最大值,忽略 NaNs。
示例
x.shape == y.shape的输入:>>> x = jnp.array([1, -5, 3, 2]) >>> y = jnp.array([-2, 4, 7, -6]) >>> jnp.maximum(x, y) Array([1, 4, 7, 2], dtype=int32)
具有广播兼容性的输入:
>>> x1 = jnp.array([[-2, 5, 7, 4], ... [1, -6, 3, 8]]) >>> y1 = jnp.array([-5, 3, 6, 9]) >>> jnp.maximum(x1, y1) Array([[-2, 5, 7, 9], [ 1, 3, 6, 9]], dtype=int32)
包含
nan的输入:>>> nan = jnp.nan >>> x2 = jnp.array([nan, -3, 9]) >>> y2 = jnp.array([[4, -2, nan], ... [-3, -5, 10]]) >>> jnp.maximum(x2, y2) Array([[nan, -2., nan], [nan, -3., 10.]], dtype=float32)