编写自定义数组容器#
Numpy 的分发机制,在 numpy 版本 v1.16 中引入,是编写与 numpy API 兼容并提供 numpy 功能自定义实现的定制 N 维数组容器的推荐方法.应用包括 dask 数组,一个分布在多个节点上的 N 维数组,以及 cupy 数组,一个在 GPU 上的 N 维数组.
为了体验编写自定义数组容器的感觉,我们将从一个简单的例子开始,这个例子实用性较窄,但说明了所涉及的概念.
>>> import numpy as np
>>> class DiagonalArray:
... def __init__(self, N, value):
... self._N = N
... self._i = value
... def __repr__(self):
... return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
... def __array__(self, dtype=None, copy=None):
... if copy is False:
... raise ValueError(
... "`copy=False` isn't supported. A copy is always created."
... )
... return self._i * np.eye(self._N, dtype=dtype)
我们的自定义数组可以像这样实例化:
>>> arr = DiagonalArray(5, 1)
>>> arr
DiagonalArray(N=5, value=1)
我们可以使用 numpy.array
或 numpy.asarray
将其转换为 numpy 数组,这将调用其 __array__
方法以获取标准的 numpy.ndarray
.
>>> np.asarray(arr)
array([[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.]])
如果我们使用 numpy 函数对 arr
进行操作,numpy 将再次使用 __array__
接口将其转换为数组,然后以通常的方式应用该函数.
>>> np.multiply(arr, 2)
array([[2., 0., 0., 0., 0.],
[0., 2., 0., 0., 0.],
[0., 0., 2., 0., 0.],
[0., 0., 0., 2., 0.],
[0., 0., 0., 0., 2.]])
注意返回类型是标准的 numpy.ndarray
.
>>> type(np.multiply(arr, 2))
<class 'numpy.ndarray'>
我们如何通过这个函数传递我们的自定义数组类型?Numpy 允许一个类通过接口 __array_ufunc__
和 __array_function__
来指示它希望以自定义定义的方式处理计算.让我们一次看一个,从 __array_ufunc__
开始.这个方法涵盖了 通用函数 (ufunc),这是一类函数,包括例如 numpy.multiply
和 numpy.sin
.
__array_ufunc__
接收:
ufunc
,类似于numpy.multiply
的函数method
是一个字符串,用于区分numpy.multiply(...)
及其变体,如numpy.multiply.outer
、numpy.multiply.accumulate
等.对于常见情况,``numpy.multiply(…)``,``method == ‘__call__’``.inputs
可以是不同类型的混合kwargs
,传递给函数的关键字参数
在这个例子中,我们只会处理 __call__
方法
>>> from numbers import Number
>>> class DiagonalArray:
... def __init__(self, N, value):
... self._N = N
... self._i = value
... def __repr__(self):
... return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
... def __array__(self, dtype=None, copy=None):
... if copy is False:
... raise ValueError(
... "`copy=False` isn't supported. A copy is always created."
... )
... return self._i * np.eye(self._N, dtype=dtype)
... def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
... if method == '__call__':
... N = None
... scalars = []
... for input in inputs:
... if isinstance(input, Number):
... scalars.append(input)
... elif isinstance(input, self.__class__):
... scalars.append(input._i)
... if N is not None:
... if N != input._N:
... raise TypeError("inconsistent sizes")
... else:
... N = input._N
... else:
... return NotImplemented
... return self.__class__(N, ufunc(*scalars, **kwargs))
... else:
... return NotImplemented
现在我们的自定义数组类型通过了 numpy 函数.
>>> arr = DiagonalArray(5, 1)
>>> np.multiply(arr, 3)
DiagonalArray(N=5, value=3)
>>> np.add(arr, 3)
DiagonalArray(N=5, value=4)
>>> np.sin(arr)
DiagonalArray(N=5, value=0.8414709848078965)
在这一点上 arr + 3
不起作用.
>>> arr + 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for +: 'DiagonalArray' and 'int'
为了支持它,我们需要定义 Python 接口 __add__
、__lt__
等,以分派到相应的 ufunc.我们可以通过继承混合类 NDArrayOperatorsMixin
方便地实现这一点.
>>> import numpy.lib.mixins
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
... def __init__(self, N, value):
... self._N = N
... self._i = value
... def __repr__(self):
... return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
... def __array__(self, dtype=None, copy=None):
... if copy is False:
... raise ValueError(
... "`copy=False` isn't supported. A copy is always created."
... )
... return self._i * np.eye(self._N, dtype=dtype)
... def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
... if method == '__call__':
... N = None
... scalars = []
... for input in inputs:
... if isinstance(input, Number):
... scalars.append(input)
... elif isinstance(input, self.__class__):
... scalars.append(input._i)
... if N is not None:
... if N != input._N:
... raise TypeError("inconsistent sizes")
... else:
... N = input._N
... else:
... return NotImplemented
... return self.__class__(N, ufunc(*scalars, **kwargs))
... else:
... return NotImplemented
>>> arr = DiagonalArray(5, 1)
>>> arr + 3
DiagonalArray(N=5, value=4)
>>> arr > 0
DiagonalArray(N=5, value=True)
现在让我们来处理 __array_function__
.我们将创建一个字典,将 numpy 函数映射到我们的自定义变体.
>>> HANDLED_FUNCTIONS = {}
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
... def __init__(self, N, value):
... self._N = N
... self._i = value
... def __repr__(self):
... return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
... def __array__(self, dtype=None, copy=None):
... if copy is False:
... raise ValueError(
... "`copy=False` isn't supported. A copy is always created."
... )
... return self._i * np.eye(self._N, dtype=dtype)
... def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
... if method == '__call__':
... N = None
... scalars = []
... for input in inputs:
... # In this case we accept only scalar numbers or DiagonalArrays.
... if isinstance(input, Number):
... scalars.append(input)
... elif isinstance(input, self.__class__):
... scalars.append(input._i)
... if N is not None:
... if N != input._N:
... raise TypeError("inconsistent sizes")
... else:
... N = input._N
... else:
... return NotImplemented
... return self.__class__(N, ufunc(*scalars, **kwargs))
... else:
... return NotImplemented
... def __array_function__(self, func, types, args, kwargs):
... if func not in HANDLED_FUNCTIONS:
... return NotImplemented
... # Note: this allows subclasses that don't override
... # __array_function__ to handle DiagonalArray objects.
... if not all(issubclass(t, self.__class__) for t in types):
... return NotImplemented
... return HANDLED_FUNCTIONS[func](*args, **kwargs)
...
一个方便的模式是定义一个装饰器 implements
,它可以用来将函数添加到 HANDLED_FUNCTIONS
中.
>>> def implements(np_function):
... "Register an __array_function__ implementation for DiagonalArray objects."
... def decorator(func):
... HANDLED_FUNCTIONS[np_function] = func
... return func
... return decorator
...
现在我们为 DiagonalArray
编写 numpy 函数的实现.为了完整性,为了支持 arr.sum()
的使用,添加一个调用 numpy.sum(self)
的方法 sum
,并对 mean
做同样处理.
>>> @implements(np.sum)
... def sum(arr):
... "Implementation of np.sum for DiagonalArray objects"
... return arr._i * arr._N
...
>>> @implements(np.mean)
... def mean(arr):
... "Implementation of np.mean for DiagonalArray objects"
... return arr._i / arr._N
...
>>> arr = DiagonalArray(5, 1)
>>> np.sum(arr)
5
>>> np.mean(arr)
0.2
如果用户尝试使用 HANDLED_FUNCTIONS
中未包含的任何 numpy 函数,numpy 将引发 TypeError
,指示此操作不受支持.例如,连接两个 DiagonalArrays
不会产生另一个对角数组,因此不受支持.
>>> np.concatenate([arr, arr])
Traceback (most recent call last):
...
TypeError: no implementation found for 'numpy.concatenate' on types that implement __array_function__: [<class '__main__.DiagonalArray'>]
此外,我们对 sum
和 mean
的实现不接受 numpy 实现中的可选参数.
>>> np.sum(arr, axis=0)
Traceback (most recent call last):
...
TypeError: sum() got an unexpected keyword argument 'axis'
用户总是可以选择使用 numpy.asarray
转换为普通的 numpy.ndarray
,然后从那里使用标准的 numpy.
>>> np.concatenate([np.asarray(arr), np.asarray(arr)])
array([[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.]])
在这个例子中,``DiagonalArray`` 的实现仅处理 np.sum
和 np.mean
函数以简洁起见.Numpy API 中的许多其他函数也可以被包装,一个成熟的自定义数组容器可以明确支持 Numpy 提供的所有可包装函数.
Numpy 提供了一些工具来帮助测试在 numpy.testing.overrides
命名空间中实现了 __array_ufunc__
和 __array_function__
协议的自定义数组容器.
要检查一个 Numpy 函数是否可以通过 __array_ufunc__
重写,你可以使用 allows_array_ufunc_override
:
>>> from numpy.testing.overrides import allows_array_ufunc_override
>>> allows_array_ufunc_override(np.add)
True
同样地,你可以通过 allows_array_function_override
检查一个函数是否可以通过 __array_function__
被重写.
Numpy API 中每个可覆盖函数的列表也可以通过 get_overridable_numpy_array_functions
获取支持 __array_function__
协议的函数,以及通过 get_overridable_numpy_ufuncs
获取支持 __array_ufunc__
协议的函数.这两个函数都返回存在于 Numpy 公共 API 中的函数集合.用户定义的 ufuncs 或其他依赖于 Numpy 的库中定义的 ufuncs 不包含在这些集合中.
更多自定义数组容器的完整示例,请参阅 dask 源代码 和 cupy 源代码.
另请参阅 NEP 18.