.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/inspection/plot_permutation_importance.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_inspection_plot_permutation_importance.py:


================================================================
Permutation Importance vs Random Forest Feature Importance (MDI)
================================================================

In this example, we compare the impurity-based feature importance of
:class:`~sklearn.ensemble.RandomForestClassifier` with the permutation
importance computed on the Titanic dataset using
:func:`~sklearn.inspection.permutation_importance`. We will show that the
impurity-based feature importance can inflate the importance of numerical
features.

Furthermore, the impurity-based feature importance of random forests suffers
from being computed on statistics derived from the training dataset: the
importance can be high even for features that have no predictive power for the
target variable, as long as the model has the capacity to use them to overfit.

This example shows how to use permutation importance as an alternative that
mitigates these limitations.

.. rubric:: References

* :doi:`L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001.
  <10.1023/A:1010933404324>`

.. GENERATED FROM PYTHON SOURCE LINES 24-32

Data loading and feature engineering
------------------------------------

Let's use pandas to load a copy of the Titanic dataset. The following shows how
to apply separate preprocessing to numerical and categorical features.

We further include two random variables that are not correlated in any way with
the target variable (``survived``):

- ``random_num`` is a high-cardinality numerical variable (as many unique
  values as records).
- ``random_cat`` is a low-cardinality categorical variable (3 possible values).

.. GENERATED FROM PYTHON SOURCE LINES 32-48

.. code-block:: Python

    import numpy as np

    from sklearn.datasets import fetch_openml
    from sklearn.model_selection import train_test_split

    X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)

    # Add two purely random features that carry no information about ``y``.
    rng = np.random.RandomState(seed=42)
    X["random_cat"] = rng.randint(3, size=X.shape[0])
    X["random_num"] = rng.randn(X.shape[0])

    categorical_columns = ["pclass", "sex", "embarked", "random_cat"]
    numerical_columns = ["age", "sibsp", "parch", "fare", "random_num"]

    # Keep only the selected columns, categorical first, then numerical.
    X = X[categorical_columns + numerical_columns]
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)

.. GENERATED FROM PYTHON SOURCE LINES 49-53

We define a predictive model based on a random forest. Therefore, we apply the
following preprocessing steps:

- use :class:`~sklearn.preprocessing.OrdinalEncoder` to encode the categorical
  features;
- use :class:`~sklearn.impute.SimpleImputer` to fill missing values of the
  numerical features using a mean strategy.

.. GENERATED FROM PYTHON SOURCE LINES 53-80

.. code-block:: Python

    from sklearn.compose import ColumnTransformer
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.impute import SimpleImputer
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import OrdinalEncoder

    # Unknown and missing categories are both encoded as -1.
    categorical_encoder = OrdinalEncoder(
        handle_unknown="use_encoded_value", unknown_value=-1, encoded_missing_value=-1
    )
    numerical_pipe = SimpleImputer(strategy="mean")

    preprocessing = ColumnTransformer(
        [
            ("cat", categorical_encoder, categorical_columns),
            ("num", numerical_pipe, numerical_columns),
        ],
        verbose_feature_names_out=False,
    )

    rf = Pipeline(
        [
            ("preprocess", preprocessing),
            ("classifier", RandomForestClassifier(random_state=42)),
        ]
    )
    rf.fit(X_train, y_train)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Pipeline(steps=[('preprocess',
                     ColumnTransformer(transformers=[('cat',
                                                      OrdinalEncoder(encoded_missing_value=-1,
                                                                     handle_unknown='use_encoded_value',
                                                                     unknown_value=-1),
                                                      ['pclass', 'sex', 'embarked',
                                                       'random_cat']),
                                                     ('num', SimpleImputer(),
                                                      ['age', 'sibsp', 'parch',
                                                       'fare', 'random_num'])],
                                       verbose_feature_names_out=False)),
                    ('classifier', RandomForestClassifier(random_state=42))])


.. GENERATED FROM PYTHON SOURCE LINES 81-90

Accuracy of the model
---------------------

Before inspecting the feature importances, it is important to check that the
predictive performance of the model is high enough. Indeed, there is little
point in inspecting the important features of a non-predictive model.

Here one can observe that the train accuracy is very high (the forest has
enough capacity to completely memorize the training set), but the model still
generalizes well to the test set thanks to the built-in bagging of random
forests.

It might be possible to trade some accuracy on the training set for a slightly
better accuracy on the test set by limiting the capacity of the trees (for
instance by setting ``min_samples_leaf=5`` or ``min_samples_leaf=10``), so as
to limit overfitting while not introducing too much underfitting.

However, let us keep the high-capacity random forest for now, so that we can
illustrate some pitfalls of feature importance on variables with many unique
values.

.. GENERATED FROM PYTHON SOURCE LINES 90-94

.. code-block:: Python

    print(f"RF train accuracy: {rf.score(X_train, y_train):.3f}")
    print(f"RF test accuracy: {rf.score(X_test, y_test):.3f}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    RF train accuracy: 1.000
    RF test accuracy: 0.814

.. GENERATED FROM PYTHON SOURCE LINES 95-107

Tree's feature importance from Mean Decrease in Impurity (MDI)
--------------------------------------------------------------

The impurity-based feature importance ranks the numerical features as the most
important ones. As a result, the non-predictive ``random_num`` variable is
ranked as one of the most important features!

This problem stems from two limitations of impurity-based feature importances:

- impurity-based importances are biased towards high-cardinality features;
- impurity-based importances are computed on training set statistics and
  therefore do not reflect how useful a feature is for making predictions that
  generalize to the test set (when the model has enough capacity).

The bias towards high-cardinality features explains why `random_num` has a very
large importance compared with `random_cat`, while we would expect both random
features to have zero importance.

The fact that we use training set statistics explains why both the `random_num`
and `random_cat` features have a non-zero importance.

.. GENERATED FROM PYTHON SOURCE LINES 107-115

.. code-block:: Python

    import pandas as pd

    # ``rf[:-1]`` is the preprocessing part of the pipeline, ``rf[-1]`` the forest.
    feature_names = rf[:-1].get_feature_names_out()

    mdi_importances = pd.Series(
        rf[-1].feature_importances_, index=feature_names
    ).sort_values(ascending=True)

.. GENERATED FROM PYTHON SOURCE LINES 116-120

.. code-block:: Python

    ax = mdi_importances.plot.barh()
    ax.set_title("Random Forest Feature Importances (MDI)")
    ax.figure.tight_layout()

.. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_001.png
   :alt: Random Forest Feature Importances (MDI)
   :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 121-124

As an alternative, the permutation importances of ``rf`` are computed on a
held-out test set. This shows that the low-cardinality categorical features
`sex` and `pclass` are the most important features. Indeed, permuting the
values of these features leads to the largest decrease in the accuracy score of
the model on the test set.

Also note that both random features have very low importances (close to 0), as
expected.

.. GENERATED FROM PYTHON SOURCE LINES 124-141

.. code-block:: Python

    from sklearn.inspection import permutation_importance

    result = permutation_importance(
        rf, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
    )

    # Sort features by mean importance so the box plot is ordered.
    sorted_importances_idx = result.importances_mean.argsort()
    importances = pd.DataFrame(
        result.importances[sorted_importances_idx].T,
        columns=X.columns[sorted_importances_idx],
    )
    ax = importances.plot.box(vert=False, whis=10)
    ax.set_title("Permutation Importances (test set)")
    ax.axvline(x=0, color="k", linestyle="--")
    ax.set_xlabel("Decrease in accuracy score")
    ax.figure.tight_layout()

.. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_002.png
   :alt: Permutation Importances (test set)
   :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 142-143

It is also possible to compute the permutation importances on the training set.
This reveals that `random_num` and `random_cat` get a significantly higher
importance ranking than when computed on the test set. The difference between
the two plots confirms that the random forest has enough capacity to use these
random numerical and categorical features to overfit.

.. GENERATED FROM PYTHON SOURCE LINES 143-159

.. code-block:: Python

    result = permutation_importance(
        rf, X_train, y_train, n_repeats=10, random_state=42, n_jobs=2
    )

    sorted_importances_idx = result.importances_mean.argsort()
    importances = pd.DataFrame(
        result.importances[sorted_importances_idx].T,
        columns=X.columns[sorted_importances_idx],
    )
    ax = importances.plot.box(vert=False, whis=10)
    ax.set_title("Permutation Importances (train set)")
    ax.axvline(x=0, color="k", linestyle="--")
    ax.set_xlabel("Decrease in accuracy score")
    ax.figure.tight_layout()

.. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_003.png
   :alt: Permutation Importances (train set)
   :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 160-161

We can repeat the experiment while limiting the capacity of the trees to
overfit, by setting `min_samples_leaf` to 20 data points.

.. GENERATED FROM PYTHON SOURCE LINES 161-164

.. code-block:: Python

    rf.set_params(classifier__min_samples_leaf=20).fit(X_train, y_train)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Pipeline(steps=[('preprocess',
                     ColumnTransformer(transformers=[('cat',
                                                      OrdinalEncoder(encoded_missing_value=-1,
                                                                     handle_unknown='use_encoded_value',
                                                                     unknown_value=-1),
                                                      ['pclass', 'sex', 'embarked',
                                                       'random_cat']),
                                                     ('num', SimpleImputer(),
                                                      ['age', 'sibsp', 'parch',
                                                       'fare', 'random_num'])],
                                       verbose_feature_names_out=False)),
                    ('classifier',
                     RandomForestClassifier(min_samples_leaf=20, random_state=42))])

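The ``classifier__min_samples_leaf`` argument used above follows the ``<step name>__<parameter>`` convention for setting parameters of estimators nested in a pipeline. A minimal sketch on a hypothetical two-step pipeline (the step names ``impute`` and ``classifier`` and the variable names are chosen here for illustration only):

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline

# Hypothetical pipeline, smaller than the one used in the example.
pipe = Pipeline(
    [
        ("impute", SimpleImputer(strategy="mean")),
        ("classifier", RandomForestClassifier(random_state=0)),
    ]
)

# Nested parameters are exposed as "<step name>__<parameter>".
default_leaf = pipe.get_params()["classifier__min_samples_leaf"]

# set_params returns the pipeline itself, which is why ``.fit`` can be chained.
returned = pipe.set_params(classifier__min_samples_leaf=20)
new_leaf = pipe.named_steps["classifier"].min_samples_leaf

print(default_leaf, new_leaf, returned is pipe)
```

Note that ``set_params`` only changes the configuration; the model has to be refitted for the new value to take effect, as done in the example.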

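To make the idea behind the permutation plots concrete, the procedure can be sketched by hand: shuffle one column at a time, re-score the fitted model, and record the drop relative to the baseline score. The code below is a simplified illustration under stated assumptions, not scikit-learn's implementation: the helper ``manual_permutation_importance`` and the synthetic data are hypothetical, and the sketch ignores options such as custom scorers or parallelism.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def manual_permutation_importance(model, X, y, n_repeats=5, seed=0):
    """Return score drops with shape (n_features, n_repeats)."""
    rng = np.random.RandomState(seed)
    baseline = model.score(X, y)
    drops = np.zeros((X.shape[1], n_repeats))
    for j in range(X.shape[1]):
        for r in range(n_repeats):
            X_perm = X.copy()
            # Shuffle only column j, which breaks its link with the target.
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops[j, r] = baseline - model.score(X_perm, y)
    return drops


# Synthetic data (hypothetical): column 0 determines the label, column 1 is noise.
rng = np.random.RandomState(42)
X_toy = rng.randn(400, 2)
y_toy = (X_toy[:, 0] > 0).astype(int)
X_fit, X_held = X_toy[:300], X_toy[300:]
y_fit, y_held = y_toy[:300], y_toy[300:]

tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_fit, y_fit)

# Importances are measured on held-out data, as in the test-set plots above.
drops = manual_permutation_importance(tree, X_held, y_held)
mean_drops = drops.mean(axis=1)
print(mean_drops)
```

In this toy setting the informative column should show a large mean drop in accuracy, while the noise column should stay close to zero, mirroring the behavior expected of `random_num` and `random_cat` on held-out data.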
.. GENERATED FROM PYTHON SOURCE LINES 165-166

Looking at the accuracy scores on the training and test sets, we see that the
two metrics are now very similar. Therefore, our model is no longer
overfitting. We can then check the permutation importances with this new model.

.. GENERATED FROM PYTHON SOURCE LINES 166-170

.. code-block:: Python

    print(f"RF train accuracy: {rf.score(X_train, y_train):.3f}")
    print(f"RF test accuracy: {rf.score(X_test, y_test):.3f}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    RF train accuracy: 0.810
    RF test accuracy: 0.832

.. GENERATED FROM PYTHON SOURCE LINES 171-179

.. code-block:: Python

    train_result = permutation_importance(
        rf, X_train, y_train, n_repeats=10, random_state=42, n_jobs=2
    )
    test_results = permutation_importance(
        rf, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
    )
    # Use the training-set ordering for both plots so they are comparable.
    sorted_importances_idx = train_result.importances_mean.argsort()

.. GENERATED FROM PYTHON SOURCE LINES 180-189

.. code-block:: Python

    train_importances = pd.DataFrame(
        train_result.importances[sorted_importances_idx].T,
        columns=X.columns[sorted_importances_idx],
    )
    test_importances = pd.DataFrame(
        test_results.importances[sorted_importances_idx].T,
        columns=X.columns[sorted_importances_idx],
    )

.. GENERATED FROM PYTHON SOURCE LINES 190-197

.. code-block:: Python

    for name, importances in zip(["train", "test"], [train_importances, test_importances]):
        ax = importances.plot.box(vert=False, whis=10)
        ax.set_title(f"Permutation Importances ({name} set)")
        ax.set_xlabel("Decrease in accuracy score")
        ax.axvline(x=0, color="k", linestyle="--")
        ax.figure.tight_layout()

.. rst-class:: sphx-glr-horizontal

    *

      .. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_004.png
         :alt: Permutation Importances (train set)
         :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_005.png
         :alt: Permutation Importances (test set)
         :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_005.png
         :class: sphx-glr-multi-img

.. GENERATED FROM PYTHON SOURCE LINES 198-199

Now we can observe that, on both sets, the `random_num` and `random_cat`
features have a lower importance than with the overfitted random forest.
However, the conclusions regarding the importance of the other features still
hold.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 2.420 seconds)

.. _sphx_glr_download_auto_examples_inspection_plot_permutation_importance.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/inspection/plot_permutation_importance.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_permutation_importance.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_permutation_importance.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_permutation_importance.zip `

.. include:: plot_permutation_importance.recommendations

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_