使用张量 / NumPy#

N 维数组(换句话说,张量)在机器学习工作负载中无处不在。本指南描述了使用此类数据的限制和最佳实践。

张量数据表示#

Ray Data 将张量表示为 NumPy ndarrays

import ray

ds = ray.data.read_images("s3://anonymous@air-example-data/digits")
print(ds)
Dataset(
   num_rows=100,
   schema={image: numpy.ndarray(shape=(28, 28), dtype=uint8)}
)

固定形状的张量批次#

如果你的张量具有固定形状,Ray Data 会将批次表示为常规的 ndarrays。

>>> import ray
>>> ds = ray.data.read_images("s3://anonymous@air-example-data/digits")
>>> batch = ds.take_batch(batch_size=32)
>>> batch["image"].shape
(32, 28, 28)
>>> batch["image"].dtype
dtype('uint8')

可变形状张量的批次#

如果你的张量形状不同,Ray Data 会将批次表示为对象数据类型的数组。

>>> import ray
>>> ds = ray.data.read_images("s3://anonymous@air-example-data/AnimalDetection")
>>> batch = ds.take_batch(batch_size=32)
>>> batch["image"].shape
(32,)
>>> batch["image"].dtype
dtype('O')

这些对象数组的各个元素是常规的 ndarrays。

>>> batch["image"][0].dtype
dtype('uint8')
>>> batch["image"][0].shape  
(375, 500, 3)
>>> batch["image"][3].shape  
(333, 465, 3)

转换张量数据#

调用 map()map_batches() 来转换张量数据。

from typing import Any, Dict

import ray
import numpy as np

ds = ray.data.read_images("s3://anonymous@air-example-data/AnimalDetection")

def increase_brightness(row: Dict[str, Any]) -> Dict[str, Any]:
    row["image"] = np.clip(row["image"] + 4, 0, 255)
    return row

# Increase the brightness, record at a time.
ds.map(increase_brightness)

def batch_increase_brightness(batch: Dict[str, np.ndarray]) -> Dict:
    batch["image"] = np.clip(batch["image"] + 4, 0, 255)
    return batch

# Increase the brightness, batch at a time.
ds.map_batches(batch_increase_brightness)

除了 NumPy ndarrays 之外,Ray Data 还将返回的 NumPy ndarrays 列表和实现 __array__ 的对象(例如,torch.Tensor)视为张量数据。

有关数据转换的更多信息,请阅读 数据转换

保存张量数据#

使用 Parquet、NumPy 和 JSON 等格式保存张量数据。要查看所有支持的格式,请参阅 输入/输出参考

调用 write_parquet() 以将数据保存为 Parquet 文件。

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_parquet("/tmp/simple")

调用 write_numpy() 以将 ndarray 列保存为 NumPy 文件。

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_numpy("/tmp/simple", column="image")

要在 JSON 文件中保存图像,请调用 write_json()

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_json("/tmp/simple")

有关保存数据的更多信息,请阅读 保存数据