ray.data.Dataset.streaming_split#

Dataset.streaming_split(n: int, *, equal: bool = False, locality_hints: List[NodeIdStr] | None = None) List[DataIterator][源代码]#

返回 n数据迭代器,可用于并行读取数据集的不相交子集。

此方法是分布式训练中推荐的使用 数据集 的方式。

流式分割通过将此 Dataset 的执行委托给一个协调器角色来工作。协调器从执行的流中拉取块引用,并将这些块分配给 n 个输出迭代器。迭代器从协调器角色拉取块,以便在 next 时返回给调用者。

返回的迭代器也是可重复的;每次迭代都会触发 Dataset 的新执行。每次迭代开始时都有一个隐式的屏障,这意味着在迭代开始之前必须在所有迭代器上调用 next

警告

因为迭代器是从同一个 Dataset 执行中拉取块,如果一个迭代器落后了,其他迭代器可能会被阻塞。

备注

此操作将触发对此数据集执行的延迟转换。

示例

import ray

ds = ray.data.range(100)
it1, it2 = ds.streaming_split(2, equal=True)

并行地从迭代器中消费数据。

@ray.remote
def consume(it):
    for batch in it.iter_batches():
       pass

ray.get([consume.remote(it1), consume.remote(it2)])

你可以多次循环遍历迭代器(多个周期)。

@ray.remote
def train(it):
    NUM_EPOCHS = 2
    for _ in range(NUM_EPOCHS):
        for batch in it.iter_batches():
            pass

ray.get([train.remote(it1), train.remote(it2)])

以下远程函数调用会阻塞,等待对 it2 的读取开始。

ray.get(train.remote(it1))
参数:
  • n – 要返回的输出迭代器数量。

  • equal – 如果 True,每个输出迭代器将看到完全相同数量的行,必要时会丢弃数据。如果 False,某些迭代器可能会比其他迭代器看到稍多或稍少的行,但不会丢弃数据。

  • locality_hints – 指定与每个迭代器位置对应的节点ID。数据集将尝试根据迭代器输出位置最小化数据移动。此列表的长度必须为 n。您可以通过调用 ray.get_runtime_context().get_node_id() 获取任务或角色的当前节点ID。

返回:

输出迭代器拆分。这些迭代器是 Ray 可序列化的,可以自由传递给任何 Ray 任务或角色。

参见

Dataset.split()

streaming_split() 不同,split() 会在内存中具体化数据集。