jax.numpy.ufunc#
- class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)#
对数组进行逐元素操作的通用函数。
JAX 实现的
numpy.ufunc。这是一个用于 JAX 支持的 NumPy 的 ufunc API 实现的类。大多数用户不需要实例化
ufunc,而是会使用jax.numpy中预定义的 ufuncs。要构建自己的 ufuncs,请参阅
jax.numpy.frompyfunc()。示例
通用函数是对广播数组逐元素应用的函数,但它们还附带了许多额外的属性和方法。
As an example, consider the function
jax.numpy.add. The object acts as a function that applies addition to broadcasted arrays in an element-wise manner:>>> x = jnp.array([1, 2, 3, 4, 5]) >>> jnp.add(x, 1) Array([2, 3, 4, 5, 6], dtype=int32)
每个
ufunc对象都包含多个属性,用于描述其行为:>>> jnp.add.nin # number of inputs 2 >>> jnp.add.nout # number of outputs 1 >>> jnp.add.identity # identity value, or None if no identity exists 0
Binary ufuncs like
jax.numpy.addinclude number of methods to apply the function to arrays in different manners.方法
outer()将函数应用于输入数组值的成对外积:>>> jnp.add.outer(x, x) Array([[ 2, 3, 4, 5, 6], [ 3, 4, 5, 6, 7], [ 4, 5, 6, 7, 8], [ 5, 6, 7, 8, 9], [ 6, 7, 8, 9, 10]], dtype=int32)
The
ufunc.reduce()方法对数组执行归约操作。例如,jnp.add.reduce()等同于jnp.sum:>>> jnp.add.reduce(x) Array(15, dtype=int32)
The
ufunc.accumulate()方法对数组执行累积归约。例如,jnp.add.accumulate()等价于jax.numpy.cumulative_sum():>>> jnp.add.accumulate(x) Array([ 1, 3, 6, 10, 15], dtype=int32)
The
ufunc.at()方法在数组中的特定索引处应用函数;对于jnp.add,计算类似于jax.lax.scatter_add():>>> jnp.add.at(x, 0, 100, inplace=False) Array([101, 2, 3, 4, 5], dtype=int32)
而
ufunc.reduceat()方法在数组的指定索引之间执行多个reduce操作;对于jnp.add,该操作类似于jax.ops.segment_sum():>>> jnp.add.reduceat(x, jnp.array([0, 2])) Array([ 3, 12], dtype=int32)
在这种情况下,第一个元素是
x[0:2].sum(),第二个元素是x[2:].sum()。- 参数:
- __init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)[源代码][源代码]#
方法
__init__(func, /, nin, nout, *[, name, ...])accumulate(a[, axis, dtype, out])从二进制ufunc派生的累积操作。
at(a, indices[, b, inplace])通过指定的单目或双目ufunc更新数组的元素。
outer(A, B, /)将函数应用于
A和B中的所有值对。reduce(a[, axis, dtype, out, keepdims, ...])从二元函数派生的归约操作。
reduceat(a, indices[, axis, dtype, out])通过二元ufunc减少指定索引之间的数组。
属性
identitynargsninnout