作者: Abheesht Sharma
创建日期: 2023/07/08
最后修改日期: 2024/03/20
描述: 使用KerasNLP对BART进行微调以执行抽象摘要任务。
在信息过载的时代,从长文档或对话中提取要点并用几句话表达出来变得至关重要。由于摘要在不同领域具有广泛的应用,因此近年来它已成为一个关键的、研究较多的NLP任务。
双向自回归转换器(BART) 是一个基于转换器的编码器-解码器模型,通常用于 序列到序列的任务,如摘要和神经机器翻译。 BART以自我监督的方式在大型文本语料库上进行预训练。在 预训练期间,文本会被损坏,BART会被训练去重建 原始文本(因此被称为“去噪自编码器”)。一些预训练任务 包括令牌掩蔽、令牌删除、句子排列(打乱句子并训练BART修复顺序)等。
在这个示例中,我们将演示如何使用KerasNLP对BART进行微调,以处理抽象 摘要任务(在对话上!),并使用微调后的模型生成摘要。
在我们开始实现管道之前,先安装并导入我们需要的所有 库。我们将使用KerasNLP库。我们还需要一些 实用库。
!pip install git+https://github.com/keras-team/keras-nlp.git py7zr -q
正在安装构建依赖项 ... [?25l[?25hdone
获取构建轮子的要求 ... [?25l[?25hdone
正在准备元数据 (pyproject.toml) ... [?25l[?25hdone
[2K ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.4/66.4 kB 1.4 MB/s eta 0:00:00
[2K ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 34.8 MB/s eta 0:00:00
[2K ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 412.3/412.3 kB 30.4 MB/s eta 0:00:00
[2K ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 138.8/138.8 kB 15.1 MB/s eta 0:00:00
[2K ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 49.8/49.8 kB 5.8 MB/s eta 0:00:00
[2K ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB 61.4 MB/s eta 0:00:00
[2K ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 93.1/93.1 kB 10.1 MB/s eta 0:00:00
[?25h 为keras-nlp构建轮子(pyproject.toml) ... [?25l[?25hdone
这个例子使用Keras 3,这样可以在
"tensorflow"
、"jax"
或 "torch"
中运行。Keras 3的支持已集成到
KerasNLP中,只需更改"KERAS_BACKEND"
环境变量即可选择
您选择的后端。我们在下面选择JAX后端。
import os
os.environ["KERAS_BACKEND"] = "jax"
导入所有必要的库。
import py7zr
import time
import keras_nlp
import keras
import tensorflow as tf
import tensorflow_datasets as tfds
正在使用JAX后端。
让我们也定义我们的超参数。
BATCH_SIZE = 8
NUM_BATCHES = 600
EPOCHS = 1 # 可以设置为更高的值以获得更好的结果
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128
MAX_GENERATION_LENGTH = 40
让我们加载SAMSum数据集。该数据集 包含大约15,000对对话和摘要。
# 下载数据集。
filename = keras.utils.get_file(
"corpus.7z",
origin="https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z",
)
# 解压`.7z`文件。
with py7zr.SevenZipFile(filename, mode="r") as z:
z.extractall(path="/root/tensorflow_datasets/downloads/manual")
# 使用TFDS加载数据。
samsum_ds = tfds.load("samsum", split="train", as_supervised=True)
从 https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z 下载数据
2944100/2944100 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
下载并准备数据集 未知大小(下载:未知大小,生成:10.71 MiB,总计:10.71 MiB)到 /root/tensorflow_datasets/samsum/1.0.0...
生成拆分...: 0%| | 0/3 [00:00<?, ? 拆分/s]
生成训练示例...: 0%| | 0/14732 [00:00<?, ? 示例/s]
洗牌 /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-train.tfrecord*...: 0%| | …
生成验证示例...: 0%| | 0/818 [00:00<?, ? 示例/s]
洗牌 /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-validation.tfrecord*...: 0%| …
生成测试示例...: 0%| | 0/819 [00:00<?, ? 示例/s]
洗牌 /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-test.tfrecord*...: 0%| | 0…
数据集 samsum 已下载并准备好至 /root/tensorflow_datasets/samsum/1.0.0。后续调用将重用此数据。
数据集有两个字段:dialogue
和 summary
。让我们看一个示例。
for dialogue, summary in samsum_ds:
print(dialogue.numpy())
print(summary.numpy())
break
b"Carter: Hey Alexis, I just wanted to let you know that I had a really nice time with you tonight. \r\nAlexis: Thanks Carter. Yeah, I really enjoyed myself as well. \r\nCarter: If you are up for it, I would really like to see you again soon.\r\nAlexis: Thanks Carter, I'm flattered. But I have a really busy week coming up.\r\nCarter: Yeah, no worries. I totally understand. But if you ever want to go grab dinner again, just let me know. \r\nAlexis: Yeah of course. Thanks again for tonight. \r\nCarter: Sure. Have a great night. "
b'Alexis and Carter met tonight. Carter would like to meet again, but Alexis is busy.'
我们现在将数据集分批,并仅保留数据集的一个子集用于本示例。对编码器馈送对话, 相应的摘要作为解码器的输入。因此,我们将数据集的格式更改为一个具有两个键的字典:“encoder_text”和“decoder_text”。这就是 keras_nlp.models.BartSeq2SeqLMPreprocessor
所期望的输入格式。
train_ds = (
samsum_ds.map(
lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
)
.batch(BATCH_SIZE)
.cache()
)
train_ds = train_ds.take(NUM_BATCHES)
让我们首先加载模型和预处理器。我们使用512和128的序列长度分别用于编码器和解码器,而不是1024(这是默认序列长度)。这将使我们能够快速在Colab上运行此示例。
如果你仔细观察,预处理器与模型关联。这意味着我们不必担心处理文本输入;所有操作将内部完成。预处理器对编码器文本和解码器文本进行标记化,添加特殊标记并填充它们。为了生成自回归训练的标签,预处理器将解码器文本向右移动一个位置。这是因为在每个时间步,模型被训练去预测下一个标记。
preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
"bart_base_en",
encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)
bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset(
"bart_base_en", preprocessor=preprocessor
)
bart_lm.summary()
Downloading data from https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/vocab.json
898823/898823 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step
Downloading data from https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/merges.txt
456318/456318 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step
Downloading data from https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/model.h5
557969120/557969120 ━━━━━━━━━━━━━━━━━━━━ 29s 0us/step
预处理器: "bart_seq2_seq_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ 分词器 (类型) ┃ 词汇数 ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ bart_tokenizer (BartTokenizer) │ 50,265 │ └────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
模型: "bart_seq2_seq_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ 层 (类型) ┃ 输出形状 ┃ 参数数 ┃ 连接到 ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ decoder_padding_mask │ (无, 无) │ 0 │ - │ │ (输入层) │ │ │ │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ decoder_token_ids │ (无, 无) │ 0 │ - │ │ (输入层) │ │ │ │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ encoder_padding_mask │ (无, 无) │ 0 │ - │ │ (输入层) │ │ │ │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ encoder_token_ids │ (无, 无) │ 0 │ - │ │ (输入层) │ │ │ │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ bart_backbone (BartBackbone) │ [(无, 无, 768), │ 139,417,344 │ decoder_padding_mask[0][0], │ │ │ (无, 无, 768)] │ │ decoder_token_ids[0][0], │ │ │ │ │ encoder_padding_mask[0][0], │ │ │ │ │ encoder_token_ids[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ reverse_embedding │ (无, 50265) │ 38,603,520 │ bart_backbone[0][0] │ │ (反向嵌入) │ │ │ │ └───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘
总参数: 139,417,344 (4.15 GB)
可训练参数: 139,417,344 (4.15 GB)
不可训练参数: 0 (0.00 B)
定义优化器和损失函数。我们使用带线性衰减学习率的Adam优化器。编译模型。
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
epsilon=1e-6,
global_clipnorm=1.0, # 梯度裁剪。
)
# 将layernorm和偏置项排除在权重衰减之外。
optimizer.exclude_from_weight_decay(var_names=["bias"])
optimizer.exclude_from_weight_decay(var_names=["gamma"])
optimizer.exclude_from_weight_decay(var_names=["beta"])
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
bart_lm.compile(
optimizer=optimizer,
loss=loss,
weighted_metrics=["accuracy"],
)
让我们训练模型!
bart_lm.fit(train_ds, epochs=EPOCHS)
600/600 ━━━━━━━━━━━━━━━━━━━━ 398s 586ms/step - loss: 0.4330
<keras_core.src.callbacks.history.History at 0x7ae2faf3e110>
模型已经训练完成,现在让我们进入有趣的部分 - 实际生成摘要!我们将从验证集中选取前100个样本并为它们生成摘要。我们将使用默认的解码策略,即贪心搜索。
KerasNLP中的生成高度优化。它得益于XLA的强大支持。其次,自注意力层和解码器中的交叉注意力层中的键/值张量会被缓存,以避免在每个时间步的重新计算。
def generate_text(model, input_text, max_length=200, print_time_taken=False):
start = time.time()
output = model.generate(input_text, max_length=max_length)
end = time.time()
print(f"总耗时: {end - start:.2f}s")
return output
# 加载数据集。
val_ds = tfds.load("samsum", split="validation", as_supervised=True)
val_ds = val_ds.take(100)
dialogues = []
ground_truth_summaries = []
for dialogue, summary in val_ds:
dialogues.append(dialogue.numpy())
ground_truth_summaries.append(summary.numpy())
# 让我们进行一个虚拟调用 - 第一次调用XLA通常需要稍长时间。
_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)
# 生成摘要。
generated_summaries = generate_text(
bart_lm,
val_ds.map(lambda dialogue, _: dialogue).batch(8),
max_length=MAX_GENERATION_LENGTH,
print_time_taken=True,
)
总耗时: 21.22s
总耗时: 49.00s
让我们看看一些摘要。
for dialogue, generated_summary, ground_truth_summary in zip(
dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
):
print("对话:", dialogue)
print("生成的摘要:", generated_summary)
print("真实摘要:", ground_truth_summary)
print("=============================")
生成的摘要看起来很棒!对于仅训练了1个周期并在5000个示例上进行训练的模型来说,这表现得不错 :)