kmeans2#
- scipy.cluster.vq.kmeans2(data, k, iter=10, thresh=1e-05, minit='random', missing='warn', check_finite=True, *, seed=None)[源代码][源代码]#
使用k-means算法将一组观测值分类为k个簇。
该算法试图最小化观测值与质心之间的欧几里得距离。包含了几种初始化方法。
- 参数:
- 数据ndarray
一个 ‘M’ 行 ‘N’ 列的数组,表示 ‘M’ 个观测值在 ‘N’ 维空间中,或一个长度为 ‘M’ 的数组,表示 ‘M’ 个一维观测值。
- kint 或 ndarray
要形成的簇的数量以及要生成的质心的数量。如果 minit 初始化字符串是 ‘matrix’,或者如果给出了一个 ndarray 代替,它将被解释为要使用的初始簇。
- iterint, 可选
k-means 算法运行的迭代次数。注意,这与 kmeans 函数的 iters 参数在含义上有所不同。
- threshfloat, 可选
(尚未使用)
- minitstr, 可选
初始化方法。可用方法有 ‘random’, ‘points’, ‘++’ 和 ‘matrix’:
‘random’: 从数据估计的均值和方差的高斯分布中生成k个中心点。
‘points’: 从数据中随机选择 k 个观测值(行)作为初始中心点。
‘++’: 根据 kmeans++ 方法选择 k 个观测值(仔细播种)
‘matrix’: 将 k 参数解释为初始质心的 k 乘 M 数组(或 1-D 数据的 k 长度数组)。
- 缺失str, 可选
处理空簇的方法。可用方法有 ‘warn’ 和 ‘raise’:
‘warn’: 发出警告并继续。
‘raise’: 引发一个 ClusterError 并终止算法。
- check_finitebool, 可选
是否检查输入矩阵是否仅包含有限数值。禁用可能会提高性能,但如果输入包含无穷大或NaN,可能会导致问题(崩溃、非终止)。默认值:True
- seed : {None, int,
numpy.random.Generator
,numpy.random.RandomState
}, 可选{None, int,} 用于初始化伪随机数生成器的种子。如果 seed 为 None(或
numpy.random
),则使用numpy.random.RandomState
单例。如果 seed 是整数,则使用新的RandomState
实例,并以 seed 为种子。如果 seed 已经是Generator
或RandomState
实例,则使用该实例。默认值为 None。
- 返回:
- 质心ndarray
一个在 k-means 最后一次迭代中找到的 ‘k’ 乘 ‘N’ 的质心数组。
- 标签ndarray
label[i] 是第 i 个观测值最接近的质心的代码或索引。
参见
参考文献
[1]D. Arthur and S. Vassilvitskii, “k-means++: the advantages of careful seeding”, Proceedings of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms, 2007.
示例
>>> from scipy.cluster.vq import kmeans2 >>> import matplotlib.pyplot as plt >>> import numpy as np
创建 z,一个形状为 (100, 2) 的数组,包含来自三个多元正态分布的样本混合。
>>> rng = np.random.default_rng() >>> a = rng.multivariate_normal([0, 6], [[2, 1], [1, 1.5]], size=45) >>> b = rng.multivariate_normal([2, 0], [[1, -1], [-1, 3]], size=30) >>> c = rng.multivariate_normal([6, 4], [[5, 0], [0, 1.2]], size=25) >>> z = np.concatenate((a, b, c)) >>> rng.shuffle(z)
计算三个聚类。
>>> centroid, label = kmeans2(z, 3, minit='points') >>> centroid array([[ 2.22274463, -0.61666946], # may vary [ 0.54069047, 5.86541444], [ 6.73846769, 4.01991898]])
每个簇中有多少个点?
>>> counts = np.bincount(label) >>> counts array([29, 51, 20]) # may vary
绘制聚类图。
>>> w0 = z[label == 0] >>> w1 = z[label == 1] >>> w2 = z[label == 2] >>> plt.plot(w0[:, 0], w0[:, 1], 'o', alpha=0.5, label='cluster 0') >>> plt.plot(w1[:, 0], w1[:, 1], 'd', alpha=0.5, label='cluster 1') >>> plt.plot(w2[:, 0], w2[:, 1], 's', alpha=0.5, label='cluster 2') >>> plt.plot(centroid[:, 0], centroid[:, 1], 'k*', label='centroids') >>> plt.axis('equal') >>> plt.legend(shadow=True) >>> plt.show()