Text Data Explanation Benchmarking: Machine Translation

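This notebook demonstrates how to use the benchmark utility to evaluate the performance of an explainer for text data. In this demo, we measure the explanation performance of the Partition explainer on a machine translation model. The evaluation metrics are "keep positive" and "keep negative", and the masker is the Text masker.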

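The new benchmark utility uses the new API: it wraps the user-supplied model in a MaskedModel and evaluates the model on masked versions of the inputs.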
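To make the idea of evaluating a model on masked inputs concrete, here is a minimal, self-contained sketch. It is not shap's MaskedModel. The names `mask_tokens`, `toy_model`, and `masked_eval` are hypothetical, and the `"..."` placeholder is an assumption made only for illustration.

```python
# Toy illustration of evaluating a model on masked inputs.
# All names here are hypothetical; this is not shap's MaskedModel.

def mask_tokens(tokens, keep, placeholder="..."):
    """Replace every token whose keep-flag is False with a placeholder."""
    return [tok if k else placeholder for tok, k in zip(tokens, keep)]

def toy_model(tokens):
    """Stand-in for a real model: counts the tokens that are not masked out."""
    return sum(1 for tok in tokens if tok != "...")

def masked_eval(model, tokens, keep):
    """Apply the mask, then score the masked input with the model."""
    return model(mask_tokens(tokens, keep))

tokens = ["the", "cat", "sat"]
keep = [True, False, True]
masked = mask_tokens(tokens, keep)
score = masked_eval(toy_model, tokens, keep)
print(masked, score)
```

A real text masker works on the tokenizer's tokens and a real model returns output scores, but the flow is the same: a binary mask selects which parts of the input are visible, and the model is scored on the result.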

[1]:
import nlp
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import shap
import shap.benchmark as benchmark

Load Data and Model

[2]:
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-es")
[3]:
dataset = nlp.load_dataset("xsum", split="train")
Using custom data configuration default
[4]:
# Take the first 10 summaries as the inputs to explain.
s = dataset["summary"][:10]

Create Explainer Object

[5]:
explainer = shap.Explainer(model, tokenizer)
explainers.Partition is still in an alpha state, so use with caution...

Run SHAP Explanation

[6]:
shap_values = explainer(s)
Partition explainer: 11it [05:38, 30.77s/it]

Define Metrics (Sort Order & Perturbation Method)

[7]:
sort_order = "positive"
perturbation = "keep"
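As a rough illustration of what the sort order controls, the sketch below ranks token indices by their attribution values. The helper `rank_tokens` is hypothetical. It assumes that "positive" puts the most positive attributions first and "negative" puts the most negative first. The benchmark utility's internals may differ in detail.

```python
# Hypothetical helper: rank token indices by attribution value.
# Assumption: "positive" = most positive first, "negative" = most negative first.

def rank_tokens(attributions, sort_order):
    indices = range(len(attributions))
    if sort_order == "positive":
        # Largest (most positive) attribution first.
        return sorted(indices, key=lambda i: -attributions[i])
    if sort_order == "negative":
        # Smallest (most negative) attribution first.
        return sorted(indices, key=lambda i: attributions[i])
    raise ValueError(f"unknown sort_order: {sort_order!r}")

attributions = [0.5, -0.2, 0.1, -0.7]
positive_rank = rank_tokens(attributions, "positive")
negative_rank = rank_tokens(attributions, "negative")
print(positive_rank, negative_rank)
```

With the "keep" perturbation, tokens are then revealed in this ranked order, starting from a fully masked input.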

Benchmark Explainer

[8]:
sp = benchmark.perturbation.SequentialPerturbation(
    explainer.model, explainer.masker, sort_order, perturbation
)
xs, ys, auc = sp.model_score(shap_values, s)
sp.plot(xs, ys, auc)
[Figure: perturbation curve plotted by sp.plot for sort_order = "positive", perturbation = "keep"]
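The curve and its area can be pictured with a small, self-contained sketch. This is a simplified illustration, not SequentialPerturbation's actual implementation, which may mask, normalize, and aggregate over samples differently. The names `keep_curve`, `trapezoid_auc`, and `additive_model`, and the toy weights, are all hypothetical.

```python
# Simplified "keep" curve on a toy additive model, with trapezoid-rule AUC.
# All names and numbers are hypothetical and for illustration only.

def keep_curve(model, n_tokens, order):
    """Score the model as tokens are revealed one at a time in `order`."""
    keep = [False] * n_tokens
    xs, ys = [0.0], [model(keep)]
    for step, idx in enumerate(order, start=1):
        keep[idx] = True
        xs.append(step / n_tokens)  # fraction of tokens kept so far
        ys.append(model(keep))
    return xs, ys

def trapezoid_auc(xs, ys):
    """Area under the piecewise-linear curve through (xs, ys)."""
    return sum(
        (xs[i + 1] - xs[i]) * (ys[i] + ys[i + 1]) / 2
        for i in range(len(xs) - 1)
    )

weights = [0.5, -0.2, 0.1, -0.7]  # toy per-token contributions

def additive_model(keep):
    """Toy model: output is the sum of the weights of the visible tokens."""
    return sum(w for w, k in zip(weights, keep) if k)

# "Keep positive": reveal tokens from most positive attribution downward.
order = sorted(range(len(weights)), key=lambda i: -weights[i])
xs, ys = keep_curve(additive_model, len(weights), order)
auc = trapezoid_auc(xs, ys)
print(xs, ys, auc)
```

In this toy setting the score rises while positively attributed tokens are revealed and falls once negatively attributed tokens are added, which is the shape a faithful explainer would be expected to produce under "keep positive".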
[9]:
sort_order = "negative"
perturbation = "keep"
[10]:
sp = benchmark.perturbation.SequentialPerturbation(
    explainer.model, explainer.masker, sort_order, perturbation
)
xs, ys, auc = sp.model_score(shap_values, s)
sp.plot(xs, ys, auc)
[Figure: perturbation curve plotted by sp.plot for sort_order = "negative", perturbation = "keep"]