使用 XGBoost 预测《英雄联盟》胜利

本笔记本使用 Kaggle 数据集 英雄联盟排位赛 ,该数据集包含自2014年以来的180,000场英雄联盟排位赛。我们使用这些数据构建了一个 XGBoost 模型,以根据玩家在比赛中的表现统计数据来预测该玩家所在的队伍是否会获胜。

这里使用的方法适用于任何数据集。我们使用这个数据集来说明SHAP值如何帮助解释诸如XGBoost这样的梯度提升树。由于其规模、交互效应、包含分类和连续特征以及其可解释性(特别是对于游戏玩家),该数据集在多个方面都适合作为一个很好的例子。有关SHAP值的更多信息,请参见:https://github.com/shap/shap

[1]:
from pathlib import Path

import matplotlib.pyplot as pl
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split

import shap

shap.initjs()

加载数据集

要自己运行此程序,您需要从Kaggle下载数据集,并确保下面的 prefix 变量是正确的。为此,请按照上述链接下载并解压数据。如有需要,请更改 prefix 变量。

[2]:
# read in the data
folder_path = Path("../local_scratch/data/league-of-legends-ranked-matches/")
matches = pd.read_csv(folder_path / "matches.csv")
participants = pd.read_csv(folder_path / "participants.csv")
stats1 = pd.read_csv(folder_path / "stats1.csv", low_memory=False)
stats2 = pd.read_csv(folder_path / "stats2.csv", low_memory=False)
stats = pd.concat([stats1, stats2])

# merge into a single DataFrame
a = pd.merge(
    participants, matches, left_on="matchid", right_on="id", suffixes=("", "_matches")
)
allstats_orig = pd.merge(
    a, stats, left_on="matchid", right_on="id", suffixes=("", "_stats")
)
allstats = allstats_orig.copy()

# drop games that lasted less than 10 minutes
allstats = allstats.loc[allstats["duration"] >= 10 * 60, :]

# Convert string-based categories to numeric values
cat_cols = ["role", "position", "version", "platformid"]
for c in cat_cols:
    allstats[c] = allstats[c].astype("category")
    allstats[c] = allstats[c].cat.codes
allstats["wardsbought"] = allstats["wardsbought"].astype(np.int32)

X = allstats.drop(columns=["win"])
y = allstats["win"]

# convert all features we want to consider as rates
rate_features = [
    "kills",
    "deaths",
    "assists",
    "killingsprees",
    "doublekills",
    "triplekills",
    "quadrakills",
    "pentakills",
    "legendarykills",
    "totdmgdealt",
    "magicdmgdealt",
    "physicaldmgdealt",
    "truedmgdealt",
    "totdmgtochamp",
    "magicdmgtochamp",
    "physdmgtochamp",
    "truedmgtochamp",
    "totheal",
    "totunitshealed",
    "dmgtoobj",
    "timecc",
    "totdmgtaken",
    "magicdmgtaken",
    "physdmgtaken",
    "truedmgtaken",
    "goldearned",
    "goldspent",
    "totminionskilled",
    "neutralminionskilled",
    "ownjunglekills",
    "enemyjunglekills",
    "totcctimedealt",
    "pinksbought",
    "wardsbought",
    "wardsplaced",
    "wardskilled",
]
for feature_name in rate_features:
    X[feature_name] /= X["duration"] / 60  # per minute rate

# convert to fraction of game
X["longesttimespentliving"] /= X["duration"]

# define friendly names for the features
full_names = {
    "kills": "Kills per min.",
    "deaths": "Deaths per min.",
    "assists": "Assists per min.",
    "killingsprees": "Killing sprees per min.",
    "longesttimespentliving": "Longest time living as % of game",
    "doublekills": "Double kills per min.",
    "triplekills": "Triple kills per min.",
    "quadrakills": "Quadra kills per min.",
    "pentakills": "Penta kills per min.",
    "legendarykills": "Legendary kills per min.",
    "totdmgdealt": "Total damage dealt per min.",
    "magicdmgdealt": "Magic damage dealt per min.",
    "physicaldmgdealt": "Physical damage dealt per min.",
    "truedmgdealt": "True damage dealt per min.",
    "totdmgtochamp": "Total damage to champions per min.",
    "magicdmgtochamp": "Magic damage to champions per min.",
    "physdmgtochamp": "Physical damage to champions per min.",
    "truedmgtochamp": "True damage to champions per min.",
    "totheal": "Total healing per min.",
    "totunitshealed": "Total units healed per min.",
    "dmgtoobj": "Damage to objects per min.",
    "timecc": "Time spent with crown control per min.",
    "totdmgtaken": "Total damage taken per min.",
    "magicdmgtaken": "Magic damage taken per min.",
    "physdmgtaken": "Physical damage taken per min.",
    "truedmgtaken": "True damage taken per min.",
    "goldearned": "Gold earned per min.",
    "goldspent": "Gold spent per min.",
    "totminionskilled": "Total minions killed per min.",
    "neutralminionskilled": "Neutral minions killed per min.",
    "ownjunglekills": "Own jungle kills per min.",
    "enemyjunglekills": "Enemy jungle kills per min.",
    "totcctimedealt": "Total crown control time dealt per min.",
    "pinksbought": "Pink wards bought per min.",
    "wardsbought": "Wards bought per min.",
    "wardsplaced": "Wards placed per min.",
    "turretkills": "# of turret kills",
    "inhibkills": "# of inhibitor kills",
    "dmgtoturrets": "Damage to turrets",
}
feature_names = [full_names.get(n, n) for n in X.columns]
X.columns = feature_names

# create train/validation split
Xt, Xv, yt, yv = train_test_split(X, y, test_size=0.2, random_state=10)
dt = xgb.DMatrix(Xt, label=yt.values)
dv = xgb.DMatrix(Xv, label=yv.values)

训练 XGBoost 模型

[3]:
params = {
    "objective": "binary:logistic",
    "base_score": np.mean(yt),
    "eval_metric": "logloss",
}
model = xgb.train(
    params,
    dt,
    num_boost_round=10,
    evals=[(dt, "train"), (dv, "valid")],
    early_stopping_rounds=5,
    verbose_eval=25,
)
[0]     train-logloss:0.57255   valid-logloss:0.57258
[9]     train-logloss:0.34293   valid-logloss:0.34323

解释XGBoost模型

因为 Tree SHAP 算法在 XGBoost 中实现,我们可以在数千个样本上快速计算精确的 SHAP 值。单个预测的 SHAP 值(包括最后一列中的预期输出)总和为该预测的模型输出。

[4]:
# compute the SHAP values for every prediction in the validation dataset
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(Xv)

解释一名玩家在特定比赛中获胜的机会

SHAP 值总和为模型预期输出与当前玩家当前输出的差异。请注意,对于 Tree SHAP 实现,解释的是模型的边际输出,而不是转换后的输出(例如逻辑回归的概率)。这意味着该模型的 SHAP 值的单位是 log 赔率比。较大的正值意味着玩家很可能获胜,而较大的负值意味着他们很可能输掉。

[5]:
shap.force_plot(explainer.expected_value, shap_values[0, :], Xv.iloc[0, :])
[5]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
[6]:
xs = np.linspace(-4, 4, 100)
pl.xlabel("Log odds of winning")
pl.ylabel("Probability of winning")
pl.title("How changes in log odds convert to probability of winning")
pl.plot(xs, 1 / (1 + np.exp(-xs)))
pl.show()
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_10_0.png

总结所有特征对整个数据集的影响

对于特定预测的一个特征的SHAP值表示当我们观察到该特征时,模型预测的变化程度。在下方的汇总图中,我们将单个特征(如``goldearned``)的所有SHAP值绘制在一行中,其中x轴是SHAP值(对于此模型,单位是获胜的对数几率)。通过对所有特征执行此操作,我们可以看到哪些特征对模型的预测影响很大(如``goldearned``),哪些特征对预测的影响很小(如``kills``)。请注意,当点在线上无法紧密排列时,它们会垂直堆叠以显示密度。每个点还根据该特征的值从高到低进行着色。

[7]:
shap.summary_plot(shap_values, Xv)
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_12_0.png

检查特征变化如何影响模型的预测

我们上面训练的XGBoost模型非常复杂,但通过绘制特征的SHAP值与所有玩家特征的实际值的关系图,我们可以看到特征值的变化如何影响模型的输出。请注意,这些图与标准的部分依赖图非常相似,但它们提供了额外的优势,即显示了上下文对特征的重要性(或者换句话说,交互项的重要性)。交互项对特征重要性的影响由数据点的垂直分散度捕捉。例如,在一场游戏中每分钟仅赚取100金币可能会使某些玩家的获胜对数几率降低10,而对其他玩家仅降低3。为什么会这样?因为这些玩家的其他特征影响了赚取金币对赢得游戏的重要性。请注意,一旦你每分钟赚取至少500金币,垂直分散度就会变窄,这意味着对于高金币赚取者来说,其他特征的上下文影响比低金币赚取者要小。我们用另一个最能解释交互效应方差的特征来给数据点着色。例如,赚取较少金币如果死亡次数不多,影响不大,但如果死亡次数也很多,那就非常糟糕。

下图中的y轴表示该特征的SHAP值,因此-4意味着观察到该特征会降低你获胜的对数几率4,而+2的值意味着观察到该特征会增加你获胜的对数几率2。

请注意,这些图表只是解释了XGBoost模型的工作原理,并不一定反映了现实情况。由于XGBoost模型是基于观测数据训练的,它不一定是一个因果模型,因此仅仅因为改变一个因素使得模型的获胜预测上升,并不总是意味着它会提高你的实际胜算。

[8]:
shap.dependence_plot(
    "Gold earned per min.", shap_values, Xv, interaction_index="Deaths per min."
)
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_14_0.png
[9]:
# sort the features indexes by their importance in the model
# (sum of SHAP value magnitudes over the validation dataset)
top_inds = np.argsort(-np.sum(np.abs(shap_values), 0))

# make SHAP plots of the three most important features
for i in range(20):
    shap.dependence_plot(top_inds[i], shap_values, Xv)
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_0.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_1.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_2.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_3.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_4.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_5.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_6.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_7.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_8.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_9.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_10.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_11.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_12.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_13.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_14.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_15.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_16.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_17.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_18.png
../../../_images/example_notebooks_tabular_examples_tree_based_models_League_of_Legends_Win_Prediction_with_XGBoost_15_19.png