代码示例 / 快速Keras食谱 / 创建 TFRecords

创建 TFRecords

作者: Dimitre Oliveira
创建日期: 2021/02/27
最后修改: 2023/12/20
描述: 将数据转换为 TFRecord 格式。

在 Colab 中查看 GitHub 源码


介绍

TFRecord 格式是一种存储二进制记录序列的简单格式。 将数据转换为 TFRecord 有许多优点,例如:

  • 更高效的存储:TFRecord 数据可能比原始数据占用更少的空间;它还可以分成多个文件。
  • 快速 I/O:TFRecord 格式可以通过并行 I/O 操作进行读取,这对于 TPUs 或多个主机非常有用。
  • 自包含文件:TFRecord 数据可以从单个源读取——例如, COCO2017 数据集最初将数据存储在两个文件夹中(“images”和“annotations”)。

TFRecord 数据格式的一个重要使用案例是 TPUs 的训练。首先,TPUs 足够快,能够从优化的 I/O 操作中获益。此外,TPUs 需要将数据存储在远程(例如 Google Cloud Storage)上,并且使用 TFRecord 格式可以更容易地加载数据,而无需批量下载。

如果您还使用 tf.data API,使用 TFRecord 格式的性能可以进一步提高。

在这个示例中,您将学习如何将不同类型(图像、文本和数字)的数据转换为 TFRecord。

参考


依赖关系

import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import json
import pprint
import tensorflow as tf
import matplotlib.pyplot as plt

下载 COCO2017 数据集

我们将使用 COCO2017 数据集,因为它具有许多不同类型的特征,包括图像、浮点数据和列表。 它将作为一个很好的示例,展示如何将不同的特征编码为 TFRecord 格式。

该数据集有两组字段:图像和注释元数据。

图像是 JPG 文件的集合,元数据存储在一个 JSON 文件中,该文件根据 官方网站 包含以下属性:

id: int,
image_id: int,
category_id: int,
segmentation: RLE or [polygon], object segmentation mask
bbox: [x,y,width,height], object bounding box coordinates
area: float, area of the bounding box
iscrowd: 0 or 1, is single object or a collection
root_dir = "datasets"
tfrecords_dir = "tfrecords"
images_dir = os.path.join(root_dir, "val2017")
annotations_dir = os.path.join(root_dir, "annotations")
annotation_file = os.path.join(annotations_dir, "instances_val2017.json")
images_url = "http://images.cocodataset.org/zips/val2017.zip"
annotations_url = (
    "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
)

# Download image files
if not os.path.exists(images_dir):
    image_zip = keras.utils.get_file(
        "images.zip",
        cache_dir=os.path.abspath("."),
        origin=images_url,
        extract=True,
    )
    os.remove(image_zip)

# Download caption annotation files
if not os.path.exists(annotations_dir):
    annotation_zip = keras.utils.get_file(
        "captions.zip",
        cache_dir=os.path.abspath("."),
        origin=annotations_url,
        extract=True,
    )
    os.remove(annotation_zip)

print("The COCO dataset has been downloaded and extracted successfully.")

with open(annotation_file, "r") as f:
    annotations = json.load(f)["annotations"]

print(f"Number of images: {len(annotations)}")
Downloading data from http://images.cocodataset.org/zips/val2017.zip
 815585330/815585330 ━━━━━━━━━━━━━━━━━━━━ 79s 0us/step
Downloading data from http://images.cocodataset.org/annotations/annotations_trainval2017.zip
 252907541/252907541 ━━━━━━━━━━━━━━━━━━━━ 5s 0us/step
The COCO dataset has been downloaded and extracted successfully.
Number of images: 36781

Contents of the COCO2017 dataset

pprint.pprint(annotations[60])
{'area': 367.89710000000014,
 'bbox': [265.67, 222.31, 26.48, 14.71],
 'category_id': 72,
 'id': 34096,
 'image_id': 525083,
 'iscrowd': 0,
 'segmentation': [[267.51,
                   222.31,
                   292.15,
                   222.31,
                   291.05,
                   237.02,
                   265.67,
                   237.02]]}
--- ## 参数 `num_samples` 是每个 TFRecord 文件中的数据样本数量。 `num_tfrecords` 是我们将创建的 TFRecords 的总数。
num_samples = 4096
num_tfrecords = len(annotations) // num_samples
if len(annotations) % num_samples:
    num_tfrecords += 1  # 如果有剩余样本,则添加一个记录

if not os.path.exists(tfrecords_dir):
    os.makedirs(tfrecords_dir)  # 创建 TFRecords 输出文件夹
--- ## 定义 TFRecords 辅助函数
def image_feature(value):
    """从字符串/字节返回 bytes_list。"""
    return tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()])
    )


def bytes_feature(value):
    """从字符串/字节返回 bytes_list。"""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode()]))


def float_feature(value):
    """从浮点数/双精度返回 float_list。"""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def int64_feature(value):
    """从布尔值/枚举/整数/无符号整数返回 int64_list。"""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def float_feature_list(value):
    """从浮点数/双精度返回 float_list 的列表。"""
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def create_example(image, path, example):
    feature = {
        "image": image_feature(image),
        "path": bytes_feature(path),
        "area": float_feature(example["area"]),
        "bbox": float_feature_list(example["bbox"]),
        "category_id": int64_feature(example["category_id"]),
        "id": int64_feature(example["id"]),
        "image_id": int64_feature(example["image_id"]),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


def parse_tfrecord_fn(example):
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "path": tf.io.FixedLenFeature([], tf.string),
        "area": tf.io.FixedLenFeature([], tf.float32),
        "bbox": tf.io.VarLenFeature(tf.float32),
        "category_id": tf.io.FixedLenFeature([], tf.int64),
        "id": tf.io.FixedLenFeature([], tf.int64),
        "image_id": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, feature_description)
    example["image"] = tf.io.decode_jpeg(example["image"], channels=3)
    example["bbox"] = tf.sparse.to_dense(example["bbox"])
    return example
--- ## 生成 TFRecord 格式的数据 让我们生成 COCO2017 数据以 TFRecord 格式存储。格式将是 `file_{number}.tfrec`(这不是必须的,但将数字序列包含在文件名中可以使计数更容易)。
for tfrec_num in range(num_tfrecords):
    samples = annotations[(tfrec_num * num_samples) : ((tfrec_num + 1) * num_samples)]

    with tf.io.TFRecordWriter(
        tfrecords_dir + "/file_%.2i-%i.tfrec" % (tfrec_num, len(samples))
    ) as writer:
        for sample in samples:
            image_path = f"{images_dir}/{sample['image_id']:012d}.jpg"
            image = tf.io.decode_jpeg(tf.io.read_file(image_path))
            example = create_example(image, image_path, sample)
            writer.write(example.SerializeToString())
--- ## 探索从生成的 TFRecord 中的一个样本
raw_dataset = tf.data.TFRecordDataset(f"{tfrecords_dir}/file_00-{num_samples}.tfrec")
parsed_dataset = raw_dataset.map(parse_tfrecord_fn)

for features in parsed_dataset.take(1):
    for key in features.keys():
        if key != "image":
            print(f"{key}: {features[key]}")

    print(f"图像形状: {features['image'].shape}")
    plt.figure(figsize=(7, 7))
    plt.imshow(features["image"].numpy())
    plt.show()
bbox: [473.07 395.93  38.65  28.67]
area: 702.1057739257812
category_id: 18
id: 1768
image_id: 289343
path: b'datasets/val2017/000000289343.jpg'
图像形状: (640, 529, 3)
![png](/img/examples/keras_recipes/creating_tfrecords/creating_tfrecords_14_1.png) --- ## 使用生成的 TFRecords 训练简单模型 TFRecord 的另一个优点是您可以向其添加许多特征,稍后只使用其中的一些,在这种情况下,我们将只使用 `image` 和 `category_id`。 --- ## 定义数据集辅助函数
def prepare_sample(features):
    image = keras.ops.image.resize(features["image"], size=(224, 224))
    return image, features["category_id"]


def get_dataset(filenames, batch_size):
    dataset = (
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
        .map(parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
        .map(prepare_sample, num_parallel_calls=AUTOTUNE)
        .shuffle(batch_size * 10)
        .batch(batch_size)
        .prefetch(AUTOTUNE)
    )
    return dataset


train_filenames = tf.io.gfile.glob(f"{tfrecords_dir}/*.tfrec")
batch_size = 32
epochs = 1
steps_per_epoch = 50
AUTOTUNE = tf.data.AUTOTUNE

input_tensor = keras.layers.Input(shape=(224, 224, 3), name="image")
model = keras.applications.EfficientNetB0(
    input_tensor=input_tensor, weights=None, classes=91
)


model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)


model.fit(
    x=get_dataset(train_filenames, batch_size),
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    verbose=1,
)
 50/50 ━━━━━━━━━━━━━━━━━━━━ 146s 2s/step - loss: 3.9206 - sparse_categorical_accuracy: 0.1690

<keras.src.callbacks.history.History at 0x7f70684c27a0>
--- ## 结论 本示例演示了如何通过 TFRecord 来简化数据的存储和读取,避免了从不同来源读取图像和注释。这个过程可以使数据的存储和读取更加简单和高效。有关更多信息,请参阅 [TFRecord 和 tf.train.Example](https://www.tensorflow.org/tutorials/load_data/tfrecord) 教程。