jax.numpy.ufunc#
- class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)#
对数组进行逐元素操作的通用函数。
JAX 实现的
numpy.ufunc
。这是一个用于 JAX 支持的 NumPy 的 ufunc API 实现的类。大多数用户不需要实例化
ufunc
,而是会使用jax.numpy
中预定义的 ufuncs。要构建自己的 ufuncs,请参阅
jax.numpy.frompyfunc()
。示例
通用函数是对广播数组逐元素应用的函数,但它们还附带了许多额外的属性和方法。
As an example, consider the function
jax.numpy.add
. The object acts as a function that applies addition to broadcasted arrays in an element-wise manner:>>> x = jnp.array([1, 2, 3, 4, 5]) >>> jnp.add(x, 1) Array([2, 3, 4, 5, 6], dtype=int32)
每个
ufunc
对象都包含多个属性,用于描述其行为:>>> jnp.add.nin # number of inputs 2 >>> jnp.add.nout # number of outputs 1 >>> jnp.add.identity # identity value, or None if no identity exists 0
Binary ufuncs like
jax.numpy.add
include number of methods to apply the function to arrays in different manners.方法
outer()
将函数应用于输入数组值的成对外积:>>> jnp.add.outer(x, x) Array([[ 2, 3, 4, 5, 6], [ 3, 4, 5, 6, 7], [ 4, 5, 6, 7, 8], [ 5, 6, 7, 8, 9], [ 6, 7, 8, 9, 10]], dtype=int32)
The
ufunc.reduce()
方法对数组执行归约操作。例如,jnp.add.reduce()
等同于jnp.sum
:>>> jnp.add.reduce(x) Array(15, dtype=int32)
The
ufunc.accumulate()
方法对数组执行累积归约。例如,jnp.add.accumulate()
等价于jax.numpy.cumulative_sum()
:>>> jnp.add.accumulate(x) Array([ 1, 3, 6, 10, 15], dtype=int32)
The
ufunc.at()
方法在数组中的特定索引处应用函数;对于jnp.add
,计算类似于jax.lax.scatter_add()
:>>> jnp.add.at(x, 0, 100, inplace=False) Array([101, 2, 3, 4, 5], dtype=int32)
而
ufunc.reduceat()
方法在数组的指定索引之间执行多个reduce
操作;对于jnp.add
,该操作类似于jax.ops.segment_sum()
:>>> jnp.add.reduceat(x, jnp.array([0, 2])) Array([ 3, 12], dtype=int32)
在这种情况下,第一个元素是
x[0:2].sum()
,第二个元素是x[2:].sum()
。- 参数:
- __init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)[源代码][源代码]#
方法
__init__
(func, /, nin, nout, *[, name, ...])accumulate
(a[, axis, dtype, out])从二进制ufunc派生的累积操作。
at
(a, indices[, b, inplace])通过指定的单目或双目ufunc更新数组的元素。
outer
(A, B, /)将函数应用于
A
和B
中的所有值对。reduce
(a[, axis, dtype, out, keepdims, ...])从二元函数派生的归约操作。
reduceat
(a, indices[, axis, dtype, out])通过二元ufunc减少指定索引之间的数组。
属性
identity
nargs
nin
nout