作者: Md Awsafur Rahman
创建日期: 2023/10/30
最后修改: 2023/10/30
描述: 全局上下文视觉变换器在图像分类中的实现和微调。
!pip install --upgrade keras_cv tensorflow
!pip install --upgrade keras
import keras
from keras_cv.layers import DropPath
from keras import ops
from keras import layers
import tensorflow as tf # 仅用于数据加载
import tensorflow_datasets as tfds # 用于花卉数据集
from skimage.data import chelsea
import matplotlib.pyplot as plt
import numpy as np
在本笔记本中,我们将利用多后端Keras 3.0来实现 GCViT: 全局上下文视觉变换器 论文, 由A Hatamizadeh等人在2023年ICML会议上提出。然后,我们将微调模型以 在花卉数据集上执行图像分类任务,利用官方的ImageNet预训练权重。此笔记本的一个亮点是它与多个后端的兼容性: TensorFlow、PyTorch和JAX,展示了多后端Keras的真正潜力。
注意: 在本节中我们将了解GCViT的背景故事,并尝试 理解为何会提出它。
O(n^2)
] 计算复杂度以及缺乏多尺度信息使得ViT
难以被视为计算机视觉任务的通用架构,如分割和目标检测,其中需要像素级的密集预测。
让我们快速浏览一下我们的关键组件,
1. Stem/PatchEmbed:
一个Stem/patchify层在网络的开始处理图像。
对于这个网络,它创建补丁/令牌并将它们转换为嵌入。
2. Level:
这是提取特征的重复构建模块,使用不同的模块。
3. Global Token Gen./FeatureExtraction:
它使用Depthwise-CNN、SqueezeAndExcitation (Squeeze-Excitation)、CNN和
MaxPooling生成全局令牌/补丁。所以基本上
它是一个特征提取器。
4. Block:
这是一个重复模块,对特征应用注意力并
将它们投影到某个维度。
1. Local-MSA:
本地多头自注意力。
2. Global-MSA:
全局多头自注意力。
3. MLP:
将向量投影到另一个维度的线性层。
5. Downsample/ReduceSize:
这与全局令牌生成模块非常相似,除非它
使用CNN而不是MaxPooling来下采样,并添加层归一化模块。
6. Head:
负责分类任务的模块。
1. Pooling:
将N x 2D
特征转换为N x 1D
特征。
2. Classifier:
处理N x 1D
特征以做出关于类别的决策。
我已经对架构图进行了注释,以便更易于理解,
注意: 这些模块用于构建论文中的其他模块。大多数模块来自其他工作或老旧工作的修改版本。
SqueezeAndExcitation
: 压缩-激励(SE)又称瓶颈模块,充当一种通道注意力。它由平均池化、密集/全连接(FC)/线性、GELU和Sigmoid模块组成。
Fused-MBConv:
这与EfficientNetV2中使用的相似。它使用深度卷积、GELU、压缩和激励、卷积,通过残差连接提取特征。请注意,这里没有声明新的模块,我们只是直接应用了相应的模块。
ReduceSize
: 这是一个基于CNN的下采样模块,使用上面提到的Fused-MBConv
模块提取特征,步幅卷积同时降低空间维度并增加特征的通道维度,最后使用层归一化模块来归一化特征。在论文/图中,此模块被称为下采样模块。我认为值得一提的是,SwinTransformer使用PatchMerging
模块而不是ReduceSize
来减少空间维度并增加通道维度,该模块使用全连接/密集/线性模块。根据GCViT论文,使用ReduceSize
的目的之一是通过CNN模块添加归纳偏差。
MLP:
这是我们自己的多层感知器模块。这个前馈/全连接/线性模块简单地将输入投影到任意维度。
class SqueezeAndExcitation(layers.Layer):
"""挤压与激励块。
Args:
output_dim: 输出特征维度,如果为`None`则使用与输入相同的维度。
expansion: 扩展比率。
"""
def __init__(self, output_dim=None, expansion=0.25, **kwargs):
super().__init__(**kwargs)
self.expansion = expansion
self.output_dim = output_dim
def build(self, input_shape):
inp = input_shape[-1]
self.output_dim = self.output_dim or inp
self.avg_pool = layers.GlobalAvgPool2D(keepdims=True, name="avg_pool")
self.fc = [
layers.Dense(int(inp * self.expansion), use_bias=False, name="fc_0"),
layers.Activation("gelu", name="fc_1"),
layers.Dense(self.output_dim, use_bias=False, name="fc_2"),
layers.Activation("sigmoid", name="fc_3"),
]
super().build(input_shape)
def call(self, inputs, **kwargs):
x = self.avg_pool(inputs)
for layer in self.fc:
x = layer(x)
return x * inputs
class ReduceSize(layers.Layer):
"""下采样块。
Args:
keepdims: 如果为False,空间维度将被减少,通道维度将被增加
"""
def __init__(self, keepdims=False, **kwargs):
super().__init__(**kwargs)
self.keepdims = keepdims
def build(self, input_shape):
embed_dim = input_shape[-1]
dim_out = embed_dim if self.keepdims else 2 * embed_dim
self.pad1 = layers.ZeroPadding2D(1, name="pad1")
self.pad2 = layers.ZeroPadding2D(1, name="pad2")
self.conv = [
layers.DepthwiseConv2D(
kernel_size=3, strides=1, padding="valid", use_bias=False, name="conv_0"
),
layers.Activation("gelu", name="conv_1"),
SqueezeAndExcitation(name="conv_2"),
layers.Conv2D(
embed_dim,
kernel_size=1,
strides=1,
padding="valid",
use_bias=False,
name="conv_3",
),
]
self.reduction = layers.Conv2D(
dim_out,
kernel_size=3,
strides=2,
padding="valid",
use_bias=False,
name="reduction",
)
self.norm1 = layers.LayerNormalization(
-1, 1e-05, name="norm1"
) # 类似于PyTorch的eps
self.norm2 = layers.LayerNormalization(-1, 1e-05, name="norm2")
def call(self, inputs, **kwargs):
x = self.norm1(inputs)
xr = self.pad1(x)
for layer in self.conv:
xr = layer(xr)
x = x + xr
x = self.pad2(x)
x = self.reduction(x)
x = self.norm2(x)
return x
class MLP(layers.Layer):
"""多层感知器(MLP)块。
Args:
hidden_features: 隐藏特征维度。
out_features: 输出特征维度。
activation: 激活函数。
dropout: dropout比率。
"""
def __init__(
self,
hidden_features=None,
out_features=None,
activation="gelu",
dropout=0.0,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_features = hidden_features
self.out_features = out_features
self.activation = activation
self.dropout = dropout
def build(self, input_shape):
self.in_features = input_shape[-1]
self.hidden_features = self.hidden_features or self.in_features
self.out_features = self.out_features or self.in_features
self.fc1 = layers.Dense(self.hidden_features, name="fc1")
self.act = layers.Activation(self.activation, name="act")
self.fc2 = layers.Dense(self.out_features, name="fc2")
self.drop1 = layers.Dropout(self.dropout, name="drop1")
self.drop2 = layers.Dropout(self.dropout, name="drop2")
def call(self, inputs, **kwargs):
x = self.fc1(inputs)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
说明: 在代码中,该模块被称为 PatchEmbed,但在论文中,它被称为 Stem。
在模型中,我们首先使用了 patch_embed
模块。让我们尝试理解这个模块。从 call
方法中可以看到,
1. 该模块首先 填充 输入
2. 然后使用 卷积 提取带有嵌入的补丁。
3. 最后,使用 ReduceSize
模块首先通过 卷积 提取特征,但既不减少空间维度也不增加空间维度。
4. 一个重要的点要注意的是,与 ViT 或 SwinTransformer 不同,GCViT 创建了 重叠补丁。我们可以从代码中注意到这一点,Conv2D(self.embed_dim, kernel_size=3, strides=2, name='proj')
。如果我们想要 非重叠 补丁,那么我们应该使用相同的 kernel_size
和 stride
。
5. 该模块将输入的空间维度减少了 4x
。
摘要:图像 → 填充 → 卷积 → (特征提取 + 下采样)
class PatchEmbed(layers.Layer):
"""补丁嵌入块。
参数:
embed_dim: 特征大小维度。
"""
def __init__(self, embed_dim, **kwargs):
super().__init__(**kwargs)
self.embed_dim = embed_dim
def build(self, input_shape):
self.pad = layers.ZeroPadding2D(1, name="pad")
self.proj = layers.Conv2D(self.embed_dim, 3, 2, name="proj")
self.conv_down = ReduceSize(keepdims=True, name="conv_down")
def call(self, inputs, **kwargs):
x = self.pad(inputs)
x = self.proj(x)
x = self.conv_down(x)
return x
说明: 这是用于施加归纳偏差的两个 CNN 模块之一。
从上面的单元可以看到,在 level
中,我们首先使用了 to_q_global/Global Token Gen./FeatureExtraction
。让我们尝试理解它的工作原理,
FeatureExtract
模块,按照论文我们需要将该模块重复 K
次,其中 K = log2(H/h)
,H = feature_map_height
,W = feature_map_width
。FeatureExtraction:
该层非常类似于 ReduceSize
模块,只是它使用 MaxPooling 模块来减少维度,它不会增加特征维度(通道数)且不使用 LayerNormalization。该模块用于在 Generate Token Gen.
模块中反复生成 全局标记 以支持 全局上下文注意力。(B, H, W, C)
的输入特征图,我们将得到输出形状 (B, h, w, C)
。如果我们将这些全局标记复制给图像中总共 M
个局部窗口,其中 M = (H x W)/(h x w) = num_window
,则输出形状为 (B * M, h, w, C)
。摘要:该模块用于将图像
调整大小
以适配窗口。
class FeatureExtraction(layers.Layer):
"""特征提取块。
Args:
keepdims: 布尔参数,用于保持分辨率。
"""
def __init__(self, keepdims=False, **kwargs):
super().__init__(**kwargs)
self.keepdims = keepdims
def build(self, input_shape):
embed_dim = input_shape[-1]
self.pad1 = layers.ZeroPadding2D(1, name="pad1")
self.pad2 = layers.ZeroPadding2D(1, name="pad2")
self.conv = [
layers.DepthwiseConv2D(3, 1, use_bias=False, name="conv_0"),
layers.Activation("gelu", name="conv_1"),
SqueezeAndExcitation(name="conv_2"),
layers.Conv2D(embed_dim, 1, 1, use_bias=False, name="conv_3"),
]
if not self.keepdims:
self.pool = layers.MaxPool2D(3, 2, name="pool")
super().build(input_shape)
def call(self, inputs, **kwargs):
x = inputs
xr = self.pad1(x)
for layer in self.conv:
xr = layer(xr)
x = x + xr
if not self.keepdims:
x = self.pool(self.pad2(x))
return x
class GlobalQueryGenerator(layers.Layer):
"""全局查询生成器。
Args:
keepdims: 用于保持FeatureExtraction层的维度。
例如,重复log(56/7) = 3个块,输入窗口维度为56,输出窗口维度为7,向下采样比例为2。请参阅GC ViT论文的图5以获取详细信息。
"""
def __init__(self, keepdims=False, **kwargs):
super().__init__(**kwargs)
self.keepdims = keepdims
def build(self, input_shape):
self.to_q_global = [
FeatureExtraction(keepdims, name=f"to_q_global_{i}")
for i, keepdims in enumerate(self.keepdims)
]
super().build(input_shape)
def call(self, inputs, **kwargs):
x = inputs
for layer in self.to_q_global:
x = layer(x)
return x
注意: 这是本文的核心贡献。
从call
方法中可以看出,
1. WindowAttention
模块根据global_query
参数应用局部和全局窗口注意力。
query, key, value
和全局注意力的key, value
。对于全局注意力,它从Global Token Gen.
获取全局查询。从代码中可以注意到,我们将特征或嵌入维度在所有Transformer的头部之间进行划分,以减少计算。
qkv = tf.reshape(qkv, [B_, N, self.qkv_size, self.num_heads, C // self.num_heads])
q_global = tf.repeat(q_global, repeats=B_//B, axis=0)
,这里B_//B
表示图像中的num_windows
。global_query
参数简单地应用local-window-self-attention
或global-window-attention
。从代码中可以注意到,我们将相对位置嵌入与注意力掩码相加,而不是与补丁嵌入相加。
attn = attn + relative_position_bias[tf.newaxis,]
(1, 8, 8, 3)
而窗口token的形状是(1, 4, 4, 3)
)。是的,你说得对,我们不能直接比较它们,因此我们使用Global Token Gen./FeatureExtraction
CNN模块将图像token调整大小以适应窗口token。以下表格应为您提供清晰的比较,模型 | 查询Token | 键值Token | 注意力类型 | 注意力覆盖 |
---|---|---|---|---|
ViT | 图像 | 图像 | 自注意力 | 全局 |
SwinTransformer | 窗口 | 窗口 | 自注意力 | 局部 |
GCViT | 调整大小的图像 | 窗口 | 图像-窗口注意力 | 全局 |
class WindowAttention(layers.Layer):
"""局部窗口注意力.
该实现由
[Liu et al., 2021](https://arxiv.org/abs/2103.14030) 在SwinTransformer中提出.
Args:
window_size: 窗口大小.
num_heads: 注意力头的数量.
global_query: 如果输入包含全局查询
qkv_bias: bool参数,表示查询、关键、值的可学习偏差.
qk_scale: bool参数,用于缩放查询和关键.
attention_dropout: 注意力的丢弃率.
projection_dropout: 输出的丢弃率.
"""
def __init__(
self,
window_size,
num_heads,
global_query,
qkv_bias=True,
qk_scale=None,
attention_dropout=0.0,
projection_dropout=0.0,
**kwargs,
):
super().__init__(**kwargs)
window_size = (window_size, window_size)
self.window_size = window_size
self.num_heads = num_heads
self.global_query = global_query
self.qkv_bias = qkv_bias
self.qk_scale = qk_scale
self.attention_dropout = attention_dropout
self.projection_dropout = projection_dropout
def build(self, input_shape):
embed_dim = input_shape[0][-1]
head_dim = embed_dim // self.num_heads
self.scale = self.qk_scale or head_dim**-0.5
self.qkv_size = 3 - int(self.global_query)
self.qkv = layers.Dense(
embed_dim * self.qkv_size, use_bias=self.qkv_bias, name="qkv"
)
self.relative_position_bias_table = self.add_weight(
name="relative_position_bias_table",
shape=[
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
self.num_heads,
],
initializer=keras.initializers.TruncatedNormal(stddev=0.02),
trainable=True,
dtype=self.dtype,
)
self.attn_drop = layers.Dropout(self.attention_dropout, name="attn_drop")
self.proj = layers.Dense(embed_dim, name="proj")
self.proj_drop = layers.Dropout(self.projection_dropout, name="proj_drop")
self.softmax = layers.Activation("softmax", name="softmax")
super().build(input_shape)
def get_relative_position_index(self):
coords_h = ops.arange(self.window_size[0])
coords_w = ops.arange(self.window_size[1])
coords = ops.stack(ops.meshgrid(coords_h, coords_w, indexing="ij"), axis=0)
coords_flatten = ops.reshape(coords, [2, -1])
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = ops.transpose(relative_coords, axes=[1, 2, 0])
relative_coords_xx = relative_coords[:, :, 0] + self.window_size[0] - 1
relative_coords_yy = relative_coords[:, :, 1] + self.window_size[1] - 1
relative_coords_xx = relative_coords_xx * (2 * self.window_size[1] - 1)
relative_position_index = relative_coords_xx + relative_coords_yy
return relative_position_index
def call(self, inputs, **kwargs):
if self.global_query:
inputs, q_global = inputs
B = ops.shape(q_global)[0] # B, N, C
else:
inputs = inputs[0]
B_, N, C = ops.shape(inputs) # B*num_window, num_tokens, channels
qkv = self.qkv(inputs)
qkv = ops.reshape(
qkv, [B_, N, self.qkv_size, self.num_heads, C // self.num_heads]
)
qkv = ops.transpose(qkv, [2, 0, 3, 1, 4])
if self.global_query:
k, v = ops.split(
qkv, indices_or_sections=2, axis=0
) # 对于未知的情况,num=None将引发错误
q_global = ops.repeat(
q_global, repeats=B_ // B, axis=0
) # num_windows = B_//B => q_global在一张图片的所有窗口中相同
q = ops.reshape(q_global, [B_, N, self.num_heads, C // self.num_heads])
q = ops.transpose(q, axes=[0, 2, 1, 3])
else:
q, k, v = ops.split(qkv, indices_or_sections=3, axis=0)
q = ops.squeeze(q, axis=0)
k = ops.squeeze(k, axis=0)
v = ops.squeeze(v, axis=0)
q = q * self.scale
attn = q @ ops.transpose(k, axes=[0, 1, 3, 2])
relative_position_bias = ops.take(
self.relative_position_bias_table,
ops.reshape(self.get_relative_position_index(), [-1]),
)
relative_position_bias = ops.reshape(
relative_position_bias,
[
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
],
)
relative_position_bias = ops.transpose(relative_position_bias, axes=[2, 0, 1])
attn = attn + relative_position_bias[None,]
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = ops.transpose((attn @ v), axes=[0, 2, 1, 3])
x = ops.reshape(x, [B_, N, C])
x = self.proj_drop(self.proj(x))
return x
备注: 该模块没有任何卷积模块。
在我们使用的第二个模块 level
是 block
。让我们尝试理解它是如何工作的。如我们从 call
方法中看到的,
1. Block
模块只接受特征图用于局部注意,或额外的全局查询用于全局注意。
2. 在将特征图发送给注意力之前,这个模块将批特征图转换为批窗口,因为我们将应用窗口注意力。
3. 然后我们将批批窗口发送给注意力。
4. 在应用了注意力后,我们将批窗口恢复为批特征图。
5. 在将应用了注意力的特征发送到输出之前,这个模块在残差连接中应用了随机深度正则化。此外,在应用随机深度之前,它用可训练参数重新缩放输入。注意,这个随机深度模块在论文的图中没有显示。
在 block
模块中,我们在应用注意力之前和之后创建了窗口。让我们尝试理解我们是如何创建窗口的,
* 以下模块将特征图 (B, H, W, C)
转换为堆叠窗口 (B x H/h x W/w, h, w, C)
→ (num_windows_batch, window_size, window_size, channel)
* 此模块使用 reshape
和 transpose
从图像中创建这些窗口,而不是对它们进行迭代。
class Block(layers.Layer):
"""GCViT 块。
Args:
window_size: 窗口大小。
num_heads: 注意力头的数量。
global_query: 应用全局窗口注意力。
mlp_ratio: MLP 比率。
qkv_bias: 查询、键、值学习偏置的布尔参数。
qk_scale: 缩放查询、键的布尔参数。
drop: dropout 率。
attention_dropout: 注意力 dropout 率。
path_drop: dropout 路径率。
activation: 激活函数。
layer_scale: 层缩放系数。
"""
def __init__(
self,
window_size,
num_heads,
global_query,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
dropout=0.0,
attention_dropout=0.0,
path_drop=0.0,
activation="gelu",
layer_scale=None,
**kwargs,
):
super().__init__(**kwargs)
self.window_size = window_size
self.num_heads = num_heads
self.global_query = global_query
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.qk_scale = qk_scale
self.dropout = dropout
self.attention_dropout = attention_dropout
self.path_drop = path_drop
self.activation = activation
self.layer_scale = layer_scale
def build(self, input_shape):
B, H, W, C = input_shape[0]
self.norm1 = layers.LayerNormalization(-1, 1e-05, name="norm1")
self.attn = WindowAttention(
window_size=self.window_size,
num_heads=self.num_heads,
global_query=self.global_query,
qkv_bias=self.qkv_bias,
qk_scale=self.qk_scale,
attention_dropout=self.attention_dropout,
projection_dropout=self.dropout,
name="attn",
)
self.drop_path1 = DropPath(self.path_drop)
self.drop_path2 = DropPath(self.path_drop)
self.norm2 = layers.LayerNormalization(-1, 1e-05, name="norm2")
self.mlp = MLP(
hidden_features=int(C * self.mlp_ratio),
dropout=self.dropout,
activation=self.activation,
name="mlp",
)
if self.layer_scale is not None:
self.gamma1 = self.add_weight(
name="gamma1",
shape=[C],
initializer=keras.initializers.Constant(self.layer_scale),
trainable=True,
dtype=self.dtype,
)
self.gamma2 = self.add_weight(
name="gamma2",
shape=[C],
initializer=keras.initializers.Constant(self.layer_scale),
trainable=True,
dtype=self.dtype,
)
else:
self.gamma1 = 1.0
self.gamma2 = 1.0
self.num_windows = int(H // self.window_size) * int(W // self.window_size)
super().build(input_shape)
def call(self, inputs, **kwargs):
if self.global_query:
inputs, q_global = inputs
else:
inputs = inputs[0]
B, H, W, C = ops.shape(inputs)
x = self.norm1(inputs)
# 创建窗口并在批次轴上连接它们
x = self.window_partition(x, self.window_size) # (B_, win_h, win_w, C)
# 展平补丁
x = ops.reshape(x, [-1, self.window_size * self.window_size, C])
# 注意力
if self.global_query:
x = self.attn([x, q_global])
else:
x = self.attn([x])
# 逆转窗口分区
x = self.window_reverse(x, self.window_size, H, W, C)
# FFN
x = inputs + self.drop_path1(x * self.gamma1)
x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x)))
return x
def window_partition(self, x, window_size):
"""
Args:
x: (B, H, W, C)
window_size: 窗口大小
Returns:
本地窗口特征 (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = ops.shape(x)
x = ops.reshape(
x,
[
-1,
H // window_size,
window_size,
W // window_size,
window_size,
C,
],
)
x = ops.transpose(x, axes=[0, 1, 3, 2, 4, 5])
windows = ops.reshape(x, [-1, window_size, window_size, C])
return windows
def window_reverse(self, windows, window_size, H, W, C):
"""
Args:
windows: 本地窗口特征 (num_windows*B, window_size, window_size, C)
window_size: 窗口大小
H: 图像高度
W: 图像宽度
C: 图像通道
Returns:
x: (B, H, W, C)
"""
x = ops.reshape(
windows,
[
-1,
H // window_size,
W // window_size,
window_size,
window_size,
C,
],
)
x = ops.transpose(x, axes=[0, 1, 3, 2, 4, 5])
x = ops.reshape(x, [-1, H, W, C])
return x
注意: 该模块同时具有 Transformer 和 CNN 模块。
在模型中,我们使用的第二个模块是 level
。让我们尝试理解这个模块。正如我们从 call
方法中看到的:
1. 首先,它使用一系列 FeatureExtraction
模块创建 global_token。正如我们后面将看到的,FeatureExtraction
其实就是一个简单的 CNN 基模块。
2. 然后,它使用一系列 Block
模块来应用 局部或全局窗口注意力,这取决于深度级别。
3. 最后,它使用 ReduceSize
来减少 上下文化特征 的维数。
摘要:feature_map → global_token → local/global window attention → downsample
class Level(layers.Layer):
"""GCViT level.
Args:
depth: 每个阶段的层数。
num_heads: 每个阶段的头数。
window_size: 每个阶段的窗口大小。
keepdims: 在 FeatureExtraction 中保留的维度。
downsample: 下采样的布尔参数。
mlp_ratio: MLP 比率。
qkv_bias: 查询、键、值可学习偏置的布尔参数。
qk_scale: 查询、键缩放的布尔参数。
drop: dropout 率。
attention_dropout: 注意力 dropout 率。
path_drop: drop path 率。
layer_scale: 层缩放系数。
"""
def __init__(
self,
depth,
num_heads,
window_size,
keepdims,
downsample=True,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
dropout=0.0,
attention_dropout=0.0,
path_drop=0.0,
layer_scale=None,
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.num_heads = num_heads
self.window_size = window_size
self.keepdims = keepdims
self.downsample = downsample
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.qk_scale = qk_scale
self.dropout = dropout
self.attention_dropout = attention_dropout
self.path_drop = path_drop
self.layer_scale = layer_scale
def build(self, input_shape):
path_drop = (
[self.path_drop] * self.depth
if not isinstance(self.path_drop, list)
else self.path_drop
)
self.blocks = [
Block(
window_size=self.window_size,
num_heads=self.num_heads,
global_query=bool(i % 2),
mlp_ratio=self.mlp_ratio,
qkv_bias=self.qkv_bias,
qk_scale=self.qk_scale,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
path_drop=path_drop[i],
layer_scale=self.layer_scale,
name=f"blocks_{i}",
)
for i in range(self.depth)
]
self.down = ReduceSize(keepdims=False, name="downsample")
self.q_global_gen = GlobalQueryGenerator(self.keepdims, name="q_global_gen")
super().build(input_shape)
def call(self, inputs, **kwargs):
x = inputs
q_global = self.q_global_gen(x) # 形状: (B, win_size, win_size, C)
for i, blk in enumerate(self.blocks):
if i % 2:
x = blk([x, q_global]) # 形状: (B, H, W, C)
else:
x = blk([x]) # 形状: (B, H, W, C)
if self.downsample:
x = self.down(x) # 形状: (B, H//2, W//2, 2*C)
return x
让我们直接跳到模型。正如我们从 call
方法中看到的:
1. 它从图像创建补丁嵌入。该层并不会将这些嵌入展平,这意味着该模块的输出将是
(batch, height/window_size, width/window_size, embed_dim)
,而不是
(batch, height x width/window_size^2, embed_dim)
。
2. 然后它应用 Dropout
模块,随机将输入单元设置为 0。
3. 它将这些嵌入传递给一系列我们称为 level
的 Level
模块,其中,
1. 生成全局令牌
1. 同时应用局部与全局注意力
1. 最后应用下采样。
4. 因此,在 n
个 levels 后的输出,形状为: (batch, width/window_size x 2^{n-1},
width/window_size x 2^{n-1}, embed_dim x 2^{n-1})
。在最后一层,
论文不使用 downsample,而是增加 通道。
5. 上述层的输出使用 LayerNormalization
模块进行归一化。
6. 在头部,2D 特征通过 Pooling
模块转换为 1D 特征。该模块后的输出形状为 (batch, embed_dim x 2^{n-1})
7. 最后,池化特征被送往 Dense/Linear
模块进行分类。
摘要:图像 → (补丁 + 嵌入) → dropout → (注意力 + 特征提取) → 归一化 → 池化 → 分类
class GCViT(keras.Model):
"""GCViT模型。
参数:
window_size: 每个阶段的窗口大小。
embed_dim: 特征尺寸维度。
depths: 每个阶段的层数。
num_heads: 每个阶段的头数。
drop_rate: dropout 率。
mlp_ratio: MLP 比率。
qkv_bias: bool 参数,表示查询、键、值的可学习偏置。
qk_scale: bool 参数,用于缩放查询、键。
attention_dropout: 注意力 dropout 率。
path_drop: drop path 率。
layer_scale: 层缩放系数。
num_classes: 类别数量。
head_activation: 头的激活函数。
"""
def __init__(
self,
window_size,
embed_dim,
depths,
num_heads,
drop_rate=0.0,
mlp_ratio=3.0,
qkv_bias=True,
qk_scale=None,
attention_dropout=0.0,
path_drop=0.1,
layer_scale=None,
num_classes=1000,
head_activation="softmax",
**kwargs,
):
super().__init__(**kwargs)
self.window_size = window_size
self.embed_dim = embed_dim
self.depths = depths
self.num_heads = num_heads
self.drop_rate = drop_rate
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.qk_scale = qk_scale
self.attention_dropout = attention_dropout
self.path_drop = path_drop
self.layer_scale = layer_scale
self.num_classes = num_classes
self.head_activation = head_activation
self.patch_embed = PatchEmbed(embed_dim=embed_dim, name="patch_embed")
self.pos_drop = layers.Dropout(drop_rate, name="pos_drop")
path_drops = np.linspace(0.0, path_drop, sum(depths))
keepdims = [(0, 0, 0), (0, 0), (1,), (1,)]
self.levels = []
for i in range(len(depths)):
path_drop = path_drops[sum(depths[:i]) : sum(depths[: i + 1])].tolist()
level = Level(
depth=depths[i],
num_heads=num_heads[i],
window_size=window_size[i],
keepdims=keepdims[i],
downsample=(i < len(depths) - 1),
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
dropout=drop_rate,
attention_dropout=attention_dropout,
path_drop=path_drop,
layer_scale=layer_scale,
name=f"levels_{i}",
)
self.levels.append(level)
self.norm = layers.LayerNormalization(axis=-1, epsilon=1e-05, name="norm")
self.pool = layers.GlobalAvgPool2D(name="pool")
self.head = layers.Dense(num_classes, name="head", activation=head_activation)
def build(self, input_shape):
super().build(input_shape)
self.built = True
def call(self, inputs, **kwargs):
x = self.patch_embed(inputs) # 形状: (B, H, W, C)
x = self.pos_drop(x)
for level in self.levels:
x = level(x) # 形状: (B, H_, W_, C_)
x = self.norm(x)
x = self.pool(x) # 形状: (B, C__)
x = self.head(x)
return x
def build_graph(self, input_shape=(224, 224, 3)):
"""
参考: https://www.kaggle.com/code/ipythonx/tf-hybrid-efficientnet-swin-transformer-gradcam
"""
x = keras.Input(shape=input_shape)
return keras.Model(inputs=[x], outputs=self.call(x), name=self.name)
def summary(self, input_shape=(224, 224, 3)):
return self.build_graph(input_shape).summary()
# 模型配置
config = {
"window_size": (7, 7, 14, 7),
"embed_dim": 64,
"depths": (2, 2, 6, 2),
"num_heads": (2, 4, 8, 16),
"mlp_ratio": 3.0,
"path_drop": 0.2,
}
ckpt_link = (
"https://github.com/awsaf49/gcvit-tf/releases/download/v1.1.6/gcvitxxtiny.keras"
)
# 构建模型
model = GCViT(**config)
inp = ops.array(np.random.uniform(size=(1, 224, 224, 3)))
out = model(inp)
# 加载权重
ckpt_path = keras.utils.get_file(ckpt_link.split("/")[-1], ckpt_link)
model.load_weights(ckpt_path)
# 概要
model.summary((224, 224, 3))
正在从 https://github.com/awsaf49/gcvit-tf/releases/download/v1.1.6/gcvitxxtiny.keras 下载数据
48767519/48767519 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
模型: "gc_vi_t"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓ ┃ 层 (类型) ┃ 输出形状 ┃ 参数 # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩ │ 输入层 (InputLayer) │ (None, 224, 224, 3) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ patch_embed (PatchEmbed) │ (None, 56, 56, 64) │ 45,632 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ pos_drop (Dropout) │ (None, 56, 56, 64) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ levels_0 (Level) │ (None, 28, 28, 128) │ 180,964 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ levels_1 (Level) │ (None, 14, 14, 256) │ 688,456 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ levels_2 (级别) │ (无, 7, 7, 512) │ 5,170,608 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ levels_3 (级别) │ (无, 7, 7, 512) │ 5,395,744 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ norm (层归一化) │ (无, 7, 7, 512) │ 1,024 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ pool (全局平均池化2D) │ (无, 512) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ head (密集层) │ (无, 1000) │ 513,000 │ └────────────────────────────────────┴───────────────────────────────┴─────────────┘
总参数: 11,995,428 (45.76 MB)
可训练参数: 11,995,428 (45.76 MB)
非可训练参数: 0 (0.00 B)
img = keras.applications.imagenet_utils.preprocess_input(
chelsea(), mode="torch"
) # 切尔西猫
img = ops.image.resize(img, (224, 224))[None,] # 调整大小并创建批次
pred = model(img)
pred_dec = keras.applications.imagenet_utils.decode_predictions(pred)[0]
print("\n# 图像:")
plt.figure(figsize=(6, 6))
plt.imshow(chelsea())
plt.show()
print()
print("# 预测 (前 5):")
for i in range(5):
print("{:<12} : {:0.2f}".format(pred_dec[i][1], pred_dec[i][2]))
从 https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json 下载数据
35363/35363 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
# 图像:
# 预测 (前 5):
埃及猫 : 0.72
虎猫 : 0.04
虎斑猫 : 0.03
填字游戏 : 0.01
笛子 : 0.00
在接下来的单元中,我们将对 GCViT 模型进行微调,该模型在花卉数据集上进行训练,包含 104
个类别。
# 模型
IMAGE_SIZE = (224, 224)
# 超参数
BATCH_SIZE = 32
EPOCHS = 5
# 数据集
CLASSES = [
"dandelion",
"daisy",
"tulips",
"sunflowers",
"roses",
] # 不要更改顺序
# 其他常量
MEAN = 255 * np.array([0.485, 0.456, 0.406], dtype="float32") # imagenet 均值
STD = 255 * np.array([0.229, 0.224, 0.225], dtype="float32") # imagenet 标准差
AUTO = tf.data.AUTOTUNE
def make_dataset(dataset: tf.data.Dataset, train: bool, image_size: int = IMAGE_SIZE):
def preprocess(image, label):
# 训练时进行数据增强
if train:
if tf.random.uniform(shape=[]) > 0.5:
image = tf.image.flip_left_right(image)
image = tf.image.resize(image, size=image_size, method="bicubic")
image = (image - MEAN) / STD # 归一化
return image, label
if train:
dataset = dataset.shuffle(BATCH_SIZE * 10)
return dataset.map(preprocess, AUTO).batch(BATCH_SIZE).prefetch(AUTO)
train_dataset, val_dataset = tfds.load(
"tf_flowers",
split=["train[:90%]", "train[90%:]"],
as_supervised=True,
try_gcs=False, # gcs_path 对于 tpu 是必要的,
)
train_dataset = make_dataset(train_dataset, True)
val_dataset = make_dataset(val_dataset, False)
正在下载和准备数据集 218.21 MiB (下载: 218.21 MiB, 生成: 221.83 MiB, 总计: 440.05 MiB) 到 /root/tensorflow_datasets/tf_flowers/3.0.1...
下载完成...: 0%| | 0/5 [00:00<?, ? file/s]
数据集 tf_flowers 已下载并准备好到 /root/tensorflow_datasets/tf_flowers/3.0.1。后续调用将重用此数据。
# 重新构建模型
model = GCViT(**config, num_classes=104)
inp = ops.array(np.random.uniform(size=(1, 224, 224, 3)))
out = model(inp)
# 加载权重
ckpt_path = keras.utils.get_file(ckpt_link.split("/")[-1], ckpt_link)
model.load_weights(ckpt_path, skip_mismatch=True)
model.compile(
loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
)
/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py:269: 用户警告: 总共有 1 个对象无法加载。关于对象 <Dense name=head, built=True> 的示例错误消息:
层 'head' 期望 2 个变量,但在加载时接收到了 0 个变量。预期: ['kernel', 'bias']
无法加载的对象列表:
[<Dense name=head, built=True>]
warnings.warn(msg)
history = model.fit(
train_dataset, validation_data=val_dataset, epochs=EPOCHS, verbose=1
)
第 1 轮/5
104/104 ━━━━━━━━━━━━━━━━━━━━ 153s 581ms/步 - 准确率: 0.5140 - 损失: 1.4615 - 验证准确率: 0.8828 - 验证损失: 0.3485
第 2 轮/5
104/104 ━━━━━━━━━━━━━━━━━━━━ 7s 69ms/步 - 准确率: 0.8775 - 损失: 0.3437 - 验证准确率: 0.8828 - 验证损失: 0.3508
第 3 轮/5
104/104 ━━━━━━━━━━━━━━━━━━━━ 7s 68ms/步 - 准确率: 0.8937 - 损失: 0.2918 - 验证准确率: 0.9019 - 验证损失: 0.2953
第 4 轮/5
104/104 ━━━━━━━━━━━━━━━━━━━━ 7s 68ms/步 - 准确率: 0.9232 - 损失: 0.2397 - 验证准确率: 0.9183 - 验证损失: 0.2212
第 5 轮/5
104/104 ━━━━━━━━━━━━━━━━━━━━ 7s 68ms/步 - 准确率: 0.9456 - 损失: 0.1645 - 验证准确率: 0.9210 - 验证损失: 0.2897