jax.numpy.piecewise

目录

jax.numpy.piecewise#

jax.numpy.piecewise(x, condlist, funclist, *args, **kw)[源代码][源代码]#

在整个定义域上分段评估一个函数。

JAX 实现的 numpy.piecewise(),基于 jax.lax.switch()

备注

numpy.piecewise() 不同,jax.numpy.piecewise() 要求 funclist 中的函数可被 JAX 追踪,因为它通过 jax.lax.switch() 实现。

参数:
  • x (ArrayLike) – 输入值的数组。

  • condlist (Array | Sequence[ArrayLike]) – 布尔数组或布尔数组的序列,对应于 funclist 中的函数。如果是数组的序列,每个数组的长度必须与 x 的长度匹配。

  • funclist (list[ArrayLike | Callable[..., Array]]) – 数组或函数的列表;必须与 condlist 长度相同,或者长度为 len(condlist) + 1,在这种情况下,最后一个条目是当所有条件都不为真时应用的默认值。或者,funclist 的条目可以是数值,在这种情况下,它们表示一个常数函数。

  • args – 附加参数会传递给 funclist 中的每个函数。

  • kwargs – 附加参数会传递给 funclist 中的每个函数。

返回:

在指定条件下对 x 进行函数求值得到的结果数组。

返回类型:

Array

参见

示例

这是一个函数示例,对于负值返回零,对于正值则线性增长:

>>> x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
>>> condlist = [x < 0, x >= 0]
>>> funclist = [lambda x: 0 * x, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)

funclist 也可以包含常量函数的简单标量值:

>>> condlist = [x < 0, x >= 0]
>>> funclist = [0, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)

你可以通过向 funclist 添加额外条件来指定默认值:

>>> condlist = [x < -1, x > 1]
>>> funclist = [lambda x: 1 + x, lambda x: x - 1, 0]
>>> jnp.piecewise(x, condlist, funclist)
Array([-3, -2,  -1,  0,  0,  0,  1,  2, 3], dtype=int32)

condlist 也可以是一个简单的标量条件数组,在这种情况下,关联的函数适用于整个范围。

>>> condlist = jnp.array([False, True, False])
>>> funclist = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100]
>>> jnp.piecewise(x, condlist, funclist)
Array([-40, -30, -20, -10,   0,  10,  20,  30,  40], dtype=int32)