作者: Yixing Fu
创建日期: 2020/06/30
最后修改日期: 2023/07/10
描述: 使用在 imagenet 上预训练权重的 EfficientNet 进行斯坦福狗分类。
EfficientNet,最早在 Tan and Le, 2019 中提出,是最有效的模型之一(即在推理时需要最少的 FLOPS),在 imagenet 和常见的图像分类迁移学习任务上达到了最先进的准确率。
最小的基础模型类似于 MnasNet,其模型显著更小且达到了接近 SOTA 的表现。通过引入一种启发式的方法来缩放模型,EfficientNet 提供了一系列模型(B0 到 B7),在不同规模上提供了效率和准确率的良好结合。这种缩放启发式(复合缩放,详细见 Tan and Le, 2019)使得以效率为导向的基础模型(B0)能够超越每个规模的模型,同时避免了对超参数的大规模网格搜索。
可在 这里 获取有关模型最新更新的摘要,其中应用了各种增强方案和半监督学习方法,以进一步提高模型在 imagenet 上的性能。这些模型的扩展可以通过更新权重而不改变模型架构来使用。
(本节提供有关“复合缩放”的一些细节,如果您只对使用模型感兴趣,可以跳过此节)
根据 原始论文,人们可能会觉得 EfficientNet 是一个连续的模型系列,由任意选择缩放因子构成,如论文的公式 (3) 中所示。然而,分辨率、深度和宽度的选择也受到许多因素的限制:
因此,EfficientNet 模型每个变体的深度、宽度和分辨率都是精心挑选并证明能产生良好结果,尽管它们可能与复合缩放公式显著偏离。因此,keras 的实现(如下所述)仅提供这 8 个模型 B0 到 B7,而不允许任意选择宽度 / 深度 / 分辨率参数。
自 v2.3 以来,EfficientNet B0 到 B7 的实现已与 Keras 一起发布。要使用 EfficientNetB0 对 ImageNet 的 1000 个类别的图像进行分类,可以运行:
from tensorflow.keras.applications import EfficientNetB0
model = EfficientNetB0(weights='imagenet')
该模型接受形状为 (224, 224, 3)
的输入图像,输入数据应在 [0, 255]
的范围内。归一化作为模型的一部分。
由于在 ImageNet 上训练 EfficientNet 需要大量资源和几个不属于模型架构本身的技术。因此,Keras 的实现默认加载通过使用 AutoAugment 训练获得的预训练权重。
对于 B0 到 B7 基础模型,输入形状是不同的。以下是每个模型期望的输入形状列表:
基础模型 | 分辨率 |
---|---|
EfficientNetB0 | 224 |
EfficientNetB1 | 240 |
EfficientNetB2 | 260 |
EfficientNetB3 | 300 |
EfficientNetB4 | 380 |
EfficientNetB5 | 456 |
EfficientNetB6 | 528 |
EfficientNetB7 | 600 |
当模型用于迁移学习时,Keras 实现提供了一个选项来移除顶层:
model = EfficientNetB0(include_top=False, weights='imagenet')
此选项排除了最终的 Dense
层,该层将倒数第二层的 1280 个特征转换为 1000 个 ImageNet 类的预测。用自定义层替换顶层允许在迁移学习工作流程中使用 EfficientNet 作为特征提取器。
模型构造函数中另一个值得注意的参数是 drop_connect_rate
,它控制负责 随机深度 的 dropout 率。该参数作为细调中额外正则化的切换,但不影响加载的权重。例如,当希望更强的正则化时,可以尝试:
model = EfficientNetB0(weights='imagenet', drop_connect_rate=0.4)
EfficientNet能够处理广泛的图像分类任务。 这使得它成为迁移学习的良好模型。 作为一个端到端的示例,我们将展示如何在 斯坦福犬 数据集上使用预训练的EfficientNetB0。
import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf # 用于 tf.data
import matplotlib.pyplot as plt
import keras
from keras import layers
from keras.applications import EfficientNetB0
# IMG_SIZE 由 EfficientNet 模型选择确定
IMG_SIZE = 224
BATCH_SIZE = 64
在这里,我们从 tensorflow_datasets (以下简称 TFDS)加载数据。 斯坦福犬数据集在 TFDS 中提供为 stanford_dogs。 它包含 20,580 张图像,属于 120 个犬种类别 (12,000 张用于训练,8,580 张用于测试)。
只需更改下面的 dataset_name
,您还可以尝试在 TFDS 中的其他数据集,例如
cifar10、
cifar100、
food101 等。当图像的大小远小于EfficientNet输入的大小时,
我们可以简单地对输入图像进行上采样。
Tan 和 Le, 2019 的研究表明,即使输入图像保持较小,增加分辨率的迁移学习结果也会更好。
dataset_name = "stanford_dogs"
(ds_train, ds_test), ds_info = tfds.load(
dataset_name, split=["train", "test"], with_info=True, as_supervised=True
)
NUM_CLASSES = ds_info.features["label"].num_classes
当数据集包含各种尺寸的图像时,我们需要将它们调整为共享的尺寸。斯坦福狗数据集仅包含至少 200x200 像素的图像。在这里,我们将图像调整为 EfficientNet 所需的输入尺寸。
size = (IMG_SIZE, IMG_SIZE)
ds_train = ds_train.map(lambda image, label: (tf.image.resize(image, size), label))
ds_test = ds_test.map(lambda image, label: (tf.image.resize(image, size), label))
以下代码显示前 9 张图像及其标签。
def format_label(label):
string_label = label_info.int2str(label)
return string_label.split("-")[1]
label_info = ds_info.features["label"]
for i, (image, label) in enumerate(ds_train.take(9)):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(image.numpy().astype("uint8"))
plt.title("{}".format(format_label(label)))
plt.axis("off")
我们可以使用预处理层 API 进行图像增强。
img_augmentation_layers = [
layers.RandomRotation(factor=0.15),
layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
layers.RandomFlip(),
layers.RandomContrast(factor=0.1),
]
def img_augmentation(images):
for layer in img_augmentation_layers:
images = layer(images)
return images
这个 Sequential
模型对象可以作为我们随后构建的模型的一部分,也可以作为在输入模型之前预处理数据的函数。将它们作为函数使用可以轻松可视化增强的图像。这里我们绘制了给定图像的 9 个增强结果示例。
for image, label in ds_train.take(1):
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
aug_img = img_augmentation(np.expand_dims(image.numpy(), axis=0))
aug_img = np.array(aug_img)
plt.imshow(aug_img[0].astype("uint8"))
plt.title("{}".format(format_label(label)))
plt.axis("off")
一旦我们验证输入数据和增强效果正常,我们就准备训练数据集。输入数据被调整为统一的 IMG_SIZE
。标签被转换为独热编码(即分类编码)。数据集被分批处理。
注意:prefetch
和 AUTOTUNE
在某些情况下可能提高性能,但取决于环境和特定的数据集。有关数据管道性能的更多信息,请参见这个 指南。
# 一热编码 / 类别编码
def input_preprocess_train(image, label):
image = img_augmentation(image)
label = tf.one_hot(label, NUM_CLASSES)
return image, label
def input_preprocess_test(image, label):
label = tf.one_hot(label, NUM_CLASSES)
return image, label
ds_train = ds_train.map(input_preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(batch_size=BATCH_SIZE, drop_remainder=True)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
ds_test = ds_test.map(input_preprocess_test, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(batch_size=BATCH_SIZE, drop_remainder=True)
我们构建一个具有 120 个输出类别的 EfficientNetB0,该模型从头开始初始化:
注意:准确度将非常缓慢地增加并可能过拟合。
model = EfficientNetB0(
include_top=True,
weights=None,
classes=NUM_CLASSES,
input_shape=(IMG_SIZE, IMG_SIZE, 3),
)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.summary()
epochs = 40 # @param {type: "slider", min:10, max:100}
hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test)
模型: "efficientnetb0"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓ ┃ 层 (类型) ┃ 输出形状 ┃ 参数 # ┃ 连接到 ┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer │ (None, 224, 224, │ 0 │ - │ │ (输入层) │ 3) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ rescaling │ (None, 224, 224, │ 0 │ input_layer[0][0] │ │ (重新缩放) │ 3) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ normalization │ (None, 224, 224, │ 7 │ rescaling[0][0] │ │ (标准化) │ 3) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ stem_conv_pad │ (None, 225, 225, │ 0 │ normalization[0][0] │ │ (ZeroPadding2D) │ 3) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ stem_conv (Conv2D) │ (None, 112, 112, │ 864 │ stem_conv_pad[0][0] │ │ │ 32) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ stem_bn │ (None, 112, 112, │ 128 │ stem_conv[0][0] │ │ (BatchNormalizatio… │ 32) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ stem_activation │ (None, 112, 112, │ 0 │ stem_bn[0][0] │ │ (Activation) │ 32) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block1a_dwconv │ (None, 112, 112, │ 288 │ stem_activation[0][… │ │ (DepthwiseConv2D) │ 32) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block1a_bn │ (None, 112, 112, │ 128 │ block1a_dwconv[0][0] │ │ (BatchNormalizatio… │ 32) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block1a_activation │ (None, 112, 112, │ 0 │ block1a_bn[0][0] │ │ (Activation) │ 32) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block1a_se_squeeze │ (无, 32) │ 0 │ block1a_activation[… │ │ (全局平均池化… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block1a_se_reshape │ (无, 1, 1, 32) │ 0 │ block1a_se_squeeze[… │ │ (重塑) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block1a_se_reduce │ (无, 1, 1, 8) │ 264 │ block1a_se_reshape[… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block1a_se_expand │ (无, 1, 1, 32) │ 288 │ block1a_se_reduce[0… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block1a_se_excite │ (无, 112, 112, │ 0 │ block1a_activation[… │ │ (乘法) │ 32) │ │ block1a_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block1a_project_co… │ (无, 112, 112, │ 512 │ block1a_se_excite[0… │ │ (卷积2D) │ 16) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block1a_project_bn │ (无, 112, 112, │ 64 │ block1a_project_con… │ │ (批标准化… │ 16) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_expand_conv │ (None, 112, 112, │ 1,536 │ block1a_project_bn[… │ │ (Conv2D) │ 96) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_expand_bn │ (None, 112, 112, │ 384 │ block2a_expand_conv… │ │ (BatchNormalizatio… │ 96) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_expand_act… │ (None, 112, 112, │ 0 │ block2a_expand_bn[0… │ │ (Activation) │ 96) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_dwconv_pad │ (None, 113, 113, │ 0 │ block2a_expand_acti… │ │ (ZeroPadding2D) │ 96) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_dwconv │ (None, 56, 56, │ 864 │ block2a_dwconv_pad[… │ │ (DepthwiseConv2D) │ 96) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_bn │ (None, 56, 56, │ 384 │ block2a_dwconv[0][0] │ │ (BatchNormalizatio… │ 96) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_activation │ (None, 56, 56, │ 0 │ block2a_bn[0][0] │ │ (Activation) │ 96) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_se_squeeze │ (无, 96) │ 0 │ block2a_activation[… │ │ (全局平均池化… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_se_reshape │ (无, 1, 1, 96) │ 0 │ block2a_se_squeeze[… │ │ (重塑) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_se_reduce │ (无, 1, 1, 4) │ 388 │ block2a_se_reshape[… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_se_expand │ (无, 1, 1, 96) │ 480 │ block2a_se_reduce[0… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_se_excite │ (无, 56, 56, │ 0 │ block2a_activation[… │ │ (乘法) │ 96) │ │ block2a_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_project_co… │ (无, 56, 56, │ 2,304 │ block2a_se_excite[0… │ │ (卷积2D) │ 24) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2a_project_bn │ (无, 56, 56, │ 96 │ block2a_project_con… │ │ (批量归一化… │ 24) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_expand_conv │ (None, 56, 56, │ 3,456 │ block2a_project_bn[… │ │ (Conv2D) │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_expand_bn │ (None, 56, 56, │ 576 │ block2b_expand_conv… │ │ (BatchNormalizatio… │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_expand_act… │ (None, 56, 56, │ 0 │ block2b_expand_bn[0… │ │ (Activation) │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_dwconv │ (None, 56, 56, │ 1,296 │ block2b_expand_acti… │ │ (DepthwiseConv2D) │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_bn │ (None, 56, 56, │ 576 │ block2b_dwconv[0][0] │ │ (BatchNormalizatio… │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_activation │ (None, 56, 56, │ 0 │ block2b_bn[0][0] │ │ (Activation) │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_se_squeeze │ (None, 144) │ 0 │ block2b_activation[… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_se_reshape │ (None, 1, 1, 144) │ 0 │ block2b_se_squeeze[… │ │ (重塑) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_se_reduce │ (None, 1, 1, 6) │ 870 │ block2b_se_reshape[… │ │ (卷积) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_se_expand │ (None, 1, 1, 144) │ 1,008 │ block2b_se_reduce[0… │ │ (卷积) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_se_excite │ (None, 56, 56, │ 0 │ block2b_activation[… │ │ (乘法) │ 144) │ │ block2b_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_project_co… │ (None, 56, 56, │ 3,456 │ block2b_se_excite[0… │ │ (卷积) │ 24) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_project_bn │ (None, 56, 56, │ 96 │ block2b_project_con… │ │ (批量归一化 │ 24) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_drop │ (None, 56, 56, │ 0 │ block2b_project_bn[… │ │ (丢弃) │ 24) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block2b_add (加法) │ (无, 56, 56, │ 0 │ block2b_drop[0][0], │ │ │ 24) │ │ block2a_project_bn[… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_expand_conv │ (无, 56, 56, │ 3,456 │ block2b_add[0][0] │ │ (卷积层) │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_expand_bn │ (无, 56, 56, │ 576 │ block3a_expand_conv… │ │ (批归一化 │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_expand_act… │ (无, 56, 56, │ 0 │ block3a_expand_bn[0… │ │ (激活层) │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_dwconv_pad │ (无, 59, 59, │ 0 │ block3a_expand_acti… │ │ (零填充) │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_dwconv │ (无, 28, 28, │ 3,600 │ block3a_dwconv_pad[… │ │ (深度可分离卷积) │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_bn │ (无, 28, 28, │ 576 │ block3a_dwconv[0][0] │ │ (批量归一化 │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_activation │ (无, 28, 28, │ 0 │ block3a_bn[0][0] │ │ (激活) │ 144) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_se_squeeze │ (无, 144) │ 0 │ block3a_activation[… │ │ (全局平均池化 │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_se_reshape │ (无, 1, 1, 144) │ 0 │ block3a_se_squeeze[… │ │ (重塑) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_se_reduce │ (无, 1, 1, 6) │ 870 │ block3a_se_reshape[… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_se_expand │ (无, 1, 1, 144) │ 1,008 │ block3a_se_reduce[0… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_se_excite │ (无, 28, 28, │ 0 │ block3a_activation[… │ │ (相乘) │ 144) │ │ block3a_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_project_co… │ (无, 28, 28, │ 5,760 │ block3a_se_excite[0… │ │ (Conv2D) │ 40) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3a_project_bn │ (None, 28, 28, │ 160 │ block3a_project_con… │ │ (BatchNormalizatio… │ 40) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_expand_conv │ (None, 28, 28, │ 9,600 │ block3a_project_bn[… │ │ (Conv2D) │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_expand_bn │ (None, 28, 28, │ 960 │ block3b_expand_conv… │ │ (BatchNormalizatio… │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_expand_act… │ (None, 28, 28, │ 0 │ block3b_expand_bn[0… │ │ (Activation) │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_dwconv │ (None, 28, 28, │ 6,000 │ block3b_expand_acti… │ │ (DepthwiseConv2D) │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_bn │ (None, 28, 28, │ 960 │ block3b_dwconv[0][0] │ │ (BatchNormalizatio… │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_activation │ (None, 28, 28, │ 0 │ block3b_bn[0][0] │ │ (激活) │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_se_squeeze │ (None, 240) │ 0 │ block3b_activation[… │ │ (全局平均池化… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_se_reshape │ (None, 1, 1, 240) │ 0 │ block3b_se_squeeze[… │ │ (重塑) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_se_reduce │ (None, 1, 1, 10) │ 2,410 │ block3b_se_reshape[… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_se_expand │ (None, 1, 1, 240) │ 2,640 │ block3b_se_reduce[0… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_se_excite │ (None, 28, 28, │ 0 │ block3b_activation[… │ │ (乘法) │ 240) │ │ block3b_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_project_co… │ (None, 28, 28, │ 9,600 │ block3b_se_excite[0… │ │ (卷积2D) │ 40) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_project_bn │ (None, 28, 28, │ 160 │ block3b_project_con… │ │ (批量归一化… │ 40) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_drop │ (无, 28, 28, │ 0 │ block3b_project_bn[… │ │ (丢弃层) │ 40) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block3b_add (加法) │ (无, 28, 28, │ 0 │ block3b_drop[0][0], │ │ │ 40) │ │ block3a_project_bn[… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_expand_conv │ (无, 28, 28, │ 9,600 │ block3b_add[0][0] │ │ (卷积层) │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_expand_bn │ (无, 28, 28, │ 960 │ block4a_expand_conv… │ │ (批量归一化… │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_expand_act… │ (无, 28, 28, │ 0 │ block4a_expand_bn[0… │ │ (激活层) │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_dwconv_pad │ (无, 29, 29, │ 0 │ block4a_expand_acti… │ │ (零填充层) │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_dwconv │ (无, 14, 14, │ 2,160 │ block4a_dwconv_pad[… │ │ (深度可分离卷积层) │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_bn │ (无, 14, 14, │ 960 │ block4a_dwconv[0][0] │ │ (批量归一化 │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_activation │ (无, 14, 14, │ 0 │ block4a_bn[0][0] │ │ (激活函数) │ 240) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_se_squeeze │ (无, 240) │ 0 │ block4a_activation[… │ │ (全局平均池化 │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_se_reshape │ (无, 1, 1, 240) │ 0 │ block4a_se_squeeze[… │ │ (重塑) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_se_reduce │ (无, 1, 1, 10) │ 2,410 │ block4a_se_reshape[… │ │ (卷积层) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_se_expand │ (无, 1, 1, 240) │ 2,640 │ block4a_se_reduce[0… │ │ (卷积层) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_se_excite │ (无, 14, 14, │ 0 │ block4a_activation[… │ │ (乘法) │ 240) │ │ block4a_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_project_co… │ (无, 14, 14, │ 19,200 │ block4a_se_excite[0… │ │ (卷积2D) │ 80) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4a_project_bn │ (无, 14, 14, │ 320 │ block4a_project_con… │ │ (批量标准化 │ 80) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_expand_conv │ (无, 14, 14, │ 38,400 │ block4a_project_bn[… │ │ (卷积2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_expand_bn │ (无, 14, 14, │ 1,920 │ block4b_expand_conv… │ │ (批量标准化 │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_expand_act… │ (无, 14, 14, │ 0 │ block4b_expand_bn[0… │ │ (激活) │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_dwconv │ (无, 14, 14, │ 4,320 │ block4b_expand_acti… │ │ (深度卷积2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_bn │ (None, 14, 14, │ 1,920 │ block4b_dwconv[0][0] │ │ (BatchNormalizatio… │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_activation │ (None, 14, 14, │ 0 │ block4b_bn[0][0] │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_se_squeeze │ (None, 480) │ 0 │ block4b_activation[… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_se_reshape │ (None, 1, 1, 480) │ 0 │ block4b_se_squeeze[… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_se_reduce │ (None, 1, 1, 20) │ 9,620 │ block4b_se_reshape[… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_se_expand │ (None, 1, 1, 480) │ 10,080 │ block4b_se_reduce[0… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_se_excite │ (None, 14, 14, │ 0 │ block4b_activation[… │ │ (Multiply) │ 480) │ │ block4b_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_project_co… │ (无, 14, 14, │ 38,400 │ block4b_se_excite[0… │ │ (卷积层) │ 80) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_project_bn │ (无, 14, 14, │ 320 │ block4b_project_con… │ │ (批标准化 │ 80) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_drop │ (无, 14, 14, │ 0 │ block4b_project_bn[… │ │ (丢弃) │ 80) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4b_add (相加) │ (无, 14, 14, │ 0 │ block4b_drop[0][0], │ │ │ 80) │ │ block4a_project_bn[… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_expand_conv │ (无, 14, 14, │ 38,400 │ block4b_add[0][0] │ │ (卷积层) │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_expand_bn │ (无, 14, 14, │ 1,920 │ block4c_expand_conv… │ │ (批标准化 │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_expand_act… │ (无, 14, 14, │ 0 │ block4c_expand_bn[0… │ │ (激活) │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_dwconv │ (None, 14, 14, │ 4,320 │ block4c_expand_acti… │ │ (深度可分离卷积) │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_bn │ (None, 14, 14, │ 1,920 │ block4c_dwconv[0][0] │ │ (批归一化 │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_activation │ (None, 14, 14, │ 0 │ block4c_bn[0][0] │ │ (激活) │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_se_squeeze │ (None, 480) │ 0 │ block4c_activation[… │ │ (全局平均池化 │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_se_reshape │ (None, 1, 1, 480) │ 0 │ block4c_se_squeeze[… │ │ (重塑) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_se_reduce │ (None, 1, 1, 20) │ 9,620 │ block4c_se_reshape[… │ │ (卷积) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_se_expand │ (None, 1, 1, 480) │ 10,080 │ block4c_se_reduce[0… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_se_excite │ (None, 14, 14, │ 0 │ block4c_activation[… │ │ (Multiply) │ 480) │ │ block4c_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_project_co… │ (None, 14, 14, │ 38,400 │ block4c_se_excite[0… │ │ (Conv2D) │ 80) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_project_bn │ (None, 14, 14, │ 320 │ block4c_project_con… │ │ (BatchNormalizatio… │ 80) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_drop │ (None, 14, 14, │ 0 │ block4c_project_bn[… │ │ (Dropout) │ 80) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block4c_add (Add) │ (None, 14, 14, │ 0 │ block4c_drop[0][0], │ │ │ 80) │ │ block4b_add[0][0] │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_expand_conv │ (None, 14, 14, │ 38,400 │ block4c_add[0][0] │ │ (Conv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_expand_bn │ (无, 14, 14, │ 1,920 │ block5a_expand_conv… │ │ (批量标准化… │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_expand_act… │ (无, 14, 14, │ 0 │ block5a_expand_bn[0… │ │ (激活) │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_dwconv │ (无, 14, 14, │ 12,000 │ block5a_expand_acti… │ │ (深度卷积2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_bn │ (无, 14, 14, │ 1,920 │ block5a_dwconv[0][0] │ │ (批量标准化… │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_activation │ (无, 14, 14, │ 0 │ block5a_bn[0][0] │ │ (激活) │ 480) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_se_squeeze │ (无, 480) │ 0 │ block5a_activation[… │ │ (全局平均池化… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_se_reshape │ (无, 1, 1, 480) │ 0 │ block5a_se_squeeze[… │ │ (重塑) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_se_reduce │ (无, 1, 1, 20) │ 9,620 │ block5a_se_reshape[… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_se_expand │ (无, 1, 1, 480) │ 10,080 │ block5a_se_reduce[0… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_se_excite │ (无, 14, 14, │ 0 │ block5a_activation[… │ │ (Multiply) │ 480) │ │ block5a_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_project_co… │ (无, 14, 14, │ 53,760 │ block5a_se_excite[0… │ │ (Conv2D) │ 112) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5a_project_bn │ (无, 14, 14, │ 448 │ block5a_project_con… │ │ (BatchNormalizatio… │ 112) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_expand_conv │ (无, 14, 14, │ 75,264 │ block5a_project_bn[… │ │ (Conv2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_expand_bn │ (无, 14, 14, │ 2,688 │ block5b_expand_conv… │ │ (BatchNormalizatio… │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_expand_act… │ (无, 14, 14, │ 0 │ block5b_expand_bn[0… │ │ (激活函数) │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_dwconv │ (无, 14, 14, │ 16,800 │ block5b_expand_acti… │ │ (深度卷积) │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_bn │ (无, 14, 14, │ 2,688 │ block5b_dwconv[0][0] │ │ (批归一化 │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_activation │ (无, 14, 14, │ 0 │ block5b_bn[0][0] │ │ (激活函数) │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_se_squeeze │ (无, 672) │ 0 │ block5b_activation[… │ │ (全局平均池化 │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_se_reshape │ (无, 1, 1, 672) │ 0 │ block5b_se_squeeze[… │ │ (重塑) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_se_reduce │ (无, 1, 1, 28) │ 18,844 │ block5b_se_reshape[… │ │ (卷积) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_se_expand │ (无, 1, 1, 672) │ 19,488 │ block5b_se_reduce[0… │ │ (卷积) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_se_excite │ (无, 14, 14, │ 0 │ block5b_activation[… │ │ (相乘) │ 672) │ │ block5b_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_project_co… │ (无, 14, 14, │ 75,264 │ block5b_se_excite[0… │ │ (卷积) │ 112) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_project_bn │ (无, 14, 14, │ 448 │ block5b_project_con… │ │ (批量归一化 │ 112) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_drop │ (无, 14, 14, │ 0 │ block5b_project_bn[… │ │ (丢弃) │ 112) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5b_add (相加) │ (无, 14, 14, │ 0 │ block5b_drop[0][0], │ │ │ 112) │ │ block5a_project_bn[… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_expand_conv │ (无, 14, 14, │ 75,264 │ block5b_add[0][0] │ │ (卷积层) │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_expand_bn │ (无, 14, 14, │ 2,688 │ block5c_expand_conv… │ │ (批归一化 │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_expand_act… │ (无, 14, 14, │ 0 │ block5c_expand_bn[0… │ │ (激活层) │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_dwconv │ (无, 14, 14, │ 16,800 │ block5c_expand_acti… │ │ (深度可分离卷积) │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_bn │ (无, 14, 14, │ 2,688 │ block5c_dwconv[0][0] │ │ (批归一化 │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_activation │ (无, 14, 14, │ 0 │ block5c_bn[0][0] │ │ (激活层) │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_se_squeeze │ (无, 672) │ 0 │ block5c_activation[… │ │ (全局平均池化 │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_se_reshape │ (无, 1, 1, 672) │ 0 │ block5c_se_squeeze[… │ │ (重塑) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_se_reduce │ (无, 1, 1, 28) │ 18,844 │ block5c_se_reshape[… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_se_expand │ (无, 1, 1, 672) │ 19,488 │ block5c_se_reduce[0… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_se_excite │ (无, 14, 14, │ 0 │ block5c_activation[… │ │ (乘法) │ 672) │ │ block5c_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_project_co… │ (无, 14, 14, │ 75,264 │ block5c_se_excite[0… │ │ (卷积2D) │ 112) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_project_bn │ (无, 14, 14, │ 448 │ block5c_project_con… │ │ (批量归一化 │ 112) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_drop │ (无, 14, 14, │ 0 │ block5c_project_bn[… │ │ (丢弃) │ 112) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block5c_add (Add) │ (无, 14, 14, │ 0 │ block5c_drop[0][0], │ │ │ 112) │ │ block5b_add[0][0] │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_expand_conv │ (无, 14, 14, │ 75,264 │ block5c_add[0][0] │ │ (Conv2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_expand_bn │ (无, 14, 14, │ 2,688 │ block6a_expand_conv… │ │ (BatchNormalizatio… │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_expand_act… │ (无, 14, 14, │ 0 │ block6a_expand_bn[0… │ │ (Activation) │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_dwconv_pad │ (无, 17, 17, │ 0 │ block6a_expand_acti… │ │ (ZeroPadding2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_dwconv │ (无, 7, 7, 672) │ 16,800 │ block6a_dwconv_pad[… │ │ (DepthwiseConv2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_bn │ (无, 7, 7, 672) │ 2,688 │ block6a_dwconv[0][0] │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_activation │ (None, 7, 7, 672) │ 0 │ block6a_bn[0][0] │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_se_squeeze │ (None, 672) │ 0 │ block6a_activation[… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_se_reshape │ (None, 1, 1, 672) │ 0 │ block6a_se_squeeze[… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_se_reduce │ (None, 1, 1, 28) │ 18,844 │ block6a_se_reshape[… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_se_expand │ (None, 1, 1, 672) │ 19,488 │ block6a_se_reduce[0… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_se_excite │ (None, 7, 7, 672) │ 0 │ block6a_activation[… │ │ (Multiply) │ │ │ block6a_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_project_co… │ (None, 7, 7, 192) │ 129,024 │ block6a_se_excite[0… │ │ (卷积层) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6a_project_bn │ (无, 7, 7, 192) │ 768 │ block6a_project_con… │ │ (批规范化 │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_expand_conv │ (无, 7, 7, │ 221,184 │ block6a_project_bn[… │ │ (卷积层) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_expand_bn │ (无, 7, 7, │ 4,608 │ block6b_expand_conv… │ │ (批规范化 │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_expand_act… │ (无, 7, 7, │ 0 │ block6b_expand_bn[0… │ │ (激活函数) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_dwconv │ (无, 7, 7, │ 28,800 │ block6b_expand_acti… │ │ (深度卷积层) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_bn │ (无, 7, 7, │ 4,608 │ block6b_dwconv[0][0] │ │ (批规范化 │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_activation │ (无, 7, 7, │ 0 │ block6b_bn[0][0] │ │ (激活) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_se_squeeze │ (无, 1152) │ 0 │ block6b_activation[… │ │ (全局平均池化… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_se_reshape │ (无, 1, 1, │ 0 │ block6b_se_squeeze[… │ │ (重塑) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_se_reduce │ (无, 1, 1, 48) │ 55,344 │ block6b_se_reshape[… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_se_expand │ (无, 1, 1, │ 56,448 │ block6b_se_reduce[0… │ │ (卷积2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_se_excite │ (无, 7, 7, │ 0 │ block6b_activation[… │ │ (乘法) │ 1152) │ │ block6b_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_project_co… │ (无, 7, 7, 192) │ 221,184 │ block6b_se_excite[0… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_project_bn │ (无, 7, 7, 192) │ 768 │ block6b_project_con… │ │ (批量标准化… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_drop │ (无, 7, 7, 192) │ 0 │ block6b_project_bn[… │ │ (随机失活) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6b_add (相加) │ (无, 7, 7, 192) │ 0 │ block6b_drop[0][0], │ │ │ │ │ block6a_project_bn[… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_expand_conv │ (无, 7, 7, │ 221,184 │ block6b_add[0][0] │ │ (卷积2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_expand_bn │ (无, 7, 7, │ 4,608 │ block6c_expand_conv… │ │ (批量标准化… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_expand_act… │ (无, 7, 7, │ 0 │ block6c_expand_bn[0… │ │ (激活) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_dwconv │ (无, 7, 7, │ 28,800 │ block6c_expand_acti… │ │ (深度卷积2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_bn │ (无, 7, 7, │ 4,608 │ block6c_dwconv[0][0] │ │ (批量标准化… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_activation │ (无, 7, 7, │ 0 │ block6c_bn[0][0] │ │ (激活) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_se_squeeze │ (无, 1152) │ 0 │ block6c_activation[… │ │ (全局平均池化… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_se_reshape │ (无, 1, 1, │ 0 │ block6c_se_squeeze[… │ │ (重塑) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_se_reduce │ (无, 1, 1, 48) │ 55,344 │ block6c_se_reshape[… │ │ (卷积2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_se_expand │ (无, 1, 1, │ 56,448 │ block6c_se_reduce[0… │ │ (卷积2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_se_excite │ (无, 7, 7, │ 0 │ block6c_activation[… │ │ (乘法) │ 1152) │ │ block6c_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_project_co… │ (无, 7, 7, 192) │ 221,184 │ block6c_se_excite[0… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_project_bn │ (无, 7, 7, 192) │ 768 │ block6c_project_con… │ │ (批量标准化 │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_drop │ (无, 7, 7, 192) │ 0 │ block6c_project_bn[… │ │ (丢弃) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6c_add (相加) │ (无, 7, 7, 192) │ 0 │ block6c_drop[0][0], │ │ │ │ │ block6b_add[0][0] │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_expand_conv │ (无, 7, 7, │ 221,184 │ block6c_add[0][0] │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_expand_bn │ (无, 7, 7, │ 4,608 │ block6d_expand_conv… │ │ (批量标准化 │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_expand_act… │ (无, 7, 7, │ 0 │ block6d_expand_bn[0… │ │ (激活) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_dwconv │ (无, 7, 7, │ 28,800 │ block6d_expand_acti… │ │ (深度卷积) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_bn │ (无, 7, 7, │ 4,608 │ block6d_dwconv[0][0] │ │ (批量归一化 │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_activation │ (无, 7, 7, │ 0 │ block6d_bn[0][0] │ │ (激活) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_se_squeeze │ (无, 1152) │ 0 │ block6d_activation[… │ │ (全局平均池化 │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_se_reshape │ (无, 1, 1, │ 0 │ block6d_se_squeeze[… │ │ (重塑) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_se_reduce │ (无, 1, 1, 48) │ 55,344 │ block6d_se_reshape[… │ │ (卷积) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_se_expand │ (无, 1, 1, │ 56,448 │ block6d_se_reduce[0… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_se_excite │ (None, 7, 7, │ 0 │ block6d_activation[… │ │ (Multiply) │ 1152) │ │ block6d_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_project_co… │ (None, 7, 7, 192) │ 221,184 │ block6d_se_excite[0… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_project_bn │ (None, 7, 7, 192) │ 768 │ block6d_project_con… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_drop │ (None, 7, 7, 192) │ 0 │ block6d_project_bn[… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block6d_add (Add) │ (None, 7, 7, 192) │ 0 │ block6d_drop[0][0], │ │ │ │ │ block6c_add[0][0] │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block7a_expand_conv │ (None, 7, 7, │ 221,184 │ block6d_add[0][0] │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block7a_expand_bn │ (无, 7, 7, │ 4,608 │ block7a_expand_conv… │ │ (批量归一化… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block7a_expand_act… │ (无, 7, 7, │ 0 │ block7a_expand_bn[0… │ │ (激活) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block7a_dwconv │ (无, 7, 7, │ 10,368 │ block7a_expand_acti… │ │ (深度卷积2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block7a_bn │ (无, 7, 7, │ 4,608 │ block7a_dwconv[0][0] │ │ (批量归一化… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block7a_activation │ (无, 7, 7, │ 0 │ block7a_bn[0][0] │ │ (激活) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block7a_se_squeeze │ (无, 1152) │ 0 │ block7a_activation[… │ │ (全局平均池… │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block7a_se_reshape │ (无, 1, 1, │ 0 │ block7a_se_squeeze[… │ │ (重塑) │ 1152) │ │ │ │ block7a_se_reduce │ (无, 1, 1, 48) │ 55,344 │ block7a_se_reshape[… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block7a_se_expand │ (无, 1, 1, │ 56,448 │ block7a_se_reduce[0… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block7a_se_excite │ (无, 7, 7, │ 0 │ block7a_activation[… │ │ (乘法) │ 1152) │ │ block7a_se_expand[0… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block7a_project_co… │ (无, 7, 7, 320) │ 368,640 │ block7a_se_excite[0… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ block7a_project_bn │ (无, 7, 7, 320) │ 1,280 │ block7a_project_con… │ │ (批归一化 │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ top_conv (Conv2D) │ (无, 7, 7, │ 409,600 │ block7a_project_bn[… │ │ │ 1280) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ top_bn │ (无, 7, 7, │ 5,120 │ top_conv[0][0] │ │ (批归一化 │ 1280) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ top_activation │ (无, 7, 7, │ 0 │ top_bn[0][0] │ │ (激活) │ 1280) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ avg_pool │ (无, 1280) │ 0 │ top_activation[0][0] │ │ (全局平均池化 │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ top_dropout │ (无, 1280) │ 0 │ avg_pool[0][0] │ │ (丢弃) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ predictions (密集) │ (无, 120) │ 153,720 │ top_dropout[0][0] │ └─────────────────────┴───────────────────┴─────────┴──────────────────────┘
总参数: 4,203,291 (16.03 MB)
可训练参数: 4,161,268 (15.87 MB)
非可训练参数: 42,023 (164.16 KB)
Epoch 1/40
1/187 ━━━━━━━━━━━━━━━━━━━━ 5:30:13 107s/step - 准确率: 0.0000e+00 - 损失: 5.1065
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1700241724.682725 1549299 device_compiler.h:187] 使用XLA编译集群!这行日志在进程的生命周期内最多记录一次。
187/187 ━━━━━━━━━━━━━━━━━━━━ 200s 501ms/step - 准确率: 0.0097 - 损失: 5.0567 - 验证准确率: 0.0100 - 验证损失: 4.9278
Epoch 2/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 95s 507ms/step - 准确率: 0.0214 - 损失: 4.6918 - 验证准确率: 0.0141 - 验证损失: 5.5380
Epoch 3/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 89s 474ms/step - 准确率: 0.0298 - 损失: 4.4749 - 验证准确率: 0.0375 - 验证损失: 4.4576
Epoch 4/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 90s 479ms/step - 准确率: 0.0423 - 损失: 4.3206 - 验证准确率: 0.0391 - 验证损失: 4.9898
Epoch 5/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 89s 473ms/step - 准确率: 0.0458 - 损失: 4.2312 - 验证准确率: 0.0416 - 验证损失: 4.3210
Epoch 6/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 141s 470ms/step - 准确率: 0.0579 - 损失: 4.1162 - 验证准确率: 0.0540 - 验证损失: 4.3371
Epoch 7/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 89s 476ms/step - 准确率: 0.0679 - 损失: 4.0150 - 验证准确率: 0.0786 - 验证损失: 3.9759
Epoch 8/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 89s 477ms/step - 准确率: 0.0828 - 损失: 3.9147 - 验证准确率: 0.0651 - 验证损失: 4.1641
Epoch 9/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 142s 475ms/step - 准确率: 0.0932 - 损失: 3.8297 - 验证准确率: 0.0928 - 验证损失: 3.8985
Epoch 10/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 88s 472ms/step - 准确率: 0.1092 - 损失: 3.7321 - 验证准确率: 0.0946 - 验证损失: 3.8618
Epoch 11/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 89s 476ms/step - 准确率: 0.1245 - 损失: 3.6451 - 验证准确率: 0.0880 - 验证损失: 3.9584
Epoch 12/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 92s 493ms/step - 准确率: 0.1457 - 损失: 3.5514 - 验证准确率: 0.1096 - 验证损失: 3.8184
Epoch 13/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 88s 471ms/step - 准确率: 0.1606 - 损失: 3.4654 - 验证准确率: 0.1118 - 验证损失: 3.8059
Epoch 14/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 87s 464ms/step - 准确率: 0.1660 - 损失: 3.3826 - 验证准确率: 0.1472 - 验证损失: 3.5726
Epoch 15/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 146s 485ms/step - 准确率: 0.1815 - 损失: 3.2935 - 验证准确率: 0.1154 - 验证损失: 3.8134
Epoch 16/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 87s 466ms/step - 准确率: 0.1942 - 损失: 3.2218 - 验证准确率: 0.1540 - 验证损失: 3.5051
Epoch 17/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 88s 471ms/step - 准确率: 0.2131 - 损失: 3.1427 - 验证准确率: 0.1381 - 验证损失: 3.7206
Epoch 18/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 87s 467ms/step - 准确率: 0.2264 - 损失: 3.0461 - 验证准确率: 0.1707 - 验证损失: 3.4122
Epoch 19/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 88s 470ms/step - 准确率: 0.2401 - 损失: 2.9821 - 验证准确率: 0.1515 - 验证损失: 3.6481
Epoch 20/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 88s 469ms/step - 准确率: 0.2613 - 损失: 2.8815 - 验证准确率: 0.1783 - 验证损失: 3.4767
Epoch 21/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 91s 485ms/step - 准确率: 0.2741 - 损失: 2.8102 - 验证准确率: 0.1927 - 验证损失: 3.3183
Epoch 22/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 90s 477ms/step - 准确率: 0.2892 - 损失: 2.7408 - 验证准确率: 0.1859 - 验证损失: 3.4887
Epoch 23/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 91s 485ms/step - 准确率: 0.3093 - 损失: 2.6526 - 验证准确率: 0.1924 - 验证损失: 3.4622
Epoch 24/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 92s 491ms/step - 准确率: 0.3201 - 损失: 2.5750 - 验证准确率: 0.2253 - 验证损失: 3.1873
Epoch 25/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 95s 508ms/step - 准确率: 0.3280 - 损失: 2.5150 - 验证准确率: 0.2148 - 验证损失: 3.3391
Epoch 26/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 92s 490ms/step - 准确率: 0.3465 - 损失: 2.4402 - 验证准确率: 0.2270 - 验证损失: 3.2679
Epoch 27/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 93s 494ms/step - 准确率: 0.3735 - 损失: 2.3199 - 验证准确率: 0.2080 - 验证损失: 3.5687
Epoch 28/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 89s 476ms/step - 准确率: 0.3837 - 损失: 2.2645 - 验证准确率: 0.2374 - 验证损失: 3.3592
Epoch 29/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 142s 474ms/step - 准确率: 0.3962 - 损失: 2.2110 - 验证准确率: 0.2008 - 验证损失: 3.6071
Epoch 30/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 87s 466ms/step - 准确率: 0.4175 - 损失: 2.1086 - 验证准确率: 0.2302 - 验证损失: 3.4161
Epoch 31/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 87s 465ms/step - 准确率: 0.4359 - 损失: 2.0610 - 验证准确率: 0.2231 - 验证损失: 3.5957
Epoch 32/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 148s 498ms/step - 准确率: 0.4463 - 损失: 1.9866 - 验证准确率: 0.2234 - 验证损失: 3.7263
Epoch 33/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 92s 489ms/step - 准确率: 0.4613 - 损失: 1.8821 - 验证准确率: 0.2239 - 验证损失: 3.6929
Epoch 34/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 139s 475ms/step - 准确率: 0.4925 - 损失: 1.7858 - 验证准确率: 0.2238 - 验证损失: 3.8351
Epoch 35/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 91s 485ms/step - 准确率: 0.5105 - 损失: 1.7074 - 验证准确率: 0.1930 - 验证损失: 4.1941
Epoch 36/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 140s 474ms/step - 准确率: 0.5334 - 损失: 1.6256 - 验证准确率: 0.2098 - 验证损失: 4.1464
Epoch 37/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 87s 464ms/step - 准确率: 0.5504 - 损失: 1.5603 - 验证准确率: 0.2306 - 验证损失: 4.0215
Epoch 38/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 90s 480ms/step - 准确率: 0.5736 - 损失: 1.4419 - 验证准确率: 0.2240 - 验证损失: 4.1604
Epoch 39/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 91s 486ms/step - 准确率: 0.6025 - 损失: 1.3612 - 验证准确率: 0.2344 - 验证损失: 4.0505
Epoch 40/40
187/187 ━━━━━━━━━━━━━━━━━━━━ 89s 474ms/step - 准确率: 0.6199 - 损失: 1.2889 - 验证准确率: 0.2151 - 验证损失: 4.3660
因此,从头开始训练需要非常仔细地选择超参数,并且很难找到合适的正则化。这也将更消耗资源。绘制训练和验证准确率的图表清楚地显示出验证准确率停滞在一个较低的值。
import matplotlib.pyplot as plt
def plot_hist(hist):
plt.plot(hist.history["accuracy"])
plt.plot(hist.history["val_accuracy"])
plt.title("模型准确度")
plt.ylabel("准确度")
plt.xlabel("轮次")
plt.legend(["训练", "验证"], loc="upper left")
plt.show()
plot_hist(hist)
在这里,我们使用预训练的ImageNet权重初始化模型,并在我们自己的数据集上进行微调。
def build_model(num_classes):
inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
model = EfficientNetB0(include_top=False, input_tensor=inputs, weights="imagenet")
# 冻结预训练权重
model.trainable = False
# 重建顶部
x = layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
x = layers.BatchNormalization()(x)
top_dropout_rate = 0.2
x = layers.Dropout(top_dropout_rate, name="top_dropout")(x)
outputs = layers.Dense(num_classes, activation="softmax", name="pred")(x)
# 编译
model = keras.Model(inputs, outputs, name="EfficientNet")
optimizer = keras.optimizers.Adam(learning_rate=1e-2)
model.compile(
optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"]
)
return model
迁移学习的第一步是冻结所有层,仅训练顶部层。在这一步中,可以使用相对较大的学习率(1e-2)。请注意,验证准确率和损失通常会优于训练准确率和损失。这是因为正则化很强,只抑制训练时的指标。
请注意,收敛可能需要多达50个轮次,具体取决于学习率的选择。如果没有应用图像增强层,验证准确率可能仅达到~60%。
model = build_model(num_classes=NUM_CLASSES)
epochs = 25 # @param {type: "slider", min:8, max:80}
hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test)
plot_hist(hist)
第 1 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 108s 432ms/step - 准确率: 0.2654 - 损失: 4.3710 - 验证准确率: 0.6888 - 验证损失: 1.0875
第 2 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 119s 412ms/step - 准确率: 0.4863 - 损失: 2.0996 - 验证准确率: 0.7282 - 验证损失: 0.9072
第 3 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 78s 416ms/step - 准确率: 0.5422 - 损失: 1.7120 - 验证准确率: 0.7411 - 验证损失: 0.8574
第 4 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 77s 412ms/step - 准确率: 0.5509 - 损失: 1.6472 - 验证准确率: 0.7451 - 验证损失: 0.8457
第 5 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 81s 431ms/step - 准确率: 0.5744 - 损失: 1.5373 - 验证准确率: 0.7424 - 验证损失: 0.8649
第 6 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 78s 417ms/step - 准确率: 0.5715 - 损失: 1.5595 - 验证准确率: 0.7374 - 验证损失: 0.8736
第 7 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 81s 432ms/step - 准确率: 0.5802 - 损失: 1.5045 - 验证准确率: 0.7430 - 验证损失: 0.8675
第 8 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 77s 411ms/step - 准确率: 0.5839 - 损失: 1.4972 - 验证准确率: 0.7392 - 验证损失: 0.8647
第 9 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 77s 411ms/step - 准确率: 0.5929 - 损失: 1.4699 - 验证准确率: 0.7508 - 验证损失: 0.8634
第 10 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 82s 437ms/step - 准确率: 0.6040 - 损失: 1.4442 - 验证准确率: 0.7520 - 验证损失: 0.8480
第 11 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 78s 416ms/step - 准确率: 0.5972 - 损失: 1.4626 - 验证准确率: 0.7379 - 验证损失: 0.8879
第 12 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 79s 421ms/step - 准确率: 0.5965 - 损失: 1.4700 - 验证准确率: 0.7383 - 验证损失: 0.9409
第 13 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 82s 420ms/step - 准确率: 0.6034 - 损失: 1.4533 - 验证准确率: 0.7474 - 验证损失: 0.8922
第 14 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 81s 435ms/step - 准确率: 0.6053 - 损失: 1.4170 - 验证准确率: 0.7416 - 验证损失: 0.9119
第 15 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 77s 411ms/step - 准确率: 0.6059 - 损失: 1.4125 - 验证准确率: 0.7406 - 验证损失: 0.9205
第 16 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 82s 438ms/step - 准确率: 0.5979 - 损失: 1.4554 - 验证准确率: 0.7392 - 验证损失: 0.9120
第 17 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 77s 411ms/step - 准确率: 0.6081 - 损失: 1.4089 - 验证准确率: 0.7423 - 验证损失: 0.9305
第 18 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 82s 436ms/step - 准确率: 0.6041 - 损失: 1.4390 - 验证准确率: 0.7380 - 验证损失: 0.9644
第 19 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 79s 417ms/step - 准确率: 0.6018 - 损失: 1.4324 - 验证准确率: 0.7439 - 验证损失: 0.9129
第 20 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 81s 430ms/step - 准确率: 0.6057 - 损失: 1.4342 - 验证准确率: 0.7305 - 验证损失: 0.9463
第 21 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 77s 410ms/step - 准确率: 0.6209 - 损失: 1.3824 - 验证准确率: 0.7410 - 验证损失: 0.9503
第 22 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 78s 419ms/step - 准确率: 0.6170 - 损失: 1.4246 - 验证准确率: 0.7336 - 验证损失: 0.9606
第 23 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 85s 455ms/step - 准确率: 0.6153 - 损失: 1.4009 - 验证准确率: 0.7334 - 验证损失: 0.9520
第 24 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 82s 438ms/step - 准确率: 0.6051 - 损失: 1.4343 - 验证准确率: 0.7435 - 验证损失: 0.9403
第 25 轮/25
187/187 ━━━━━━━━━━━━━━━━━━━━ 138s 416ms/step - 准确率: 0.6065 - 损失: 1.4131 - 验证准确率: 0.7456 - 验证损失: 0.9307
第二步是解冻一些层,并使用较小的学习率拟合模型。在这个例子中,我们展示了解冻所有层,但根据特定数据集,可能只想解冻部分层。
当使用预训练模型进行特征提取效果足够好时,这一步对验证准确率的提升非常有限。在我们的案例中,我们只看到小幅提升,因为ImageNet预训练已经让模型接触了相当数量的狗。
另一方面,当我们在一个与ImageNet差异较大的数据集上使用预训练权重时,这一步微调非常关键,因为特征提取器也需要作出相当大的调整。如果选择CIFAR-100数据集,微调可以将验证准确率提升约10%,使EfficientNetB0
的准确率超过80%。
关于冻结/解冻模型的附注:设置Model
的trainable
属性将同时设置所有属于该Model
的层为相同的trainable
属性。每一层只有在层本身和包含它的模型都是可训练时才是可训练的。因此,当我们需要部分冻结/解冻模型时,需要确保模型的trainable
属性设置为True
。
def unfreeze_model(model):
# 我们解冻顶部20层,同时保持BatchNorm层被冻结
for layer in model.layers[-20:]:
if not isinstance(layer, layers.BatchNormalization):
layer.trainable = True
optimizer = keras.optimizers.Adam(learning_rate=1e-5)
model.compile(
optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"]
)
unfreeze_model(model)
epochs = 4 # @param {type: "slider", min:4, max:10}
hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test)
plot_hist(hist)
Epoch 1/4
187/187 ━━━━━━━━━━━━━━━━━━━━ 111s 442ms/step - accuracy: 0.6310 - loss: 1.3425 - val_accuracy: 0.7565 - val_loss: 0.8874
Epoch 2/4
187/187 ━━━━━━━━━━━━━━━━━━━━ 77s 413ms/step - accuracy: 0.6518 - loss: 1.2755 - val_accuracy: 0.7635 - val_loss: 0.8588
Epoch 3/4
187/187 ━━━━━━━━━━━━━━━━━━━━ 82s 437ms/step - accuracy: 0.6491 - loss: 1.2426 - val_accuracy: 0.7663 - val_loss: 0.8419
Epoch 4/4
187/187 ━━━━━━━━━━━━━━━━━━━━ 79s 419ms/step - accuracy: 0.6625 - loss: 1.1775 - val_accuracy: 0.7701 - val_loss: 0.8284
关于解冻层:
BatchNormalization
层需要保持被冻结状态
(更多细节)。
如果它们也被设置为可训练,解冻后的第一个时期将显著降低准确率。利用EfficientNet的其他一些提示: