Open Ended GPT2 Text Generation Explanations
This notebook demonstrates how to get explanations for the output of GPT2 used for open-ended text generation. In this demo, we use the pretrained gpt2 model provided by Hugging Face (https://huggingface.co/gpt2) to explain the text it generates. We further show how to get explanations for custom output text and how to plot the global input token importance for any generated output token.
[1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import shap
Load model and tokenizer
[2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
Below, we set certain model configurations. We need to define whether the model is a decoder or an encoder-decoder. This can be set through the 'is_decoder' or 'is_encoder_decoder' param in the model's config file. We can also set custom model generation parameters, which will be used during decoding when the output text is generated.
[3]:
# set model decoder to true
model.config.is_decoder = True
# set text-generation params under task_specific_params
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 50,
    "no_repeat_ngram_size": 2,
}
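As an optional sanity check (not part of the original notebook flow), you can preview what the model generates with these parameters before explaining it. A minimal sketch, assuming the model and tokenizer loaded above and a CUDA device:

# Sketch: preview a sampled continuation using the generation params set above.
inputs = tokenizer("I enjoy walking with my cute dog", return_tensors="pt").to("cuda")
output_ids = model.generate(**inputs, **model.config.task_specific_params["text-generation"])
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))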
Define the initial text
[4]:
s = ["I enjoy walking with my cute dog"]
Create an explainer object and compute the SHAP values
[5]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
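Before plotting, it can help to inspect the returned Explanation object. A minimal sketch (not part of the original notebook), assuming the shap_values computed above:

# Sketch: the explanation is indexed as [input sentence, input token, output token].
print(shap_values.shape)         # e.g. (1, n_input_tokens, n_output_tokens)
print(shap_values.output_names)  # the tokens of the generated continuation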
Visualize the SHAP explanations
[6]:
shap.plots.text(shap_values)
[Interactive SHAP text plot. Generated continuation: ", but I'm not sure if I'll ever be able to". Selecting an output token shows each input token's SHAP value; for example, "dog" is by far the strongest contributor (+4.064) to the first generated token.]
Another example…
[7]:
s = [
    "Scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth"
]
[8]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[9]:
shap.plots.text(shap_values)
[Interactive SHAP text plot. Generated continuation: "in the coming days.". Here "Earth" (+1.436) and "collide" (+1.28) contribute most to the first generated token, while "will" (+2.257) and "collide" (+2.382) drive "coming".]
Custom text generation and debugging biased outputs
Below we demonstrate how to explain a model's likelihood of generating a particular output sentence given an input sentence. For example, we ask: for the input sentence "I know many people who are [target].", which country's people (target) have a higher likelihood of generating the token "vodka" in the output sentence "They love their vodka!"? For this, we first define the input-output sentence pairs
[10]:
# define input
x = [
    "I know many people who are Russian.",
    "I know many people who are Greek.",
    "I know many people who are Australian.",
    "I know many people who are American.",
    "I know many people who are Italian.",
    "I know many people who are Spanish.",
    "I know many people who are German.",
    "I know many people who are Indian.",
]
[11]:
# define output
y = [
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
]
We wrap the model in a Teacher Forcing scoring class and create a Text masker (with collapse_mask_token=True, a run of consecutive masked tokens is collapsed into a single "..." rather than one per token)
[12]:
teacher_forcing_model = shap.models.TeacherForcing(model, tokenizer)
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)
Create an explainer…
[13]:
explainer = shap.Explainer(teacher_forcing_model, masker)
Generate the SHAP explanation values!
[14]:
shap_values = explainer(x, y)
Now that we have the SHAP values, we can look through the text plot at how the input tokens contribute to the token "vodka" in the output sentence. Note: red indicates a positive contribution and blue a negative one, and the intensity of the color reflects the strength of the contribution in that direction.
[15]:
shap.plots.text(shap_values)
[Interactive SHAP text plots, one per input sentence [0]-[7], each explaining the output "They love their vodka!". For the output token "vodka", the nationality token dominates: "Russian" contributes +2.648 in [0], "Italian" +0.779 in [4], and "German" +0.726 in [6], while "Australian" (-0.393), "American" (-0.519), and "Indian" (-0.666) contribute negatively.]
To see which input tokens drive (positively or negatively) the likelihood of generating the word "vodka", we plot the global token importance for the word "vodka".
Voila! Russians love their vodka, don't they? :)
[16]:
shap.plots.bar(shap_values[0, :, "vodka"])
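To compare all eight nationalities at once rather than inspecting one bar plot at a time, one can aggregate over the explanation object. A minimal sketch (a hypothetical follow-up, not in the original notebook), assuming the x and shap_values defined above:

# Sketch: sum each sentence's input-token contributions to the output token "vodka".
for i, sentence in enumerate(x):
    total = shap_values[i, :, "vodka"].values.sum()
    print(f"{total:+.3f}  {sentence}")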
Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!