ray.data.Dataset.take_batch#

Dataset.take_batch(batch_size: int = 20, *, batch_format: str | None = 'default') pyarrow.Table | pandas.DataFrame | Dict[str, numpy.ndarray][源代码]#

Dataset 中返回最多 batch_size 行作为一个批次。

Ray Data 将批次表示为 NumPy 数组或 pandas DataFrame。您可以通过指定 batch_format 来配置批次类型。

此方法对于检查 map_batches() 的输入很有用。

警告

take_batch() 将最多 batch_size 行移动到调用者的机器上。如果 batch_size 很大,此方法可能会导致调用者出现 ` OutOfMemory 错误。

备注

此操作将触发对此数据集执行的延迟转换。

示例

>>> import ray
>>> ds = ray.data.range(100)
>>> ds.take_batch(5)
{'id': array([0, 1, 2, 3, 4])}

时间复杂度:O(指定的批量大小)

参数:
  • batch_size – 返回的最大行数。

  • batch_format – 如果 "default""numpy",批次是 Dict[str, numpy.ndarray]。如果 "pandas",批次是 pandas.DataFrame

返回:

数据集中最多 batch_size 行的一批数据。

抛出:

ValueError – 如果数据集为空。