Benchmarking Explanations for Image Data: Image Multiclass Classification

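This notebook demonstrates how to use the benchmark utility to measure the performance of an explainer for image data. It evaluates the Partition explainer on an image multiclass classification model. The evaluation metrics are "keep positive" and "keep negative"; the cells below run "keep positive". The masker is the Image masker with Inpaint Telea.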

The new benchmark utility uses the new API, in which MaskedModel wraps the user-supplied model and evaluates it on masked versions of the inputs.

[1]:
import json

import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input

import shap
import shap.benchmark as benchmark

Load Data and Model

[2]:
model = ResNet50(weights="imagenet")
X, y = shap.datasets.imagenet50()

Class Label Mapping

[3]:
url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(url)) as file:
    class_names = [v[1] for v in json.load(file).values()]

Define Score Function

[4]:
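def f(x):
    # Work on a copy so the caller's array is not modified.
    tmp = x.copy()
    # Restore flattened rows to image shape before calling the model.
    if len(tmp.shape) == 2:
        tmp = tmp.reshape(tmp.shape[0], *X[0].shape)
    # Use the returned array instead of relying on in-place modification.
    tmp = preprocess_input(tmp)
    return model(tmp)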
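
The `len(tmp.shape) == 2` branch covers the case where inputs reach the score function as flattened rows rather than as image tensors. The following is a minimal NumPy-only sketch of that round trip. The `image_shape` value is a hypothetical stand-in for `X[0].shape`, and the assumption that inputs may arrive flattened is inferred from the reshape in `f`, not stated by the notebook.

```python
import numpy as np

# Hypothetical stand-in for X[0].shape (the real images are much larger).
image_shape = (4, 4, 3)

# A batch of two "images" with distinct values so the round trip is checkable.
batch = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, *image_shape)

# Flatten each image into one row: (n_samples, n_features).
flat = batch.reshape(batch.shape[0], -1)

# Same restore step as in f(x): (n_samples, n_features) -> (n_samples, H, W, C).
restored = flat.reshape(flat.shape[0], *image_shape)

print(flat.shape, restored.shape)
```

Because reshaping does not reorder elements, the restored batch is identical to the original one.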

Define Image Masker

[5]:
masker = shap.maskers.Image("inpaint_telea", X[0].shape)
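
A masker decides what replaces the hidden part of an image while the explainer probes the model. With `"inpaint_telea"` the hidden pixels are filled in by inpainting from the surrounding ones. The sketch below only illustrates the masking semantics and swaps in a constant fill, which is simpler and is not Telea inpainting. The names `apply_mask`, `keep` and `fill_value` are hypothetical and not part of the SHAP API.

```python
import numpy as np


def apply_mask(image, keep, fill_value=0.0):
    """Return a copy of `image` where pixels with keep == False are replaced.

    Assumption: a constant `fill_value` stands in for the pixels that an
    inpainting-based masker would synthesize from the neighbourhood.
    """
    keep = np.asarray(keep, dtype=bool)
    # Broadcast the (H, W) mask across the channel axis.
    return np.where(keep[..., None], image, fill_value)


image = np.ones((2, 2, 3), dtype=np.float32)
keep = np.array([[True, False], [False, True]])
masked = apply_mask(image, keep)

# Show the first channel: kept pixels retain their value, hidden ones are filled.
print(masked[..., 0])
```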

Create Explainer Object

[6]:
explainer = shap.Explainer(f, masker, output_names=class_names)
explainers.Partition is still in an alpha state, so use with caution...

Run SHAP Explanation

[7]:
shap_values = explainer(
    X[1:3], max_evals=500, batch_size=50, outputs=shap.Explanation.argsort.flip[:4]
)
Partition explainer:  50%|█████████████████                 | 1/2 [00:00<?, ?it/s]
Partition explainer: 3it [03:15, 65.24s/it]

Plot SHAP Explanation

[8]:
shap.image_plot(shap_values)
[Figure: output of shap.image_plot(shap_values)]

Get Output Class Indices

[9]:
output = f(X[1:3]).numpy()
num_of_outputs = 4
sorted_indexes = np.argsort(-output, axis=1)
# Keep the indices of the top `num_of_outputs` classes for each sample.
sliced_indexes = sorted_indexes[:, :num_of_outputs]
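
The cell above ranks classes by score and keeps the top few per sample. A small self-contained illustration of the same pattern follows. The `scores` array and `top_k` are hypothetical values chosen so the result is easy to check by eye.

```python
import numpy as np

# Hypothetical scores for two samples over five classes.
scores = np.array(
    [
        [0.10, 0.50, 0.05, 0.30, 0.05],
        [0.60, 0.10, 0.20, 0.05, 0.05],
    ]
)
top_k = 2

# Negating the scores makes argsort rank from highest to lowest.
ranked = np.argsort(-scores, axis=1)

# Column slicing keeps the top_k class indices of every row.
top_indexes = ranked[:, :top_k]
print(top_indexes)
```

Each row of `top_indexes` lists class indices in decreasing order of score for the corresponding sample.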

Define Metrics (Sort Order and Perturbation Method)

[10]:
sort_order = "positive"
perturbation = "keep"
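
Combining `sort_order = "positive"` with `perturbation = "keep"` gives the "keep positive" metric: features are ranked by how positive their attribution is, and the model is re-evaluated as progressively more of the top-ranked features are kept. The sketch below shows that idea on a toy linear model. It is a simplified illustration and not SHAP's implementation; `keep_positive_curve`, the zero baseline and the weights are hypothetical.

```python
import numpy as np


def keep_positive_curve(model, x, attributions, baseline):
    """Model output as features are revealed from most to least positive.

    Simplified sketch: features that are not kept take the value in `baseline`.
    """
    order = np.argsort(-attributions)  # most positive attribution first
    current = baseline.astype(float).copy()
    fractions = [0.0]
    scores = [model(current)]
    for step, idx in enumerate(order, start=1):
        current[idx] = x[idx]  # keep (reveal) one more feature
        fractions.append(step / len(order))
        scores.append(model(current))
    return np.array(fractions), np.array(scores)


# Hypothetical linear model; with a zero baseline its attributions are w * x.
weights = np.array([2.0, -1.0, 0.5, 3.0])


def model(v):
    return float(weights @ v)


x = np.ones(4)
attributions = weights * x

fractions, scores = keep_positive_curve(model, x, attributions, np.zeros(4))
print(fractions)
print(scores)
```

If the attributions rank features well, the score should rise quickly as the first features are kept, and then flatten or drop once features with negative attributions are added.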

Benchmark Explainer

[11]:
sequential_perturbation = benchmark.perturbation.SequentialPerturbation(
    explainer.model, explainer.masker, sort_order, perturbation
)
xs, ys, auc = sequential_perturbation.model_score(
    shap_values, X[1:2], indices=sliced_indexes[0]
)
sequential_perturbation.plot(xs, ys, auc)
[Figure: output of sequential_perturbation.plot(xs, ys, auc)]
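
The `auc` value condenses the perturbation curve into a single number. The notebook does not show how the benchmark computes it. One common way to get the area under a sampled curve is the trapezoidal rule, sketched below as an assumption. `trapezoid_auc` and the sample curve are hypothetical.

```python
import numpy as np


def trapezoid_auc(xs, ys):
    """Area under a piecewise-linear curve via the trapezoidal rule."""
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    widths = np.diff(xs)  # horizontal extent of each segment
    mean_heights = (ys[1:] + ys[:-1]) / 2.0  # average height of each segment
    return float(np.sum(widths * mean_heights))


# Hypothetical curve: the score rises as a larger fraction of the input is kept.
xs = np.linspace(0.0, 1.0, 5)
ys = np.array([0.0, 0.5, 0.75, 0.9, 1.0])
area = trapezoid_auc(xs, ys)
print(area)
```

For a "keep positive" curve, a larger area means the model's score recovers sooner as the highest-ranked features are kept.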