create_counterfactual: 通过反事实来解释模型
Wachter等人于2017年提出的一种反事实方法的实现,用于模型可解释性。
from mlxtend.evaluate import create_counterfactual
概述
反事实是解释与蕴含相关场景的实例:“如果不是x,那么就不是y”的假设背景。例如,“如果我没有努力学习,我的成绩会更糟。”
在机器学习的背景下,我们可以认为反事实实例来自训练集中,通过人为改变其特征来改变模型预测。改变训练示例的特征对于理解模型的行为是有用的。
请注意,这种创建反事实的实现是模型无关的,适用于任何支持predict
(并且理想情况下还支持predict_proba
)方法的scikit-learn估计器。
特别是,create_counterfactual
实现了Wachter等人于2017年描述的方法[1]。C. Molnar的《可解释的机器学习》一书中也提供了对此方法的简短而良好的描述[2]。
简而言之,Wachter等人的方法最小化损失
$$L\left(x, x^{\prime}, y^{\prime}, \lambda\right)=\lambda \cdot\left(\hat{f}\left(x^{\prime}\right)-y^{\prime}\right)^{2}+d\left(x, x^{\prime}\right).$$
左边的项,$\lambda \cdot\left(\hat{f}\left(x^{\prime}\right)-y^{\prime}\right)^{2}$,最小化反事实$x'$的模型预测,即$\hat{f}\left(x^{\prime}\right)$与所需预测(由用户指定)$y^{\prime}$之间的平方差。请注意,$\lambda$是一个超参数,用于权衡这个左边项相对于第二个项$d\left(x, x^{\prime}\right)$的重要性。
第二个项$d\left(x, x^{\prime}\right)$计算给定实例$x$和生成的反事实$x'$之间的距离。简而言之,第二个项将生成的反事实保持与实例相似。相反,第一个项则最大化反事实的模型预测与所需预测(例如,另一个类别标签)之间的差异。
距离函数实现为每个特征维度的绝对差异,按绝对中位数偏差(MAD)缩放:
$$d\left(x, x^{\prime}\right)=\sum_{j=1}^{p} \frac{\left|x_{j}-x_{j}^{\prime}\right|}{M A D_{j}}.$$
MAD衡量给定特征的分布,使用中位数作为其中心:
$$MAD_{j}=\operatorname{median}{i \in{1, \ldots, n}}\left(\left|x{i, j}-\operatorname{median}{l \in{1, \ldots, n}}\left(x{l, j}\right)\right|\right).$$
使用 create_counterfactual
函数的一般步骤如下:
- 选择一个您想要解释的实例,并指定该实例的期望预测 $y'$(这通常与其原始预测不同)。
- 为超参数 $\lambda$ 选择一个值。
- 使用
create_counterfactual
函数优化损失 $L$。 - 可选地,正如作者所推荐的那样,您可以通过增加 $\lambda$ 重复步骤 2 和 3,直到达到用户定义的阈值 $\epsilon$,即:
- 当 $\left|\hat{f}\left(x^{\prime}\right)-y^{\prime}\right|>\epsilon$ 时:
- 增加 $\lambda$
- 当 $\left|\hat{f}\left(x^{\prime}\right)-y^{\prime}\right|>\epsilon$ 时:
参考文献
- [1] Wachter, S., Mittelstadt, B., & Russell, C. (2017). 不打开黑箱的反事实解释:自动化决策与GDPR. 哈佛法律与技术期刊, 31, 841., https://arxiv.org/abs/1711.00399
- [2] Christoph Molnar (2018). 可解释的机器学习, 第6.1章
例子 1 -- 简单的鸢尾花例子
为了简单起见,本示例说明了如何使用 create_counterfactual
函数来解释来自鸢尾花数据集的数据实例。
假设我们在鸢尾花数据集上训练了一个逻辑回归模型,并选择第16个训练点,我们希望通过反事实来解释该点的预测。
from mlxtend.data import iris_data
from sklearn.linear_model import LogisticRegression
X, y = iris_data()
clf = LogisticRegression()
clf.fit(X, y)
x_ref = X[15]
print('True label:', y[15])
print('Predicted label:', clf.predict(x_ref.reshape(1, -1))[0])
print('Predicted probas:', clf.predict_proba(x_ref.reshape(1, -1)))
print('Predicted probability for label 0:', clf.predict_proba(x_ref.reshape(1, -1))[0][0])
True label: 0
Predicted label: 0
Predicted probas: [[9.86677291e-01 1.33226960e-02 1.28980184e-08]]
Predicted probability for label 0: 0.9866772910539873
我们可以看到,上面对于类别0的预测分数是98.6%的概率。现在,我们将通过设置 y_desired=2
来将预测推向类别2。此外,我们通过 y_desired_proba=1
将类别2的概率设置为100%。
from mlxtend.evaluate import create_counterfactual
res = create_counterfactual(x_reference=x_ref,
y_desired=2,
model=clf,
X_dataset=X,
y_desired_proba=1.,
lammbda=1, # 超参数
random_seed=123)
print('Features of the 16th training example:', x_ref)
print('Features of the countefactual:', res)
print('Predictions for counterfactual:\n')
print('Predicted label:', clf.predict(res.reshape(1, -1))[0])
print('Predicted probas:', clf.predict_proba(res.reshape(1, -1)))
Features of the 16th training example: [5.7 4.4 1.5 0.4]
Features of the countefactual: [5.72271344 3.99169005 6.45305374 0.40000002]
Predictions for counterfactual:
Predicted label: 2
Predicted probas: [[1.41639932e-04 3.13292297e-01 6.86566063e-01]]
As we can see above, the counterfactual is relatively similar to the original training example, i.e, only the 3rd feature has changed substantially (from 1.5 to 6.45). The predicted label has changed from class 0 t class 2.
Interpretation-wise, this means increasing the petal length of a Iris-setosa flower may make it more similar to a Iris-virginica flower.
示例 2 -- 简单的鸢尾花示例,带有决策区域和阈值停止准则
这个例子类似于示例1;然而,它基于一个二维鸢尾花数据集,仅包含花瓣长度和花瓣宽度特征,以便结果可以通过决策区域图进行绘制。
from mlxtend.plotting import plot_decision_regions
import matplotlib.pyplot as plt
X, y = iris_data()
X = X[:, 2:]
clf = LogisticRegression()
clf.fit(X, y)
LogisticRegression()
# 绘制决策区域
ax = plot_decision_regions(X, y, clf=clf, legend=2)
scatter_highlight_defaults = {'c': 'red',
'edgecolor': 'yellow',
'alpha': 1.0,
'linewidths': 2,
'marker': 'o',
's': 80}
ax.scatter(*X[15],
**scatter_highlight_defaults)
plt.show()
上面图中突出显示的大点表示第16个训练数据点。
下面的代码将创建一个与示例1中相同设置的反事实:
counterfact = create_counterfactual(x_reference=X[15],
y_desired=2,
model=clf,
X_dataset=X,
y_desired_proba=1.0,
lammbda=1,
random_seed=123)
ax = plot_decision_regions(X, y, clf=clf, legend=2)
ax.scatter(*counterfact,
**scatter_highlight_defaults)
plt.show()
正如我们上面所看到的,反事实主要沿着x轴(花瓣长度)移动,从而使得参考点与反事实之间的预测从类别0变为类别2。
以下图表基于使用不同的lambda值重复该过程得出:
for i in [0.4, 0.5, 1.0, 5.0, 100]:
counterfact = create_counterfactual(x_reference=X[15],
y_desired=2,
model=clf,
X_dataset=X,
y_desired_proba=1.0,
lammbda=i,
random_seed=123)
ax = plot_decision_regions(X, y, clf=clf, legend=2)
ax.scatter(*counterfact,
**scatter_highlight_defaults)
plt.show()
正如我们所看到的,$\lambda$ 值越大,损失中的第一个项
$$L\left(x, x^{\prime}, y^{\prime}, \lambda\right)=\lambda \cdot\left(\hat{f}\left(x^{\prime}\right)-y^{\prime}\right)^{2}+d\left(x, x^{\prime}\right).$$
的主导作用越强。
应用Wachter等人的阈值概念,
- 可选地,正如作者所建议的,您可以通过增加 $\lambda$ 重复步骤 2 和 3,直到达到用户定义的阈值 $\epsilon$,即,
- 当 $\left|\hat{f}\left(x^{\prime}\right)-y^{\prime}\right|>\epsilon$ 时:
- 增加 $\lambda$
我们可以定义一个用户定义的阈值并实现如下:
import numpy as np
desired_class_2_proba = 1.0
for i in np.arange(0, 10000, 0.1):
counterfact = create_counterfactual(x_reference=X[15],
y_desired=2,
model=clf,
X_dataset=X,
y_desired_proba=desired_class_2_proba,
lammbda=i,
random_seed=123)
predicted_class_2_proba = clf.predict_proba(counterfact.reshape(1, -1))[0][2]
if not i:
print('Initial lambda:', i)
print('Initial diff:', np.abs(predicted_class_2_proba - desired_class_2_proba))
if not np.abs(predicted_class_2_proba - desired_class_2_proba) > 0.3:
break
ax = plot_decision_regions(X, y, clf=clf, legend=2)
ax.scatter(*counterfact,
**scatter_highlight_defaults)
print('Final lambda:', i)
print('Final diff:', np.abs(predicted_class_2_proba - desired_class_2_proba))
plt.show()
Initial lambda: 0.0
Initial diff: 0.9999998976132334
Final lambda: 1.1
Final diff: 0.2962621523225484
API
create_counterfactual(x_reference, y_desired, model, X_dataset, y_desired_proba=None, lammbda=0.1, random_seed=None)
Implementation of the counterfactual method by Wachter et al.
References:
- Wachter, S., Mittelstadt, B., & Russell, C. (2017). Counterfactual explanations without opening the black box: Automated decisions and the GDPR. Harv. JL & Tech., 31, 841., https://arxiv.org/abs/1711.00399
Parameters
-
x_reference
: array-like, shape=[m_features]The data instance (training example) to be explained.
-
y_desired
: intThe desired class label for
x_reference
. -
model
: estimatorA (scikit-learn) estimator implementing
.predict()
and/orpredict_proba()
. - Ifmodel
supportspredict_proba()
, then this is used by default for the first loss term,(lambda * model.predict[_proba](x_counterfact) - y_desired[_proba])^2
- Otherwise, method will fall back topredict
. -
X_dataset
: array-like, shape=[n_examples, m_features]A (training) dataset for picking the initial counterfactual as initial value for starting the optimization procedure.
-
y_desired_proba
: float (default: None)A float within the range [0, 1] designating the desired class probability for
y_desired
. - Ify_desired_proba=None
(default), the first loss term is(lambda * model(x_counterfact) - y_desired)^2
wherey_desired
is a class label - Ify_desired_proba
is not None, the first loss term is(lambda * model(x_counterfact) - y_desired_proba)^2
-
lammbda
: Weighting parameter for the first loss term,(lambda * model(x_counterfact) - y_desired[_proba])^2
-
random_seed
: int (default=None)If int, random_seed is the seed used by the random number generator for selecting the inital counterfactual from
X_dataset
.
ython