Text to Multiclass Explanation: Language Modeling Example

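This notebook demonstrates how to get explanations for the top-k next words generated by a language model. In this demo, we use the pretrained gpt2 model provided by Hugging Face (https://huggingface.co/gpt2) to predict the top-k next words. We treat the top-k next words as k separate classes and then learn an explanation for each of them. This lets us explain how much each word in the input contributes to the predicted probability of each of the top-k next words.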
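As a rough illustration of what "treating the top-k next words as k classes" means, the sketch below picks the k highest-scoring candidates from a toy score table and reports each one as log odds. The vocabulary, the logit values, and the helper names `softmax` and `top_k_log_odds` are all hypothetical; this is not the code that shap runs internally, only a minimal stand-in for the idea.

```python
import math


def softmax(logits):
    """Turn a {token: logit} dict into a {token: probability} dict."""
    m = max(logits.values())  # subtract the max for numerical stability
    exps = {tok: math.exp(v - m) for tok, v in logits.items()}
    z = sum(exps.values())
    return {tok: e / z for tok, e in exps.items()}


def top_k_log_odds(logits, k):
    """Keep the k most probable tokens and score each by log(p / (1 - p))."""
    probs = softmax(logits)
    top = sorted(probs, key=probs.get, reverse=True)[:k]
    return {tok: math.log(probs[tok] / (1.0 - probs[tok])) for tok in top}


# Hypothetical next-token logits for a five-word vocabulary (made-up numbers).
toy_logits = {" cave": 3.1, " forest": 2.7, " small": 2.2, " zoo": 0.9, " spreadsheet": -1.5}

# Each of the k surviving tokens is one "class" that gets its own explanation.
classes = top_k_log_odds(toy_logits, k=3)
for token, log_odds in classes.items():
    print(repr(token), round(log_odds, 3))
```

Each key of `classes` plays the role of one output class, so an explanation method ends up producing one attribution vector per kept token.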

[1]:
from transformers import AutoModelForCausalLM, AutoTokenizer

import shap

Load model and tokenizer

[2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
# .cuda() moves the model to the GPU, so this cell needs a CUDA-capable device
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()

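Next, we wrap the model in a TopKLM model, which extracts the log odds of the top-k next words. We also create a Text masker, initialized with mask_token = "..." and collapse_mask_token = True, which is used to infill text while the inputs are perturbed.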

[3]:
wrapped_model = shap.models.TopKLM(model, tokenizer, k=100)
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)
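To make the two masker settings concrete, here is a small sketch of what masking with a mask token and collapsing adjacent masks could look like on a list of words. The helper `mask_tokens` and its `keep` flags are hypothetical and only mimic the described behaviour; the actual `shap.maskers.Text` works on tokenizer output and may differ in detail.

```python
def mask_tokens(tokens, keep, mask_token="...", collapse=True):
    """Replace every token whose keep flag is False with mask_token.

    With collapse=True, a run of adjacent masked tokens is emitted as a
    single mask_token instead of one mask_token per hidden token.
    """
    out = []
    prev_masked = False
    for tok, keep_it in zip(tokens, keep):
        if keep_it:
            out.append(tok)
            prev_masked = False
        else:
            if not (collapse and prev_masked):
                out.append(mask_token)
            prev_masked = True
    return out


tokens = ["scientists", "discovered", "a", "herd", "of", "unicorns"]
keep = [True, False, False, True, True, False]  # which words stay visible

collapsed = mask_tokens(tokens, keep)
uncollapsed = mask_tokens(tokens, keep, collapse=False)
print(" ".join(collapsed))
print(" ".join(uncollapsed))
```

Collapsing keeps the perturbed text closer to something a language model might plausibly see, since long hidden spans do not turn into long runs of repeated mask tokens.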

Define data

Here we set the initial text for which we want the gpt2 model to predict the next word.

[4]:
s = ["In a shocking finding, scientists discovered a herd of unicorns living in a"]

Create an explainer object

[5]:
explainer = shap.Explainer(wrapped_model, masker)

Compute SHAP values

[6]:
shap_values = explainer(s)

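Visualize the SHAP values across the input sentence for the top-k next words

We can now see the top-k next words predicted by gpt2 under "Output Text" in the plot below. Hover over each token to see which words in the input sentence drive the generation of that particular output word.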

[7]:
shap.plots.text(shap_values)


[0]
outputs:
cave, forest, small, desert, tiny, ", remote, zoo, tree, field, house, nest, tropical, lake, large, mountain, farm, group, wild, very, single, barn, jungle, new, valley, world, garden, herd, grass, natural, park, swamp, laboratory, nearby, well, rural, pond, dark, wood, subter, room, lab, cage, huge, New, water, colony, massive, common, state, deep, home, man, mine, human, rock, region, box, river, part, hollow, c, hole, vast, village, different, virtual, city, strange, greenhouse, frozen, shallow, semi, flat, patch, mysterious, local, giant, sub, barren, special, mountainous, mud, cemetery, pod, hive, newly, closed, community, California, place, flooded, prehistoric, sw, high, z, hot, far, 1, pasture

[Interactive text plot, shown here for the output token "cave": base value = -12.5736, f_cave(inputs) = -2.91492]

inputs (SHAP value of each input token for the output "cave"):
In          0.115
a           0.138
shocking   -0.2
finding    -0.236
,          -0.304
scientists -0.153
discovered -0.304
a          -0.176
herd        0.165
of          0.474
unic        0.463
orns        0.741
living      1.266
in          2.87
a           4.799
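The numbers reported above for the output token "cave" can be checked by hand. The sketch below copies the base value, the model output, and the per-token SHAP values from the plot, confirms that the base value plus the contributions lands close to the model output (the additivity property of SHAP values), and ranks the input tokens by contribution. The conversion to a probability assumes the values are log odds, as the TopKLM description earlier in this notebook states; the variable names are illustrative only.

```python
import math

# Values copied from the plot for the output token "cave".
base_value = -12.5736
f_cave = -2.91492
contributions = [
    ("In", 0.115), ("a", 0.138), ("shocking", -0.2), ("finding", -0.236),
    (",", -0.304), ("scientists", -0.153), ("discovered", -0.304),
    ("a", -0.176), ("herd", 0.165), ("of", 0.474), ("unic", 0.463),
    ("orns", 0.741), ("living", 1.266), ("in", 2.87), ("a", 4.799),
]

# Additivity: base value + sum of SHAP values should roughly equal f(inputs).
reconstructed = base_value + sum(value for _, value in contributions)
gap = abs(reconstructed - f_cave)  # nonzero only because the plot rounds values


def sigmoid(x):
    """Map log odds back to a probability (assumes the scores are log odds)."""
    return 1.0 / (1.0 + math.exp(-x))


p_base = sigmoid(base_value)  # probability of "cave" at the base value
p_cave = sigmoid(f_cave)      # probability of "cave" given the full sentence

# Input tokens that push hardest toward "cave".
top3 = sorted(contributions, key=lambda tv: tv[1], reverse=True)[:3]

print("reconstructed f(inputs):", round(reconstructed, 4), "gap:", round(gap, 4))
print("p(cave) at base value vs. full input:", p_base, p_cave)
print("strongest positive tokens:", top3)
```

Read this way, the last three tokens, "living in a", account for most of the shift toward "cave", while the opening clause contributes small negative amounts.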

Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!