Keras 数据加载工具,位于 keras.utils
中,帮助你从磁盘上的原始数据转到可以有效训练模型的 tf.data.Dataset
对象。
这些加载工具可以与 预处理层 结合使用,以进一步转换你的输入数据集,然后再进行训练。
这里有一个快速示例:假设你有 10 个文件夹,每个文件夹包含来自不同类别的 10,000 张图像,你想训练一个将图像映射到其类别的分类器。
你的训练数据文件夹将如下所示:
training_data/
...class_a/
......a_image_1.jpg
......a_image_2.jpg
...class_b/
......b_image_1.jpg
......b_image_2.jpg
等等。
你也可以有一个结构相同的验证数据文件夹 validation_data/
。
你可以简单地这样做:
import keras
train_ds = keras.utils.image_dataset_from_directory(
directory='training_data/',
labels='inferred',
label_mode='categorical',
batch_size=32,
image_size=(256, 256))
validation_ds = keras.utils.image_dataset_from_directory(
directory='validation_data/',
labels='inferred',
label_mode='categorical',
batch_size=32,
image_size=(256, 256))
model = keras.applications.Xception(
weights=None, input_shape=(256, 256, 3), classes=10)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
model.fit(train_ds, epochs=10, validation_data=validation_ds)