ray.data.Dataset.take_batch#
- Dataset.take_batch(batch_size: int = 20, *, batch_format: str | None = 'default') pyarrow.Table | pandas.DataFrame | Dict[str, numpy.ndarray] [源代码]#
从
Dataset
中返回最多batch_size
行作为一个批次。Ray Data 将批次表示为 NumPy 数组或 pandas DataFrame。您可以通过指定
batch_format
来配置批次类型。此方法对于检查
map_batches()
的输入很有用。警告
take_batch()
将最多batch_size
行移动到调用者的机器上。如果batch_size
很大,此方法可能会导致调用者出现 `OutOfMemory
错误。备注
此操作将触发对此数据集执行的延迟转换。
示例
>>> import ray >>> ds = ray.data.range(100) >>> ds.take_batch(5) {'id': array([0, 1, 2, 3, 4])}
时间复杂度:O(指定的批量大小)
- 参数:
batch_size – 返回的最大行数。
batch_format – 如果
"default"
或"numpy"
,批次是Dict[str, numpy.ndarray]
。如果"pandas"
,批次是pandas.DataFrame
。
- 返回:
数据集中最多
batch_size
行的一批数据。- 抛出:
ValueError – 如果数据集为空。