使用张量 / NumPy#
N 维数组(换句话说,张量)在机器学习工作负载中无处不在。本指南描述了使用此类数据的限制和最佳实践。
张量数据表示#
Ray Data 将张量表示为 NumPy ndarrays。
import ray
ds = ray.data.read_images("s3://anonymous@air-example-data/digits")
print(ds)
Dataset(
num_rows=100,
schema={image: numpy.ndarray(shape=(28, 28), dtype=uint8)}
)
固定形状的张量批次#
如果你的张量具有固定形状,Ray Data 会将批次表示为常规的 ndarrays。
>>> import ray
>>> ds = ray.data.read_images("s3://anonymous@air-example-data/digits")
>>> batch = ds.take_batch(batch_size=32)
>>> batch["image"].shape
(32, 28, 28)
>>> batch["image"].dtype
dtype('uint8')
可变形状张量的批次#
如果你的张量形状不同,Ray Data 会将批次表示为对象数据类型的数组。
>>> import ray
>>> ds = ray.data.read_images("s3://anonymous@air-example-data/AnimalDetection")
>>> batch = ds.take_batch(batch_size=32)
>>> batch["image"].shape
(32,)
>>> batch["image"].dtype
dtype('O')
这些对象数组的各个元素是常规的 ndarrays。
>>> batch["image"][0].dtype
dtype('uint8')
>>> batch["image"][0].shape
(375, 500, 3)
>>> batch["image"][3].shape
(333, 465, 3)
转换张量数据#
调用 map()
或 map_batches()
来转换张量数据。
from typing import Any, Dict
import ray
import numpy as np
ds = ray.data.read_images("s3://anonymous@air-example-data/AnimalDetection")
def increase_brightness(row: Dict[str, Any]) -> Dict[str, Any]:
row["image"] = np.clip(row["image"] + 4, 0, 255)
return row
# Increase the brightness, record at a time.
ds.map(increase_brightness)
def batch_increase_brightness(batch: Dict[str, np.ndarray]) -> Dict:
batch["image"] = np.clip(batch["image"] + 4, 0, 255)
return batch
# Increase the brightness, batch at a time.
ds.map_batches(batch_increase_brightness)
除了 NumPy ndarrays 之外,Ray Data 还将返回的 NumPy ndarrays 列表和实现 __array__
的对象(例如,torch.Tensor
)视为张量数据。
有关数据转换的更多信息,请阅读 数据转换。
保存张量数据#
使用 Parquet、NumPy 和 JSON 等格式保存张量数据。要查看所有支持的格式,请参阅 输入/输出参考。
调用 write_parquet()
以将数据保存为 Parquet 文件。
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_parquet("/tmp/simple")
调用 write_numpy()
以将 ndarray 列保存为 NumPy 文件。
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_numpy("/tmp/simple", column="image")
要在 JSON 文件中保存图像,请调用 write_json()
。
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_json("/tmp/simple")
有关保存数据的更多信息,请阅读 保存数据。