Source code for dask.optimization

from __future__ import annotations

import math
import numbers
from collections.abc import Iterable
from enum import Enum
from typing import Any

from dask import config, core, utils
from dask._task_spec import GraphNode
from dask.base import normalize_token, tokenize
from dask.core import (
    flatten,
    get_dependencies,
    ishashable,
    istask,
    reverse_dict,
    subs,
    toposort,
)
from dask.tokenize import normalize_token, tokenize
from dask.typing import Graph, Key


def cull(dsk, keys):
    """Return new dask with only the tasks required to calculate keys.

    In other words, remove unnecessary tasks from dask.
    ``keys`` may be a single key, or a list or set of keys.

    Examples
    --------

    >>> def inc(x):
    ...     return x + 1

    >>> def add(x, y):
    ...     return x + y

    >>> d = {'x': 1, 'y': (inc, 'x'), 'out': (add, 'x', 10)}
    >>> dsk, dependencies = cull(d, 'out')
    >>> dsk  # doctest: +ELLIPSIS
    {'out': (<function add at ...>, 'x', 10), 'x': 1}
    >>> dependencies  # doctest: +ELLIPSIS
    {'out': ['x'], 'x': []}

    Returns
    -------
    dsk: culled dask graph
    dependencies: Dict mapping {key: [deps]}.  Useful side effect to accelerate
        other optimizations, notably fuse.

    A requested key (or one of its dependencies) that is missing from ``dsk``
    is not tolerated: the ``dsk[k]`` lookup raises ``KeyError``.
    """
    if not isinstance(keys, (list, set)):
        keys = [keys]

    # Breadth-first walk from the requested keys towards their inputs.
    work = list(set(flatten(keys)))
    # Seed ``seen`` with the requested keys themselves.  Otherwise a requested
    # key that is also a dependency of another requested key would be queued
    # again and processed a second time.
    seen = set(work)
    dependencies = dict()
    out = {}

    while work:
        new_work = []
        for k in work:
            dependencies_k = get_dependencies(dsk, k, as_list=True)  # fuse needs lists
            out[k] = dsk[k]
            dependencies[k] = dependencies_k
            for d in dependencies_k:
                if d not in seen:
                    seen.add(d)
                    new_work.append(d)

        work = new_work

    return out, dependencies
def default_fused_linear_keys_renamer(keys): """Create new keys for fused tasks""" typ = type(keys[0]) if typ is str: names = [utils.key_split(x) for x in keys[:0:-1]] names.append(keys[0]) return "-".join(names) elif typ is tuple and len(keys[0]) > 0 and isinstance(keys[0][0], str): names = [utils.key_split(x) for x in keys[:0:-1]] names.append(keys[0][0]) return ("-".join(names),) + keys[0][1:] else: return None def fuse_linear(dsk, keys=None, dependencies=None, rename_keys=True): """Return new dask graph with linear sequence of tasks fused together. If specified, the keys in ``keys`` keyword argument are *not* fused. Supply ``dependencies`` from output of ``cull`` if available to avoid recomputing dependencies. **This function is mostly superseded by ``fuse``** Parameters ---------- dsk: dict keys: list dependencies: dict, optional {key: [list-of-keys]}. Must be a list to provide count of each key This optional input often comes from ``cull`` rename_keys: bool or func, optional Whether to rename fused keys with ``default_fused_linear_keys_renamer`` or not. Renaming fused keys can keep the graph more understandable and comprehensive, but it comes at the cost of additional processing. If False, then the top-most key will be used. For advanced usage, a func is also accepted, ``new_key = rename_keys(fused_key_list)``. Examples -------- >>> def inc(x): ... return x + 1 >>> def add(x, y): ... return x + y >>> d = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')} >>> dsk, dependencies = fuse(d) >>> dsk # doctest: +SKIP {'a-b-c': (inc, (inc, 1)), 'c': 'a-b-c'} >>> dsk, dependencies = fuse(d, rename_keys=False) >>> dsk # doctest: +ELLIPSIS {'c': (<function inc at ...>, (<function inc at ...>, 1))} >>> dsk, dependencies = fuse(d, keys=['b'], rename_keys=False) >>> dsk # doctest: +ELLIPSIS {'b': (<function inc at ...>, 1), 'c': (<function inc at ...>, 'b')} Returns ------- dsk: output graph with keys fused dependencies: dict mapping dependencies after fusion. 
Useful side effect to accelerate other downstream optimizations. """ if keys is not None and not isinstance(keys, set): if not isinstance(keys, list): keys = [keys] keys = set(flatten(keys)) if dependencies is None: dependencies = {k: get_dependencies(dsk, k, as_list=True) for k in dsk} # locate all members of linear chains child2parent = {} unfusible = set() for parent in dsk: deps = dependencies[parent] has_many_children = len(deps) > 1 for child in deps: if keys is not None and child in keys: unfusible.add(child) elif child in child2parent: del child2parent[child] unfusible.add(child) elif has_many_children: unfusible.add(child) elif child not in unfusible: child2parent[child] = parent # construct the chains from ancestor to descendant chains = [] parent2child = dict(map(reversed, child2parent.items())) while child2parent: child, parent = child2parent.popitem() chain = [child, parent] while parent in child2parent: parent = child2parent.pop(parent) del parent2child[parent] chain.append(parent) chain.reverse() while child in parent2child: child = parent2child.pop(child) del child2parent[child] chain.append(child) chains.append(chain) dependencies = {k: set(v) for k, v in dependencies.items()} if rename_keys is True: key_renamer = default_fused_linear_keys_renamer elif rename_keys is False: key_renamer = None else: key_renamer = rename_keys # create a new dask with fused chains rv = {} fused = set() aliases = set() is_renamed = False for chain in chains: if key_renamer is not None: new_key = key_renamer(chain) is_renamed = ( new_key is not None and new_key not in dsk and new_key not in rv ) child = chain.pop() val = dsk[child] while chain: parent = chain.pop() dependencies[parent].update(dependencies.pop(child)) dependencies[parent].remove(child) val = subs(dsk[parent], child, val) fused.add(child) child = parent fused.add(child) if is_renamed: rv[new_key] = val rv[child] = new_key dependencies[new_key] = dependencies[child] dependencies[child] = {new_key} 
aliases.add(child) else: rv[child] = val for key, val in dsk.items(): if key not in fused: rv[key] = val if aliases: for key, deps in dependencies.items(): for old_key in deps & aliases: new_key = rv[old_key] deps.remove(old_key) deps.add(new_key) rv[key] = subs(rv[key], old_key, new_key) if keys is not None: for key in aliases - keys: del rv[key] del dependencies[key] return rv, dependencies def _flat_set(x): if x is None: return set() elif isinstance(x, set): return x elif not isinstance(x, (list, set)): x = [x] return set(x)
def inline(dsk, keys=None, inline_constants=True, dependencies=None):
    """Return new dask with the given keys inlined with their values.

    Inlines all constants if ``inline_constants`` keyword is True. Note that
    the constant keys will remain in the graph, to remove them follow
    ``inline`` with ``cull``.

    ``keys`` may be None, a single key, a list of keys or a set of keys.
    ``dependencies``, when given, maps every key to its dependencies.  If the
    first value is a list, all values are converted to sets.  Otherwise the
    mapping is computed from ``dsk``.  A "constant" is a key whose value is
    another key of ``dsk`` (an alias), or a non-task value with no
    dependencies.

    Examples
    --------

    >>> def inc(x):
    ...     return x + 1

    >>> def add(x, y):
    ...     return x + y

    >>> d = {'x': 1, 'y': (inc, 'x'), 'z': (add, 'x', 'y')}
    >>> inline(d)  # doctest: +ELLIPSIS
    {'x': 1, 'y': (<function inc at ...>, 1), 'z': (<function add at ...>, 1, 'y')}

    >>> inline(d, keys='y')  # doctest: +ELLIPSIS
    {'x': 1, 'y': (<function inc at ...>, 1), 'z': (<function add at ...>, 1, (<function inc at ...>, 1))}

    >>> inline(d, keys='y', inline_constants=False)  # doctest: +ELLIPSIS
    {'x': 1, 'y': (<function inc at ...>, 'x'), 'z': (<function add at ...>, 'x', (<function inc at ...>, 'x'))}
    """
    # The set algebra below needs set-valued dependencies; only the first
    # value is inspected to decide whether a conversion is required.
    if dependencies and isinstance(next(iter(dependencies.values())), list):
        dependencies = {key: set(deps) for key, deps in dependencies.items()}

    keys = _flat_set(keys)

    if dependencies is None:
        dependencies = {key: get_dependencies(dsk, key) for key in dsk}

    if inline_constants:
        for key, value in dsk.items():
            # Alias of another key, or a literal with nothing to compute.
            if (ishashable(value) and value in dsk) or (
                not dependencies[key] and not istask(value)
            ):
                keys.add(key)

    # Inlined keys may reference other inlined keys, so expand them in
    # topological order.  Each value stored in ``expanded`` then no longer
    # refers to other inlined keys.
    replaceorder = toposort(
        {key: dsk[key] for key in keys if key in dsk}, dependencies=dependencies
    )
    expanded = {}
    for key in replaceorder:
        value = dsk[key]
        for dep in keys & dependencies[key]:
            replacement = expanded[dep] if dep in expanded else dsk[dep]
            value = subs(value, dep, replacement)
        expanded[key] = value

    # Output: the expanded keys first, then every remaining key with its
    # references to inlined keys substituted.
    result = expanded.copy()
    for key, value in dsk.items():
        if key in result:
            continue
        for dep in keys & dependencies[key]:
            value = subs(value, dep, expanded[dep])
        result[key] = value
    return result
def inline_functions(
    dsk, output, fast_functions=None, inline_constants=False, dependencies=None
):
    """Inline cheap functions into larger operations

    A task is inlined into its dependents when the following hold:

    - it is a plain task (not a ``GraphNode``);
    - it is not listed in ``output``;
    - at least one other key depends on it;
    - every function it contains is in ``fast_functions``;
    - none of its dependents is a ``GraphNode``.

    The inlined keys are then removed from the returned graph.  When
    ``fast_functions`` is empty, ``dsk`` is returned as-is.

    Examples
    --------
    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> double = lambda x: x * 2
    >>> dsk = {'out': (add, 'i', 'd'),  # doctest: +SKIP
    ...        'i': (inc, 'x'),
    ...        'd': (double, 'y'),
    ...        'x': 1, 'y': 1}
    >>> inline_functions(dsk, [], [inc])  # doctest: +SKIP
    {'out': (add, (inc, 'x'), 'd'),
     'd': (double, 'y'),
     'x': 1, 'y': 1}

    Protect output keys.  In the example below ``i`` is not inlined because it
    is marked as an output key.

    >>> inline_functions(dsk, ['i', 'out'], [inc, double])  # doctest: +SKIP
    {'out': (add, 'i', (double, 'y')),
     'i': (inc, 'x'),
     'x': 1, 'y': 1}
    """
    if not fast_functions:
        return dsk

    output = set(output)
    fast_functions = set(fast_functions)

    if dependencies is None:
        dependencies = {k: get_dependencies(dsk, k) for k in dsk}
    dependents = reverse_dict(dependencies)

    def _is_inlinable(key, task):
        # Guard clauses: plain tasks only, not protected, and actually consumed.
        if isinstance(task, GraphNode) or not istask(task):
            return False
        if key in output or not dependents[key]:
            return False
        try:
            return functions_of(task).issubset(fast_functions) and not any(
                isinstance(dsk[dependent], GraphNode)
                for dependent in dependents[key]
            )
        except TypeError:
            # Presumably an unhashable callable inside the task; treat it as
            # "not inlinable" rather than failing the optimization.
            return False

    to_inline = []
    for key, task in dsk.items():
        if _is_inlinable(key, task):
            to_inline.append(key)

    if to_inline:
        dsk = inline(
            dsk, to_inline, inline_constants=inline_constants, dependencies=dependencies
        )
        for key in to_inline:
            del dsk[key]
    return dsk
def unwrap_partial(func):
    """Return the innermost callable of a chain of ``.func`` wrappers.

    Repeatedly follows the ``func`` attribute (as exposed by e.g.
    ``functools.partial``) until an object without one is reached.  An
    object without ``func`` is returned unchanged.  A cyclic ``func`` chain
    would loop forever.
    """
    inner = func
    while hasattr(inner, "func"):
        inner = getattr(inner, "func")
    return inner
def functions_of(task):
    """Set of functions contained within nested task

    The task is walked breadth-first.  Only objects whose exact type is
    ``list`` or ``tuple`` are descended into.  For each task, the unwrapped
    callable (see ``unwrap_partial``) is collected and its arguments are
    explored; any other list/tuple has all of its elements explored.

    Examples
    --------
    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> mul = lambda x, y: x * y
    >>> task = (add, (mul, 1, 2), (inc, 3))  # doctest: +SKIP
    >>> functions_of(task)  # doctest: +SKIP
    set([add, mul, inc])
    """
    found = set()
    frontier = [task]
    while frontier:
        next_frontier = []
        for node in frontier:
            if type(node) not in (list, tuple):
                continue
            if istask(node):
                found.add(unwrap_partial(node[0]))
                next_frontier.extend(node[1:])
            else:
                next_frontier.extend(node)
        frontier = next_frontier
    return found
def default_fused_keys_renamer(keys, max_fused_key_length=120): """Create new keys for ``fuse`` tasks. The optional parameter `max_fused_key_length` is used to limit the maximum string length for each renamed key. If this parameter is set to `None`, there is no limit. """ it = reversed(keys) first_key = next(it) typ = type(first_key) if max_fused_key_length: # Take into account size of hash suffix max_fused_key_length -= 5 def _enforce_max_key_limit(key_name): if max_fused_key_length and len(key_name) > max_fused_key_length: name_hash = f"{hash(key_name):x}"[:4] key_name = f"{key_name[:max_fused_key_length]}-{name_hash}" return key_name if typ is str: first_name = utils.key_split(first_key) names = {utils.key_split(k) for k in it} names.discard(first_name) names = sorted(names) names.append(first_key) concatenated_name = "-".join(names) return _enforce_max_key_limit(concatenated_name) elif typ is tuple and len(first_key) > 0 and isinstance(first_key[0], str): first_name = utils.key_split(first_key) names = {utils.key_split(k) for k in it} names.discard(first_name) names = sorted(names) names.append(first_key[0]) concatenated_name = "-".join(names) return (_enforce_max_key_limit(concatenated_name),) + first_key[1:] # PEP-484 compliant singleton constant # https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions class Default(Enum): token = 0 def __repr__(self) -> str: return "<default>" _default = Default.token
def fuse(
    dsk,
    keys=None,
    dependencies=None,
    ave_width=_default,
    max_width=_default,
    max_height=_default,
    max_depth_new_edges=_default,
    rename_keys=_default,
    fuse_subgraphs=_default,
):
    """Fuse tasks that form reductions; more advanced than ``fuse_linear``

    This trades parallelism opportunities for faster scheduling by making tasks
    less granular.  It can replace ``fuse_linear`` in optimization passes.

    This optimization applies to all reductions--tasks that have at most one
    dependent--so it may be viewed as fusing "multiple input, single output"
    groups of tasks into a single task.  There are many parameters to fine
    tune the behavior, which are described below.  ``ave_width`` is the
    natural parameter with which to compare parallelism to granularity, so
    it should always be specified.  Reasonable values for other parameters
    will be determined using ``ave_width`` if necessary.

    If the config value ``optimization.fuse.active`` is exactly False, the
    graph is returned unchanged.  Every tuning parameter left at its default
    is read from ``dask.config`` under the key listed with it.

    Parameters
    ----------
    dsk: dict
        dask graph
    keys: list or set, optional
        Keys that must remain in the returned dask graph
    dependencies: dict, optional
        {key: [list-of-keys]}.  Must be a list to provide count of each key;
        values that are not lists are recomputed from ``dsk``.
        This optional input often comes from ``cull``
    ave_width: float
        Upper limit for ``width = num_nodes / height``, a good measure of
        parallelizability.  A value of 0 disables fusion.
        dask.config key: ``optimization.fuse.ave-width``
    max_width: int, float or None
        Don't fuse if total width is greater than this.  Set to None to
        dynamically adjust to ``1.5 + ave_width * log(ave_width + 1)``.
        dask.config key: ``optimization.fuse.max-width``
    max_height: int or float
        Don't fuse more than this many levels.  A falsy value (0 or None)
        disables fusion altogether; there is no dynamic adjustment.
        dask.config key: ``optimization.fuse.max-height``
    max_depth_new_edges: int or None
        Don't fuse if new dependencies are added after this many levels.
        Set to None to dynamically adjust to ave_width * 1.5.
        dask.config key: ``optimization.fuse.max-depth-new-edges``
    rename_keys: bool or func, optional
        Whether to rename the fused keys with ``default_fused_keys_renamer``
        or not.  Renaming fused keys can keep the graph more understandable
        and comprehensive, but it comes at the cost of additional processing.
        If False, then the top-most key will be used.  For advanced usage, a
        function to create the new name is also accepted.  Any other
        non-boolean, non-callable value raises ``TypeError``.
        dask.config key: ``optimization.fuse.rename-keys``
    fuse_subgraphs : bool or None, optional
        Whether to fuse multiple tasks into ``SubgraphCallable`` objects.
        Set to None to let the default optimizer of individual dask
        collections decide.  Inside this function None is treated as False.
        dask.config key: ``optimization.fuse.subgraphs``

    Returns
    -------
    dsk
        output graph with keys fused
    dependencies
        dict mapping dependencies after fusion (values are sets).  Useful side
        effect to accelerate other downstream optimizations.  When fusion is
        switched off (``optimization.fuse.active`` is False, or ``ave_width``
        / ``max_height`` is falsy), the ``dependencies`` argument is returned
        unchanged and may be None.
    """
    # Perform low-level fusion unless the user has
    # specified False explicitly.
    if config.get("optimization.fuse.active") is False:
        return dsk, dependencies

    if keys is not None and not isinstance(keys, set):
        if not isinstance(keys, list):
            keys = [keys]
        keys = set(flatten(keys))

    # Read defaults from dask.yaml and/or user-defined config file
    if ave_width is _default:
        ave_width = config.get("optimization.fuse.ave-width")
        assert ave_width is not _default
    if max_height is _default:
        max_height = config.get("optimization.fuse.max-height")
        assert max_height is not _default
    if max_depth_new_edges is _default:
        max_depth_new_edges = config.get("optimization.fuse.max-depth-new-edges")
        assert max_depth_new_edges is not _default
    if max_depth_new_edges is None:
        max_depth_new_edges = ave_width * 1.5
    if max_width is _default:
        max_width = config.get("optimization.fuse.max-width")
        assert max_width is not _default
    if max_width is None:
        max_width = 1.5 + ave_width * math.log(ave_width + 1)
    if fuse_subgraphs is _default:
        fuse_subgraphs = config.get("optimization.fuse.subgraphs")
        assert fuse_subgraphs is not _default
    if fuse_subgraphs is None:
        fuse_subgraphs = False

    # A zero average width or a falsy height limit means "never fuse".
    if not ave_width or not max_height:
        return dsk, dependencies

    if rename_keys is _default:
        rename_keys = config.get("optimization.fuse.rename-keys")
        assert rename_keys is not _default
    if rename_keys is True:
        key_renamer = default_fused_keys_renamer
    elif rename_keys is False:
        key_renamer = None
    elif not callable(rename_keys):
        raise TypeError("rename_keys must be a boolean or callable")
    else:
        key_renamer = rename_keys
    # From here on ``rename_keys`` is only a flag: track fused key lists or not.
    rename_keys = key_renamer is not None

    if dependencies is None:
        deps = {k: get_dependencies(dsk, k, as_list=True) for k in dsk}
    else:
        deps = {
            k: v if isinstance(v, list) else get_dependencies(dsk, k, as_list=True)
            for k, v in dependencies.items()
        }

    # Build reverse dependencies (key -> list of dependents) from the list
    # form, then switch ``deps`` values to sets for the set algebra below.
    rdeps = {}
    for k, vals in deps.items():
        for v in vals:
            if v not in rdeps:
                rdeps[v] = [k]
            else:
                rdeps[v].append(k)
        deps[k] = set(vals)

    # A key is "reducible" (may be merged into its dependent) when all of
    # these hold:
    # - it has exactly one dependent;
    # - it is not protected by ``keys`` and is present in ``dsk``;
    # - it is not a GraphNode, and its value is a tuple, number or str;
    # - its dependent is not a GraphNode.
    reducible = set()
    for k, vals in rdeps.items():
        if (
            len(vals) == 1
            and k not in (keys or ())
            and k in dsk
            and not isinstance(dsk[k], GraphNode)
            and (type(dsk[k]) is tuple or isinstance(dsk[k], (numbers.Number, str)))
            and not any(isinstance(dsk[v], GraphNode) for v in vals)
        ):
            reducible.add(k)

    if not reducible and (
        not fuse_subgraphs or all(len(set(v)) != 1 for v in rdeps.values())
    ):
        # Quick return if there's nothing to do. Only progress if there's tasks
        # fusible by the main `fuse`, or by `fuse_subgraphs` if enabled.
        return dsk, deps

    rv = dsk.copy()
    # root key -> list of keys fused into it (only filled when renaming)
    fused_trees = {}
    # These are the stacks we use to store data as we traverse the graph.
    # Each ``info_stack`` entry is the tuple
    # (key, task, fused_keys, height, width, num_nodes, fudge, edges).
    info_stack = []
    children_stack = []
    # For speed
    deps_pop = deps.pop
    reducible_add = reducible.add
    reducible_pop = reducible.pop
    reducible_remove = reducible.remove
    fused_trees_pop = fused_trees.pop
    info_stack_append = info_stack.append
    info_stack_pop = info_stack.pop
    children_stack_append = children_stack.append
    children_stack_extend = children_stack.extend
    children_stack_pop = children_stack.pop
    while reducible:
        # Peek at an arbitrary reducible key (pop + re-add leaves the set intact)
        parent = reducible_pop()
        reducible_add(parent)
        while parent in reducible:
            # Go to the top: follow single dependents until reaching the
            # non-reducible root of this reduction region.
            parent = rdeps[parent][0]
        children_stack_append(parent)
        children_stack_extend(reducible & deps[parent])
        while True:
            child = children_stack[-1]
            if child != parent:
                children = reducible & deps[child]
                while children:
                    # Depth-first search
                    children_stack_extend(children)
                    parent = child
                    child = children_stack[-1]
                    children = reducible & deps[child]
                children_stack_pop()
                # This is a leaf node in the reduction region
                # key, task, fused_keys, height, width, number of nodes, fudge, set of edges
                info_stack_append(
                    (
                        child,
                        rv[child],
                        [child] if rename_keys else None,
                        1,
                        1,
                        1,
                        0,
                        deps[child] - reducible,
                    )
                )
            else:
                children_stack_pop()
                # Calculate metrics and fuse as appropriate.
                # ``edges`` are dependencies from outside the reduction region.
                deps_parent = deps[parent]
                edges = deps_parent - reducible
                children = deps_parent - edges
                num_children = len(children)

                if num_children == 1:
                    (
                        child_key,
                        child_task,
                        child_keys,
                        height,
                        width,
                        num_nodes,
                        fudge,
                        children_edges,
                    ) = info_stack_pop()
                    num_children_edges = len(children_edges)

                    if fudge > num_children_edges - 1 >= 0:
                        fudge = num_children_edges - 1
                    edges |= children_edges
                    no_new_edges = len(edges) == num_children_edges
                    if not no_new_edges:
                        fudge += 1
                    if (
                        (num_nodes + fudge) / height <= ave_width
                        and
                        # Sanity check; don't go too deep if new levels introduce new edge dependencies
                        (no_new_edges or height < max_depth_new_edges)
                        and (
                            not isinstance(dsk[parent], GraphNode)
                            # TODO: substitute can be implemented with GraphNode.inline
                            # or isinstance(dsk[child_key], GraphNode)
                        )
                    ):
                        # Perform substitutions as we go
                        val = subs(dsk[parent], child_key, child_task)
                        deps_parent.remove(child_key)
                        deps_parent |= deps_pop(child_key)
                        del rv[child_key]
                        reducible_remove(child_key)
                        if rename_keys:
                            child_keys.append(parent)
                            fused_trees[parent] = child_keys
                            fused_trees_pop(child_key, None)

                        if children_stack:
                            if no_new_edges:
                                # Linear fuse: height does not grow
                                info_stack_append(
                                    (
                                        parent,
                                        val,
                                        child_keys,
                                        height,
                                        width,
                                        num_nodes,
                                        fudge,
                                        edges,
                                    )
                                )
                            else:
                                info_stack_append(
                                    (
                                        parent,
                                        val,
                                        child_keys,
                                        height + 1,
                                        width,
                                        num_nodes + 1,
                                        fudge,
                                        edges,
                                    )
                                )
                        else:
                            # Reached the root of the region; store the fused task.
                            rv[parent] = val
                            break
                    else:
                        # Not fused: keep the child as its own task.
                        rv[child_key] = child_task
                        reducible_remove(child_key)
                        if children_stack:
                            # Allow the parent to be fused, but only under strict circumstances.
                            # Ensure that linear chains may still be fused.
                            if fudge > int(ave_width - 1):
                                fudge = int(ave_width - 1)
                            # This task *implicitly* depends on `edges`
                            info_stack_append(
                                (
                                    parent,
                                    rv[parent],
                                    [parent] if rename_keys else None,
                                    1,
                                    width,
                                    1,
                                    fudge,
                                    edges,
                                )
                            )
                        else:
                            break
                else:
                    # Several children: aggregate their metrics first.
                    child_keys = []
                    height = 1
                    width = 0
                    num_single_nodes = 0
                    num_nodes = 0
                    fudge = 0
                    children_edges = set()
                    max_num_edges = 0
                    children_info = info_stack[-num_children:]
                    del info_stack[-num_children:]
                    for (
                        _,
                        _,
                        _,
                        cur_height,
                        cur_width,
                        cur_num_nodes,
                        cur_fudge,
                        cur_edges,
                    ) in children_info:
                        if cur_height == 1:
                            num_single_nodes += 1
                        elif cur_height > height:
                            height = cur_height
                        width += cur_width
                        num_nodes += cur_num_nodes
                        fudge += cur_fudge
                        if len(cur_edges) > max_num_edges:
                            max_num_edges = len(cur_edges)
                        children_edges |= cur_edges
                    # Fudge factor to account for possible parallelism with the boundaries
                    num_children_edges = len(children_edges)
                    fudge += min(
                        num_children - 1, max(0, num_children_edges - max_num_edges)
                    )

                    if fudge > num_children_edges - 1 >= 0:
                        fudge = num_children_edges - 1
                    edges |= children_edges
                    no_new_edges = len(edges) == num_children_edges
                    if not no_new_edges:
                        fudge += 1
                    if (
                        (num_nodes + fudge) / height <= ave_width
                        and num_single_nodes <= ave_width
                        and width <= max_width
                        and height <= max_height
                        and
                        # Sanity check; don't go too deep if new levels introduce new edge dependencies
                        (no_new_edges or height < max_depth_new_edges)
                        and (
                            not isinstance(dsk[parent], GraphNode)
                            and not any(
                                isinstance(dsk[child_key], GraphNode)
                                for child_key in children
                            )
                            # TODO: substitute can be implemented with GraphNode.inline
                            # or all(
                            #     isinstance(dsk[child], GraphNode) for child in children
                        )
                    ):
                        # Perform substitutions as we go
                        val = dsk[parent]
                        children_deps = set()
                        for child_info in children_info:
                            cur_child = child_info[0]
                            val = subs(val, cur_child, child_info[1])
                            del rv[cur_child]
                            children_deps |= deps_pop(cur_child)
                            reducible_remove(cur_child)
                            if rename_keys:
                                fused_trees_pop(cur_child, None)
                                child_keys.extend(child_info[2])
                        deps_parent -= children
                        deps_parent |= children_deps

                        if rename_keys:
                            child_keys.append(parent)
                            fused_trees[parent] = child_keys

                        if children_stack:
                            info_stack_append(
                                (
                                    parent,
                                    val,
                                    child_keys,
                                    height + 1,
                                    width,
                                    num_nodes + 1,
                                    fudge,
                                    edges,
                                )
                            )
                        else:
                            rv[parent] = val
                            break
                    else:
                        # Not fused: every child stays as its own task.
                        for child_info in children_info:
                            rv[child_info[0]] = child_info[1]
                            reducible_remove(child_info[0])
                        if children_stack:
                            # Allow the parent to be fused, but only under strict circumstances.
                            # Ensure that linear chains may still be fused.
                            if width > max_width:
                                width = max_width
                            if fudge > int(ave_width - 1):
                                fudge = int(ave_width - 1)
                            # key, task, fused_keys, height, width, number of nodes, fudge, set of edges
                            # This task *implicitly* depends on `edges`
                            info_stack_append(
                                (
                                    parent,
                                    rv[parent],
                                    [parent] if rename_keys else None,
                                    1,
                                    width,
                                    1,
                                    fudge,
                                    edges,
                                )
                            )
                        else:
                            break
                # Traverse upwards
                parent = rdeps[parent][0]

    if fuse_subgraphs:
        _inplace_fuse_subgraphs(rv, keys, deps, fused_trees, rename_keys)

    if key_renamer:
        # The fused task moves under the new name; the root key becomes an alias.
        for root_key, fused_keys in fused_trees.items():
            alias = key_renamer(fused_keys)
            if alias is not None and alias not in rv:
                rv[alias] = rv[root_key]
                rv[root_key] = alias
                deps[alias] = deps[root_key]
                deps[root_key] = {alias}

    return rv, deps
def _inplace_fuse_subgraphs(dsk, keys, dependencies, fused_trees, rename_keys): """Subroutine of fuse. Mutates dsk, dependencies, and fused_trees inplace""" # locate all members of linear chains child2parent = {} unfusible = set() for parent in dsk: deps = dependencies[parent] has_many_children = len(deps) > 1 for child in deps: if keys is not None and child in keys: unfusible.add(child) elif child in child2parent: del child2parent[child] unfusible.add(child) elif has_many_children: unfusible.add(child) elif child not in unfusible: child2parent[child] = parent # construct the chains from ancestor to descendant chains = [] parent2child = {v: k for k, v in child2parent.items()} while child2parent: child, parent = child2parent.popitem() chain = [child, parent] while parent in child2parent: parent = child2parent.pop(parent) del parent2child[parent] chain.append(parent) chain.reverse() while child in parent2child: child = parent2child.pop(child) del child2parent[child] chain.append(child) # Skip chains with < 2 executable tasks ntasks = 0 for key in chain: ntasks += istask(dsk[key]) if ntasks > 1: chains.append(chain) break # Mutate dsk fusing chains into subgraphs for chain in chains: subgraph = {k: dsk[k] for k in chain} outkey = chain[0] # Update dependencies and graph inkeys_set = dependencies[outkey] = dependencies[chain[-1]] for k in chain[1:]: del dependencies[k] del dsk[k] # Create new task inkeys = tuple(inkeys_set) dsk[outkey] = (SubgraphCallable(subgraph, outkey, inkeys),) + inkeys # Mutate `fused_trees` if key renaming is needed (renaming done in fuse) if rename_keys: chain2 = [] for k in chain: subchain = fused_trees.pop(k, False) if subchain: chain2.extend(subchain) else: chain2.append(k) fused_trees[outkey] = chain2 class SubgraphCallable: """Create a callable object from a dask graph. Parameters ---------- dsk : dict A dask graph outkey : Dask key The output key from the graph inkeys : list A list of keys to be used as arguments to the callable. 
name : str, optional The name to use for the function. """ dsk: Graph outkey: Key inkeys: tuple[Key, ...] name: str __slots__ = tuple(__annotations__) def __init__( self, dsk: Graph, outkey: Key, inkeys: Iterable[Key], name: str | None = None ): self.dsk = dsk self.outkey = outkey self.inkeys = tuple(inkeys) if name is None: name = "subgraph_callable-" + tokenize(dsk, outkey, self.inkeys) self.name = name def __repr__(self) -> str: return self.name def __eq__(self, other: Any) -> bool: return ( type(self) is type(other) and self.name == other.name and self.outkey == other.outkey and set(self.inkeys) == set(other.inkeys) ) def __call__(self, *args: Any) -> Any: if not len(args) == len(self.inkeys): raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args))) return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args))) def __reduce__(self) -> tuple: return SubgraphCallable, (self.dsk, self.outkey, self.inkeys, self.name) def __hash__(self) -> int: return hash((self.outkey, frozenset(self.inkeys), self.name)) def __dask_tokenize__(self) -> object: return ( "SubgraphCallable", normalize_token(self.dsk), self.outkey, self.inkeys, self.name, )