作者: Hamid Ali
创建日期: 2023/05/30
最后修改: 2023/07/13
描述: 基于DUTS数据集训练的边界感知分割模型。
深度语义分割算法最近有所改善,但仍然无法正确预测物体边界周围的像素。在这个例子中,我们实现了 边界感知分割网络(BASNet),使用两个阶段的预测和精炼架构,结合混合损失,可以预测高度准确的边界和精细结构,用于图像分割。
我们将使用DUTS-TE数据集进行训练。它有5,019 张图像,但我们将使用140张进行训练和验证,以节省笔记本运行时间。DUTS是 一个相对较大的显著对象分割数据集,包含多样化的质地和 结构,常见于前景和背景的真实世界图像中。
!wget http://saliencydetection.net/duts/download/DUTS-TE.zip
!unzip -q DUTS-TE.zip
--2023-08-06 19:07:37-- http://saliencydetection.net/duts/download/DUTS-TE.zip
正在解析 saliencydetection.net (saliencydetection.net)... 36.55.239.177
连接到 saliencydetection.net (saliencydetection.net)|36.55.239.177|:80... 已连接.
发送HTTP请求,等待响应... 200 OK
长度: 139799089 (133M) [application/zip]
正在保存到: ‘DUTS-TE.zip’
DUTS-TE.zip 100%[===================>] 133.32M 1.76MB/s 用时 77s
2023-08-06 19:08:55 (1.73 MB/s) - ‘DUTS-TE.zip’ 保存完成 [139799089/139799089]
import os
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import keras_cv
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, backend
使用 TensorFlow 后端
IMAGE_SIZE = 288
BATCH_SIZE = 4
OUT_CLASSES = 1
TRAIN_SPLIT_RATIO = 0.90
DATA_DIR = "./DUTS-TE/"
我们将使用 load_paths() 加载 140 个样本路径并将其拆分为训练集和验证集,然后使用 load_dataset() 将这些路径转换为 tf.data.Dataset 对象。
def load_paths(path, split_ratio, limit=140):
    """Collect sorted image/mask paths and split them into train/val sets.

    Args:
        path: Dataset root containing ``DUTS-TE-Image`` and ``DUTS-TE-Mask``
            sub-directories.
        split_ratio: Fraction of the samples assigned to the training split.
        limit: Maximum number of samples to load. Defaults to 140 to keep
            the example's runtime short; pass ``None`` to use every file.

    Returns:
        ``((train_images, train_masks), (val_images, val_masks))`` path lists.
    """
    # Sorting keeps images and masks aligned index-by-index.
    images = sorted(glob(os.path.join(path, "DUTS-TE-Image/*")))[:limit]
    masks = sorted(glob(os.path.join(path, "DUTS-TE-Mask/*")))[:limit]
    len_ = int(len(images) * split_ratio)
    return (images[:len_], masks[:len_]), (images[len_:], masks[len_:])
def read_image(path, size, mode):
    """Load an image from disk, resize it, and scale pixels to [0, 1] float32.

    Args:
        path: File path of the image to read.
        size: Target (height, width) for resizing.
        mode: Color mode passed to Keras, e.g. ``"rgb"`` or ``"grayscale"``.

    Returns:
        A float32 numpy array with values in [0, 1].
    """
    img = keras.utils.load_img(path, target_size=size, color_mode=mode)
    arr = keras.utils.img_to_array(img)
    return (arr / 255.0).astype(np.float32)
def preprocess(x_batch, y_batch, img_size, out_classes):
    """Decode an (image path, mask path) pair into float tensors.

    Runs the Python-side image loading through ``tf.numpy_function`` and
    restores the static shapes afterwards (numpy_function erases them).
    """

    def decode_pair(img_path, mask_path):
        # Paths arrive as bytes from the tf.data pipeline.
        img_path, mask_path = img_path.decode(), mask_path.decode()
        image = read_image(img_path, (img_size, img_size), mode="rgb")  # Image.
        mask = read_image(mask_path, (img_size, img_size), mode="grayscale")  # Mask.
        return image, mask

    images, masks = tf.numpy_function(
        decode_pair, [x_batch, y_batch], [tf.float32, tf.float32]
    )
    images.set_shape([img_size, img_size, 3])
    masks.set_shape([img_size, img_size, out_classes])
    return images, masks
def load_dataset(image_paths, mask_paths, img_size, out_classes, batch, shuffle=True):
    """Build a batched, prefetched ``tf.data`` pipeline from path lists.

    Args:
        image_paths: List of input image file paths.
        mask_paths: List of corresponding mask file paths.
        img_size: Square size to which images/masks are resized.
        out_classes: Number of mask channels.
        batch: Batch size.
        shuffle: Whether to cache and shuffle (use for training only).

    Returns:
        A ready-to-iterate ``tf.data.Dataset`` of (image, mask) batches.
    """
    ds = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    if shuffle:
        ds = ds.cache().shuffle(buffer_size=1000)
    ds = ds.map(
        lambda img, msk: preprocess(img, msk, img_size, out_classes),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return ds.batch(batch).prefetch(tf.data.AUTOTUNE)
# Split the path lists and build the training / validation input pipelines.
train_paths, val_paths = load_paths(DATA_DIR, TRAIN_SPLIT_RATIO)
# Training pipeline is shuffled; validation keeps a deterministic order.
train_dataset = load_dataset(
    train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True
)
val_dataset = load_dataset(
    val_paths[0], val_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=False
)
print(f"训练数据集: {train_dataset}")
print(f"验证数据集: {val_dataset}")
训练数据集: <_PrefetchDataset element_spec=(TensorSpec(shape=(None, 288, 288, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 288, 288, 1), dtype=tf.float32, name=None))>
验证数据集: <_PrefetchDataset element_spec=(TensorSpec(shape=(None, 288, 288, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 288, 288, 1), dtype=tf.float32, name=None))>
def display(display_list):
    """Plot up to three tensors side by side: input image, ground-truth mask,
    and (optionally) the predicted mask."""
    title = ["输入图像", "真实掩膜", "预测掩膜"]
    n_panels = len(display_list)
    for idx, item in enumerate(display_list):
        plt.subplot(1, n_panels, idx + 1)
        plt.title(title[idx])
        plt.imshow(keras.utils.array_to_img(item), cmap="gray")
        plt.axis("off")
    plt.show()
# Preview one validation sample: input image and its ground-truth mask.
for image, mask in val_dataset.take(1):
    display([image[0], mask[0]])
让我们打印上述显示的掩膜的唯一值。你会发现尽管属于同一类,它的 强度在低(0)到高(255)之间变化。这种强度变化使得 网络很难为显著或伪装物体分割生成良好的分割图。 由于其残差精炼模块 (RMs),BASNet 在生成高精度 边界和细结构方面表现良好。
# Inspect the mask's intensity spread (rescaled back to the 0-255 range) —
# despite being a single class, its values vary continuously.
print(f"唯一值计数: {len(np.unique((mask[0] * 255)))}")
print("唯一值:")
print(np.unique((mask[0] * 255)).astype(int))
唯一值计数: 245
唯一值:
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
54 55 56 57 58 59 61 62 63 65 66 67 68 69 70 71 73 74
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
93 94 95 96 97 98 99 100 101 102 103 104 105 108 109 110 111 112
113 114 115 116 117 118 119 120 122 123 124 125 128 129 130 131 132 133
134 135 136 137 138 139 140 141 142 144 145 146 147 148 149 150 151 152
153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 170 171
172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
190 191 192 193 194 195 196 197 198 199 201 202 203 204 205 206 207 208
209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
245 246 247 248 249 250 251 252 253 254 255]
BASNet 包含一个预测-精炼架构和一个混合损失。预测-精炼 架构由一个密集监督的编码器-解码器网络和一个残差精炼 模块组成,分别用于预测和精炼分割概率图。
def basic_block(x_input, filters, stride=1, down_sample=None, activation=None):
    """Residual (identity) block with two 3x3 convolutions.

    Args:
        x_input: Input feature tensor.
        filters: Number of convolution filters.
        stride: Stride of the first convolution.
        down_sample: Optional tensor to use as the shortcut branch instead
            of ``x_input`` (e.g. a projected/down-sampled input).
        activation: Optional activation name applied after the residual add.

    Returns:
        The output feature tensor.
    """
    out = layers.Conv2D(
        filters, (3, 3), strides=stride, padding="same", use_bias=False
    )(x_input)
    out = layers.BatchNormalization()(out)
    out = layers.Activation("relu")(out)
    out = layers.Conv2D(
        filters, (3, 3), strides=(1, 1), padding="same", use_bias=False
    )(out)
    out = layers.BatchNormalization()(out)
    # Shortcut branch: the raw input unless a replacement tensor is given.
    shortcut = x_input if down_sample is None else down_sample
    out = layers.Add()([out, shortcut])
    if activation is not None:
        out = layers.Activation(activation)(out)
    return out
def convolution_block(x_input, filters, dilation=1):
    """Apply convolution + batch normalization + ReLU."""
    conv = layers.Conv2D(filters, (3, 3), padding="same", dilation_rate=dilation)
    normed = layers.BatchNormalization()(conv(x_input))
    return layers.Activation("relu")(normed)
def segmentation_head(x_input, out_classes, final_size):
    """Map a decoder-stage output to ``out_classes`` channels.

    If ``final_size`` is given, the logits are resized to that (height,
    width); pass ``None`` to keep the native resolution.
    """
    logits = layers.Conv2D(out_classes, kernel_size=(3, 3), padding="same")(x_input)
    if final_size is None:
        return logits
    return layers.Resizing(final_size[0], final_size[1])(logits)
def get_resnet_block(_resnet, block_num):
    """Extract one ResNet-34 stage as a standalone Keras model.

    ``block_num`` selects one of the four stacks (0-3); the sub-model spans
    from that stack's first convolution to its last residual add.
    """
    blocks_per_stage = [3, 4, 6, 3]  # ResNet-34 residual-block counts per stack.
    last_block = blocks_per_stage[block_num]
    return keras.models.Model(
        inputs=_resnet.get_layer(f"v2_stack_{block_num}_block1_1_conv").input,
        outputs=_resnet.get_layer(
            f"v2_stack_{block_num}_block{last_block}_add"
        ).output,
        name=f"resnet34_block{block_num + 1}",
    )
预测模块是一个像 U-Net 一样的重型编码器解码器结构。编码器包括一个输入
卷积层和六个阶段。前四个阶段来自 ResNet-34,其余是基本
res-blocks。由于ResNet-34的第一个卷积层和池化层被跳过,因此我们将使用
get_resnet_block()
提取前四个块。桥接和解码器使用三个带有侧输出的卷积层。该模块在训练期间生成七个分割概率图,最后一个被视为最终输出。
def basnet_predict(input_shape, out_classes):
    """BASNet prediction module; outputs a coarse label map.

    A heavy U-Net-like encoder-decoder: the first four encoder stages come
    from ResNet-34, the last two from basic residual blocks. Produces seven
    side-output probability maps (six decoder stages plus the bridge), each
    resized to the input spatial size; the first output is treated as the
    module's main result.
    """
    filters = 64
    num_stages = 6
    x_input = layers.Input(input_shape)
    # -------------Encoder--------------
    x = layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x_input)
    resnet = keras_cv.models.ResNet34Backbone(
        include_rescaling=False,
    )
    encoder_blocks = []
    for i in range(num_stages):
        if i < 4:  # First four stages are adopted from ResNet-34 blocks.
            x = get_resnet_block(resnet, i)(x)
            encoder_blocks.append(x)
            x = layers.Activation("relu")(x)
        else:  # Last two stages consist of three basic resnet blocks each.
            x = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)
            x = basic_block(x, filters=filters * 8, activation="relu")
            x = basic_block(x, filters=filters * 8, activation="relu")
            x = basic_block(x, filters=filters * 8, activation="relu")
            encoder_blocks.append(x)
    # -------------Bridge-------------
    x = convolution_block(x, filters=filters * 8, dilation=2)
    x = convolution_block(x, filters=filters * 8, dilation=2)
    x = convolution_block(x, filters=filters * 8, dilation=2)
    encoder_blocks.append(x)
    # -------------Decoder-------------
    decoder_blocks = []
    for i in reversed(range(num_stages)):
        if i != (num_stages - 1):  # Except the first, upscale the other decoder stages.
            shape = keras.backend.int_shape(x)
            x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
        x = layers.concatenate([encoder_blocks[i], x], axis=-1)
        x = convolution_block(x, filters=filters * 8)
        x = convolution_block(x, filters=filters * 8)
        x = convolution_block(x, filters=filters * 8)
        decoder_blocks.append(x)
    decoder_blocks.reverse()  # Change order from last decoder stage to first.
    decoder_blocks.append(encoder_blocks[-1])  # Copy the bridge to the decoder outputs.
    # -------------Side Outputs--------------
    decoder_blocks = [
        segmentation_head(decoder_block, out_classes, input_shape[:2])
        for decoder_block in decoder_blocks
    ]
    return keras.models.Model(inputs=[x_input], outputs=decoder_blocks)
精炼模块(RMs)设计为残差块,旨在提炼由预测模块生成的粗糙(模糊和噪声)分割图。与预测模块类似,它也是一个编码解码结构,但具有轻量级的4个阶段,每个阶段包含一个
convolutional block()
初始化。最后,它将粗略输出和残差输出相加,生成精炼输出。
def basnet_rrm(base_model, out_classes):
    """BASNet residual refinement module (RRM); outputs a fine label map.

    A lightweight four-stage encoder-decoder that refines the coarse
    (blurry/noisy) map from the prediction module. The refined output is
    the coarse input plus a learned residual.
    """
    num_stages = 4
    filters = 64
    x_input = base_model.output[0]
    # -------------Encoder--------------
    x = layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x_input)
    encoder_blocks = []
    for _ in range(num_stages):
        x = convolution_block(x, filters=filters)
        encoder_blocks.append(x)
        x = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)
    # -------------Bridge--------------
    x = convolution_block(x, filters=filters)
    # -------------Decoder--------------
    for i in reversed(range(num_stages)):
        shape = keras.backend.int_shape(x)
        x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
        x = layers.concatenate([encoder_blocks[i], x], axis=-1)
        x = convolution_block(x, filters=filters)
    x = segmentation_head(x, out_classes, None)  # Segmentation head.
    # ------------- refined = coarse + residual
    x = layers.Add()([x_input, x])  # Add the coarse prediction and the residual output.
    return keras.models.Model(inputs=[base_model.input], outputs=[x])
def basnet(input_shape, out_classes):
    """Assemble the full BASNet: the prediction module combined with the
    residual refinement module (RRM)."""
    # Coarse prediction model.
    predict_model = basnet_predict(input_shape, out_classes)
    # Refinement model built on top of the prediction model.
    refine_model = basnet_rrm(predict_model, out_classes)
    # Refined output first, followed by all seven coarse side outputs.
    combined = [refine_model.output, *predict_model.output]
    # Sigmoid activation on every output map.
    combined = [layers.Activation("sigmoid")(tensor) for tensor in combined]
    return keras.models.Model(inputs=[predict_model.input], outputs=combined)
BASNet的另一个重要特征是其混合损失函数,它是一个组合的 binary cross entropy、结构相似性和交集比损失,指导网络学习三层级(即像素、补丁和地图级别)层次表示。
class BasnetLoss(keras.losses.Loss):
    """BASNet hybrid loss.

    Combines binary cross-entropy (pixel level), SSIM (patch level) and
    IoU (map level), guiding the network to learn a three-level hierarchy
    of representations.
    """

    def __init__(self, **kwargs):
        super().__init__(name="basnet_loss", **kwargs)
        # Small constant guarding against division by zero.
        self.smooth = 1.0e-9
        # Pixel-level term: binary cross-entropy.
        self.cross_entropy_loss = keras.losses.BinaryCrossentropy()
        # Patch-level term: structural similarity index.
        self.ssim_value = tf.image.ssim
        # Map-level term: Jaccard / IoU.
        self.iou_value = self.calculate_iou

    def calculate_iou(
        self,
        y_true,
        y_pred,
    ):
        """Compute the mean intersection-over-union between mask batches."""
        overlap = backend.sum(backend.abs(y_true * y_pred), axis=[1, 2, 3])
        combined = backend.sum(y_true, [1, 2, 3]) + backend.sum(y_pred, [1, 2, 3])
        combined = combined - overlap
        return backend.mean(
            (overlap + self.smooth) / (combined + self.smooth), axis=0
        )

    def call(self, y_true, y_pred):
        bce = self.cross_entropy_loss(y_true, y_pred)
        ssim = backend.mean(
            1 - self.ssim_value(y_true, y_pred, max_val=1) + self.smooth, axis=0
        )
        iou = 1 - self.iou_value(y_true, y_pred)
        # Sum all three loss terms.
        return bce + ssim + iou
basnet_model = basnet(
    input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3], out_classes=OUT_CLASSES
)  # Create the model.
basnet_model.summary()  # Show the model summary.
optimizer = keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8)
# Compile the model with the hybrid loss and MAE metric.
basnet_model.compile(
    loss=BasnetLoss(),
    optimizer=optimizer,
    metrics=[keras.metrics.MeanAbsoluteError(name="mae")],
)
Model: "model_2"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 288, 288, 3)] 0 []
conv2d (Conv2D) (None, 288, 288, 64) 1792 ['input_1[0][0]']
resnet34_block1 (Functiona (None, None, None, 64) 222720 ['conv2d[0][0]']
l)
activation (Activation) (None, 288, 288, 64) 0 ['resnet34_block1[0][0]']
resnet34_block2 (Functiona (None, None, None, 128) 1118720 ['activation[0][0]']
l)
activation_1 (Activation) (None, 144, 144, 128) 0 ['resnet34_block2[0][0]']
resnet34_block3 (Functiona (None, None, None, 256) 6829056 ['activation_1[0][0]']
l)
activation_2 (Activation) (None, 72, 72, 256) 0 ['resnet34_block3[0][0]']
resnet34_block4 (Functiona (None, None, None, 512) 1312153 ['activation_2[0][0]']
l) 6
activation_3 (Activation) (None, 36, 36, 512) 0 ['resnet34_block4[0][0]']
max_pooling2d (MaxPooling2 (None, 18, 18, 512) 0 ['activation_3[0][0]']
D)
conv2d_1 (Conv2D) (None, 18, 18, 512) 2359296 ['max_pooling2d[0][0]']
batch_normalization (Batch (None, 18, 18, 512) 2048 ['conv2d_1[0][0]']
Normalization)
activation_4 (Activation) (None, 18, 18, 512) 0 ['batch_normalization[0][0]']
conv2d_2 (Conv2D) (None, 18, 18, 512) 2359296 ['activation_4[0][0]']
batch_normalization_1 (Bat (None, 18, 18, 512) 2048 ['conv2d_2[0][0]']
chNormalization)
add (Add) (None, 18, 18, 512) 0 ['batch_normalization_1[0][0]'
, 'max_pooling2d[0][0]']
activation_5 (Activation) (None, 18, 18, 512) 0 ['add[0][0]']
conv2d_3 (Conv2D) (None, 18, 18, 512) 2359296 ['activation_5[0][0]']
batch_normalization_2 (Bat (None, 18, 18, 512) 2048 ['conv2d_3[0][0]']
chNormalization)
activation_6 (Activation) (None, 18, 18, 512) 0 ['batch_normalization_2[0][0]'
]
conv2d_4 (Conv2D) (None, 18, 18, 512) 2359296 ['activation_6[0][0]']
batch_normalization_3 (Bat (None, 18, 18, 512) 2048 ['conv2d_4[0][0]']
chNormalization)
add_1 (Add) (None, 18, 18, 512) 0 ['batch_normalization_3[0][0]'
, 'activation_5[0][0]']
activation_7 (Activation) (None, 18, 18, 512) 0 ['add_1[0][0]']
conv2d_5 (Conv2D) (None, 18, 18, 512) 2359296 ['activation_7[0][0]']
batch_normalization_4 (Bat (None, 18, 18, 512) 2048 ['conv2d_5[0][0]']
chNormalization)
activation_8 (Activation) (None, 18, 18, 512) 0 ['batch_normalization_4[0][0]'
]
conv2d_6 (Conv2D) (None, 18, 18, 512) 2359296 ['activation_8[0][0]']
batch_normalization_5 (Bat (None, 18, 18, 512) 2048 ['conv2d_6[0][0]']
chNormalization)
add_2 (Add) (None, 18, 18, 512) 0 ['batch_normalization_5[0][0]'
, 'activation_7[0][0]']
activation_9 (Activation) (None, 18, 18, 512) 0 ['add_2[0][0]']
max_pooling2d_1 (MaxPoolin (None, 9, 9, 512) 0 ['activation_9[0][0]']
g2D)
conv2d_7 (Conv2D) (None, 9, 9, 512) 2359296 ['max_pooling2d_1[0][0]']
batch_normalization_6 (Bat (None, 9, 9, 512) 2048 ['conv2d_7[0][0]']
chNormalization)
activation_10 (Activation) (None, 9, 9, 512) 0 ['batch_normalization_6[0][0]'
]
conv2d_8 (Conv2D) (None, 9, 9, 512) 2359296 ['activation_10[0][0]']
batch_normalization_7 (Bat (None, 9, 9, 512) 2048 ['conv2d_8[0][0]']
chNormalization)
add_3 (Add) (None, 9, 9, 512) 0 ['batch_normalization_7[0][0]'
, 'max_pooling2d_1[0][0]']
activation_11 (Activation) (None, 9, 9, 512) 0 ['add_3[0][0]']
conv2d_9 (Conv2D) (None, 9, 9, 512) 2359296 ['activation_11[0][0]']
batch_normalization_8 (Bat (None, 9, 9, 512) 2048 ['conv2d_9[0][0]']
chNormalization)
activation_12 (Activation) (None, 9, 9, 512) 0 ['batch_normalization_8[0][0]'
]
conv2d_10 (Conv2D) (None, 9, 9, 512) 2359296 ['activation_12[0][0]']
batch_normalization_9 (Bat (None, 9, 9, 512) 2048 ['conv2d_10[0][0]']
chNormalization)
add_4 (Add) (None, 9, 9, 512) 0 ['batch_normalization_9[0][0]'
, 'activation_11[0][0]']
activation_13 (Activation) (None, 9, 9, 512) 0 ['add_4[0][0]']
conv2d_11 (Conv2D) (None, 9, 9, 512) 2359296 ['activation_13[0][0]']
batch_normalization_10 (Ba (None, 9, 9, 512) 2048 ['conv2d_11[0][0]']
tchNormalization)
activation_14 (Activation) (None, 9, 9, 512) 0 ['batch_normalization_10[0][0]
']
conv2d_12 (Conv2D) (None, 9, 9, 512) 2359296 ['activation_14[0][0]']
batch_normalization_11 (Ba (None, 9, 9, 512) 2048 ['conv2d_12[0][0]']
tchNormalization)
add_5 (Add) (None, 9, 9, 512) 0 ['batch_normalization_11[0][0]
',
'activation_13[0][0]']
activation_15 (Activation) (None, 9, 9, 512) 0 ['add_5[0][0]']
conv2d_13 (Conv2D) (None, 9, 9, 512) 2359808 ['activation_15[0][0]']
batch_normalization_12 (Ba (None, 9, 9, 512) 2048 ['conv2d_13[0][0]']
tchNormalization)
activation_16 (Activation) (None, 9, 9, 512) 0 ['batch_normalization_12[0][0]
']
conv2d_14 (Conv2D) (None, 9, 9, 512) 2359808 ['activation_16[0][0]']
batch_normalization_13 (Ba (None, 9, 9, 512) 2048 ['conv2d_14[0][0]']
tchNormalization)
activation_17 (Activation) (None, 9, 9, 512) 0 ['batch_normalization_13[0][0]
']
conv2d_15 (Conv2D) (None, 9, 9, 512) 2359808 ['activation_17[0][0]']
batch_normalization_14 (Ba (None, 9, 9, 512) 2048 ['conv2d_15[0][0]']
tchNormalization)
activation_18 (Activation) (None, 9, 9, 512) 0 ['batch_normalization_14[0][0]
']
concatenate (Concatenate) (None, 9, 9, 1024) 0 ['activation_15[0][0]',
'activation_18[0][0]']
conv2d_16 (Conv2D) (None, 9, 9, 512) 4719104 ['concatenate[0][0]']
batch_normalization_15 (Ba (None, 9, 9, 512) 2048 ['conv2d_16[0][0]']
tchNormalization)
activation_19 (Activation) (None, 9, 9, 512) 0 ['batch_normalization_15[0][0]
']
conv2d_17 (Conv2D) (None, 9, 9, 512) 2359808 ['activation_19[0][0]']
batch_normalization_16 (Ba (None, 9, 9, 512) 2048 ['conv2d_17[0][0]']
tchNormalization)
activation_20 (Activation) (None, 9, 9, 512) 0 ['batch_normalization_16[0][0]
']
conv2d_18 (Conv2D) (None, 9, 9, 512) 2359808 ['activation_20[0][0]']
batch_normalization_17 (Ba (None, 9, 9, 512) 2048 ['conv2d_18[0][0]']
tchNormalization)
activation_21 (Activation) (None, 9, 9, 512) 0 ['batch_normalization_17[0][0]
']
resizing (Resizing) (None, 18, 18, 512) 0 ['activation_21[0][0]']
concatenate_1 (Concatenate (None, 18, 18, 1024) 0 ['activation_9[0][0]',
) 'resizing[0][0]']
conv2d_19 (Conv2D) (None, 18, 18, 512) 4719104 ['concatenate_1[0][0]']
batch_normalization_18 (Ba (None, 18, 18, 512) 2048 ['conv2d_19[0][0]']
tchNormalization)
activation_22 (Activation) (None, 18, 18, 512) 0 ['batch_normalization_18[0][0]
']
conv2d_20 (Conv2D) (None, 18, 18, 512) 2359808 ['activation_22[0][0]']
batch_normalization_19 (Ba (None, 18, 18, 512) 2048 ['conv2d_20[0][0]']
tchNormalization)
activation_23 (Activation) (None, 18, 18, 512) 0 ['batch_normalization_19[0][0]
']
conv2d_21 (Conv2D) (None, 18, 18, 512) 2359808 ['activation_23[0][0]']
batch_normalization_20 (Ba (None, 18, 18, 512) 2048 ['conv2d_21[0][0]']
tchNormalization)
activation_24 (Activation) (None, 18, 18, 512) 0 ['batch_normalization_20[0][0]
']
resizing_1 (Resizing) (None, 36, 36, 512) 0 ['activation_24[0][0]']
concatenate_2 (Concatenate (None, 36, 36, 1024) 0 ['resnet34_block4[0][0]',
) 'resizing_1[0][0]']
conv2d_22 (Conv2D) (None, 36, 36, 512) 4719104 ['concatenate_2[0][0]']
batch_normalization_21 (Ba (None, 36, 36, 512) 2048 ['conv2d_22[0][0]']
tchNormalization)
activation_25 (Activation) (None, 36, 36, 512) 0 ['batch_normalization_21[0][0]
']
conv2d_23 (Conv2D) (None, 36, 36, 512) 2359808 ['activation_25[0][0]']
batch_normalization_22 (Ba (None, 36, 36, 512) 2048 ['conv2d_23[0][0]']
tchNormalization)
activation_26 (Activation) (None, 36, 36, 512) 0 ['batch_normalization_22[0][0]
']
conv2d_24 (Conv2D) (None, 36, 36, 512) 2359808 ['activation_26[0][0]']
batch_normalization_23 (Ba (None, 36, 36, 512) 2048 ['conv2d_24[0][0]']
tchNormalization)
activation_27 (Activation) (None, 36, 36, 512) 0 ['batch_normalization_23[0][0]
']
resizing_2 (Resizing) (None, 72, 72, 512) 0 ['activation_27[0][0]']
concatenate_3 (Concatenate (None, 72, 72, 768) 0 ['resnet34_block3[0][0]',
) 'resizing_2[0][0]']
conv2d_25 (Conv2D) (None, 72, 72, 512) 3539456 ['concatenate_3[0][0]']
batch_normalization_24 (Ba (None, 72, 72, 512) 2048 ['conv2d_25[0][0]']
tchNormalization)
activation_28 (Activation) (None, 72, 72, 512) 0 ['batch_normalization_24[0][0]
']
conv2d_26 (Conv2D) (None, 72, 72, 512) 2359808 ['activation_28[0][0]']
batch_normalization_25 (Ba (None, 72, 72, 512) 2048 ['conv2d_26[0][0]']
tchNormalization)
activation_29 (Activation) (None, 72, 72, 512) 0 ['batch_normalization_25[0][0]
']
conv2d_27 (Conv2D) (None, 72, 72, 512) 2359808 ['activation_29[0][0]']
batch_normalization_26 (Ba (None, 72, 72, 512) 2048 ['conv2d_27[0][0]']
tchNormalization)
activation_30 (Activation) (None, 72, 72, 512) 0 ['batch_normalization_26[0][0]
']
resizing_3 (Resizing) (None, 144, 144, 512) 0 ['activation_30[0][0]']
concatenate_4 (Concatenate (None, 144, 144, 640) 0 ['resnet34_block2[0][0]',
) 'resizing_3[0][0]']
conv2d_28 (Conv2D) (None, 144, 144, 512) 2949632 ['concatenate_4[0][0]']
batch_normalization_27 (Ba (None, 144, 144, 512) 2048 ['conv2d_28[0][0]']
tchNormalization)
activation_31 (Activation) (None, 144, 144, 512) 0 ['batch_normalization_27[0][0]
']
conv2d_29 (Conv2D) (None, 144, 144, 512) 2359808 ['activation_31[0][0]']
batch_normalization_28 (Ba (None, 144, 144, 512) 2048 ['conv2d_29[0][0]']
tchNormalization)
activation_32 (Activation) (None, 144, 144, 512) 0 ['batch_normalization_28[0][0]
']
conv2d_30 (Conv2D) (None, 144, 144, 512) 2359808 ['activation_32[0][0]']
batch_normalization_29 (Ba (None, 144, 144, 512) 2048 ['conv2d_30[0][0]']
tchNormalization)
activation_33 (Activation) (None, 144, 144, 512) 0 ['batch_normalization_29[0][0]
']
resizing_4 (Resizing) (None, 288, 288, 512) 0 ['activation_33[0][0]']
concatenate_5 (Concatenate (None, 288, 288, 576) 0 ['resnet34_block1[0][0]',
) 'resizing_4[0][0]']
conv2d_31 (Conv2D) (None, 288, 288, 512) 2654720 ['concatenate_5[0][0]']
batch_normalization_30 (Ba (None, 288, 288, 512) 2048 ['conv2d_31[0][0]']
tchNormalization)
activation_34 (Activation) (None, 288, 288, 512) 0 ['batch_normalization_30[0][0]
']
conv2d_32 (Conv2D) (None, 288, 288, 512) 2359808 ['activation_34[0][0]']
batch_normalization_31 (Ba (None, 288, 288, 512) 2048 ['conv2d_32[0][0]']
tchNormalization)
activation_35 (Activation) (None, 288, 288, 512) 0 ['batch_normalization_31[0][0]
']
conv2d_33 (Conv2D) (None, 288, 288, 512) 2359808 ['activation_35[0][0]']
batch_normalization_32 (Ba (None, 288, 288, 512) 2048 ['conv2d_33[0][0]']
tchNormalization)
activation_36 (Activation) (None, 288, 288, 512) 0 ['batch_normalization_32[0][0]
']
conv2d_34 (Conv2D) (None, 288, 288, 1) 4609 ['activation_36[0][0]']
resizing_5 (Resizing) (None, 288, 288, 1) 0 ['conv2d_34[0][0]']
conv2d_41 (Conv2D) (None, 288, 288, 64) 640 ['resizing_5[0][0]']
conv2d_42 (Conv2D) (None, 288, 288, 64) 36928 ['conv2d_41[0][0]']
batch_normalization_33 (Ba (None, 288, 288, 64) 256 ['conv2d_42[0][0]']
tchNormalization)
activation_37 (Activation) (None, 288, 288, 64) 0 ['batch_normalization_33[0][0]
']
max_pooling2d_2 (MaxPoolin (None, 144, 144, 64) 0 ['activation_37[0][0]']
g2D)
conv2d_43 (Conv2D) (None, 144, 144, 64) 36928 ['max_pooling2d_2[0][0]']
batch_normalization_34 (Ba (None, 144, 144, 64) 256 ['conv2d_43[0][0]']
tchNormalization)
activation_38 (Activation) (None, 144, 144, 64) 0 ['batch_normalization_34[0][0]
']
max_pooling2d_3 (MaxPoolin (None, 72, 72, 64) 0 ['activation_38[0][0]']
g2D)
conv2d_44 (Conv2D) (None, 72, 72, 64) 36928 ['max_pooling2d_3[0][0]']
batch_normalization_35 (Ba (None, 72, 72, 64) 256 ['conv2d_44[0][0]']
tchNormalization)
activation_39 (Activation) (None, 72, 72, 64) 0 ['batch_normalization_35[0][0]
']
max_pooling2d_4 (MaxPoolin (None, 36, 36, 64) 0 ['activation_39[0][0]']
g2D)
conv2d_45 (Conv2D) (None, 36, 36, 64) 36928 ['max_pooling2d_4[0][0]']
batch_normalization_36 (Ba (None, 36, 36, 64) 256 ['conv2d_45[0][0]']
tchNormalization)
activation_40 (Activation) (None, 36, 36, 64) 0 ['batch_normalization_36[0][0]
']
max_pooling2d_5 (MaxPoolin (None, 18, 18, 64) 0 ['activation_40[0][0]']
g2D)
conv2d_46 (Conv2D) (None, 18, 18, 64) 36928 ['max_pooling2d_5[0][0]']
batch_normalization_37 (Ba (None, 18, 18, 64) 256 ['conv2d_46[0][0]']
tchNormalization)
activation_41 (Activation) (None, 18, 18, 64) 0 ['batch_normalization_37[0][0]
']
resizing_12 (Resizing) (None, 36, 36, 64) 0 ['activation_41[0][0]']
concatenate_6 (Concatenate (None, 36, 36, 128) 0 ['activation_40[0][0]',
) 'resizing_12[0][0]']
conv2d_47 (Conv2D) (None, 36, 36, 64) 73792 ['concatenate_6[0][0]']
batch_normalization_38 (Ba (None, 36, 36, 64) 256 ['conv2d_47[0][0]']
tchNormalization)
activation_42 (Activation) (None, 36, 36, 64) 0 ['batch_normalization_38[0][0]
']
resizing_13 (Resizing) (None, 72, 72, 64) 0 ['activation_42[0][0]']
concatenate_7 (Concatenate (None, 72, 72, 128) 0 ['activation_39[0][0]',
) 'resizing_13[0][0]']
conv2d_48 (Conv2D) (None, 72, 72, 64) 73792 ['concatenate_7[0][0]']
batch_normalization_39 (Ba (None, 72, 72, 64) 256 ['conv2d_48[0][0]']
tchNormalization)
activation_43 (Activation) (None, 72, 72, 64) 0 ['batch_normalization_39[0][0]
']
resizing_14 (Resizing) (None, 144, 144, 64) 0 ['activation_43[0][0]']
concatenate_8 (Concatenate (None, 144, 144, 128) 0 ['activation_38[0][0]',
) 'resizing_14[0][0]']
conv2d_49 (Conv2D) (None, 144, 144, 64) 73792 ['concatenate_8[0][0]']
batch_normalization_40 (Ba (None, 144, 144, 64) 256 ['conv2d_49[0][0]']
tchNormalization)
activation_44 (Activation) (None, 144, 144, 64) 0 ['batch_normalization_40[0][0]
']
resizing_15 (Resizing) (None, 288, 288, 64) 0 ['activation_44[0][0]']
concatenate_9 (Concatenate (None, 288, 288, 128) 0 ['activation_37[0][0]',
) 'resizing_15[0][0]']
conv2d_50 (Conv2D) (None, 288, 288, 64) 73792 ['concatenate_9[0][0]']
batch_normalization_41 (Ba (None, 288, 288, 64) 256 ['conv2d_50[0][0]']
tchNormalization)
activation_45 (Activation) (None, 288, 288, 64) 0 ['batch_normalization_41[0][0]
']
conv2d_51 (Conv2D) (None, 288, 288, 1) 577 ['activation_45[0][0]']
conv2d_35 (Conv2D) (None, 144, 144, 1) 4609 ['activation_33[0][0]']
conv2d_36 (Conv2D) (None, 72, 72, 1) 4609 ['activation_30[0][0]']
conv2d_37 (Conv2D) (None, 36, 36, 1) 4609 ['activation_27[0][0]']
conv2d_38 (Conv2D) (None, 18, 18, 1) 4609 ['activation_24[0][0]']
conv2d_39 (Conv2D) (None, 9, 9, 1) 4609 ['activation_21[0][0]']
conv2d_40 (Conv2D) (None, 9, 9, 1) 4609 ['activation_18[0][0]']
add_6 (Add) (None, 288, 288, 1) 0 ['resizing_5[0][0]',
'conv2d_51[0][0]']
resizing_6 (Resizing) (None, 288, 288, 1) 0 ['conv2d_35[0][0]']
resizing_7 (Resizing) (None, 288, 288, 1) 0 ['conv2d_36[0][0]']
resizing_8 (Resizing) (None, 288, 288, 1) 0 ['conv2d_37[0][0]']
resizing_9 (Resizing) (None, 288, 288, 1) 0 ['conv2d_38[0][0]']
resizing_10 (Resizing) (None, 288, 288, 1) 0 ['conv2d_39[0][0]']
resizing_11 (Resizing) (None, 288, 288, 1) 0 ['conv2d_40[0][0]']
activation_46 (Activation) (None, 288, 288, 1) 0 ['add_6[0][0]']
activation_47 (Activation) (None, 288, 288, 1) 0 ['resizing_5[0][0]']
activation_48 (Activation) (None, 288, 288, 1) 0 ['resizing_6[0][0]']
activation_49 (Activation) (None, 288, 288, 1) 0 ['resizing_7[0][0]']
activation_50 (Activation) (None, 288, 288, 1) 0 ['resizing_8[0][0]']
activation_51 (Activation) (None, 288, 288, 1) 0 ['resizing_9[0][0]']
activation_52 (Activation) (None, 288, 288, 1) 0 ['resizing_10[0][0]']
activation_53 (Activation) (None, 288, 288, 1) 0 ['resizing_11[0][0]']
==================================================================================================
Total params: 108886792 (415.37 MB)
Trainable params: 108834952 (415.17 MB)
Non-trainable params: 51840 (202.50 KB)
__________________________________________________________________________________________________
# Train for one epoch only, to keep the example's runtime short.
basnet_model.fit(train_dataset, validation_data=val_dataset, epochs=1)
32/32 [==============================] - 153s 2s/step - loss: 16.3507 - activation_46_loss: 2.1445 - activation_47_loss: 2.1512 - activation_48_loss: 2.0621 - activation_49_loss: 2.0755 - activation_50_loss: 2.1406 - activation_51_loss: 1.9035 - activation_52_loss: 1.8702 - activation_53_loss: 2.0031 - activation_46_mae: 0.2972 - activation_47_mae: 0.3126 - activation_48_mae: 0.2793 - activation_49_mae: 0.2887 - activation_50_mae: 0.3280 - activation_51_mae: 0.2548 - activation_52_mae: 0.2330 - activation_53_mae: 0.2564 - val_loss: 18.4498 - val_activation_46_loss: 2.3113 - val_activation_47_loss: 2.3143 - val_activation_48_loss: 2.3356 - val_activation_49_loss: 2.3093 - val_activation_50_loss: 2.3187 - val_activation_51_loss: 2.3943 - val_activation_52_loss: 2.2712 - val_activation_53_loss: 2.1952 - val_activation_46_mae: 0.2770 - val_activation_47_mae: 0.2681 - val_activation_48_mae: 0.2424 - val_activation_49_mae: 0.2691 - val_activation_50_mae: 0.2765 - val_activation_51_mae: 0.1907 - val_activation_52_mae: 0.1885 - val_activation_53_mae: 0.2938
<keras.src.callbacks.History at 0x79b024bd83a0>
在论文中,BASNet是在DUTS-TR数据集上进行训练的,该数据集包含10,553张图像。模型训练了400k次迭代,批量大小为八,且没有验证数据集。训练后,该模型在DUTS-TE数据集上进行了评估,达到了0.042的平均绝对误差。
由于BASNet是一个深度模型,无法在短时间内完成训练,这是keras示例笔记本的要求,因此我们将从这里加载预训练权重以显示模型预测。由于计算机性能的限制,该模型训练了120k次迭代,但仍然展示了其能力。有关于训练参数的更多详情,请查看给定链接。
!!gdown 1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg
def normalize_output(prediction):
    """Min-max normalize a prediction map into the [0, 1] range.

    Args:
        prediction: A numpy array of arbitrary shape.

    Returns:
        An array of the same shape scaled so its minimum maps to 0 and its
        maximum maps to 1. A constant-valued input (max == min) returns all
        zeros instead of dividing by zero and producing NaN/inf.
    """
    max_value = np.max(prediction)
    min_value = np.min(prediction)
    value_range = max_value - min_value
    if value_range == 0:  # Degenerate case: flat map, nothing to scale.
        return np.zeros_like(prediction)
    return (prediction - min_value) / value_range
# Load the pretrained weights.
basnet_model.load_weights("./basnet_weights.h5")
['下载中...',
'来源: https://drive.google.com/uc?id=1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg',
'目标: /content/keras-io/scripts/tmp_3792671/basnet_weights.h5',
'',
' 0% 0.00/436M [00:00<?, ?B/s]',
' 1% 4.72M/436M [00:00<00:25, 16.7MB/s]',
' 4% 17.3M/436M [00:00<00:13, 31.5MB/s]',
' 7% 30.9M/436M [00:00<00:07, 54.5MB/s]',
' 9% 38.8M/436M [00:00<00:08, 48.2MB/s]',
' 12% 50.9M/436M [00:01<00:08, 45.2MB/s]',
' 15% 65.0M/436M [00:01<00:05, 62.2MB/s]',
' 17% 73.4M/436M [00:01<00:07, 50.6MB/s]',
' 19% 84.4M/436M [00:01<00:07, 48.3MB/s]',
' 23% 100M/436M [00:01<00:05, 66.7MB/s] ',
' 25% 110M/436M [00:02<00:05, 59.1MB/s]',
' 27% 118M/436M [00:02<00:06, 48.4MB/s]',
' 31% 135M/436M [00:02<00:05, 52.7MB/s]',
' 35% 152M/436M [00:02<00:04, 70.2MB/s]',
' 37% 161M/436M [00:03<00:04, 56.9MB/s]',
' 42% 185M/436M [00:03<00:04, 56.2MB/s]',
' 48% 210M/436M [00:03<00:03, 65.0MB/s]',
' 53% 231M/436M [00:03<00:02, 83.6MB/s]',
' 56% 243M/436M [00:04<00:02, 71.4MB/s]',
' 60% 261M/436M [00:04<00:02, 73.9MB/s]',
' 62% 272M/436M [00:04<00:02, 80.1MB/s]',
' 66% 286M/436M [00:04<00:01, 79.3MB/s]',
' 68% 295M/436M [00:04<00:01, 81.2MB/s]',
' 71% 308M/436M [00:04<00:01, 91.3MB/s]',
' 73% 319M/436M [00:04<00:01, 88.2MB/s]',
' 75% 329M/436M [00:05<00:01, 83.5MB/s]',
' 78% 339M/436M [00:05<00:01, 87.6MB/s]',
' 81% 353M/436M [00:05<00:00, 90.4MB/s]',
' 83% 362M/436M [00:05<00:00, 87.0MB/s]',
' 87% 378M/436M [00:05<00:00, 104MB/s] ',
' 89% 389M/436M [00:05<00:00, 101MB/s]',
' 93% 405M/436M [00:05<00:00, 115MB/s]',
' 96% 417M/436M [00:05<00:00, 110MB/s]',
' 98% 428M/436M [00:06<00:00, 91.4MB/s]',
'100% 436M/436M [00:06<00:00, 71.3MB/s]']
# Predict on one validation batch and show image / ground truth / prediction.
for image, mask in val_dataset.take(1):
    pred_mask = basnet_model.predict(image)
    # pred_mask[0] is the refined output head; normalize it for display.
    display([image[0], mask[0], normalize_output(pred_mask[0][0])])
1/1 [==============================] - 2s 2s/step