jax.numpy.clip#
- jax.numpy.clip(arr=None, /, min=None, max=None, *, a=Deprecated, a_min=Deprecated, a_max=Deprecated)[源代码][源代码]#
将数组值裁剪到指定范围。
JAX 实现的
numpy.clip()
。- 参数:
arr (ArrayLike | None) – 要裁剪的 N 维数组。
min (ArrayLike | None) – 裁剪范围的可选最小值;如果为
None``(默认),则结果不会被裁剪到任何最小值。如果指定,它应与 ``arr
和max
广播兼容。max (ArrayLike | None) – 裁剪范围的可选最大值;如果为
None``(默认),则结果不会被裁剪到任何最大值。如果指定,它应与 ``arr
和min
广播兼容。a (ArrayLike | DeprecatedArg) –
arr
参数的已弃用别名。如果使用,将导致DeprecationWarning
。a_min (ArrayLike | None | DeprecatedArg) –
min
参数的已弃用别名。如果使用,将导致DeprecationWarning
。a_max (ArrayLike | None | DeprecatedArg) –
max
参数的已弃用别名。如果使用,将导致DeprecationWarning
。
- 返回:
一个数组,包含来自
arr
的值,其中小于min
的值被设置为min
,大于max
的值被设置为max
。- 返回类型:
参见
jax.numpy.minimum()
: 计算两个数组元素级的最小值。jax.numpy.maximum()
: 计算两个数组元素级的最大值。
示例
>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7]) >>> jnp.clip(arr, 2, 5) Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)