Emotion classification multiclass example

This notebook demonstrates how to use the Partition explainer for a multiclass text classification scenario. Once the SHAP values are computed for a set of sentences, we then visualize the feature contributions to each of the classes. The text classification model we use is BERT, fine-tuned on an emotion dataset to classify a sentence into one of six classes: joy, sadness, anger, fear, love, and surprise.

[1]:
import datasets
import pandas as pd
import transformers

import shap

# load the emotion dataset
dataset = datasets.load_dataset("emotion", split="train")
data = pd.DataFrame({"text": dataset["text"], "emotion": dataset["label"]})
Using custom data configuration default
Reusing dataset emotion (/home/slundberg/.cache/huggingface/datasets/emotion/default/0.0.0/aa34462255cd487d04be8387a2d572588f6ceee23f784f37365aa714afeb8fe6)
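
The emotion labels come back as integers. The mapping from integer to emotion name is stored on the dataset's ClassLabel feature, which you can inspect directly:

# the label names, in label-id order
print(dataset.features["label"].names)
# ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']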

Build a transformers pipeline

Note that we have set return_all_scores=True for the pipeline so we can observe the model's behavior for all classes, not just the one with the highest output.

[2]:
# load the model and tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "nateraw/bert-base-uncased-emotion", use_fast=True
)
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "nateraw/bert-base-uncased-emotion"
).cuda()

# build a pipeline object to do predictions
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=0,
    return_all_scores=True,
)
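
To see what return_all_scores=True gives us, we can call the pipeline directly; the sentence below is just an illustrative example, but the output contains one score per emotion rather than only the top label:

# one dict per class, in label-id order
print(pred("i feel a bit uneasy today"))
# [[{'label': 'sadness', 'score': ...}, {'label': 'joy', 'score': ...}, ...]]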

Create a pipeline explainer

A transformers pipeline object can be passed directly to shap.Explainer, which will wrap the pipeline model as a shap.models.TransformersPipeline model and the pipeline tokenizer as a shap.maskers.Text masker.

[3]:
explainer = shap.Explainer(pred)
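
If you want the wrapping to be explicit, the following sketch builds an equivalent explainer by hand (explainer_explicit is just an illustrative name):

# wrap the pipeline as a SHAP model and the tokenizer as a text masker
wrapped_model = shap.models.TransformersPipeline(pred)
masker = shap.maskers.Text(tokenizer)
explainer_explicit = shap.Explainer(wrapped_model, masker)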

Compute SHAP values

Explainers have the same method signature as the models they are explaining, so we just pass a list of strings for which to explain the classifications.

[4]:
shap_values = explainer(data["text"][:3])
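
The resulting Explanation object is indexed as examples x tokens x output classes. A quick way to inspect it (a sketch):

print(shap_values.output_names)  # ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
print(shap_values[0].values.shape)  # (number of tokens in the first sentence, 6)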

Visualize the impact on all the output classes

In the plots below, when you hover your mouse over an output class you get the explanation for that output class. When you click an output class name, that class remains the focus of the explanation visualization until you click another class.

The base value is what the model outputs when the entire input text is masked, while \(f_{\text{output class}}(\text{inputs})\) is the output of the model for the full original input. The SHAP values explain in an additive way how unmasking each word changes the model output from the base value (where the entire input is masked) to the final prediction.

[5]:
shap.plots.text(shap_values)


[Interactive SHAP text plots for the three sentences, with one selectable output per class (sadness, joy, love, anger, fear, surprise). For the sadness output, "humiliated" carries almost all of the attribution in the first sentence, "hopeless" and "feeling" dominate the second, and in the third every token pushes the sadness output down, with "greedy" the strongest.]
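
You can check the additivity described above numerically; a minimal sketch, reusing the shap_values computed earlier:

# the base value plus the sum of the per-token SHAP values reproduces
# f(inputs) for each class (up to numerical error)
f_inputs = shap_values.base_values[0] + shap_values.values[0].sum(axis=0)
print(f_inputs)  # per-class probabilities for the first sentence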

Visualize the impact on a single class

Since Explanation objects are sliceable, we can slice out just a single output class to visualize the model output toward that class.

[11]:
shap.plots.text(shap_values[:, :, "anger"])


[Interactive SHAP text plots for the "anger" slice. "humiliated" and "feel" push the anger output down in the first sentence, "hopeful" and "hopeless" push it down in the second, while "greedy" and "wrong" drive the high anger output for the third.]
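
The same slicing pattern works for any class or subset of examples, for instance (a sketch):

# first sentence only, "fear" class only
shap.plots.text(shap_values[0, :, "fear"])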

Plotting the top words impacting a specific class

In addition to slicing, Explanation objects also support a set of reduction methods. Here we use .mean(0) to compute the average impact of all words on the "joy" class. Note that here we are also averaging over only three examples; to get a better summary you would want to use a larger portion of the dataset.

[12]:
shap.plots.bar(shap_values[:, :, "joy"].mean(0))
[Image: bar chart of the mean SHAP values toward the "joy" class]
[13]:
# we can sort the bar chart in descending order
shap.plots.bar(shap_values[:, :, "joy"].mean(0), order=shap.Explanation.argsort)
[Image: the same bar chart sorted in descending order]
[14]:
# ...or ascending order
shap.plots.bar(shap_values[:, :, "joy"].mean(0), order=shap.Explanation.argsort.flip)
[Image: the same bar chart sorted in ascending order]
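
Since .mean(0) lets positive and negative contributions cancel out, a common alternative (a sketch) is to rank words by the mean of the absolute SHAP values instead:

# rank words by overall magnitude of impact on the "joy" class
shap.plots.bar(shap_values[:, :, "joy"].abs.mean(0))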

Explaining log odds instead of probabilities

In the examples above we explained the direct output of the pipeline object, which is class probabilities. Sometimes it makes more sense to work in log-odds space, since it is the space in which addition and subtraction are natural (adding and subtracting evidence). To get logits we can use a parameter of the shap.models.TransformersPipeline object:

[15]:
logit_explainer = shap.Explainer(
    shap.models.TransformersPipeline(pred, rescale_to_logits=True)
)

logit_shap_values = logit_explainer(data["text"][:3])
shap.plots.text(logit_shap_values)


[Interactive SHAP text plots of the log-odds explanations for the sadness output. The attributions are now on the logit scale: "humiliated" contributes about 6.9 in the first sentence and "hopeless" about 5.9 in the second, while "greedy" contributes about -3.2 in the third.]
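
As a sanity check (a sketch, reusing the objects defined above, and assuming the pipeline returns scores in label-id order), rescale_to_logits applies the log-odds transform log(p / (1 - p)) to the pipeline's probabilities, so the base value plus the summed SHAP values should match the logit of the predicted probabilities:

import numpy as np
from scipy.special import logit

per_class = logit_shap_values.base_values[0] + logit_shap_values.values[0].sum(axis=0)
probs = np.array([d["score"] for d in pred(data["text"][0])[0]])
print(per_class)  # reconstructed logit-space outputs for the first sentence
print(logit(probs))  # should match up to numerical error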

Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!