自定义集合

对于许多问题,内置的 Dask 集合(dask.arraydask.dataframedask.bagdask.delayed)已经足够。对于那些它们无法满足的情况,可以创建自己的 Dask 集合。这里我们描述了实现 Dask 集合接口所需的方法。

备注

这被视为一个高级功能。在大多数情况下,内置的集合可能已经足够。

在阅读本文之前,您应该阅读并理解:

目录

Dask 集合接口

要创建您自己的 Dask 集合,您需要实现由 dask.typing.DaskCollection 协议定义的接口。请注意,没有必需的基类。

建议也阅读 核心 Dask 方法的内部机制 以了解此接口在 Dask 内部的使用方式。

收集协议

class dask.typing.DaskCollection(*args, **kwargs)[源代码]

定义 Dask 集合接口的协议。

abstract __dask_graph__() collections.abc.Mapping[typing.Union[str, bytes, int, float, tuple[typing.Union[str, bytes, int, float, tuple[ForwardRef('Key'), ...]], ...]], typing.Any][源代码]

Dask 任务图。

Dask 的核心集合(Array、DataFrame、Bag 和 Delayed)使用 HighLevelGraph 来表示集合任务图。也可以使用 Python 字典将任务图表示为低级图。

返回
映射

Dask 任务图。如果实例返回一个 dask.highlevelgraph.HighLevelGraph ,那么必须实现 __dask_layers__() 方法,如 HLGDaskCollection 协议所定义。

abstract __dask_keys__() list[typing.Union[str, bytes, int, float, tuple[typing.Union[str, bytes, int, float, tuple[ForwardRef('Key'), ...]], ...], ForwardRef('NestedKeys')]][源代码]

任务图的输出键。

请注意,Dask 集合的键有额外的约束条件,这些约束条件在 任务图规范文档 中没有描述。以下将描述这些额外的约束条件。

所有键必须是非空字符串,或者是以非空字符串为第一个元素的元组,后面跟随零个或多个任意str、bytes、int、float或其元组。非空字符串通常被称为*集合名称*。dask包中嵌入的所有集合都有一个且仅有一个名称,但这不是必需的。

这些都是有效的输出:

  • []

  • ["x", "y"]

  • [[("y", "a", 0), ("y", "a", 1)], [("y", "b", 0), ("y", "b", 1)]

返回
列表

一个可能嵌套的键列表,表示图的输出。计算后,结果将以相同的布局返回,键被其对应的输出替换。

__dask_optimize__: Any

给定一个图和键,返回一个新的优化图。

这个方法可以是 staticmethodclassmethod ,但不能是 instancemethod 。例如实现可以参见 Dask 核心集合中 __dask_optimize__ 的定义:dask.array.Arraydask.dataframe.DataFrame 等。

请注意,在调用 __dask_optimize__ 之前,图和键会被合并;因此,传递给此方法的图和键可能代表多个共享相同优化方法的集合。

参数
dsk

所有共享相同 __dask_optimize__ 方法的集合合并后的图表。

Sequence[Key]

共享相同 __dask_optimize__ 方法的所有集合中 __dask_keys__ 的输出列表。

**kwargs任何

从调用 computepersist 传递的额外关键字参数。可以根据需要使用或忽略。

返回
MutableMapping

优化后的 Dask 图。

abstract __dask_postcompute__() tuple[collections.abc.Callable, tuple][源代码]

最终化函数和用于构建最终结果的可选参数。

在计算过程中,集合中的每个键都会有一个内存中的结果,postcompute 函数将每个键的结果合并成一个最终的内存表示。例如,dask.array.Array 将每个块中的数组连接成一个最终的内存数组。

返回
PostComputeCallable

接收每个最终键结果序列以及可选参数的可调用对象。示例签名可能是 finalize(results: Sequence[Any], *args)

tuple[Any, …]

传递给函数的关键结果后的可选参数(PostComputeCallable*args 部分)。如果没有额外的参数要传递,那么这必须是一个空元组。

abstract __dask_postpersist__() tuple[dask.typing.PostPersistCallable, tuple][源代码]

重建函数和可选参数以构造一个持久化的集合。

另请参阅 dask.typing.PostPersistCallable 的文档。

返回
PostPersistCallable

可调用对象,用于重建集合。签名应为 rebuild(dsk: Mapping, *args: Any, rename: Mapping[str, str] | None) (根据 PostPersistCallable 协议定义)。该可调用对象应返回一个等效的 Dask 集合,其键与 self 相同,但结果是通过不同的图计算的。在 dask.persist() 的情况下,新图将仅包含输出键和已计算的值。

tuple[Any, …]

传递给重新构建可调用对象的可选参数。如果没有要传递的附加参数,则这必须是一个空元组。

__dask_scheduler__: staticmethod

此对象使用的默认调度器 get

通常作为静态方法附加到类上,例如:

>>> import dask.threaded
>>> class MyCollection:
...     # Use the threaded scheduler by default
...     __dask_scheduler__ = staticmethod(dask.threaded.get)
abstract __dask_tokenize__() collections.abc.Hashable[源代码]

必须完全表示对象的值。

abstract compute(**kwargs: Any) Any[源代码]

计算这个 dask 集合。

这将一个懒惰的 Dask 集合转换为其内存等效项。例如,Dask 数组转换为 NumPy 数组,Dask 数据帧转换为 Pandas 数据帧。在调用此操作之前,整个数据集必须适合内存。

参数
调度器字符串,可选

使用哪种调度器,如“线程”、“同步”或“进程”。如果没有提供,默认首先检查全局设置,然后回退到集合的默认设置。

优化图bool, 可选

如果为 True [默认],图表在计算前会被优化。否则,图表将按原样运行。这对于调试很有用。

kwargs

传递给调度器函数的额外关键字参数。

返回
集合的计算结果。

参见

dask.compute
abstract persist(**kwargs: Any) dask.typing.CollType[源代码]

将此 dask 集合持久化到内存中

这将一个懒惰的 Dask 集合转换为一个具有相同元数据的 Dask 集合,但现在结果已完全计算或在后台主动计算。

函数的操作在很大程度上取决于活动的任务调度器。如果任务调度器支持异步计算,例如 dask.distributed 调度器的情况,那么 persist 将 立即 返回,并且返回值的任务图中将包含 Dask Future 对象。然而,如果任务调度器仅支持阻塞计算,那么对 persist 的调用将 阻塞,并且返回值的任务图中将包含具体的 Python 结果。

在使用分布式系统时,此功能特别有用,因为结果将保存在分布式内存中,而不是像使用 compute 那样返回到本地进程。

参数
调度器字符串,可选

使用哪种调度器,如“线程”、“同步”或“进程”。如果没有提供,默认首先检查全局设置,然后回退到集合的默认设置。

优化图bool, 可选

如果为 True [默认],图表在计算前会被优化。否则,图表将按原样运行。这对于调试很有用。

**kwargs

传递给调度器函数的额外关键字参数。

返回
新的 dask 集合由内存数据支持

参见

dask.persist
abstract visualize(filename: str = 'mydask', format: str | None = None, optimize_graph: bool = False, **kwargs: Any) DisplayObject | None[源代码]

使用 graphviz 渲染此对象任务图的计算。

需要安装 graphviz

参数
文件名str 或 None, 可选

要写入磁盘的文件名。如果提供的 filename 不包含扩展名,默认将使用 ‘.png’。如果 filename 为 None,则不会写入文件,我们将仅通过管道与 dot 通信。

格式{‘png’, ‘pdf’, ‘dot’, ‘svg’, ‘jpeg’, ‘jpg’}, 可选

写入输出文件的格式。默认是 ‘png’。

优化图bool, 可选

如果为 True,图表在渲染前会被优化。否则,图表将按原样显示。默认值为 False。

颜色: {None, ‘顺序’}, 可选

颜色节点的选项。提供 cmap= 关键字以获取额外的颜色映射

**kwargs

传递给 to_graphviz 的额外关键字参数。

返回
结果IPython.display.Image, IPython.display.SVG, 或 None

更多信息请参见 dask.dot.dot_graph。

参见

dask.visualize
dask.dot.dot_graph

注释

有关优化的更多信息,请参见此处:

http://www.aidoczh.com/dask/zh_CN/latest/optimize.html

示例

>>> x.visualize(filename='dask.pdf')  
>>> x.visualize(filename='dask.pdf', color='order')  

HLG 集合协议

基于 Dask 的 高层次图 的集合必须实现一个额外的方法,由该协议定义:

class dask.typing.HLGDaskCollection(*args, **kwargs)[源代码]

定义使用 HighLevelGraphs 的 Dask 集合的协议。

此协议与 DaskCollection 几乎完全相同,增加了 __dask_layers__ 方法(对于由高级图支持的集合是必需的)。

abstract __dask_layers__() collections.abc.Sequence[str][源代码]

HighLevelGraph 层的名称。

调度器 get 协议

SchedulerGetProtocol 定义了 Dask 集合的 __dask_scheduler__ 定义必须遵守的签名。

class dask.typing.SchedulerGetCallable(*args, **kwargs)[源代码]

定义 __dask_scheduler__ 可调用对象签名的协议。

__call__(dsk: collections.abc.Mapping[typing.Union[str, bytes, int, float, tuple[typing.Union[str, bytes, int, float, tuple[ForwardRef('Key'), ...]], ...]], typing.Any], keys: Union[collections.abc.Sequence[typing.Union[str, bytes, int, float, tuple[typing.Union[str, bytes, int, float, tuple[ForwardRef('Key'), ...]], ...]]], str, bytes, int, float, tuple[typing.Union[str, bytes, int, float, tuple[ForwardRef('Key'), ...]], ...]], **kwargs: Any) Any[源代码]

作为集合的默认调度器调用的方法。

参数
dsk

任务图。

对应所需数据的键。

**kwargs

附加参数。

返回
任何

keys 相关的结果

持久化后可调用协议

集合必须定义一个 __dask_postpersist__ 方法,该方法返回一个符合 PostPersistCallable 接口的可调用对象。

class dask.typing.PostPersistCallable(*args, **kwargs)[源代码]

定义 __dask_postpersist__ 可调用对象签名的协议。

__call__(dsk: collections.abc.Mapping[typing.Union[str, bytes, int, float, tuple[typing.Union[str, bytes, int, float, tuple[ForwardRef('Key'), ...]], ...]], typing.Any], *args: Any, rename: collections.abc.Mapping[str, str] | None = None) dask.typing.CollType_co[源代码]

调用以重建持久化集合的方法。

参数
dsk: 映射

一个包含至少由 __dask_keys__() 返回的输出键的映射。

*args任何

附加可选参数 如果没有额外的参数是必要的,它必须是一个空元组。

重命名Mapping[str, str], 可选

如果定义了,它表示输出键也可能发生变化;例如,如果之前 __dask_keys__() 的输出是 [('a', 0), ('a', 1)],在调用 rebuild(dsk, *extra_args, rename={'a': 'b'}) 之后,它必须变成 [('b', 0), ('b', 1)]rename 映射可能不包含集合名称;在这种情况下,相关键不会改变。它可能包含意外名称的替换,这些必须被忽略。

返回
集合

通过不同图计算出的具有相同键的等效Dask集合。

核心 Dask 方法的内部机制

Dask 有几种 核心 函数(以及相应的方法)实现了常见的操作:

  • compute: 将一个或多个 Dask 集合转换为其内存中的对应物

  • persist: 将一个或多个 Dask 集合转换为等效的 Dask 集合,其结果已经计算并缓存在内存中

  • optimize: 将一个或多个 Dask 集合转换为共享一个大型优化图的等效 Dask 集合

  • visualize: 给定一个或多个 Dask 集合,绘制出在调用 computepersist 时将传递给调度器的图。

在这里,我们简要描述这些函数的内部结构,以说明它们与上述接口的关系。

计算

compute 的操作可以分为三个阶段:

  1. 图合并与优化

    首先,各个集合被转换为一个单一的大图和嵌套的键列表。这个过程如何发生取决于 optimize_graph 关键字的值,每个函数都会接受这个关键字:

    • 如果 optimize_graphTrue (默认),那么集合首先根据它们的 __dask_optimize__ 方法进行分组。所有具有相同 __dask_optimize__ 方法的集合将其图合并并连接键,然后对合并后的图和键进行一次各自的 __dask_optimize__ 调用。然后合并生成的图。

    • 如果 optimize_graphFalse,那么所有图表将被合并,所有键将被连接。

    在这一阶段之后,有一个单一的大型图和嵌套的键列表,代表了所有的集合。

  2. 计算

    在图表合并并执行任何优化后,生成的较大图表和嵌套键列表将传递给调度器。所使用的调度器选择如下:

    • 如果 get 函数直接作为关键字指定,则使用该函数

    • 否则,如果设置了全局调度器,则使用该调度器

    • 否则,回退到给定集合的默认调度器。请注意,如果所有集合不共享相同的 __dask_scheduler__,则会引发错误。

    一旦确定了适当的调度器 get 函数,它就会与合并的图、键和额外的关键字参数一起调用。在此阶段之后,results 是一个嵌套的值列表。该列表的结构反映了 keys 的结构,每个键都被其对应的结果所替代。

  3. 后计算

    生成结果后,需要构建 compute 的输出值。这就是 __dask_postcompute__ 方法的作用。__dask_postcompute__ 返回两件事:

    • 一个 finalize 函数,它接收相应键的结果

    • 在结果之后传递给 finalize 的额外参数的元组

    为了构建输出,会遍历集合和结果列表,并对其各自的结果调用每个集合的最终处理函数。

在伪代码中,这个过程看起来如下:

def compute(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    if kwargs.pop('optimize_graph', True):
        # If optimization is turned on, group the collections by
        # optimization method, and apply each method only once to the merged
        # sub-graphs.
        optimization_groups = groupby_optimization_methods(collections)
        graphs = []
        for optimize_method, cols in optimization_groups:
            # Merge the graphs and keys for the subset of collections that
            # share this optimization method
            sub_graph = merge_graphs([x.__dask_graph__() for x in cols])
            sub_keys = [x.__dask_keys__() for x in cols]
            # kwargs are forwarded to ``__dask_optimize__`` from compute
            optimized_graph = optimize_method(sub_graph, sub_keys, **kwargs)
            graphs.append(optimized_graph)
        graph = merge_graphs(graphs)
    else:
        graph = merge_graphs([x.__dask_graph__() for x in collections])
    # Keys are always the same
    keys = [x.__dask_keys__() for x in collections]

    # 2. Computation
    # --------------
    # Determine appropriate get function based on collections, global
    # settings, and keyword arguments
    get = determine_get_function(collections, **kwargs)
    # Pass the merged graph, keys, and kwargs to ``get``
    results = get(graph, keys, **kwargs)

    # 3. Postcompute
    # --------------
    output = []
    # Iterate over the results and collections
    for res, collection in zip(results, collections):
        finalize, extra_args = collection.__dask_postcompute__()
        out = finalize(res, **extra_args)
        output.append(out)

    # `dask.compute` always returns tuples
    return tuple(output)

持久化

Persist 与 compute 非常相似,除了返回值的创建方式不同。它也有三个阶段:

  1. 图合并与优化

    compute 中相同。

  2. 计算

    compute 相同,除了在分布式调度器的情况下,results 中的值是未来对象而不是实际值。

  3. Postpersist

    类似于 __dask_postcompute____dask_postpersist__ 用于在调用 persist 时重建值。 __dask_postpersist__ 返回两样东西:

    • 一个 rebuild 函数,它接收一个持久化的图。这个图的键与相应集合的 __dask_keys__ 相同,而值是计算结果(对于单机调度器)或未来对象(对于分布式调度器)。

    • 在图之后传递给 rebuild 的额外参数的元组

    要构建 persist 的输出,会遍历集合列表和结果,并针对每个集合的结果图调用其重建器。

在伪代码中,这看起来如下:

def persist(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    # **Same as in compute**
    graph = ...
    keys = ...

    # 2. Computation
    # --------------
    # **Same as in compute**
    results = ...

    # 3. Postpersist
    # --------------
    output = []
    # Iterate over the results and collections
    for res, collection in zip(results, collections):
        # res has the same structure as keys
        keys = collection.__dask_keys__()
        # Get the computed graph for this collection.
        # Here flatten converts a nested list into a single list
        subgraph = {k: r for (k, r) in zip(flatten(keys), flatten(res))}

        # Rebuild the output dask collection with the computed graph
        rebuild, extra_args = collection.__dask_postpersist__()
        out = rebuild(subgraph, *extra_args)

        output.append(out)

    # dask.persist always returns tuples
    return tuple(output)

优化

optimize 的操作可以分为两个阶段:

  1. 图合并与优化

    compute 中相同。

  2. 重建

    类似于 persistrebuild 函数和 __dask_postpersist__ 中的参数用于从优化后的图中重建等效的集合。

在伪代码中,这看起来如下:

def optimize(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    # **Same as in compute**
    graph = ...

    # 2. Rebuilding
    # -------------
    # Rebuild each dask collection using the same large optimized graph
    output = []
    for collection in collections:
        rebuild, extra_args = collection.__dask_postpersist__()
        out = rebuild(graph, *extra_args)
        output.append(out)

    # dask.optimize always returns tuples
    return tuple(output)

可视化

Visualize 是 4 个核心功能中最简单的一个。它只有两个阶段:

  1. 图合并与优化

    compute 中相同。

  2. 图形绘制

    生成的合并图使用 graphviz 绘制,并输出到指定文件。

在伪代码中,这看起来如下:

def visualize(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    # **Same as in compute**
    graph = ...

    # 2. Graph Drawing
    # ----------------
    # Draw the graph with graphviz's `dot` tool and return the result.
    return dot_graph(graph, **kwargs)

将核心 Dask 方法添加到您的类中

定义上述接口将允许您的对象被 Dask 核心函数使用(dask.computedask.persistdask.visualize 等)。要添加这些函数的相应方法版本,您可以从 dask.base.DaskMethodsMixin 子类化,该子类基于上述接口添加了 computepersistvisualize 的实现。

示例 Dask 集合

这里我们创建一个表示元组的Dask集合。元组中的每个元素在图中都表示为一个任务。请注意,这仅用于说明目的 - 使用包含 dask.delayed 元素的普通元组也可以实现相同的使用体验:

# Saved as dask_tuple.py
import dask
from dask.base import DaskMethodsMixin, replace_name_in_key
from dask.optimization import cull

def tuple_optimize(dsk, keys, **kwargs):
    # We cull unnecessary tasks here. See
    # http://www.aidoczh.com/dask/en/stable/optimize.html for more
    # information on optimizations in Dask.
    dsk2, _ = cull(dsk, keys)
    return dsk2

# We subclass from DaskMethodsMixin to add common dask methods to
# our class (compute, persist, and visualize). This is nice but not
# necessary for creating a Dask collection (you can define them
# yourself).
class Tuple(DaskMethodsMixin):
    def __init__(self, dsk, keys):
        # The init method takes in a dask graph and a set of keys to use
        # as outputs.
        self._dsk = dsk
        self._keys = keys

    def __dask_graph__(self):
        return self._dsk

    def __dask_keys__(self):
        return self._keys

    # use the `tuple_optimize` function defined above
    __dask_optimize__ = staticmethod(tuple_optimize)

    # Use the threaded scheduler by default.
    __dask_scheduler__ = staticmethod(dask.threaded.get)

    def __dask_postcompute__(self):
        # We want to return the results as a tuple, so our finalize
        # function is `tuple`. There are no extra arguments, so we also
        # return an empty tuple.
        return tuple, ()

    def __dask_postpersist__(self):
        # We need to return a callable with the signature
        # rebuild(dsk, *extra_args, rename: Mapping[str, str] = None)
        return Tuple._rebuild, (self._keys,)

    @staticmethod
    def _rebuild(dsk, keys, *, rename=None):
        if rename is not None:
            keys = [replace_name_in_key(key, rename) for key in keys]
        return Tuple(dsk, keys)

    def __dask_tokenize__(self):
        # For tokenize to work we want to return a value that fully
        # represents this object. In this case it's the list of keys
        # to be computed.
        return self._keys

演示此类:

>>> from dask_tuple import Tuple
>>> from operator import add, mul

# Define a dask graph
>>> dsk = {"k0": 1,
...        ("x", "k1"): 2,
...        ("x", 1): (add, "k0", ("x", "k1")),
...        ("x", 2): (mul, ("x", "k1"), 2),
...        ("x", 3): (add, ("x", "k1"), ("x", 1))}

# The output keys for this graph.
# The first element of each tuple must be the same across the whole collection;
# the remainder are arbitrary, unique str, bytes, int, or floats
>>> keys = [("x", "k1"), ("x", 1), ("x", 2), ("x", 3)]

>>> x = Tuple(dsk, keys)

# Compute turns Tuple into a tuple
>>> x.compute()
(2, 3, 4, 5)

# Persist turns Tuple into a Tuple, with each task already computed
>>> x2 = x.persist()
>>> isinstance(x2, Tuple)
True
>>> x2.__dask_graph__()
{('x', 'k1'): 2, ('x', 1): 3, ('x', 2): 4, ('x', 3): 5}
>>> x2.compute()
(2, 3, 4, 5)

# Run-time typechecking
>>> from dask.typing import DaskCollection
>>> isinstance(x, DaskCollection)
True

检查对象是否为 Dask 集合

要检查一个对象是否是 Dask 集合,请使用 dask.base.is_dask_collection

>>> from dask.base import is_dask_collection
>>> from dask import delayed

>>> x = delayed(sum)([1, 2, 3])
>>> is_dask_collection(x)
True
>>> is_dask_collection(1)
False

实现确定性哈希

Dask 实现了自己的确定性哈希函数,用于根据参数的值生成键。这个函数作为 dask.base.tokenize 提供。许多常见类型已经实现了 tokenize,这些实现可以在 dask/base.py 中找到。

在创建自己的自定义类时,您可能需要注册一个 tokenize 实现。有两种方法可以做到这一点:

  1. __dask_tokenize__ 方法

    如果可能,建议定义 __dask_tokenize__ 方法。该方法不接受任何参数,并且应返回一个完全代表对象的值。在返回任何非平凡对象之前,最好从该方法中调用 dask.base.normalize_token

  2. 使用 dask.base.normalize_token 注册一个函数

    如果在类上定义方法不可行,或者你需要为已经注册的超类定制标记化函数(例如,如果你需要子类化内置类),你可以使用 normalize_token 调度注册一个标记化函数。该函数应具有如上所述的相同签名。

在这两种情况下,实现应该是相同的,只是定义的位置不同。

备注

Dask 集合和普通的 Python 对象都可以使用上述任一方法实现 tokenize

示例

>>> from dask.base import tokenize, normalize_token

# Define a tokenize implementation using a method.
>>> class Point:
...     def __init__(self, x, y):
...         self.x = x
...         self.y = y
...
...     def __dask_tokenize__(self):
...         # This tuple fully represents self
...         # Wrap non-trivial objects with normalize_token before returning them
...         return normalize_token(Point), self.x, self.y

>>> x = Point(1, 2)
>>> tokenize(x)
'5988362b6e07087db2bc8e7c1c8cc560'
>>> tokenize(x) == tokenize(x)  # token is idempotent
True
>>> tokenize(Point(1, 2)) == tokenize(Point(1, 2))  # token is deterministic
True
>>> tokenize(Point(1, 2)) == tokenize(Point(2, 1))  # tokens are unique
False


# Register an implementation with normalize_token
>>> class Point3D:
...     def __init__(self, x, y, z):
...         self.x = x
...         self.y = y
...         self.z = z

>>> @normalize_token.register(Point3D)
... def normalize_point3d(x):
...     return normalize_token(Point3D), x.x, x.y, x.z

>>> y = Point3D(1, 2, 3)
>>> tokenize(y)
'5a7e9c3645aa44cf13d021c14452152e'

更多示例,请参见 dask/base.py 或任何内置的 Dask 集合。