scipy.cluster.vq.

kmeans2#

scipy.cluster.vq.kmeans2(data, k, iter=10, thresh=1e-05, minit='random', missing='warn', check_finite=True, *, seed=None)[源代码][源代码]#

使用k-means算法将一组观测值分类为k个簇。

该算法试图最小化观测值与质心之间的欧几里得距离。包含了几种初始化方法。

参数:
数据ndarray

一个 ‘M’ 行 ‘N’ 列的数组,表示 ‘M’ 个观测值在 ‘N’ 维空间中,或一个长度为 ‘M’ 的数组,表示 ‘M’ 个一维观测值。

kint 或 ndarray

要形成的簇的数量以及要生成的质心的数量。如果 minit 初始化字符串是 ‘matrix’,或者如果给出了一个 ndarray 代替,它将被解释为要使用的初始簇。

iterint, 可选

k-means 算法运行的迭代次数。注意,这与 kmeans 函数的 iters 参数在含义上有所不同。

threshfloat, 可选

(尚未使用)

minitstr, 可选

初始化方法。可用方法有 ‘random’, ‘points’, ‘++’ 和 ‘matrix’:

‘random’: 从数据估计的均值和方差的高斯分布中生成k个中心点。

‘points’: 从数据中随机选择 k 个观测值(行)作为初始中心点。

‘++’: 根据 kmeans++ 方法选择 k 个观测值(仔细播种)

‘matrix’: 将 k 参数解释为初始质心的 k 乘 M 数组(或 1-D 数据的 k 长度数组)。

缺失str, 可选

处理空簇的方法。可用方法有 ‘warn’ 和 ‘raise’:

‘warn’: 发出警告并继续。

‘raise’: 引发一个 ClusterError 并终止算法。

check_finitebool, 可选

是否检查输入矩阵是否仅包含有限数值。禁用可能会提高性能,但如果输入包含无穷大或NaN,可能会导致问题(崩溃、非终止)。默认值:True

seed : {None, int, numpy.random.Generator, numpy.random.RandomState}, 可选{None, int,}

用于初始化伪随机数生成器的种子。如果 seed 为 None(或 numpy.random),则使用 numpy.random.RandomState 单例。如果 seed 是整数,则使用新的 RandomState 实例,并以 seed 为种子。如果 seed 已经是 GeneratorRandomState 实例,则使用该实例。默认值为 None。

返回:
质心ndarray

一个在 k-means 最后一次迭代中找到的 ‘k’ 乘 ‘N’ 的质心数组。

标签ndarray

label[i] 是第 i 个观测值最接近的质心的代码或索引。

参见

kmeans

参考文献

[1]

D. Arthur and S. Vassilvitskii, “k-means++: the advantages of careful seeding”, Proceedings of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms, 2007.

示例

>>> from scipy.cluster.vq import kmeans2
>>> import matplotlib.pyplot as plt
>>> import numpy as np

创建 z,一个形状为 (100, 2) 的数组,包含来自三个多元正态分布的样本混合。

>>> rng = np.random.default_rng()
>>> a = rng.multivariate_normal([0, 6], [[2, 1], [1, 1.5]], size=45)
>>> b = rng.multivariate_normal([2, 0], [[1, -1], [-1, 3]], size=30)
>>> c = rng.multivariate_normal([6, 4], [[5, 0], [0, 1.2]], size=25)
>>> z = np.concatenate((a, b, c))
>>> rng.shuffle(z)

计算三个聚类。

>>> centroid, label = kmeans2(z, 3, minit='points')
>>> centroid
array([[ 2.22274463, -0.61666946],  # may vary
       [ 0.54069047,  5.86541444],
       [ 6.73846769,  4.01991898]])

每个簇中有多少个点?

>>> counts = np.bincount(label)
>>> counts
array([29, 51, 20])  # may vary

绘制聚类图。

>>> w0 = z[label == 0]
>>> w1 = z[label == 1]
>>> w2 = z[label == 2]
>>> plt.plot(w0[:, 0], w0[:, 1], 'o', alpha=0.5, label='cluster 0')
>>> plt.plot(w1[:, 0], w1[:, 1], 'd', alpha=0.5, label='cluster 1')
>>> plt.plot(w2[:, 0], w2[:, 1], 's', alpha=0.5, label='cluster 2')
>>> plt.plot(centroid[:, 0], centroid[:, 1], 'k*', label='centroids')
>>> plt.axis('equal')
>>> plt.legend(shadow=True)
>>> plt.show()
../../_images/scipy-cluster-vq-kmeans2-1.png