jax.numpy.ufunc

jax.numpy.ufunc#

class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)#

对数组进行逐元素操作的通用函数。

JAX 实现的 numpy.ufunc

这是一个用于 JAX 支持的 NumPy 的 ufunc API 实现的类。大多数用户不需要实例化 ufunc,而是会使用 jax.numpy 中预定义的 ufuncs。

要构建自己的 ufuncs,请参阅 jax.numpy.frompyfunc()

示例

通用函数是对广播数组逐元素应用的函数,但它们还附带了许多额外的属性和方法。

As an example, consider the function jax.numpy.add. The object acts as a function that applies addition to broadcasted arrays in an element-wise manner:

>>> x = jnp.array([1, 2, 3, 4, 5])
>>> jnp.add(x, 1)
Array([2, 3, 4, 5, 6], dtype=int32)

每个 ufunc 对象都包含多个属性,用于描述其行为:

>>> jnp.add.nin  # number of inputs
2
>>> jnp.add.nout  # number of outputs
1
>>> jnp.add.identity  # identity value, or None if no identity exists
0

Binary ufuncs like jax.numpy.add include number of methods to apply the function to arrays in different manners.

方法 outer() 将函数应用于输入数组值的成对外积:

>>> jnp.add.outer(x, x)
Array([[ 2,  3,  4,  5,  6],
       [ 3,  4,  5,  6,  7],
       [ 4,  5,  6,  7,  8],
       [ 5,  6,  7,  8,  9],
       [ 6,  7,  8,  9, 10]], dtype=int32)

The ufunc.reduce() 方法对数组执行归约操作。例如,jnp.add.reduce() 等同于 jnp.sum

>>> jnp.add.reduce(x)
Array(15, dtype=int32)

The ufunc.accumulate() 方法对数组执行累积归约。例如,jnp.add.accumulate() 等价于 jax.numpy.cumulative_sum():

>>> jnp.add.accumulate(x)
Array([ 1,  3,  6, 10, 15], dtype=int32)

The ufunc.at() 方法在数组中的特定索引处应用函数;对于 jnp.add,计算类似于 jax.lax.scatter_add()

>>> jnp.add.at(x, 0, 100, inplace=False)
Array([101,   2,   3,   4,   5], dtype=int32)

ufunc.reduceat() 方法在数组的指定索引之间执行多个 reduce 操作;对于 jnp.add ,该操作类似于 jax.ops.segment_sum()

>>> jnp.add.reduceat(x, jnp.array([0, 2]))
Array([ 3, 12], dtype=int32)

在这种情况下,第一个元素是 x[0:2].sum(),第二个元素是 x[2:].sum()

参数:
  • func (Callable[..., Any])

  • nin (int)

  • nout (int)

  • name (str | None)

  • nargs (int | None)

  • identity (Any)

  • call (Callable[..., Any] | None)

  • reduce (Callable[..., Any] | None)

  • accumulate (Callable[..., Any] | None)

  • at (Callable[..., Any] | None)

  • reduceat (Callable[..., Any] | None)

__init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)[源代码][源代码]#
参数:
  • func (Callable[..., Any])

  • nin (int)

  • nout (int)

  • name (str | None)

  • nargs (int | None)

  • identity (Any)

  • call (Callable[..., Any] | None)

  • reduce (Callable[..., Any] | None)

  • accumulate (Callable[..., Any] | None)

  • at (Callable[..., Any] | None)

  • reduceat (Callable[..., Any] | None)

方法

__init__(func, /, nin, nout, *[, name, ...])

accumulate(a[, axis, dtype, out])

从二进制ufunc派生的累积操作。

at(a, indices[, b, inplace])

通过指定的单目或双目ufunc更新数组的元素。

outer(A, B, /)

将函数应用于 AB 中的所有值对。

reduce(a[, axis, dtype, out, keepdims, ...])

从二元函数派生的归约操作。

reduceat(a, indices[, axis, dtype, out])

通过二元ufunc减少指定索引之间的数组。

属性

identity

nargs

nin

nout