jax.lax.associative_scan#
- jax.lax.associative_scan(fn, elems, reverse=False, axis=0)[源代码][源代码]#
在并行环境中执行带有关联二元操作的扫描。
关于关联扫描的介绍,请参见 [BLE1990]。
- 参数:
fn (Callable) – 一个实现关联二元操作的Python可调用对象,其签名是
r = fn(a, b)。函数 fn 必须是关联的,即它必须满足方程fn(a, fn(b, c)) == fn(fn(a, b), c)。输入和结果是(可能是嵌套的Python树结构)与elems匹配的数组。每个数组在axis维度上有一个维度。fn 应在axis维度上逐元素应用(例如,通过使用jax.vmap()对逐元素函数进行操作)。结果r与两个输入a和b具有相同的形状(和结构)。elems – 一个(可能是嵌套的Python树结构)数组,每个数组的
axis维度大小为num_elems。reverse (bool) – 一个布尔值,表示扫描是否应相对于
axis维度反向进行。axis (int) – 一个整数,标识扫描应沿其进行的轴。
- 返回:
一个(可能是嵌套的Python树结构)与
elems形状和结构相同的数组,其中axis的第k个元素是通过递归应用fn来组合elems沿axis的前k个元素的结果。例如,给定elems = [a, b, c, ...],结果将是[a, fn(a, b), fn(fn(a, b), c), ...]。
示例 1:数组中数字的部分和:
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4)) Array([0, 1, 3, 6], dtype=int32)
示例 2:矩阵数组的局部乘积
>>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2)) >>> partial_prods = lax.associative_scan(jnp.matmul, mats) >>> partial_prods.shape (4, 2, 2)
示例 3:数组中数字的反向部分和
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True) Array([6, 6, 5, 3], dtype=int32)
[BLE1990]Blelloch, Guy E. 1990. “前缀和及其应用。”,技术报告 CMU-CS-90-190,卡内基梅隆大学计算机科学学院。