ray.data.DataIterator.iter_torch_batches#

DataIterator.iter_torch_batches(*, prefetch_batches: int = 1, batch_size: int | None = 256, dtypes: torch.dtype | Dict[str, torch.dtype] | None = None, device: str = 'auto', collate_fn: Callable[[Dict[str, numpy.ndarray]], CollatedData] | None = None, drop_last: bool = False, local_shuffle_buffer_size: int | None = None, local_shuffle_seed: int | None = None) Iterable[TorchBatchType][源代码]#

返回数据集上 Torch 张量的批量可迭代对象。

这个可迭代对象生成一个列张量的字典。如果你在张量转换(例如,转换数据类型)或批次格式中需要更多的灵活性,可以尝试直接使用 iter_batches()

示例

>>> import ray
>>> for batch in ray.data.range(
...     12,
... ).iterator().iter_torch_batches(batch_size=4):
...     print(batch)
{'id': tensor([0, 1, 2, 3])}
{'id': tensor([4, 5, 6, 7])}
{'id': tensor([ 8,  9, 10, 11])}

使用 collate_fn 来自定义张量批次的创建方式。

>>> from typing import Any, Dict
>>> import torch
>>> import numpy as np
>>> import ray
>>> def collate_fn(batch: Dict[str, np.ndarray]) -> Any:
...     return torch.stack(
...         [torch.as_tensor(array) for array in batch.values()],
...         axis=1
...     )
>>> iterator = ray.data.from_items([
...     {"col_1": 1, "col_2": 2},
...     {"col_1": 3, "col_2": 4}]).iterator()
>>> for batch in iterator.iter_torch_batches(collate_fn=collate_fn):
...     print(batch)
tensor([[1, 2],
        [3, 4]])

时间复杂度: O(1)

参数:
  • prefetch_batches – 要预取的批次数量,超过当前批次。如果设置为大于0,将使用一个单独的线程池来将对象获取到本地节点,格式化批次,并应用collate_fn。默认为1。

  • batch_size – 每个批次中的行数,或 None 以使用整个块作为批次(块可能包含不同数量的行)。如果 drop_lastFalse,则最后一个批次可能包含少于 batch_size 行。默认为 256。

  • dtypes – 创建的张量(s)的 Torch 数据类型(s);如果为 None,数据类型将从张量数据中推断。不能与 collate_fn 参数一起使用。

  • device – 张量应放置的设备。默认为 “auto”,当数据集传递给 Ray Train 且未提供 collate_fn 时,张量会被移动到适当的设备。否则,默认为 CPU。您不能将此参数与 collate_fn 一起使用。

  • collate_fn – 一个将 Numpy 批次转换为 PyTorch 张量批次的函数。当指定此参数时,用户应在 collate_fn 外部手动处理主机到设备的数据传输。这对于在数据被批处理后进一步处理数据非常有用。潜在的用例包括沿第一个维度以外的维度进行整理,填充不同长度的序列,或一般处理不同长度的张量批次。如果未提供,则使用默认的整理函数,该函数仅将 Numpy 数组的批次转换为 PyTorch 张量的批次。此 API 仍在实验阶段,可能会发生变化。您不能将此参数与 dtypesdevice 一起使用。

  • drop_last – 如果最后一个批次不完整,是否丢弃它。

  • local_shuffle_buffer_size – 如果非空,数据将使用本地内存中的随机洗牌缓冲区进行随机洗牌,并且此值将作为本地内存中随机洗牌缓冲区中必须存在的最小行数,以便生成一个批次。当没有更多的行可以添加到缓冲区时,缓冲区中剩余的行将被排空。此缓冲区大小必须大于或等于 batch_size,因此在使用本地洗牌时也必须指定 batch_size

  • local_shuffle_seed – 用于本地随机洗牌的种子。

返回:

一个遍历 Torch Tensor 批次的可迭代对象。

PublicAPI (测试版): 此API目前处于测试阶段,在成为稳定版本之前可能会发生变化。