Open Ended GPT2 Text Generation Explanations

This notebook demonstrates how to get explanations for the output of gpt2 used for open-ended text generation. In this demo, we use the pretrained gpt2 model provided by Hugging Face (https://huggingface.co/gpt2) to explain the text generated by gpt2. We further show how to get explanations for custom output generated text and how to plot the global input token importances for any output generated token.

[1]:
from transformers import AutoModelForCausalLM, AutoTokenizer

import shap

Load model and tokenizer

[2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
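
If a GPU is not available, the .cuda() call above will fail. A minimal fallback (an optional aside, not part of the original notebook; the device variable is our own) is to pick the device at runtime:

import torch

# Assumption: run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)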

Below, we set certain model configurations. We need to define whether the model is a decoder or an encoder-decoder. This can be set through the 'is_decoder' or 'is_encoder_decoder' param in the model's config file. We can also set custom model generation parameters, which will be used during the decoding process when generating the output text.

[3]:
# set model decoder to true
model.config.is_decoder = True
# set text-generation params under task_specific_params
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 50,
    "no_repeat_ngram_size": 2,
}
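
As an optional sanity check (not part of the original notebook), the same parameters can be passed directly to model.generate to see one sampled continuation; the variable names below are illustrative:

# encode a prompt and sample a continuation with the same generation settings
inputs = tokenizer("I enjoy walking with my cute dog", return_tensors="pt").to(model.device)
sample = model.generate(
    **inputs,
    do_sample=True,
    max_length=50,
    temperature=0.7,
    top_k=50,
    no_repeat_ngram_size=2,
)
print(tokenizer.decode(sample[0], skip_special_tokens=True))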

Define the initial text

[4]:
s = ["I enjoy walking with my cute dog"]

Create an explainer object and compute the SHAP values

[5]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
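
Before plotting, the returned Explanation object can be inspected directly. This is an optional aside; the attributes below are standard shap.Explanation fields, but exact shapes may vary across shap versions:

# tokens generated by the model for the input sentence
print(shap_values.output_names)
# per-row attribution matrix: (number of input tokens, number of output tokens)
print(shap_values.values[0].shape)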

Visualize the SHAP explanations

[6]:
shap.plots.text(shap_values)


[SHAP text plot: for the input "I enjoy walking with my cute dog", the generated continuation is ", but I'm not sure if I'll ever be able to". The force plot for the first generated token shows "dog" (+4.064) as the dominant positive contribution among the input tokens.]
Another example …

[7]:
s = [
    "Scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth"
]
[8]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[9]:
shap.plots.text(shap_values)


[SHAP text plot: the generated continuation is "in the coming days.". "Earth" (+1.436) and "collide" (+1.28) are the strongest positive contributions among the input tokens for the first generated token.]

Custom text generation and debugging biased outputs

Below we demonstrate how to use the model explanations to debug the likelihood of generating a particular output sentence given an input sentence. For example, we ask: in the input sentence "I know many people who are [target].", for which nationality [target] is the token "vodka" more likely to be generated in the output sentence "They love their vodka!"? For this, we first define the input-output sentence pairs.

[10]:
# define input
x = [
    "I know many people who are Russian.",
    "I know many people who are Greek.",
    "I know many people who are Australian.",
    "I know many people who are American.",
    "I know many people who are Italian.",
    "I know many people who are Spanish.",
    "I know many people who are German.",
    "I know many people who are Indian.",
]
[11]:
# define output
y = [
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
]

We wrap the model in a Teacher Forcing scoring class and create a Text masker

[12]:
teacher_forcing_model = shap.models.TeacherForcing(model, tokenizer)
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)
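
For intuition: the TeacherForcing wrapper scores how likely each token of a given target output sentence is under the model, conditioned on a (possibly masked) input sentence; the explainer then perturbs the input through the masker and tracks how those scores change. As an optional, hedged aside, the wrapper can also be called directly on (input, output) pairs; the exact return shape may differ between shap versions:

# one row per (input, output) pair, one column per output token
# ("They", "love", "their", "vodka", "!")
scores = teacher_forcing_model(x, y)
print(scores.shape)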

Create an explainer …

[13]:
explainer = shap.Explainer(teacher_forcing_model, masker)

Generate the SHAP explanation values!

[14]:
shap_values = explainer(x, y)

Now that we have generated the SHAP values, we can look at the contribution of the input tokens driving the token "vodka" in the output sentence using the text plot. Note: red indicates a positive contribution, blue indicates a negative contribution, and the intensity of the color shows the strength of the contribution in that direction.

[15]:
shap.plots.text(shap_values)


[SHAP text plots for the eight input sentences, one per nationality. In every case the explained output is "They love their vodka!", and each plot shows the per-token input contributions to the selected output token; only the nationality token ("Russian", "Greek", "Australian", "American", "Italian", "Spanish", "German", "Indian") differs between examples.]

To check which input tokens influence (positively/negatively) the likelihood of generating the word "vodka", we plot the global token importances for the word "vodka".

Voila! Russians love their vodka, don't they? :)

[16]:
shap.plots.bar(shap_values[0, :, "vodka"])
[Bar plot: global input token importances for the output token "vodka" (first sentence).]
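
The same indexing pattern works for the other sentences as well; for example, the following (an optional extra, not in the original notebook) shows the token importances for "vodka" with the second input sentence ("Greek"):

shap.plots.bar(shap_values[1, :, "vodka"])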

Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!