DeepExplainer MNIST 示例(首页示例)

一个简单的示例,展示如何使用 DeepExplainer 解释一个用 Keras 训练的 MNIST 卷积神经网络(CNN)的预测。

[1]:
# this is the code from here --> https://github.com/keras-team/keras/blob/master/examples/demo_mnist_convnet.py
import keras
import numpy as np
from keras import layers
from keras.utils import to_categorical

import shap

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Seed Python, NumPy and the Keras backend in one call so that weight
# initialisation, dropout and the shuffling done by `fit` are reproducible
# under Restart & Run All.
SEED = 0
keras.utils.set_random_seed(SEED)

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Append a trailing channel axis so each image matches `input_shape` (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")


# Convert integer class labels to one-hot matrices, the target format
# expected by the categorical_crossentropy loss used below.
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

batch_size = 128
epochs = 3

# Small CNN: two conv/pool stages, then dropout and a softmax classifier.
model = keras.Sequential(
    [
        layers.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

# Hold out 10% of the training data for validation; keep the history for inspection.
history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 26, 26, 32)     │           320 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 13, 13, 32)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_1 (Conv2D)               │ (None, 11, 11, 64)     │        18,496 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 5, 5, 64)       │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten (Flatten)               │ (None, 1600)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 1600)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 10)             │        16,010 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 34,826 (136.04 KB)
 Trainable params: 34,826 (136.04 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/3
422/422 ━━━━━━━━━━━━━━━━━━━━ 18s 38ms/step - accuracy: 0.7699 - loss: 0.7657 - val_accuracy: 0.9788 - val_loss: 0.0789
Epoch 2/3
422/422 ━━━━━━━━━━━━━━━━━━━━ 26s 61ms/step - accuracy: 0.9616 - loss: 0.1196 - val_accuracy: 0.9853 - val_loss: 0.0572
Epoch 3/3
422/422 ━━━━━━━━━━━━━━━━━━━━ 12s 29ms/step - accuracy: 0.9737 - loss: 0.0857 - val_accuracy: 0.9862 - val_loss: 0.0492
Test loss: 0.047078389674425125
Test accuracy: 0.9847000241279602
[2]:
# Select a set of background examples to take an expectation over.
# A dedicated seeded generator keeps the background sample (and therefore the
# SHAP values) reproducible, independent of any other global NumPy state.
N_BACKGROUND = 100
N_EXPLAIN = 5
rng = np.random.default_rng(0)
background = x_train[rng.choice(x_train.shape[0], N_BACKGROUND, replace=False)]

# Explain the model's predictions on the first N_EXPLAIN (five) test images.
# NOTE(review): in the recorded run, constructing the explainer raised
# "AttributeError: 'tuple' object has no attribute 'as_list'" inside shap's
# TFDeep (model_output.shape is a plain tuple there). Presumably the installed
# shap's TensorFlow DeepExplainer does not support this Keras version — confirm
# the shap/Keras combination before relying on the outputs below.
e = shap.DeepExplainer(model, background)
# Alternative: pass the input/output tensors directly, e.g.
# e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)
shap_values = e.shap_values(x_test[:N_EXPLAIN])
C:\Users\Tobias Pitters\programming\shap\shap\explainers\_deep\deep_tf.py:99: UserWarning: Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
  warnings.warn("Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.")
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[2], line 5
      2 background = x_train[np.random.choice(x_train.shape[0], 100, replace=False)]
      4 # explain predictions of the model on three images
----> 5 e = shap.DeepExplainer(model, background)
      6 # ...or pass tensors directly
      7 # e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)
      8 shap_values = e.shap_values(x_test[0:5])

File ~\programming\shap\shap\explainers\_deep\__init__.py:90, in DeepExplainer.__init__(self, model, data, session, learning_phase_flags)
     87 super().__init__(model, masker)
     89 if framework == 'tensorflow':
---> 90     self.explainer = TFDeep(model, data, session, learning_phase_flags)
     91 elif framework == 'pytorch':
     92     self.explainer = PyTorchDeep(model, data)

File ~\programming\shap\shap\explainers\_deep\deep_tf.py:172, in TFDeep.__init__(self, model, data, session, learning_phase_flags)
    170     self.phi_symbolics = [None]
    171 else:
--> 172     noutputs = self.model_output.shape.as_list()[1]
    173     if noutputs is not None:
    174         self.phi_symbolics = [None for i in range(noutputs)]

AttributeError: 'tuple' object has no attribute 'as_list'
[ ]:
# Plot the SHAP attributions for the five explained test images, one row per image.
# The pixel values are negated before plotting (presumably so the digits render
# dark-on-light behind the attribution maps — confirm against the rendered figure).
shap.image_plot(
    shap_values,
    pixel_values=-x_test[:5],
)
../../../_images/example_notebooks_image_examples_image_classification_Front_Page_DeepExplainer_MNIST_Example_3_0.png

上图展示了模型对五张测试图像的预测中,每个类别的解释。每一行对应一张输入图像:最左侧一列为输入图像,其右侧从左到右依次为类别 0–9 的解释。红色像素会提高模型对该类别的输出,蓝色像素则会降低该输出。