作者: Khalid Salama
创建日期: 2020/05/10
最后修改: 2021/02/15
描述: 实现用于文本分类的 Switch Transformer。
本示例演示了用于文本分类的 Switch Transformer 模型的实现。
Switch Transformer 用一种专家混合(MoE)路由层替换了标准 Transformer 中的前馈网络(FFN)层,每个专家独立处理序列中的标记。这允许增加模型大小,而不会增加处理每个示例所需的计算量。
请注意,为了高效地训练 Switch Transformer,需要应用数据和模型并行性,以使专家模块可以同时运行,每个专家在其自己的加速器上。虽然论文中描述的实现使用了 TensorFlow Mesh 框架进行分布式训练, 但本示例展示了一个简单的非分布式 Switch Transformer 模型的实现以供演示。
import keras
from keras import ops
from keras import layers
vocab_size = 20000 # 仅考虑前 20k 个单词
num_tokens_per_example = 200 # 仅考虑每条影评的前 200 个单词
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)
print(len(x_train), "训练序列")
print(len(x_val), "验证序列")
x_train = keras.utils.pad_sequences(x_train, maxlen=num_tokens_per_example)
x_val = keras.utils.pad_sequences(x_val, maxlen=num_tokens_per_example)
25000 训练序列
25000 验证序列
embed_dim = 32 # 每个标记的嵌入大小。
num_heads = 2 # 注意力头的数量
ff_dim = 32 # 前馈网络中的隐藏层大小。
num_experts = 10 # Switch Transformer 中使用的专家数量。
batch_size = 50 # 批大小。
learning_rate = 0.001 # 学习率。
dropout_rate = 0.25 # 丢失率。
num_epochs = 3 # 训练轮数。
num_tokens_per_batch = (
batch_size * num_tokens_per_example
) # 每批的总标记数。
print(f"每批的标记数: {num_tokens_per_batch}")
每批的标记数: 10000
它由两个单独的嵌入层组成,一个用于标记,一个用于标记索引(位置)。
class TokenAndPositionEmbedding(layers.Layer):
def __init__(self, maxlen, vocab_size, embed_dim):
super().__init__()
self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
def call(self, x):
maxlen = ops.shape(x)[-1]
positions = ops.arange(start=0, stop=maxlen, step=1)
positions = self.pos_emb(positions)
x = self.token_emb(x)
return x + positions
这用于 Switch Transformer 中的专家混合。
def create_feedforward_network(ff_dim, embed_dim, name=None):
return keras.Sequential(
[layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim)], name=name
)
这是一种辅助损失,旨在鼓励专家之间的负载均衡。
def load_balanced_loss(router_probs, expert_mask):
# router_probs [tokens_per_batch, num_experts] 是为每个token分配给每个专家的概率。 expert_mask [tokens_per_batch, num_experts] 包含了具有最高路由概率的专家的一热格式。
num_experts = ops.shape(expert_mask)[-1]
# 获取分配给每个专家的token比例。
# density 是一个长度为num experts的向量,和为1。
density = ops.mean(expert_mask, axis=0)
# 从路由器获取分配给每个专家的概率质量比例
# 跨越所有token。 density_proxy 是一个长度为num experts的向量,和为1。
density_proxy = ops.mean(router_probs, axis=0)
# 希望两个向量在所有num_expert元素上都有均匀分配 (1/num experts)。
# 当点积最小化时,这两个向量将被推向均匀分配。
loss = ops.mean(density_proxy * density) * ops.cast((num_experts**2), "float32")
return loss
class Router(layers.Layer):
def __init__(self, num_experts, expert_capacity):
self.num_experts = num_experts
self.route = layers.Dense(units=num_experts)
self.expert_capacity = expert_capacity
super().__init__()
def call(self, inputs, training=False):
# 输入形状: [每批次的tokens, 嵌入维度]
# 路由日志形状: [每批次的tokens, 专家数量]
router_logits = self.route(inputs)
if training:
# 添加噪声以进行专家之间的探索。
router_logits += keras.random.uniform(
shape=router_logits.shape, minval=0.9, maxval=1.1
)
# 每个token应该发送到哪个专家的概率。
router_probs = keras.activations.softmax(router_logits, axis=-1)
# 获取每个token的top-1专家。 expert_gate是每个token的路由器的top-1概率
# expert_index是每个token将要路由到的专家。
expert_gate, expert_index = ops.top_k(router_probs, k=1)
# expert_mask形状: [每批次的tokens, 专家数量]
expert_mask = ops.one_hot(expert_index, self.num_experts)
# 计算负载均衡损失。
aux_loss = load_balanced_loss(router_probs, expert_mask)
self.add_loss(aux_loss)
# 专家有固定的容量,确保我们不超过它。构建
# 批次索引,确保每个专家接收的例子不超过专家容量。
position_in_expert = ops.cast(
ops.cumsum(expert_mask, axis=0) * expert_mask, "int32"
)
# 仅保留符合专家容量的tokens。
expert_mask *= ops.cast(
ops.less(ops.cast(position_in_expert, "int32"), self.expert_capacity),
"float32",
)
expert_mask_flat = ops.sum(expert_mask, axis=-1)
# 去掉超出专家容量的专家。
expert_gate *= expert_mask_flat
# 结合专家输出和与路由概率的缩放。
# combine_tensor形状: [每批次的tokens, 专家数量, 专家容量]
combined_tensor = ops.expand_dims(
expert_gate
* expert_mask_flat
* ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),
-1,
) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)
# 创建二进制 dispatch_tensor [每批次的tokens, 专家数量, 专家容量]
# 如果token被路由到相应的专家,则为1。
dispatch_tensor = ops.cast(combined_tensor, "float32")
return dispatch_tensor, combined_tensor
class Switch(layers.Layer):
def __init__(
self, num_experts, embed_dim, ff_dim, num_tokens_per_batch, capacity_factor=1
):
self.num_experts = num_experts
self.embed_dim = embed_dim
self.experts = [
create_feedforward_network(ff_dim, embed_dim) for _ in range(num_experts)
]
self.expert_capacity = num_tokens_per_batch // self.num_experts
self.router = Router(self.num_experts, self.expert_capacity)
super().__init__()
def call(self, inputs):
batch_size = ops.shape(inputs)[0]
num_tokens_per_example = ops.shape(inputs)[1]
# 输入形状: [每批次的tokens数量, 嵌入维度]
inputs = ops.reshape(inputs, [num_tokens_per_batch, self.embed_dim])
# dispatch_tensor形状: [专家容量, 专家数量, 每批次的tokens]
# combine_tensor形状: [每批次的tokens, 专家数量, 专家容量]
dispatch_tensor, combine_tensor = self.router(inputs)
# expert_inputs形状: [专家数量, 专家容量, 嵌入维度]
expert_inputs = ops.einsum("ab,acd->cdb", inputs, dispatch_tensor)
expert_inputs = ops.reshape(
expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim]
)
# 分派给专家
expert_input_list = ops.unstack(expert_inputs, axis=0)
expert_output_list = [
self.experts[idx](expert_input)
for idx, expert_input in enumerate(expert_input_list)
]
# expert_outputs形状: [专家容量, 专家数量, 嵌入维度]
expert_outputs = ops.stack(expert_output_list, axis=1)
# expert_outputs_combined形状: [每批次的tokens, 嵌入维度]
expert_outputs_combined = ops.einsum(
"abc,xba->xc", expert_outputs, combine_tensor
)
# 输出形状: [批次大小, 每个示例的tokens数量, 嵌入维度]
outputs = ops.reshape(
expert_outputs_combined,
[batch_size, num_tokens_per_example, self.embed_dim],
)
return outputs
class TransformerBlock(layers.Layer):
def __init__(self, embed_dim, num_heads, ffn, dropout_rate=0.1):
super().__init__()
self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
# ffn可以是一个标准的前馈网络或一个具有专家混合的切换层。
self.ffn = ffn
self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
self.dropout1 = layers.Dropout(dropout_rate)
self.dropout2 = layers.Dropout(dropout_rate)
def call(self, inputs, training=False):
attn_output = self.att(inputs, inputs)
attn_output = self.dropout1(attn_output, training=training)
out1 = self.layernorm1(inputs + attn_output)
ffn_output = self.ffn(out1)
ffn_output = self.dropout2(ffn_output, training=training)
return self.layernorm2(out1 + ffn_output)
TransformerBlock
层为我们输入序列的每个时间步输出一个向量。
在这里,我们对所有时间步进行平均,并在其上使用一个前馈网络来分类文本。
def create_classifier():
switch = Switch(num_experts, embed_dim, ff_dim, num_tokens_per_batch)
transformer_block = TransformerBlock(embed_dim // num_heads, num_heads, switch)
inputs = layers.Input(shape=(num_tokens_per_example,))
embedding_layer = TokenAndPositionEmbedding(
num_tokens_per_example, vocab_size, embed_dim
)
x = embedding_layer(inputs)
x = transformer_block(x)
x = layers.GlobalAveragePooling1D()(x)
x = layers.Dropout(dropout_rate)(x)
x = layers.Dense(ff_dim, activation="relu")(x)
x = layers.Dropout(dropout_rate)(x)
outputs = layers.Dense(2, activation="softmax")(x)
classifier = keras.Model(inputs=inputs, outputs=outputs)
return classifier
def run_experiment(classifier):
classifier.compile(
optimizer=keras.optimizers.Adam(learning_rate),
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
)
history = classifier.fit(
x_train,
y_train,
batch_size=batch_size,
epochs=num_epochs,
validation_data=(x_val, y_val),
)
return history
classifier = create_classifier()
run_experiment(classifier)
Epoch 1/3
500/500 ━━━━━━━━━━━━━━━━━━━━ 251s 485ms/step - accuracy: 0.7121 - loss: 1.5394 - val_accuracy: 0.8748 - val_loss: 1.2891
Epoch 2/3
500/500 ━━━━━━━━━━━━━━━━━━━━ 240s 480ms/step - accuracy: 0.9243 - loss: 1.2063 - val_accuracy: 0.8752 - val_loss: 1.3090
Epoch 3/3
500/500 ━━━━━━━━━━━━━━━━━━━━ 242s 485ms/step - accuracy: 0.9572 - loss: 1.1222 - val_accuracy: 0.8614 - val_loss: 1.3744
<keras.src.callbacks.history.History at 0x7efb79d82a90>
与标准的Transformer架构相比,Switch Transformer可以具有更多的参数,从而增加模型的容量,同时维持合理的计算成本。