解释VGG16在ImageNet上的中间层(PyTorch)
解释预测结果相对于原始输入图像比解释预测结果相对于更高卷积层更困难(因为更高卷积层更接近输出)。本笔记本提供了一个简单示例,展示如何使用 GradientExplainer 来解释预训练 VGG16 网络第7层相对于模型输出的情况。
请注意,默认情况下会抽取200个样本以计算期望值。为了运行更快,您可以减少每次解释的样本数量。
[1]:
import json
import numpy as np
import torch
from torchvision import models
import shap
[2]:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
def normalize(image):
if image.max() > 1:
image /= 255
image = (image - mean) / std
# in addition, roll the axis so that they suit pytorch
return torch.tensor(image.swapaxes(-1, 1).swapaxes(2, 3)).float()
[3]:
# load the model
model = models.vgg16(pretrained=True).eval()
X, y = shap.datasets.imagenet50()
X /= 255
to_explain = X[[39, 41]]
# load the ImageNet class names
url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
fname = shap.datasets.cache(url)
with open(fname) as f:
class_names = json.load(f)
e = shap.GradientExplainer((model, model.features[7]), normalize(X))
shap_values, indexes = e.shap_values(
normalize(to_explain), ranked_outputs=2, nsamples=200
)
# get the names for the classes
index_names = np.vectorize(lambda x: class_names[str(x)][1])(indexes)
# plot the explanations
shap_values = [np.swapaxes(np.swapaxes(s, 2, 3), 1, -1) for s in shap_values]
shap.image_plot(shap_values, to_explain, index_names)
用局部平滑解释
梯度解释器使用预期梯度,它将积分梯度、SHAP 和 SmoothGrad 的思想合并到一个单一的期望方程中。要像 SmoothGrad 那样使用平滑,只需将 local_smoothing 参数设置为非零值。这将在期望计算期间向输入添加具有该标准偏差的正态分布噪声。它可以创建更平滑的特征归属,更好地捕捉图像的相关区域。
[4]:
# note that because the inputs are scaled to be between 0 and 1, the local smoothing also has to be
# scaled compared to the Keras model
explainer = shap.GradientExplainer(
(model, model.features[7]), normalize(X), local_smoothing=0.5
)
shap_values, indexes = explainer.shap_values(
normalize(to_explain), ranked_outputs=2, nsamples=200
)
# get the names for the classes
index_names = np.vectorize(lambda x: class_names[str(x)][1])(indexes)
# plot the explanations
shap_values = [np.swapaxes(np.swapaxes(s, 2, 3), 1, -1) for s in shap_values]
shap.image_plot(shap_values, to_explain, index_names)