jax.numpy.apply_along_axis

jax.numpy.apply_along_axis#

jax.numpy.apply_along_axis(func1d, axis, arr, *args, **kwargs)[源代码][源代码]#

沿给定轴对1-D切片应用函数。

LAX-backend 对 numpy.apply_along_axis() 的实现。

原始文档字符串如下。

执行 func1d(a, *args, **kwargs) 其中 func1d 对 1-D 数组进行操作,而 a 是沿 axisarr 进行 1-D 切片的结果。

这等同于(但比以下使用 ndindexs_ 更快),它将 iijjkk 分别设置为一个索引元组:

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        f = func1d(arr[ii + s_[:,] + kk])
        Nj = f.shape
        for jj in ndindex(Nj):
            out[ii + jj + kk] = f[jj]

同样地,消除内部循环,这可以表示为:

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        out[ii + s_[...,] + kk] = func1d(arr[ii + s_[:,] + kk])
参数:
  • func1d (function (M,) -> (Nj...)) – 此函数应接受一维数组。它将应用于沿指定轴的 arr 的一维切片。

  • axis (integer) – 沿其对 arr 进行切片的轴。

  • arr (ndarray (Ni..., M, Nk...)) – 输入数组。

  • args (any) – func1d 的额外参数。

  • kwargs (any) – func1d 的附加命名参数。

返回:

out – 输出数组。out 的形状与 arr 的形状相同,除了沿 axis 维度。该维度被移除,并替换为与 func1d 返回值的形状相等的新维度。因此,如果 func1d 返回一个标量,out 将比 arr 少一个维度。

返回类型:

ndarray (Ni…, Nj…, Nk…)