dask.数组.注册块类型
dask.数组.注册块类型¶
- dask.array.register_chunk_type(type)[源代码]¶
将给定类型注册为有效的块和下转型数组类型
- 参数
- 类型类型
要注册为 Dask 可以安全地作为块包装的鸭子数组类型,以及 Dask 在算术操作和 NumPy 函数/ufuncs 中不推迟处理的类型。
注释
一个
dask.array.Array
可以在其块中包含任何足够“NumPy-like”的数组。这些数组也被称为“鸭子数组”,因为它们匹配了NumPy数组API中最重要的部分,因此,在使用鸭子类型时,它们的行为方式相同。然而,为了使多种鸭子数组类型能够正确地互操作,它们需要在算术运算和NumPy函数/ufuncs中根据一个定义良好的类型转换层次结构( 参见NEP 13 )相互正确地让步。为了维护这个层次结构,Dask默认让步于所有其他鸭子数组类型,除了其内部注册表中的那些。默认情况下,这个注册表包含
cupy.ndarray
sparse.SparseArray
此函数用于将任何其他类型附加到此注册表中。如果某个类型不在此注册表中,但却是向下转换的类型(它在类型转换层次结构中位于
dask.array.Array
之下),则会由于所有操作数类型返回NotImplemented
而引发TypeError
。示例
使用一个模拟的
FlaggedArray
类作为示例,这是一个Dask未知的块类型,具有最小的鸭子数组API:>>> import numpy.lib.mixins >>> class FlaggedArray(numpy.lib.mixins.NDArrayOperatorsMixin): ... def __init__(self, a, flag=False): ... self.a = a ... self.flag = flag ... def __repr__(self): ... return f"Flag: {self.flag}, Array: " + repr(self.a) ... def __array__(self): ... return np.asarray(self.a) ... def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): ... if method == '__call__': ... downcast_inputs = [] ... flag = False ... for input in inputs: ... if isinstance(input, self.__class__): ... flag = flag or input.flag ... downcast_inputs.append(input.a) ... elif isinstance(input, np.ndarray): ... downcast_inputs.append(input) ... else: ... return NotImplemented ... return self.__class__(ufunc(*downcast_inputs, **kwargs), flag) ... else: ... return NotImplemented ... @property ... def shape(self): ... return self.a.shape ... @property ... def ndim(self): ... return self.a.ndim ... @property ... def dtype(self): ... return self.a.dtype ... def __getitem__(self, key): ... return type(self)(self.a[key], self.flag) ... def __setitem__(self, key, value): ... self.a[key] = value
在注册
FlaggedArray
之前,两种类型都将尝试推迟到另一种类型:>>> import dask.array as da >>> da.ones(5) - FlaggedArray(np.ones(5), True) Traceback (most recent call last): ... TypeError: operand type(s) all returned NotImplemented ...
然而,一旦注册,Dask 将能够处理这种新类型的操作:
>>> da.register_chunk_type(FlaggedArray) >>> x = da.ones(5) - FlaggedArray(np.ones(5), True) >>> x dask.array<sub, shape=(5,), dtype=float64, chunksize=(5,), chunktype=dask.FlaggedArray> >>> x.compute() Flag: True, Array: array([0., 0., 0., 0., 0.])