jax.Array.at#
- abstract property Array.at[源代码]#
用于索引更新功能的辅助属性。
at
属性提供了一个与就地数组修改功能等价的纯函数实现。特别是:
备用语法
等效的就地表达式
x = x.at[idx].set(y)
x[idx] = y
x = x.at[idx].add(y)
x[idx] += y
x = x.at[idx].multiply(y)
x[idx] *= y
x = x.at[idx].divide(y)
x[idx] /= y
x = x.at[idx].power(y)
x[idx] **= y
x = x.at[idx].min(y)
x[idx] = minimum(x[idx], y)
x = x.at[idx].max(y)
x[idx] = maximum(x[idx], y)
x = x.at[idx].apply(ufunc)
ufunc.at(x, idx)
x = x.at[idx].get()
x = x[idx]
所有
x.at
表达式都不会修改原始的x
;相反,它们返回的是x
的修改后的副本。然而,在jit()
编译的函数内部,像x = x.at[idx].set(y)
这样的表达式可以保证是原地应用的。与NumPy的就地操作如
x[idx] += y
不同,如果有多个索引指向同一位置,所有更新都将被应用(NumPy只会应用最后一次更新,而不是应用所有更新。)冲突更新的应用顺序是实现定义的,并且可能是非确定性的(例如,由于某些硬件平台上的并发性。)默认情况下,JAX 假设所有索引都在边界内。可以通过
mode
参数指定替代的越界索引语义(见下文)。- 参数:
mode (str) – 指定越界索引模式。选项包括: -
"promise_in_bounds"
:(默认)用户承诺索引在边界内。indices_are_sorted (bool) – 如果为 True,实现将假定传递给
at[]
的索引是按升序排列的,这可以在某些后端上实现更高效的执行。unique_indices (bool) – 如果为 True,实现将假设传递给
at[]
的索引是唯一的,这可以在某些后端上实现更高效的执行。fill_value (Any) – 仅适用于
get()
方法:当 mode 为'fill'
时,返回超出边界切片的填充值。否则忽略。默认值为:对于不精确类型为NaN
,对于有符号类型为最大负值,对于无符号类型为最大正值,对于布尔类型为True
。
示例
>>> x = jnp.arange(5.0) >>> x Array([0., 1., 2., 3., 4.], dtype=float32) >>> x.at[2].add(10) Array([ 0., 1., 12., 3., 4.], dtype=float32) >>> x.at[10].add(10) # out-of-bounds indices are ignored Array([0., 1., 2., 3., 4.], dtype=float32) >>> x.at[20].add(10, mode='clip') Array([ 0., 1., 2., 3., 14.], dtype=float32) >>> x.at[2].get() Array(2., dtype=float32) >>> x.at[20].get() # out-of-bounds indices clipped Array(4., dtype=float32) >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN Array(nan, dtype=float32) >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value Array(-1., dtype=float32)