.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/cluster/plot_cluster_comparison.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_cluster_plot_cluster_comparison.py>`
        to download the full example code or to run this example in your browser via Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_cluster_plot_cluster_comparison.py:

==========================================================
Comparing different clustering algorithms on toy datasets
==========================================================

This example shows the characteristics of different clustering algorithms on
datasets that are "interesting" but still two-dimensional. With the exception
of the last dataset, the parameters of each of these dataset-algorithm pairs
have been tuned to produce good clustering results. Some algorithms are more
sensitive to parameter values than others.

The last dataset is an example of a "null" situation for clustering: the data
is homogeneous, and there is no good clustering. For this example, the null
dataset uses the same parameters as the dataset in the row above it, which
represents a mismatch between the parameter values and the data structure.

While these examples give some intuition about the algorithms, this intuition
might not apply to very high-dimensional data.

.. GENERATED FROM PYTHON SOURCE LINES 13-271

.. image-sg:: /auto_examples/cluster/images/sphx_glr_plot_cluster_comparison_001.png
   :alt: MiniBatch KMeans, Affinity Propagation, MeanShift, Spectral Clustering, Ward, Agglomerative Clustering, DBSCAN, HDBSCAN, OPTICS, BIRCH, Gaussian Mixture
   :srcset: /auto_examples/cluster/images/sphx_glr_plot_cluster_comparison_001.png
   :class: sphx-glr-single-img

.. code-block:: Python

    import time
    import warnings
    from itertools import cycle, islice

    import matplotlib.pyplot as plt
    import numpy as np

    from sklearn import cluster, datasets, mixture
    from sklearn.neighbors import kneighbors_graph
    from sklearn.preprocessing import StandardScaler

    # ============
    # Generate datasets. We choose the size big enough to see the scalability
    # of the algorithms, but not too big to avoid too long running times.
    # ============
    n_samples = 500
    seed = 30
    noisy_circles = datasets.make_circles(
        n_samples=n_samples, factor=0.5, noise=0.05, random_state=seed
    )
    noisy_moons = datasets.make_moons(n_samples=n_samples, noise=0.05, random_state=seed)
    blobs = datasets.make_blobs(n_samples=n_samples, random_state=seed)
    rng = np.random.RandomState(seed)
    no_structure = rng.rand(n_samples, 2), None

    # Anisotropically distributed data
    random_state = 170
    X, y = datasets.make_blobs(n_samples=n_samples, random_state=random_state)
    transformation = [[0.6, -0.6], [-0.4, 0.8]]
    X_aniso = np.dot(X, transformation)
    aniso = (X_aniso, y)

    # Blobs with varied variances
    varied = datasets.make_blobs(
        n_samples=n_samples, cluster_std=[1.0, 2.5, 0.5], random_state=random_state
    )

    # ============
    # Set up cluster parameters
    # ============
    plt.figure(figsize=(9 * 2 + 3, 13))
    plt.subplots_adjust(
        left=0.02, right=0.98, bottom=0.001, top=0.95, wspace=0.05, hspace=0.01
    )

    plot_num = 1

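    # Each entry below is consumed by a specific estimator built inside the
    # loop that follows:
    #   - "quantile": bandwidth estimation for MeanShift
    #   - "eps": DBSCAN
    #   - "damping", "preference": AffinityPropagation
    #   - "n_neighbors": the k-nearest-neighbors connectivity graph shared by
    #     the two AgglomerativeClustering variants (Ward and average linkage)
    #   - "n_clusters": MiniBatchKMeans, SpectralClustering, both
    #     AgglomerativeClustering variants, BIRCH, and (as n_components)
    #     GaussianMixture
    #   - "min_samples", "xi", "min_cluster_size": OPTICS
    #   - "hdbscan_min_samples", "hdbscan_min_cluster_size",
    #     "allow_single_cluster": HDBSCAN
    #   - "random_state": the estimators with a stochastic component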
    default_base = {
        "quantile": 0.3,
        "eps": 0.3,
        "damping": 0.9,
        "preference": -200,
        "n_neighbors": 3,
        "n_clusters": 3,
        "min_samples": 7,
        "xi": 0.05,
        "min_cluster_size": 0.1,
        "allow_single_cluster": True,
        "hdbscan_min_cluster_size": 15,
        "hdbscan_min_samples": 3,
        "random_state": 42,
    }

    datasets = [
        (
            noisy_circles,
            {
                "damping": 0.77,
                "preference": -240,
                "quantile": 0.2,
                "n_clusters": 2,
                "min_samples": 7,
                "xi": 0.08,
            },
        ),
        (
            noisy_moons,
            {
                "damping": 0.75,
                "preference": -220,
                "n_clusters": 2,
                "min_samples": 7,
                "xi": 0.1,
            },
        ),
        (
            varied,
            {
                "eps": 0.18,
                "n_neighbors": 2,
                "min_samples": 7,
                "xi": 0.01,
                "min_cluster_size": 0.2,
            },
        ),
        (
            aniso,
            {
                "eps": 0.15,
                "n_neighbors": 2,
                "min_samples": 7,
                "xi": 0.1,
                "min_cluster_size": 0.2,
            },
        ),
        (blobs, {"min_samples": 7, "xi": 0.1, "min_cluster_size": 0.2}),
        (no_structure, {}),
    ]

    for i_dataset, (dataset, algo_params) in enumerate(datasets):
        # Update parameters with dataset-specific values
        params = default_base.copy()
        params.update(algo_params)

        X, y = dataset

        # Normalize the dataset for easier parameter selection
        X = StandardScaler().fit_transform(X)

        # Estimate bandwidth for mean shift
        bandwidth = cluster.estimate_bandwidth(X, quantile=params["quantile"])

        # Connectivity matrix for structured Ward
        connectivity = kneighbors_graph(
            X, n_neighbors=params["n_neighbors"], include_self=False
        )
        # Make connectivity symmetric
        connectivity = 0.5 * (connectivity + connectivity.T)

        # ============
        # Create cluster objects
        # ============
        ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
        two_means = cluster.MiniBatchKMeans(
            n_clusters=params["n_clusters"],
            random_state=params["random_state"],
        )
        ward = cluster.AgglomerativeClustering(
            n_clusters=params["n_clusters"], linkage="ward", connectivity=connectivity
        )
        spectral = cluster.SpectralClustering(
            n_clusters=params["n_clusters"],
            eigen_solver="arpack",
            affinity="nearest_neighbors",
            random_state=params["random_state"],
        )
        dbscan = cluster.DBSCAN(eps=params["eps"])
        hdbscan = cluster.HDBSCAN(
            min_samples=params["hdbscan_min_samples"],
            min_cluster_size=params["hdbscan_min_cluster_size"],
            allow_single_cluster=params["allow_single_cluster"],
        )
        optics = cluster.OPTICS(
            min_samples=params["min_samples"],
            xi=params["xi"],
            min_cluster_size=params["min_cluster_size"],
        )
        affinity_propagation = cluster.AffinityPropagation(
            damping=params["damping"],
            preference=params["preference"],
            random_state=params["random_state"],
        )
        average_linkage = cluster.AgglomerativeClustering(
            linkage="average",
            metric="cityblock",
            n_clusters=params["n_clusters"],
            connectivity=connectivity,
        )
        birch = cluster.Birch(n_clusters=params["n_clusters"])
        gmm = mixture.GaussianMixture(
            n_components=params["n_clusters"],
            covariance_type="full",
            random_state=params["random_state"],
        )

        clustering_algorithms = (
            ("MiniBatch\nKMeans", two_means),
            ("Affinity\nPropagation", affinity_propagation),
            ("MeanShift", ms),
            ("Spectral\nClustering", spectral),
            ("Ward", ward),
            ("Agglomerative\nClustering", average_linkage),
            ("DBSCAN", dbscan),
            ("HDBSCAN", hdbscan),
            ("OPTICS", optics),
            ("BIRCH", birch),
            ("Gaussian\nMixture", gmm),
        )

        for name, algorithm in clustering_algorithms:
            t0 = time.time()

            # Catch warnings related to kneighbors_graph
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message="the number of connected components of the "
                    + "connectivity matrix is [0-9]{1,2}"
                    + " > 1. Completing it to avoid stopping the tree early.",
                    category=UserWarning,
                )
                warnings.filterwarnings(
                    "ignore",
                    message="Graph is not fully connected, spectral embedding"
                    + " may not work as expected.",
                    category=UserWarning,
                )
                algorithm.fit(X)

            t1 = time.time()
            if hasattr(algorithm, "labels_"):
                y_pred = algorithm.labels_.astype(int)
            else:
                y_pred = algorithm.predict(X)

            plt.subplot(len(datasets), len(clustering_algorithms), plot_num)
            if i_dataset == 0:
                plt.title(name, size=18)

            colors = np.array(
                list(
                    islice(
                        cycle(
                            [
                                "#377eb8",
                                "#ff7f00",
                                "#4daf4a",
                                "#f781bf",
                                "#a65628",
                                "#984ea3",
                                "#999999",
                                "#e41a1c",
                                "#dede00",
                            ]
                        ),
                        int(max(y_pred) + 1),
                    )
                )
            )
            # Add black color for outliers (if any)
            colors = np.append(colors, ["#000000"])
            plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])

            plt.xlim(-2.5, 2.5)
            plt.ylim(-2.5, 2.5)
            plt.xticks(())
            plt.yticks(())
            plt.text(
                0.99,
                0.01,
                ("%.2fs" % (t1 - t0)).lstrip("0"),
                transform=plt.gca().transAxes,
                size=15,
                horizontalalignment="right",
            )
            plot_num += 1

    plt.show()

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 3.540 seconds)

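To experiment with a single cell of the figure in isolation, the minimal sketch
below fits DBSCAN on the standardized noisy-circles data with ``eps=0.3``, the
value taken from ``default_base`` (the noisy-circles row does not override it),
and counts the resulting clusters. It should recover the two rings, matching
the DBSCAN column of the first row of the figure.

.. code-block:: Python

    import numpy as np

    from sklearn import cluster, datasets
    from sklearn.preprocessing import StandardScaler

    # Same data generation and scaling as the first row of the figure
    X, _ = datasets.make_circles(n_samples=500, factor=0.5, noise=0.05, random_state=30)
    X = StandardScaler().fit_transform(X)

    # eps=0.3 comes from default_base; min_samples keeps its scikit-learn default
    dbscan = cluster.DBSCAN(eps=0.3)
    labels = dbscan.fit(X).labels_

    # A label of -1 marks points that DBSCAN considers noise
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    n_noise = int(np.sum(labels == -1))
    print(f"estimated clusters: {n_clusters}, noise points: {n_noise}")
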
.. _sphx_glr_download_auto_examples_cluster_plot_cluster_comparison.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/cluster/plot_cluster_comparison.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_cluster_comparison.ipynb <plot_cluster_comparison.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_cluster_comparison.py <plot_cluster_comparison.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_cluster_comparison.zip <plot_cluster_comparison.zip>`

.. include:: plot_cluster_comparison.recommendations

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_