代码示例 / 自然语言处理 / 抽象文本摘要生成与BART

抽象文本摘要生成与BART

作者: Abheesht Sharma
创建日期: 2023/07/08
最后修改日期: 2024/03/20
描述: 使用KerasNLP对BART进行微调以执行抽象摘要任务。

在Colab中查看 GitHub源


介绍

在信息过载的时代,从长文档或对话中提取要点并用几句话表达出来变得至关重要。由于摘要在不同领域具有广泛的应用,因此近年来它已成为一个关键的、研究较多的NLP任务。

双向自回归转换器(BART) 是一个基于转换器的编码器-解码器模型,通常用于 序列到序列的任务,如摘要和神经机器翻译。 BART以自我监督的方式在大型文本语料库上进行预训练。在 预训练期间,文本会被损坏,BART会被训练去重建 原始文本(因此被称为“去噪自编码器”)。一些预训练任务 包括令牌掩蔽、令牌删除、句子排列(打乱句子并训练BART修复顺序)等。

在这个示例中,我们将演示如何使用KerasNLP对BART进行微调,以处理抽象 摘要任务(在对话上!),并使用微调后的模型生成摘要。


设置

在我们开始实现管道之前,先安装并导入我们需要的所有 库。我们将使用KerasNLP库。我们还需要一些 实用库。

!pip install git+https://github.com/keras-team/keras-nlp.git py7zr -q
  正在安装构建依赖项 ... [?25l[?25hdone
  获取构建轮子的要求 ... [?25l[?25hdone
  正在准备元数据 (pyproject.toml) ... [?25l[?25hdone
[2K     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.4/66.4 kB 1.4 MB/s eta 0:00:00
[2K     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 34.8 MB/s eta 0:00:00
[2K     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 412.3/412.3 kB 30.4 MB/s eta 0:00:00
[2K     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 138.8/138.8 kB 15.1 MB/s eta 0:00:00
[2K     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 49.8/49.8 kB 5.8 MB/s eta 0:00:00
[2K     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB 61.4 MB/s eta 0:00:00
[2K     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 93.1/93.1 kB 10.1 MB/s eta 0:00:00
[?25h  为keras-nlp构建轮子(pyproject.toml) ... [?25l[?25hdone

这个例子使用Keras 3,这样可以在 "tensorflow""jax""torch" 中运行。Keras 3的支持已集成到 KerasNLP中,只需更改"KERAS_BACKEND"环境变量即可选择 您选择的后端。我们在下面选择JAX后端。

import os

os.environ["KERAS_BACKEND"] = "jax"

导入所有必要的库。

import py7zr
import time

import keras_nlp
import keras
import tensorflow as tf
import tensorflow_datasets as tfds
正在使用JAX后端。

让我们也定义我们的超参数。

BATCH_SIZE = 8
NUM_BATCHES = 600
EPOCHS = 1  # 可以设置为更高的值以获得更好的结果
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128
MAX_GENERATION_LENGTH = 40

数据集

让我们加载SAMSum数据集。该数据集 包含大约15,000对对话和摘要。

# 下载数据集。
filename = keras.utils.get_file(
    "corpus.7z",
    origin="https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z",
)

# 解压`.7z`文件。
with py7zr.SevenZipFile(filename, mode="r") as z:
    z.extractall(path="/root/tensorflow_datasets/downloads/manual")

# 使用TFDS加载数据。
samsum_ds = tfds.load("samsum", split="train", as_supervised=True)
从 https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z 下载数据
 2944100/2944100 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
下载并准备数据集 未知大小(下载:未知大小,生成:10.71 MiB,总计:10.71 MiB)到 /root/tensorflow_datasets/samsum/1.0.0...

生成拆分...:   0%|          | 0/3 [00:00<?, ? 拆分/s]

生成训练示例...:   0%|          | 0/14732 [00:00<?, ? 示例/s]

洗牌 /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-train.tfrecord*...:   0%|          | …

生成验证示例...:   0%|          | 0/818 [00:00<?, ? 示例/s]

洗牌 /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-validation.tfrecord*...:   0%|       …

生成测试示例...:   0%|          | 0/819 [00:00<?, ? 示例/s]

洗牌 /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-test.tfrecord*...:   0%|          | 0…

数据集 samsum 已下载并准备好至 /root/tensorflow_datasets/samsum/1.0.0。后续调用将重用此数据。

数据集有两个字段:dialoguesummary。让我们看一个示例。

for dialogue, summary in samsum_ds:
    print(dialogue.numpy())
    print(summary.numpy())
    break
b"Carter: Hey Alexis, I just wanted to let you know that I had a really nice time with you tonight. \r\nAlexis: Thanks Carter. Yeah, I really enjoyed myself as well. \r\nCarter: If you are up for it, I would really like to see you again soon.\r\nAlexis: Thanks Carter, I'm flattered. But I have a really busy week coming up.\r\nCarter: Yeah, no worries. I totally understand. But if you ever want to go grab dinner again, just let me know. \r\nAlexis: Yeah of course. Thanks again for tonight. \r\nCarter: Sure. Have a great night. "
b'Alexis and Carter met tonight. Carter would like to meet again, but Alexis is busy.'

我们现在将数据集分批,并仅保留数据集的一个子集用于本示例。对编码器馈送对话, 相应的摘要作为解码器的输入。因此,我们将数据集的格式更改为一个具有两个键的字典:“encoder_text”和“decoder_text”。这就是 keras_nlp.models.BartSeq2SeqLMPreprocessor 所期望的输入格式。

train_ds = (
    samsum_ds.map(
        lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
    )
    .batch(BATCH_SIZE)
    .cache()
)
train_ds = train_ds.take(NUM_BATCHES)

微调 BART

让我们首先加载模型和预处理器。我们使用512和128的序列长度分别用于编码器和解码器,而不是1024(这是默认序列长度)。这将使我们能够快速在Colab上运行此示例。

如果你仔细观察,预处理器与模型关联。这意味着我们不必担心处理文本输入;所有操作将内部完成。预处理器对编码器文本和解码器文本进行标记化,添加特殊标记并填充它们。为了生成自回归训练的标签,预处理器将解码器文本向右移动一个位置。这是因为在每个时间步,模型被训练去预测下一个标记。

preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)
bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset(
    "bart_base_en", preprocessor=preprocessor
)

bart_lm.summary()
Downloading data from https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/vocab.json
 898823/898823 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step
Downloading data from https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/merges.txt
 456318/456318 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step
Downloading data from https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/model.h5
 557969120/557969120 ━━━━━━━━━━━━━━━━━━━━ 29s 0us/step
预处理器: "bart_seq2_seq_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ 分词器 (类型)                                                                                词汇数 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ bart_tokenizer (BartTokenizer)                     │                                              50,265 │
└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
模型: "bart_seq2_seq_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ 层 (类型)                   输出形状                   参数数  连接到                   ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ decoder_padding_mask          │ (, )              │           0 │ -                              │
│ (输入层)                  │                           │             │                                │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ decoder_token_ids             │ (, )              │           0 │ -                              │
│ (输入层)                  │                           │             │                                │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ encoder_padding_mask          │ (, )              │           0 │ -                              │
│ (输入层)                  │                           │             │                                │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ encoder_token_ids             │ (, )              │           0 │ -                              │
│ (输入层)                  │                           │             │                                │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ bart_backbone (BartBackbone)  │ [(, , 768),       │ 139,417,344 │ decoder_padding_mask[0][0],    │
│                               │ (, , 768)]        │             │ decoder_token_ids[0][0],       │
│                               │                           │             │ encoder_padding_mask[0][0],    │
│                               │                           │             │ encoder_token_ids[0][0]        │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ reverse_embedding             │ (, 50265)             │  38,603,520 │ bart_backbone[0][0]            │
│ (反向嵌入)            │                           │             │                                │
└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘
 总参数: 139,417,344 (4.15 GB)
 可训练参数: 139,417,344 (4.15 GB)
 不可训练参数: 0 (0.00 B)

定义优化器和损失函数。我们使用带线性衰减学习率的Adam优化器。编译模型。

optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
    epsilon=1e-6,
    global_clipnorm=1.0,  # 梯度裁剪。
)
# 将layernorm和偏置项排除在权重衰减之外。
optimizer.exclude_from_weight_decay(var_names=["bias"])
optimizer.exclude_from_weight_decay(var_names=["gamma"])
optimizer.exclude_from_weight_decay(var_names=["beta"])

loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

bart_lm.compile(
    optimizer=optimizer,
    loss=loss,
    weighted_metrics=["accuracy"],
)

让我们训练模型!

bart_lm.fit(train_ds, epochs=EPOCHS)
 600/600 ━━━━━━━━━━━━━━━━━━━━ 398s 586ms/step - loss: 0.4330

<keras_core.src.callbacks.history.History at 0x7ae2faf3e110>

生成摘要并评估它们!

模型已经训练完成,现在让我们进入有趣的部分 - 实际生成摘要!我们将从验证集中选取前100个样本并为它们生成摘要。我们将使用默认的解码策略,即贪心搜索。

KerasNLP中的生成高度优化。它得益于XLA的强大支持。其次,自注意力层和解码器中的交叉注意力层中的键/值张量会被缓存,以避免在每个时间步的重新计算。

def generate_text(model, input_text, max_length=200, print_time_taken=False):
    start = time.time()
    output = model.generate(input_text, max_length=max_length)
    end = time.time()
    print(f"总耗时: {end - start:.2f}s")
    return output


# 加载数据集。
val_ds = tfds.load("samsum", split="validation", as_supervised=True)
val_ds = val_ds.take(100)

dialogues = []
ground_truth_summaries = []
for dialogue, summary in val_ds:
    dialogues.append(dialogue.numpy())
    ground_truth_summaries.append(summary.numpy())

# 让我们进行一个虚拟调用 - 第一次调用XLA通常需要稍长时间。
_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)

# 生成摘要。
generated_summaries = generate_text(
    bart_lm,
    val_ds.map(lambda dialogue, _: dialogue).batch(8),
    max_length=MAX_GENERATION_LENGTH,
    print_time_taken=True,
)
总耗时: 21.22s
总耗时: 49.00s

让我们看看一些摘要。

for dialogue, generated_summary, ground_truth_summary in zip(
    dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
):
    print("对话:", dialogue)
    print("生成的摘要:", generated_summary)
    print("真实摘要:", ground_truth_summary)
    print("=============================")
对话: b'Tony: 老板在吗?\r\nClaire: 还没有。\r\nTony: 能告诉我他来的时候吗?\r\nClaire: 当然可以。\r\nTony: 谢谢。' 生成摘要: Tony 会在她的老板到达时告诉 Claire。 真实摘要: b"老板还没有到。Claire 会在他来的时候告诉 Tony。" ============================= 对话: b"James: 我该给她买什么?\r\nTim: 谁?\r\nJames: 啊,玛丽,我女朋友\r\nTim: 我真的是你应该问的人吗?\r\nJames: 哦,得了,星期六是她生日\r\nTim: 问桑迪吧\r\nTim: 我真的不是问这个的合适人选\r\nJames: 真是的,好吧!" 生成摘要: 玛丽的女朋友是生日。James 和 Tim 要去问桑迪给她买。 真实摘要: b"玛丽的生日在星期六。她的男朋友 James 在寻找礼物的点子。Tim 建议他去问桑迪。" ============================= 对话: b"Mary: 那以色列怎么样?你去海滩了吗?\r\nKate: 太贵了!但他们说,这是特拉维夫... 我们明天去耶路撒冷。\r\nMary: 我听说以色列很贵,莫妮卡去年去度假时抱怨它有多贵。你打算在死海没死之前去吗?哈哈哈\r\nKate: 哈哈哈,是的,几天后。" 生成摘要: Kate 在特拉维夫度假。Mary 准备在几天后去死海。 真实摘要: b'Mary 和 Kate 讨论以色列的高昂费用。Kate 现在在特拉维夫,计划明天前往耶路撒冷,几天后去死海。' ============================= 对话: b"Giny: 我们有米吗?\r\nRiley: 没有,已经用完了\r\nGiny: 真操蛋!\r\nGiny: 好吧,我去买" 生成摘要: Giny 想从 Riley 那儿买米。 真实摘要: b"Giny 和 Riley 没有剩下任何米了。Giny 会去买一些。" ============================= 对话: b"Jude: 我在十二月初会在华沙,所以我们可以再见面\r\nLeon: !!!\r\nLeon: 不在前面的意思是...? \r\nLeon: 因为我在第一个周末不在这里\r\nJude: 10\r\nJude: 但我觉得是星期一,所以算了,我想 :D\r\nLeon: 嗯,星期一对我来说真的不行 :D\r\nLeon: :<\r\nJude: 哦,好吧下次 :d\r\nLeon: 是啊...!" 生成摘要: Jude 和 Leon 将于这周末在早上10点再见面。 真实摘要: b'Jude 将于十二月十日来到华沙,想见 Leon。Leon 没有时间。'

生成的摘要看起来很棒!对于仅训练了1个周期并在5000个示例上进行训练的模型来说,这表现得不错 :)