.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/mixture/plot_gmm_selection.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_mixture_plot_gmm_selection.py>`
        to download the full example code or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_mixture_plot_gmm_selection.py:

================================
Gaussian Mixture Model Selection
================================

This example shows how model selection can be performed with Gaussian Mixture
Models (GMM) using information-theoretic criteria. Model selection concerns both
the covariance type and the number of components in the model.

In this case, both the Akaike Information Criterion (AIC) and the Bayes
Information Criterion (BIC) provide the right result, but we only demonstrate the
latter as BIC is better suited to identify the true model among a set of
candidates. Unlike Bayesian procedures, such inferences are prior-free.

.. GENERATED FROM PYTHON SOURCE LINES 13-17

Data generation
---------------

We generate two components (each one containing `n_samples`) by randomly
sampling the standard normal distribution as returned by `numpy.random.randn`.
One component is kept spherical yet shifted and re-scaled. The other one is
deformed to have a more general covariance matrix.

.. GENERATED FROM PYTHON SOURCE LINES 17-28

.. code-block:: Python

    import numpy as np

    n_samples = 500
    np.random.seed(0)

    C = np.array([[0.0, -0.1], [1.7, 0.4]])
    component_1 = np.dot(np.random.randn(n_samples, 2), C)  # general
    component_2 = 0.7 * np.random.randn(n_samples, 2) + np.array([-4, 1])  # spherical

    X = np.concatenate([component_1, component_2])

.. GENERATED FROM PYTHON SOURCE LINES 29-30

We can visualize the different components:

.. GENERATED FROM PYTHON SOURCE LINES 30-40

.. code-block:: Python

    import matplotlib.pyplot as plt

    plt.scatter(component_1[:, 0], component_1[:, 1], s=0.8)
    plt.scatter(component_2[:, 0], component_2[:, 1], s=0.8)
    plt.title("Gaussian Mixture components")
    plt.axis("equal")
    plt.show()

.. image-sg:: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_001.png
   :alt: Gaussian Mixture components
   :srcset: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 41-54

Model training and selection
----------------------------

We vary the number of components from 1 to 6 and the type of covariance
parameters to use:

- `"full"`: each component has its own general covariance matrix.
- `"tied"`: all components share the same general covariance matrix.
- `"diag"`: each component has its own diagonal covariance matrix.
- `"spherical"`: each component has its own single variance.

We score the different models and keep the best model (the one with the lowest
BIC). This is done by using :class:`~sklearn.model_selection.GridSearchCV` and a
user-defined score function which returns the negative BIC score, as
:class:`~sklearn.model_selection.GridSearchCV` is designed to **maximize** a
score (maximizing the negative BIC is equivalent to minimizing the BIC).

The best set of parameters and the best estimator are stored in `best_params_`
and `best_estimator_`, respectively.

.. GENERATED FROM PYTHON SOURCE LINES 54-74

.. code-block:: Python

    from sklearn.mixture import GaussianMixture
    from sklearn.model_selection import GridSearchCV


    def gmm_bic_score(estimator, X):
        """Callable to pass to GridSearchCV that uses the BIC score."""
        # Make it negative since GridSearchCV expects a score to maximize.
        return -estimator.bic(X)


    param_grid = {
        "n_components": range(1, 7),
        "covariance_type": ["spherical", "tied", "diag", "full"],
    }
    grid_search = GridSearchCV(
        GaussianMixture(), param_grid=param_grid, scoring=gmm_bic_score
    )
    grid_search.fit(X)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    GridSearchCV(estimator=GaussianMixture(),
                 param_grid={'covariance_type': ['spherical', 'tied', 'diag',
                                                 'full'],
                             'n_components': range(1, 7)},
                 scoring=<function gmm_bic_score at 0xffff4c10af20>)
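
The selected configuration and the refitted model can be inspected directly
through the `best_params_` and `best_estimator_` attributes mentioned above. A
minimal sketch (here `best_gmm` is just an illustrative variable name):

.. code-block:: Python

    # Illustrative check: print the configuration retained by the grid search.
    print(grid_search.best_params_)

    # The model refitted on the whole dataset with the selected parameters.
    best_gmm = grid_search.best_estimator_
    print(best_gmm.weights_)  # mixing proportions of the selected components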
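
As mentioned in the introduction, the AIC reaches the same conclusion on this
data. A hedged sketch of how the scorer could be swapped, reusing `param_grid`
and `X` from above (`gmm_aic_score` and `aic_grid_search` are illustrative
names, not part of the example):

.. code-block:: Python

    # Illustrative AIC-based variant of the search above.
    def gmm_aic_score(estimator, X):
        """Callable to pass to GridSearchCV that uses the AIC score."""
        # Make it negative since GridSearchCV expects a score to maximize.
        return -estimator.aic(X)


    aic_grid_search = GridSearchCV(
        GaussianMixture(), param_grid=param_grid, scoring=gmm_aic_score
    )
    aic_grid_search.fit(X)
    print(aic_grid_search.best_params_)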


.. GENERATED FROM PYTHON SOURCE LINES 75-79

Plot the BIC scores
-------------------

To ease the plotting we can create a `pandas.DataFrame` from the results of the
cross-validation done by the grid search. We re-invert the sign of the BIC score
to show the effect of minimizing it.

.. GENERATED FROM PYTHON SOURCE LINES 79-95

.. code-block:: Python

    import pandas as pd

    df = pd.DataFrame(grid_search.cv_results_)[
        ["param_n_components", "param_covariance_type", "mean_test_score"]
    ]
    df["mean_test_score"] = -df["mean_test_score"]
    df = df.rename(
        columns={
            "param_n_components": "Number of components",
            "param_covariance_type": "Type of covariance",
            "mean_test_score": "BIC score",
        }
    )
    df.sort_values(by="BIC score").head()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

        Number of components Type of covariance    BIC score
    19                     2               full  1046.829429
    20                     3               full  1084.038689
    21                     4               full  1114.517272
    22                     5               full  1148.512281
    23                     6               full  1179.977890
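
The same scores can also be read as a compact table with one row per number of
components and one column per covariance type. A minimal sketch reusing the
`df` built above (`bic_table` is an illustrative name):

.. code-block:: Python

    # Illustrative pivot of the BIC scores: components as rows, covariance
    # types as columns.
    bic_table = df.pivot(
        index="Number of components", columns="Type of covariance", values="BIC score"
    )
    print(bic_table.round(1))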


.. GENERATED FROM PYTHON SOURCE LINES 96-107

.. code-block:: Python

    import seaborn as sns

    sns.catplot(
        data=df,
        kind="bar",
        x="Number of components",
        y="BIC score",
        hue="Type of covariance",
    )
    plt.show()

.. image-sg:: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_002.png
   :alt: plot gmm selection
   :srcset: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 108-119

In the present case, the model with 2 components and full covariance (which
corresponds to the true generative model) has the lowest BIC score and is
therefore selected by the grid search.

Plot the best model
-------------------

We plot an ellipse to show each Gaussian component of the selected model. For
that purpose, one needs to find the eigenvalues of the covariance matrices as
returned by the `covariances_` attribute. The shape of such matrices depends on
the `covariance_type`:

- `"full"`: (`n_components`, `n_features`, `n_features`)
- `"tied"`: (`n_features`, `n_features`)
- `"diag"`: (`n_components`, `n_features`)
- `"spherical"`: (`n_components`,)

.. GENERATED FROM PYTHON SOURCE LINES 119-154

.. code-block:: Python

    from matplotlib.patches import Ellipse
    from scipy import linalg

    color_iter = sns.color_palette("tab10", 2)[::-1]
    Y_ = grid_search.predict(X)

    fig, ax = plt.subplots()

    for i, (mean, cov, color) in enumerate(
        zip(
            grid_search.best_estimator_.means_,
            grid_search.best_estimator_.covariances_,
            color_iter,
        )
    ):
        # Eigen-decomposition of the component covariance matrix.
        v, w = linalg.eigh(cov)
        if not np.any(Y_ == i):
            continue
        plt.scatter(X[Y_ == i, 0], X[Y_ == i, 1], 0.8, color=color)

        angle = np.arctan2(w[0][1], w[0][0])
        angle = 180.0 * angle / np.pi  # convert to degrees
        # Scale the eigenvalues to the axis lengths of the ellipse.
        v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
        ellipse = Ellipse(mean, v[0], v[1], angle=180.0 + angle, color=color)
        ellipse.set_clip_box(fig.bbox)
        ellipse.set_alpha(0.5)
        ax.add_artist(ellipse)

    plt.title(
        f"Selected GMM: {grid_search.best_params_['covariance_type']} model, "
        f"{grid_search.best_params_['n_components']} components"
    )
    plt.axis("equal")
    plt.show()

.. image-sg:: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_003.png
   :alt: Selected GMM: full model, 2 components
   :srcset: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.737 seconds)


.. _sphx_glr_download_auto_examples_mixture_plot_gmm_selection.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/mixture/plot_gmm_selection.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_gmm_selection.ipynb <plot_gmm_selection.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_gmm_selection.py <plot_gmm_selection.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_gmm_selection.zip <plot_gmm_selection.zip>`

.. include:: plot_gmm_selection.recommendations

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_