Multi-Input Text Explanation: Textual Entailment with Facebook BART

This notebook demonstrates how to get explanations for the output of the Facebook BART model trained on the mnli dataset, used here for textual entailment. Because mnli is not supported in the environment required by shap, we use an example from the snli dataset instead.

BART: https://huggingface.co/facebook/bart-large-mnli

[1]:
import numpy as np
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

import shap

Load the model and tokenizer

[2]:
model = AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
Some weights of the model checkpoint at facebook/bart-large-mnli were not used when initializing BartForSequenceClassification: ['model.encoder.version', 'model.decoder.version']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[3]:
# load dataset
dataset = load_dataset("snli")
snli_label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}
example_ind = 6
premise, hypothesis, label = (
    dataset["train"]["premise"][example_ind],
    dataset["train"]["hypothesis"][example_ind],
    dataset["train"]["label"][example_ind],
)
print("Premise: " + premise)
print("Hypothesis: " + hypothesis)
true_label = snli_label_map[label]
print(f"The true label is: {true_label}")
Reusing dataset snli (C:\Users\v-jocelinsu\.cache\huggingface\datasets\snli\plain_text\1.0.0\bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Premise: A boy is jumping on skateboard in the middle of a red bridge.
Hypothesis: The boy skates down the sidewalk.
The true label is: contradiction
[4]:
# test model
input_ids = tokenizer.encode(premise, hypothesis, return_tensors="pt")
logits = model(input_ids)[0]
probs = logits.softmax(dim=1)

bart_label_map = {0: "contradiction", 1: "neutral", 2: "entailment"}
for i, lab in bart_label_map.items():
    print(f"{lab} probability: {probs[0][i] * 100:0.2f}%")
contradiction probability: 99.95%
neutral probability: 0.03%
entailment probability: 0.02%
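
Note that the label order of the BART mnli classification head (0: contradiction, 1: neutral, 2: entailment) is the reverse of the SNLI dataset labels (0: entailment, 2: contradiction), which is why two separate label maps are used. As a quick sanity check, a short snippet (not part of the original notebook) can map the model's top prediction back to a label name and compare it with the true SNLI label:

# Sanity check (illustrative only): map the model's argmax prediction back to a
# label name and compare it with the true SNLI label printed above.
pred_idx = int(probs.argmax(dim=1))
pred_label = bart_label_map[pred_idx]
print(f"Predicted: {pred_label}, true: {true_label}, match: {pred_label == true_label}")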

Run SHAP values

[5]:
import scipy as sp
import torch


# wrapper function for the model
# takes in masked strings, each of the form: premise <separator token(s)> hypothesis
def f(x):
    outputs = []
    for _x in x:
        encoding = torch.tensor([tokenizer.encode(_x)])
        output = model(encoding)[0].detach().cpu().numpy()
        outputs.append(output[0])
    outputs = np.array(outputs)
    scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
    val = sp.special.logit(scores)
    return val
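
As a rough usage check (a sketch, not part of the original notebook), the wrapper can be called directly on a manually concatenated premise/hypothesis string; it should return a single row of three logit-scale scores in the order contradiction, neutral, entailment. The next cell builds the same concatenated string more carefully via the tokenizer.

# Illustrative only: a rough manual concatenation using BART's </s></s> separator.
sample = premise + "</s></s>" + hypothesis
print(f([sample]).shape)  # expected shape: (1, 3)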
[6]:
# Construct explainer
bart_labels = ["contradiction", "neutral", "entailment"]
explainer = shap.Explainer(f, tokenizer, output_names=bart_labels)
explainers.Partition is still in an alpha state, so use with caution...
[7]:
# encode then decode premise, hypothesis to get concatenated sentences
encoded = tokenizer(premise, hypothesis)["input_ids"][
    1:-1
]  # ignore the start and end tokens, since the tokenizer will naturally add them back
decoded = tokenizer.decode(encoded)
print(decoded)
A boy is jumping on skateboard in the middle of a red bridge.</s></s>The boy skates down the sidewalk.
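
The double </s></s> in the decoded string is simply how the BART tokenizer joins a sentence pair. An optional check (not part of the original notebook) makes the full token sequence, including the special tokens, visible:

# Optional check: show the token sequence for the premise/hypothesis pair,
# including the special tokens inserted by the tokenizer.
pair_ids = tokenizer(premise, hypothesis)["input_ids"]
print(tokenizer.convert_ids_to_tokens(pair_ids))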
[8]:
shap_values = explainer([decoded])  # wrap input in list
print(shap_values)
Partition explainer: 2it [00:17, 18.00s/it]
.values =
array([[[-0.18482581,  0.11629296, -0.05710324],
        [-0.18482581,  0.11629296, -0.05710324],
        [-0.18482581,  0.11629296, -0.05710324],
        [-0.18482581,  0.11629296, -0.05710324],
        [-0.18482581,  0.11629296, -0.05710324],
        [-0.18482581,  0.11629296, -0.05710324],
        [-0.18482581,  0.11629296, -0.05710324],
        [-0.18482581,  0.11629296, -0.05710324],
        [ 0.49238822, -0.37113822, -0.48107536],
        [ 0.49238822, -0.37113822, -0.48107536],
        [ 0.49238822, -0.37113822, -0.48107536],
        [ 0.49238822, -0.37113822, -0.48107536],
        [ 0.65840166, -0.54717401, -0.45434223],
        [ 0.65840166, -0.54717401, -0.45434223],
        [ 0.65840166, -0.54717401, -0.45434223],
        [ 0.65840166, -0.54717401, -0.45434223],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.21500799, -0.29488914,  0.00956938],
        [ 0.21500799, -0.29488914,  0.00956938],
        [ 0.21500799, -0.29488914,  0.00956938],
        [ 0.21500799, -0.29488914,  0.00956938],
        [ 0.88112425, -0.62802847, -0.69218032],
        [ 0.88112425, -0.62802847, -0.69218032],
        [ 1.51606662, -1.12249615, -1.38898808],
        [ 1.51606662, -1.12249615, -1.38898808],
        [ 0.43230298, -0.19067168, -0.23281629],
        [ 0.        ,  0.        ,  0.        ]]])

.base_values =
array([[-1.50853336, -0.49898115, -0.23684637]])

.data =
array([['', 'A ', 'boy ', 'is ', 'jumping ', 'on ', 'skate', 'board ',
        'in ', 'the ', 'middle ', 'of ', 'a ', 'red ', 'bridge', '.',
        '</s>', '</s>', 'The ', 'boy ', 'sk', 'ates ', 'down ', 'the ',
        'sidewalk', '.', '']], dtype='<U8')
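
Because the explanation covers both inputs at once, the per-token values can also be aggregated per segment. The following sketch (not part of the original notebook, and assuming the two </s> entries in .data mark the premise/hypothesis boundary) sums the contributions of the premise tokens and the hypothesis tokens for each output class:

# Illustrative aggregation: sum SHAP values over premise tokens and hypothesis
# tokens separately for each class (assumes the "</s>" tokens mark the boundary).
tokens = shap_values[0].data
sep = [i for i, t in enumerate(tokens) if t == "</s>"]
premise_contrib = shap_values.values[0][: sep[0]].sum(axis=0)
hypothesis_contrib = shap_values.values[0][sep[-1] + 1 :].sum(axis=0)
for j, lab in enumerate(bart_labels):
    print(f"{lab}: premise {premise_contrib[j]:+.3f}, hypothesis {hypothesis_contrib[j]:+.3f}")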

Explanation visualization

[9]:
shap.plots.text(shap_values)

[interactive SHAP text plot, 0th instance: per-token attributions over the input tokens (premise </s></s> hypothesis) for the output classes contradiction, neutral, and entailment]

Input partition tree - dendrogram

[10]:
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
[11]:
Z = shap_values[0].abs.clustering
Z[-1][2] = (
    Z[-2][2] + 10
)  # last row's distance is extremely large, so make it a more reasonable value
print(Z)
[[ 0.  1. 12.  2.]
 [ 2.  3. 12.  2.]
 [ 4.  5. 12.  2.]
 [ 6.  7. 12.  2.]
 [ 8.  9. 12.  2.]
 [10. 11. 12.  2.]
 [12. 13. 12.  2.]
 [17. 18. 12.  2.]
 [19. 20. 12.  2.]
 [21. 22. 12.  2.]
 [23. 24. 12.  2.]
 [33. 14. 13.  3.]
 [27. 28. 14.  4.]
 [29. 30. 14.  4.]
 [31. 32. 14.  4.]
 [34. 35. 14.  4.]
 [36. 37. 14.  4.]
 [38. 15. 15.  4.]
 [43. 25. 16.  5.]
 [39. 40. 18.  8.]
 [41. 44. 18.  8.]
 [42. 45. 19.  9.]
 [46. 47. 26. 16.]
 [48. 26. 40. 10.]
 [49. 16. 47. 17.]
 [51. 50. 57. 27.]]
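
Each row of Z follows SciPy's linkage format: [left id, right id, merge distance, cluster size], where an id greater than or equal to the number of tokens refers to a cluster formed by an earlier row. A small illustrative snippet (not part of the original notebook) spells out the first few merges:

# Illustrative: print the first few merges of the partition tree in plain terms.
tokens = shap_values[0].data
n_tokens = len(tokens)

def name(i):
    i = int(i)
    return repr(tokens[i]) if i < n_tokens else f"cluster {i - n_tokens}"

for left, right, dist, size in Z[:3]:
    print(f"merge {name(left)} + {name(right)} at distance {dist:g} (size {int(size)})")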
[12]:
labels_arr = shap_values[0].data

# # clean labels of unusual characters (only needed for the slow tokenizer, i.e. use_fast=False)
# labels_arr = []
# for token in shap_values[0].data:
#     if token[0] == 'Ġ':
#         labels_arr.append(token[1:])
#     else:
#         labels_arr.append(token)
print(labels_arr)
['' 'A ' 'boy ' 'is ' 'jumping ' 'on ' 'skate' 'board ' 'in ' 'the '
 'middle ' 'of ' 'a ' 'red ' 'bridge' '.' '</s>' '</s>' 'The ' 'boy ' 'sk'
 'ates ' 'down ' 'the ' 'sidewalk' '.' '']
[13]:
fig = plt.figure(figsize=(len(Z) + 20, 15))
dn = dendrogram(Z, labels=labels_arr)
plt.show()
[figure: dendrogram of the input partition tree over the tokens]

Benchmarking

[14]:
sort_order = "positive"
perturbation = "keep"
[15]:
from shap import benchmark
[16]:
sper = benchmark.perturbation.SequentialPerturbation(
    explainer.model, explainer.masker, sort_order, perturbation
)
xs, ys, auc = sper.model_score(shap_values, [decoded])
sper.plot(xs, ys, auc)
[figure: sequential perturbation benchmark curve (sort_order="positive", perturbation="keep")]
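
The run above keeps tokens in order of most positive attribution. A complementary run (a sketch, not part of the original notebook, assuming the "negative" sort order and "remove" perturbation offered by shap.benchmark) removes tokens in order of most negative attribution instead:

# Illustrative complementary benchmark: sort by most negative attribution and
# remove those tokens rather than keeping them.
sper_neg = benchmark.perturbation.SequentialPerturbation(
    explainer.model, explainer.masker, "negative", "remove"
)
xs, ys, auc = sper_neg.model_score(shap_values, [decoded])
sper_neg.plot(xs, ys, auc)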
[ ]: