简单的 Kernel SHAP

本笔记本给出了 Kernel SHAP 的一个简单暴力实现:它枚举全部 \(2^M\) 个特征子集(联盟)。我们还将其结果与完整的 KernelExplainer 实现进行了比较。请注意,KernelExplainer 在 \(M\) 较大时采用采样近似,而在 \(M\) 较小时给出精确结果。

暴力枚举版 Kernel SHAP

[1]:
import itertools

import numpy as np
import scipy.special


def powerset(iterable):
    """Return a lazy iterator over every subset of *iterable*, as tuples.

    Subsets come out grouped by size: the empty tuple first, then all
    1-element subsets, and so on up to the full set. Within one size the
    order is the one produced by ``itertools.combinations``.
    """
    items = list(iterable)
    # One combinations() iterator per subset size, 0 through len(items).
    subsets_by_size = (
        itertools.combinations(items, size) for size in range(len(items) + 1)
    )
    return itertools.chain.from_iterable(subsets_by_size)


def shapley_kernel(M, s):
    """Return the Shapley kernel weight for a coalition of size *s* out of *M*.

    For ``0 < s < M`` the weight is
    ``(M - 1) / (binom(M, s) * s * (M - s))``. The empty and the full
    coalition, where that expression would divide by zero, get the fixed
    weight 10000 instead.
    """
    if s in (0, M):
        # Large finite stand-in for the two boundary coalitions.
        return 10000
    denominator = scipy.special.binom(M, s) * s * (M - s)
    return (M - 1) / denominator


def f(X):
    """Evaluate a fixed toy linear model, ``X @ beta + 10``.

    Args:
        X: Array-like whose last axis indexes the features; either a single
            sample of shape ``(M,)`` or a batch of shape ``(n, M)``.

    Returns:
        ``np.dot(X, beta) + 10``, where ``beta`` holds ``M`` coefficients
        drawn uniformly from [0, 1) with seed 0, so every call with the same
        feature count uses the same coefficients.
    """
    X = np.asarray(X)
    # A private RandomState seeded with 0 draws exactly the values that
    # ``np.random.seed(0); np.random.rand(M)`` would, but without reseeding
    # the process-wide NumPy RNG as a side effect of evaluating the model.
    beta = np.random.RandomState(0).rand(X.shape[-1])
    return np.dot(X, beta) + 10


def kernel_shap(f, x, reference, M):
    """Compute Kernel SHAP values by brute-force enumeration of all coalitions.

    Every one of the ``2**M`` feature subsets is turned into one model input
    (features in the subset taken from *x*, all others from *reference*), and
    a weighted linear regression of the model outputs on the subset
    indicator vectors is solved.

    Args:
        f: Model callable. It is called once with an array of shape
            ``(2**M, M)`` and must return one output per row.
        x: 1-D array-like of length ``M``; the instance being explained.
        reference: Background values used for features that are absent from
            a coalition; broadcast across all rows.
        M: Number of features.

    Returns:
        Array of length ``M + 1``: the first ``M`` entries are the regression
        coefficients of the features (the SHAP values), the last entry is the
        intercept (the base value).
    """
    x = np.asarray(x)  # fancy indexing below needs an ndarray
    n_coalitions = 2**M

    # Regression design matrix: one indicator column per feature plus a
    # trailing constant column for the intercept.
    X = np.zeros((n_coalitions, M + 1))
    X[:, -1] = 1
    weights = np.zeros(n_coalitions)

    # Model inputs: every row starts as the reference point.
    V = np.zeros((n_coalitions, M))
    V[:] = reference

    for i, subset in enumerate(powerset(range(M))):
        members = list(subset)
        V[i, members] = x[members]
        X[i, members] = 1
        weights[i] = shapley_kernel(M, len(members))

    y = f(V)
    # Weighted least squares: scaling rows by sqrt(weight) turns the weighted
    # problem into an ordinary one that lstsq can solve.
    wsq = np.sqrt(weights)
    return np.linalg.lstsq(wsq[:, None] * X, wsq * y, rcond=None)[0]


# Problem setup: a 4-feature instance drawn from a seeded standard normal,
# explained against an all-zeros reference point.
M = 4
np.random.seed(1)
x = np.random.randn(M)
reference = np.zeros_like(x)

# The solver returns the M feature coefficients followed by the intercept.
phi = kernel_shap(f, x, reference, M)
shap_values, base_value = phi[:-1], phi[-1]

# The last two printed values let the reader compare sum(phi) with f(x).
print("  reference =", reference)
print("          x =", x)
print("shap_values =", shap_values)
print(" base_value =", base_value)
print("   sum(phi) =", phi.sum())
print("       f(x) =", f(x))
  reference = [0. 0. 0. 0.]
          x = [ 1.62434536 -0.61175641 -0.52817175 -1.07296862]
shap_values = [ 0.89146267 -0.43752168 -0.31836259 -0.58464256]
 base_value = 10.000000000000002
   sum(phi) = 9.55093584213122
       f(x) = 9.55093584213122

使用 KernelExplainer

[2]:
import shap

# Run the library's KernelExplainer on the same model, using the reference
# point as a one-row background array, and print its values for comparison.
explainer = shap.KernelExplainer(f, reference[np.newaxis, :])
shap_values = explainer.shap_values(x)
print("shap_values =", shap_values)
print("base value =", explainer.expected_value)
shap_values = [ 0.89146267 -0.43752168 -0.31836259 -0.58464256]
base value = 10.0