使用图像#

使用 Ray Data,您可以轻松读取和转换大型图像数据集。

本指南向您展示如何:

读取图像#

Ray Data 可以从多种格式读取图像。

要查看支持的文件格式的完整列表,请参阅 输入/输出参考

要加载像JPEG文件这样的原始图像,请调用 read_images()

备注

read_images() 使用 PIL。有关支持的文件格式列表,请参阅 图像文件格式

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea/JPEGImages")

print(ds.schema())
Column  Type
------  ----
image   numpy.ndarray(shape=(32, 32, 3), dtype=uint8)

要加载存储为 NumPy 格式的图像,请调用 read_numpy()

import ray

ds = ray.data.read_numpy("s3://anonymous@air-example-data/cifar-10/images.npy")

print(ds.schema())
Column  Type
------  ----
data    numpy.ndarray(shape=(32, 32, 3), dtype=uint8)

图像数据集通常包含 tf.train.Example 消息,如下所示:

features {
    feature {
        key: "image"
        value {
            bytes_list {
                value: ...  # Raw image bytes
            }
        }
    }
    feature {
        key: "label"
        value {
            int64_list {
                value: 3
            }
        }
    }
}

要加载存储在此格式中的示例,请调用 read_tfrecords()。然后,调用 map() 来解码原始图像字节。

import io
from typing import Any, Dict
import numpy as np
from PIL import Image
import ray

def decode_bytes(row: Dict[str, Any]) -> Dict[str, Any]:
    data = row["image"]
    image = Image.open(io.BytesIO(data))
    row["image"] = np.array(image)
    return row

ds = (
    ray.data.read_tfrecords(
        "s3://anonymous@air-example-data/cifar-10/tfrecords"
    )
    .map(decode_bytes)
)

print(ds.schema())
Column  Type
------  ----
image   numpy.ndarray(shape=(32, 32, 3), dtype=uint8)
label   int64

要加载存储在 Parquet 文件中的图像数据,请调用 ray.data.read_parquet()

import ray

ds = ray.data.read_parquet("s3://anonymous@air-example-data/cifar-10/parquet")

print(ds.schema())
Column  Type
------  ----
image   numpy.ndarray(shape=(32, 32, 3), dtype=uint8)
label   int64

有关创建数据集的更多信息,请参阅 加载数据

转换图像#

要转换图像,请调用 map()map_batches()

from typing import Any, Dict
import numpy as np
import ray

def increase_brightness(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    batch["image"] = np.clip(batch["image"] + 4, 0, 255)
    return batch

ds = (
    ray.data.read_images("s3://anonymous@ray-example-data/batoidea/JPEGImages")
    .map_batches(increase_brightness)
)

有关数据转换的更多信息,请参阅 数据转换

对图像进行推理#

要使用预训练模型进行推理,首先加载并转换您的数据。

from typing import Any, Dict
from torchvision import transforms
import ray

def transform_image(row: Dict[str, Any]) -> Dict[str, Any]:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((32, 32))
    ])
    row["image"] = transform(row["image"])
    return row

ds = (
    ray.data.read_images("s3://anonymous@ray-example-data/batoidea/JPEGImages")
    .map(transform_image)
)

接下来,实现一个可调用的类,用于设置并调用你的模型。

import torch
from torchvision import models

class ImageClassifier:
    def __init__(self):
        weights = models.ResNet18_Weights.DEFAULT
        self.model = models.resnet18(weights=weights)
        self.model.eval()

    def __call__(self, batch):
        inputs = torch.from_numpy(batch["image"])
        with torch.inference_mode():
            outputs = self.model(inputs)
        return {"class": outputs.argmax(dim=1)}

最后,调用 Dataset.map_batches()

predictions = ds.map_batches(
    ImageClassifier,
    concurrency=2,
    batch_size=4
)
predictions.show(3)
{'class': 118}
{'class': 153}
{'class': 296}

有关执行推理的更多信息,请参阅 端到端:离线批量推理有状态变换

保存图像#

以PNG、Parquet和NumPy等格式保存图像。要查看所有支持的格式,请参阅 输入/输出参考

要将图像保存为图像文件,请调用 write_images()

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_images("/tmp/simple", column="image", file_format="png")

要在 Parquet 文件中保存图像,请调用 write_parquet()

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_parquet("/tmp/simple")

要将图像保存为NumPy文件,请调用 write_numpy()

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_numpy("/tmp/simple", column="image")

有关保存数据的更多信息,请参阅 保存数据