Tree SHAP 的 Python 版本
这是一个用Python编写的Tree SHAP示例实现,便于阅读。
[1]:
import time
import numba
import numpy as np
import sklearn.ensemble
import xgboost
import shap
加载加利福尼亚数据集
[2]:
# Load 1000 rows of the California dataset shipped with shap as features X
# and targets y; X is used as a DataFrame later on (``X.iloc``).
X, y = shap.datasets.california(n_points=1000)
# Bare trailing expression: the notebook displays the feature-matrix shape.
X.shape
[2]:
(1000, 8)
训练 sklearn 随机森林
[3]:
# Random forest of 1000 trees, each limited to depth 4; this is the model the
# Python TreeExplainer below is run against.
model = sklearn.ensemble.RandomForestRegressor(n_estimators=1000, max_depth=4)
# fit() is left as the trailing expression so the notebook shows the estimator.
model.fit(X, y)
[3]:
RandomForestRegressor(max_depth=4, n_estimators=1000)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
RandomForestRegressor(max_depth=4, n_estimators=1000)
训练 XGBoost 模型
[4]:
# Booster settings: a small learning rate with trees limited to depth 4.
xgb_params = {"learning_rate": 0.01, "max_depth": 4}
# Train for 1000 boosting rounds on the same data as the random forest.
bst = xgboost.train(xgb_params, xgboost.DMatrix(X, label=y), num_boost_round=1000)
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
Python TreeExplainer
下面的实现使用 numba 对核心函数进行即时编译,以加快计算速度。
[5]:
class TreeExplainer:
    """Tree SHAP explainer for a fitted scikit-learn ``RandomForestRegressor``.

    The forest is converted into a list of ``Tree`` objects and scratch
    buffers for the recursive algorithm are allocated once, so they can be
    reused for every explained sample.
    """

    def __init__(self, model, **kwargs):
        """Convert ``model`` into ``Tree`` objects and allocate scratch space.

        Args:
            model: A fitted ``sklearn.ensemble.RandomForestRegressor``.
            **kwargs: Accepted for call compatibility and ignored.

        Raises:
            TypeError: If ``model`` is not a ``RandomForestRegressor``.
        """
        # The type is matched by name so scikit-learn does not have to be
        # referenced here. Any other model used to fall through and fail
        # later with an AttributeError on ``self.trees``.
        if not str(type(model)).endswith(
            "sklearn.ensemble._forest.RandomForestRegressor'>"
        ):
            raise TypeError("Unsupported model type: " + str(type(model)))

        self.trees = [
            Tree(
                children_left=e.tree_.children_left,
                children_right=e.tree_.children_right,
                # No separate default branch is available here, so the right
                # child doubles as the branch taken for missing values.
                children_default=e.tree_.children_right,
                feature=e.tree_.feature,
                threshold=e.tree_.threshold,
                value=e.tree_.value[:, 0, 0],
                node_sample_weight=e.tree_.weighted_n_node_samples,
            )
            for e in model.estimators_
        ]

        # Preallocate space for the unique path data. Every recursion level
        # works on its own slice of these buffers, so they are sized from
        # the deepest tree in the forest.
        maxd = np.max([t.max_depth for t in self.trees]) + 2
        s = (maxd * (maxd + 1)) // 2
        self.feature_indexes = np.zeros(s, dtype=np.int32)
        self.zero_fractions = np.zeros(s, dtype=np.float64)
        self.one_fractions = np.zeros(s, dtype=np.float64)
        self.pweights = np.zeros(s, dtype=np.float64)

    def shap_values(self, X, **kwargs):
        """Compute SHAP values for one sample or a batch of samples.

        Args:
            X: A 1-D array (single sample) or 2-D array (one sample per row).
                pandas ``Series`` / ``DataFrame`` inputs are converted through
                ``.values``. Any numeric dtype is accepted.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            An array with one more column than there are features: the
            per-feature attributions followed by the bias term in the last
            position, averaged over all trees. The shape is
            ``(n_features + 1,)`` for 1-D input and
            ``(n_samples, n_features + 1)`` for 2-D input.

        Raises:
            AssertionError: If ``X`` is not a NumPy array after conversion or
                does not have 1 or 2 dimensions.
        """
        # convert dataframes (matched by name so pandas is not required)
        if str(type(X)).endswith(
            ("pandas.core.series.Series'>", "'pandas.core.frame.DataFrame'>")
        ):
            X = X.values

        assert str(type(X)).endswith("'numpy.ndarray'>"), (
            "Unknown instance type: " + str(type(X))
        )
        assert (
            len(X.shape) == 1 or len(X.shape) == 2
        ), "Instance must have 1 or 2 dimensions!"

        # The compiled kernel is declared for float64 samples only, so
        # integer or float32 input is converted (no copy if already float64).
        X = X.astype(np.float64, copy=False)

        # single instance
        if len(X.shape) == 1:
            phi = np.zeros(X.shape[0] + 1)
            x_missing = np.zeros(X.shape[0], dtype=bool)
            for t in self.trees:
                self.tree_shap(t, X, x_missing, phi)
            phi /= len(self.trees)
        # batch of instances, one per row
        else:
            phi = np.zeros((X.shape[0], X.shape[1] + 1))
            x_missing = np.zeros(X.shape[1], dtype=bool)
            for i in range(X.shape[0]):
                for t in self.trees:
                    self.tree_shap(t, X[i, :], x_missing, phi[i, :])
            phi /= len(self.trees)
        return phi

    def tree_shap(self, tree, x, x_missing, phi, condition=0, condition_feature=0):
        """Add the contribution of a single tree for sample ``x`` to ``phi``.

        Args:
            tree: The ``Tree`` to explain.
            x: 1-D float64 feature vector.
            x_missing: Boolean vector; a set entry routes the sample through
                the node's default child for that feature.
            phi: Output vector updated in place; its last entry is the bias.
            condition: Forwarded to the recursion. When 0 the tree's root
                value is added to the bias term.
            condition_feature: Forwarded to the recursion unchanged.
        """
        # update the bias term, which is the last index in phi
        # (note the paper has this as phi_0 instead of phi_M)
        if condition == 0:
            phi[-1] += tree.values[0]

        # start the recursive algorithm at the root with an empty path
        tree_shap_recursive(
            tree.children_left,
            tree.children_right,
            tree.children_default,
            tree.features,
            tree.thresholds,
            tree.values,
            tree.node_sample_weight,
            x,
            x_missing,
            phi,
            0,
            0,
            self.feature_indexes,
            self.zero_fractions,
            self.one_fractions,
            self.pweights,
            1,
            1,
            -1,
            condition,
            condition_feature,
            1,
        )
[6]:
# Append one split to the unique decision path and update the path weights.
@numba.jit(
    numba.types.void(
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.int32,
        numba.types.float64,
        numba.types.float64,
        numba.types.int32,
    ),
    nopython=True,
    nogil=True,
)
def extend_path(
    feature_indexes,
    zero_fractions,
    one_fractions,
    pweights,
    unique_depth,
    zero_fraction,
    one_fraction,
    feature_index,
):
    """Store a new path element at ``unique_depth`` and update ``pweights``.

    The feature index and its zero/one fractions are written at position
    ``unique_depth`` of the path arrays. ``pweights[0..unique_depth]`` is
    then updated in place: each existing weight is split into a part scaled
    by ``one_fraction`` that moves one slot up and a part scaled by
    ``zero_fraction`` that stays in its slot.
    """
    # Record the new element at the end of the path.
    feature_indexes[unique_depth] = feature_index
    zero_fractions[unique_depth] = zero_fraction
    one_fractions[unique_depth] = one_fraction

    # The first element seeds the weights with 1; later slots start empty.
    pweights[unique_depth] = 1 if unique_depth == 0 else 0

    new_length = unique_depth + 1
    # Walk from the top down so each weight is read before it is overwritten.
    for i in range(unique_depth - 1, -1, -1):
        current = pweights[i]
        pweights[i + 1] += one_fraction * current * (i + 1) / new_length
        pweights[i] = zero_fraction * current * (unique_depth - i) / new_length
# Remove one element from the decision path, reversing its extend_path call.
@numba.jit(
    numba.types.void(
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.int32,
        numba.types.int32,
    ),
    nopython=True,
    nogil=True,
)
def unwind_path(
    feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index
):
    """Undo the path extension stored at ``path_index``, in place.

    ``pweights[0..unique_depth-1]`` is rewritten using the one/zero fractions
    of the removed element. The feature index and fraction entries after
    ``path_index`` are then shifted down by one slot. ``pweights`` itself is
    not shifted, and the caller is responsible for decrementing its depth.
    """
    removed_one = one_fractions[path_index]
    removed_zero = zero_fractions[path_index]
    length = unique_depth + 1

    # The removed element's one-fraction is the same for every slot, so the
    # case distinction is made once rather than per iteration.
    if removed_one != 0:
        carry = pweights[unique_depth]
        for i in range(unique_depth - 1, -1, -1):
            previous = pweights[i]
            restored = carry * length / ((i + 1) * removed_one)
            pweights[i] = restored
            carry = previous - restored * removed_zero * (unique_depth - i) / length
    else:
        for i in range(unique_depth - 1, -1, -1):
            pweights[i] = (pweights[i] * length) / (removed_zero * (unique_depth - i))

    # Close the gap left by the removed element.
    for j in range(path_index, unique_depth):
        feature_indexes[j] = feature_indexes[j + 1]
        zero_fractions[j] = zero_fractions[j + 1]
        one_fractions[j] = one_fractions[j + 1]
# Compute the total permutation weight the path would have if the element at
# path_index were unwound, without modifying the path.
@numba.jit(
    numba.types.float64(
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.int32,
        numba.types.int32,
    ),
    nopython=True,
    nogil=True,
)
def unwound_path_sum(
    feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index
):
    """Return the sum of the weights that unwinding ``path_index`` would give.

    This performs the same weight recurrence as ``unwind_path`` but only
    accumulates the results; none of the arrays is written to.
    ``feature_indexes`` is accepted for signature symmetry and is not read.
    """
    one_fraction = one_fractions[path_index]
    zero_fraction = zero_fractions[path_index]
    length = unique_depth + 1
    acc = 0.0

    # The branch depends only on the removed element, so it is chosen once.
    if one_fraction != 0:
        carry = pweights[unique_depth]
        for i in range(unique_depth - 1, -1, -1):
            term = carry * length / ((i + 1) * one_fraction)
            acc += term
            carry = pweights[i] - term * zero_fraction * ((unique_depth - i) / length)
    else:
        for i in range(unique_depth - 1, -1, -1):
            acc += (pweights[i] / zero_fraction) / ((unique_depth - i) / length)
    return acc
class Tree:
    """Array-based view of one decision tree, prepared for Tree SHAP.

    Attributes are flat per-node arrays indexed by node id, with node 0 as
    the root. ``max_depth`` is the depth of the tree as reported by
    ``compute_expectations``.
    """

    def __init__(
        self,
        children_left,
        children_right,
        children_default,
        feature,
        threshold,
        value,
        node_sample_weight,
    ):
        """Store the tree arrays and fill in internal-node expectations.

        Args:
            children_left: Left child index per node.
            children_right: Right child index per node; -1 marks a leaf.
            children_default: Child index followed for missing values.
            feature: Split feature index per node.
            threshold: Split threshold per node.
            value: Output value per node. It is copied, never modified.
            node_sample_weight: Training sample weight that reached each node.
        """
        self.children_left = children_left.astype(np.int32)
        self.children_right = children_right.astype(np.int32)
        self.children_default = children_default.astype(np.int32)
        self.features = feature.astype(np.int32)
        # The compiled kernel is declared for float64 arrays; asarray avoids
        # a copy when the input already has that dtype.
        self.thresholds = np.asarray(threshold, dtype=np.float64)
        # compute_expectations overwrites internal-node entries in place, so
        # work on a private copy. The caller may pass a view onto a fitted
        # model's own buffer, which must not be modified.
        self.values = np.array(value, dtype=np.float64)
        self.node_sample_weight = np.asarray(node_sample_weight, dtype=np.float64)

        # Replaces each internal node's value with the weighted mean of its
        # children and returns the depth of the tree.
        self.max_depth = compute_expectations(
            self.children_left,
            self.children_right,
            self.node_sample_weight,
            self.values,
            0,
        )
@numba.jit(nopython=True)
def compute_expectations(
    children_left, children_right, node_sample_weight, values, i, depth=0
):
    """Fill in the expected value of every internal node below node ``i``.

    Each internal node's entry in ``values`` is overwritten in place with
    the mean of its two children's values, weighted by
    ``node_sample_weight``. Leaf entries are left untouched.

    Returns the height of the subtree rooted at ``i`` (0 for a leaf).
    ``depth`` is passed down the recursion but does not affect the result.
    """
    # Leaf: its value is already final.
    if children_right[i] == -1:
        return 0

    left = children_left[i]
    right = children_right[i]

    # Children first, so their values are final before being averaged.
    left_height = compute_expectations(
        children_left, children_right, node_sample_weight, values, left, depth + 1
    )
    right_height = compute_expectations(
        children_left, children_right, node_sample_weight, values, right, depth + 1
    )

    w_left = node_sample_weight[left]
    w_right = node_sample_weight[right]
    values[i] = (w_left * values[left] + w_right * values[right]) / (w_left + w_right)
    return max(left_height, right_height) + 1
# recursive computation of SHAP values for a decision tree
@numba.jit(
    numba.types.void(
        numba.types.int32[:],
        numba.types.int32[:],
        numba.types.int32[:],
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.boolean[:],
        numba.types.float64[:],
        numba.types.int64,
        numba.types.int64,
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64,
        numba.types.float64,
        numba.types.int64,
        numba.types.int64,
        numba.types.int64,
        numba.types.float64,
    ),
    nopython=True,
    nogil=True,
)
def tree_shap_recursive(
    children_left,
    children_right,
    children_default,
    features,
    thresholds,
    values,
    node_sample_weight,
    x,
    x_missing,
    phi,
    node_index,
    unique_depth,
    parent_feature_indexes,
    parent_zero_fractions,
    parent_one_fractions,
    parent_pweights,
    parent_zero_fraction,
    parent_one_fraction,
    parent_feature_index,
    condition,
    condition_feature,
    condition_fraction,
):
    """Recursively accumulate the Tree SHAP attributions of ``x`` into ``phi``.

    At every node the path inherited from the parent is copied into the next
    free region of the scratch buffers and extended with the split that led
    here. A leaf then converts the path into per-feature contributions added
    to ``phi``; an internal node recurses into both children.

    Args:
        children_left, children_right: Child node index per node;
            ``children_right[i] == -1`` marks node ``i`` as a leaf.
        children_default: Child followed when the split feature is flagged
            in ``x_missing``.
        features: Split feature index per node.
        thresholds: Split threshold per node. Samples with
            ``x <= threshold`` follow the left child (scikit-learn's rule).
        values: Output value per node; only leaf entries are read here.
        node_sample_weight: Training sample weight that reached each node.
        x: Feature vector being explained.
        x_missing: Boolean flag per feature marking missing values.
        phi: Output vector indexed by feature index, updated in place.
        node_index: Node currently visited.
        unique_depth: Number of path elements inherited from the parent.
        parent_feature_indexes, parent_zero_fractions, parent_one_fractions,
        parent_pweights: The parent's slices of the scratch buffers.
        parent_zero_fraction, parent_one_fraction, parent_feature_index:
            The path element describing the split that led to this node.
        condition: 0 for plain SHAP values. If nonzero, splits on
            ``condition_feature`` are kept out of the path: a positive
            value follows only the hot branch, and a negative value weights
            both branches by their sample fractions.
        condition_feature: Feature index that ``condition`` applies to.
        condition_fraction: Weight carried into this subtree by conditioning.
    """
    # stop if we have no weight coming down to us
    if condition_fraction == 0:
        return

    # extend the unique path: take the region just past the parent's
    # entries as this call's own working copy
    feature_indexes = parent_feature_indexes[unique_depth + 1 :]
    feature_indexes[: unique_depth + 1] = parent_feature_indexes[: unique_depth + 1]
    zero_fractions = parent_zero_fractions[unique_depth + 1 :]
    zero_fractions[: unique_depth + 1] = parent_zero_fractions[: unique_depth + 1]
    one_fractions = parent_one_fractions[unique_depth + 1 :]
    one_fractions[: unique_depth + 1] = parent_one_fractions[: unique_depth + 1]
    pweights = parent_pweights[unique_depth + 1 :]
    pweights[: unique_depth + 1] = parent_pweights[: unique_depth + 1]

    # the conditioned-on feature is never added to the path
    if condition == 0 or condition_feature != parent_feature_index:
        extend_path(
            feature_indexes,
            zero_fractions,
            one_fractions,
            pweights,
            unique_depth,
            parent_zero_fraction,
            parent_one_fraction,
            parent_feature_index,
        )

    split_index = features[node_index]

    # leaf node
    if children_right[node_index] == -1:
        # index 0 is skipped: it holds the element added by the root call
        for i in range(1, unique_depth + 1):
            w = unwound_path_sum(
                feature_indexes,
                zero_fractions,
                one_fractions,
                pweights,
                unique_depth,
                i,
            )
            phi[feature_indexes[i]] += (
                w
                * (one_fractions[i] - zero_fractions[i])
                * values[node_index]
                * condition_fraction
            )

    # internal node
    else:
        # find which branch is "hot" (meaning x would follow it)
        hot_index = 0
        cleft = children_left[node_index]
        cright = children_right[node_index]
        if x_missing[split_index] == 1:
            hot_index = children_default[node_index]
        # scikit-learn sends samples with x <= threshold to the left child,
        # so a value exactly on the threshold must go left. Trees that use
        # a strict "<" rule would need their thresholds adjusted upstream.
        elif x[split_index] <= thresholds[node_index]:
            hot_index = cleft
        else:
            hot_index = cright
        cold_index = cright if hot_index == cleft else cleft

        # share of this node's training weight that went to each child
        w = node_sample_weight[node_index]
        hot_zero_fraction = node_sample_weight[hot_index] / w
        cold_zero_fraction = node_sample_weight[cold_index] / w
        incoming_zero_fraction = 1
        incoming_one_fraction = 1

        # see if we have already split on this feature,
        # if so we undo that split so we can redo it for this node
        path_index = 0
        while path_index <= unique_depth:
            if feature_indexes[path_index] == split_index:
                break
            path_index += 1

        if path_index != unique_depth + 1:
            incoming_zero_fraction = zero_fractions[path_index]
            incoming_one_fraction = one_fractions[path_index]
            unwind_path(
                feature_indexes,
                zero_fractions,
                one_fractions,
                pweights,
                unique_depth,
                path_index,
            )
            unique_depth -= 1

        # divide up the condition_fraction among the recursive calls
        hot_condition_fraction = condition_fraction
        cold_condition_fraction = condition_fraction
        if condition > 0 and split_index == condition_feature:
            cold_condition_fraction = 0
            unique_depth -= 1
        elif condition < 0 and split_index == condition_feature:
            hot_condition_fraction *= hot_zero_fraction
            cold_condition_fraction *= cold_zero_fraction
            unique_depth -= 1

        # hot branch: x follows it, so its one-fraction is passed through
        tree_shap_recursive(
            children_left,
            children_right,
            children_default,
            features,
            thresholds,
            values,
            node_sample_weight,
            x,
            x_missing,
            phi,
            hot_index,
            unique_depth + 1,
            feature_indexes,
            zero_fractions,
            one_fractions,
            pweights,
            hot_zero_fraction * incoming_zero_fraction,
            incoming_one_fraction,
            split_index,
            condition,
            condition_feature,
            hot_condition_fraction,
        )

        # cold branch: x does not follow it, so its one-fraction is 0
        tree_shap_recursive(
            children_left,
            children_right,
            children_default,
            features,
            thresholds,
            values,
            node_sample_weight,
            x,
            x_missing,
            phi,
            cold_index,
            unique_depth + 1,
            feature_indexes,
            zero_fractions,
            one_fractions,
            pweights,
            cold_zero_fraction * incoming_zero_fraction,
            0,
            split_index,
            condition,
            condition_feature,
            cold_condition_fraction,
        )
比较 XGBoost Tree SHAP 的运行时间…
[7]:
# Time XGBoost's built-in Tree SHAP (pred_contribs=True) on the whole dataset.
# perf_counter() is used for the interval because, unlike time.time(), it is
# monotonic and unaffected by system clock adjustments.
start = time.perf_counter()
shap_values = bst.predict(xgboost.DMatrix(X), pred_contribs=True)
print(time.perf_counter() - start)
0.1391134262084961
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
再与 Python (numba) 版 Tree SHAP 的运行时间进行对比…
[8]:
# Explain one synthetic sample whose features are all 1.0. A 1-D input takes
# the single-instance path of shap_values.
x = np.ones(X.shape[1])
# Trailing expression: the notebook displays the attributions, with the bias
# term as the last entry.
TreeExplainer(model).shap_values(x)
[8]:
array([-0.5140118 , 0.07819144, 0.09588052, -0.01004643, 0.21487504,
0.37846563, 0.15740923, -0.04002657, 2.04526993])
[9]:
# Time the two phases of the Python implementation separately. perf_counter()
# is used because it is monotonic and unaffected by clock adjustments.

# Phase 1: building the explainer (tree conversion and buffer allocation).
start = time.perf_counter()
ex = TreeExplainer(model)
print(time.perf_counter() - start)

# Phase 2: explaining every row of the dataset.
start = time.perf_counter()
ex.shap_values(X)
print(time.perf_counter() - start)
0.0075643062591552734
10.171732664108276