.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/neighbors/plot_digits_kde_sampling.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_neighbors_plot_digits_kde_sampling.py:


=========================
Kernel Density Estimation
=========================

This example shows how kernel density estimation (KDE), a powerful
non-parametric density estimation technique, can be used to learn a
generative model for a dataset. Once this generative model has been fitted,
new samples can be drawn from it, and these samples reflect the underlying
distribution of the data.

Here the 64-pixel digit images are first projected onto 15 principal
components. A kernel density model is fitted in that reduced space, with its
bandwidth selected by cross-validated grid search, and 44 new points are then
sampled from it and mapped back to pixel space.

.. GENERATED FROM PYTHON SOURCE LINES 9-63



.. image-sg:: /auto_examples/neighbors/images/sphx_glr_plot_digits_kde_sampling_001.png
   :alt: Selection from the input data, "New" digits drawn from the kernel density model
   :srcset: /auto_examples/neighbors/images/sphx_glr_plot_digits_kde_sampling_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    best bandwidth: 3.79269019073225


|

.. code-block:: Python

    import matplotlib.pyplot as plt
    import numpy as np

    from sklearn.datasets import load_digits
    from sklearn.decomposition import PCA
    from sklearn.model_selection import GridSearchCV
    from sklearn.neighbors import KernelDensity

    # Load the 8x8 digit images (64 features per sample)
    digits = load_digits()

    # Project the 64-dimensional data onto 15 principal components
    pca = PCA(n_components=15, whiten=False)
    data = pca.fit_transform(digits.data)

    # Use grid-search cross-validation to select the bandwidth among 20
    # log-spaced candidates between 0.1 and 10; each candidate is scored by
    # KernelDensity.score, i.e. the held-out log-likelihood
    params = {"bandwidth": np.logspace(-1, 1, 20)}
    grid = GridSearchCV(KernelDensity(), params)
    grid.fit(data)

    print("best bandwidth: {0}".format(grid.best_estimator_.bandwidth))

    # Use the best estimator (refit on the full data set) as the density model
    kde = grid.best_estimator_

    # Draw 44 new points from the density model in PCA space,
    # then map them back to the 64-dimensional pixel space
    new_data = kde.sample(44, random_state=0)
    new_data = pca.inverse_transform(new_data)

    # Arrange the generated and the real digits as 4x11 grids
    new_data = new_data.reshape((4, 11, -1))
    real_data = digits.data[:44].reshape((4, 11, -1))

    # Plot the real digits (rows 0-3) and the resampled digits (rows 5-8);
    # row 4 is hidden and only serves as a spacer
    fig, ax = plt.subplots(9, 11, subplot_kw=dict(xticks=[], yticks=[]))
    for j in range(11):
        ax[4, j].set_visible(False)
        for i in range(4):
            im = ax[i, j].imshow(
                real_data[i, j].reshape((8, 8)), cmap=plt.cm.binary, interpolation="nearest"
            )
            im.set_clim(0, 16)
            im = ax[i + 5, j].imshow(
                new_data[i, j].reshape((8, 8)), cmap=plt.cm.binary, interpolation="nearest"
            )
            im.set_clim(0, 16)

    ax[0, 5].set_title("Selection from the input data")
    ax[5, 5].set_title('"New" digits drawn from the kernel density model')

    plt.show()


.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 4.740 seconds)
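With the default Gaussian kernel of ``KernelDensity``, drawing a point from the fitted density comes down to two steps. First, pick one training point uniformly at random. Second, add isotropic Gaussian noise whose standard deviation equals the bandwidth. The NumPy sketch below illustrates that idea. It is not scikit-learn's own code. The helper ``sample_gaussian_kde`` is hypothetical, and it assumes unweighted samples and a Gaussian kernel.

```python
import numpy as np


def sample_gaussian_kde(data, bandwidth, n_samples, seed=None):
    """Draw n_samples points from a Gaussian KDE fitted to `data`.

    Hypothetical helper for illustration only (unweighted samples,
    Gaussian kernel assumed): each draw picks one training point
    uniformly at random and perturbs it with N(0, bandwidth**2 * I) noise.
    """
    data = np.asarray(data, dtype=float)
    rng = np.random.default_rng(seed)
    # Step 1: choose which kernel (i.e. which training point) each sample uses
    idx = rng.integers(0, data.shape[0], size=n_samples)
    # Step 2: add noise drawn from that kernel, centred on the chosen point
    noise = rng.normal(scale=bandwidth, size=(n_samples, data.shape[1]))
    return data[idx] + noise


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Toy 2-D data made of two well-separated clusters
    toy = np.vstack([rng.normal(-5, 0.5, (100, 2)), rng.normal(5, 0.5, (100, 2))])
    samples = sample_gaussian_kde(toy, bandwidth=0.5, n_samples=1000, seed=1)
    print(samples.shape)
```

In the example script, the equivalent sampling step (``kde.sample``) runs in the 15-dimensional PCA space. Its output is then mapped back to 64 pixel values by ``pca.inverse_transform``. This is why a bandwidth that is too large produces blurry digits, while one that is too small mostly reproduces training images.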