ray.data.from_tf#

ray.data.from_tf(dataset: tf.data.Dataset) MaterializedDataset[源代码]#

TensorFlow 数据集 创建一个 Dataset

这个函数效率低下。用于读取小数据集或原型设计。

警告

如果你的数据集很大,这个函数可能会执行缓慢或引发内存不足错误。为了避免问题,请使用类似 read_images() 的函数读取底层数据。

备注

此函数未并行化。它在将数据移动到分布式对象存储之前,会将整个数据集加载到本地节点的内存中。

示例

>>> import ray
>>> import tensorflow_datasets as tfds
>>> dataset, _ = tfds.load('cifar10', split=["train", "test"])  
>>> ds = ray.data.from_tf(dataset)  
>>> ds  
MaterializedDataset(
    num_blocks=...,
    num_rows=50000,
    schema={
        id: binary,
        image: numpy.ndarray(shape=(32, 32, 3), dtype=uint8),
        label: int64
    }
)
>>> ds.take(1)  
[{'id': b'train_16399', 'image': array([[[143,  96,  70],
[141,  96,  72],
[135,  93,  72],
...,
[ 96,  37,  19],
[105,  42,  18],
[104,  38,  20]],
...,
[[195, 161, 126],
[187, 153, 123],
[186, 151, 128],
...,
[212, 177, 147],
[219, 185, 155],
[221, 187, 157]]], dtype=uint8), 'label': 7}]
参数:

dataset – 一个 TensorFlow 数据集

返回:

一个包含存储在 TensorFlow Dataset 中的样本的 MaterializedDataset