ray.data.from_torch#

ray.data.from_torch(dataset: torch.utils.data.Dataset, local_read: bool = False) Dataset[源代码]#

Torch 数据集 创建一个 Dataset

备注

输入数据集可以是映射样式或可迭代样式,并且可以包含任意大量的数据。数据将以单个读取任务顺序流式传输。

示例

>>> import ray
>>> from torchvision import datasets
>>> dataset = datasets.MNIST("data", download=True)  
>>> ds = ray.data.from_torch(dataset)  
>>> ds  
MaterializedDataset(num_blocks=..., num_rows=60000, schema={item: object})
>>> ds.take(1)  
{"item": (<PIL.Image.Image image mode=L size=28x28 at 0x...>, 5)}
参数:
  • dataset – 一个 Torch 数据集

  • local_read – 如果 True ,执行本地读取。

返回:

包含 Torch 数据集样本的 Dataset