Kmeans: k-均值聚类

k-means聚类的实现。

> 来自 mlxtend.cluster 的 Kmeans

概述

聚类属于无监督学习的范畴,这是机器学习的一个子领域,在实际应用中我们无法获取真实的标签。在聚类中,我们的目标是根据相似性对样本进行分组(在k均值中:欧几里得距离)。

k均值算法可以总结如下:

  1. 随机选择k个质心作为初始聚类中心。
  2. 将每个样本分配给最近的质心 $\mu(j), \; j \in {1,...,k}$。
  3. 将质心移动到被分配给它的样本的中心。
  4. 重复步骤2和3,直到聚类分配不再变化,或者达到用户定义的容忍度或最大迭代次数。

参考文献

示例 1 - 三个斑点

加载一些示例数据:

import matplotlib.pyplot as plt
from mlxtend.data import three_blobs_data

X, y = three_blobs_data()
plt.scatter(X[:, 0], X[:, 1], c='white')
plt.show()

png

计算聚类中心:

from mlxtend.cluster import Kmeans

km = Kmeans(k=3, 
            max_iter=50, 
            random_seed=1, 
            print_progress=3)

km.fit(X)

print('Iterations until convergence:', km.iterations_)
print('Final centroids:\n', km.centroids_)

Iteration: 2/50 | Elapsed: 00:00:00 | ETA: 00:00:00

Iterations until convergence: 2
Final centroids:
 [[-1.5947298   2.92236966]
 [ 2.06521743  0.96137409]
 [ 0.9329651   4.35420713]]

可视化聚类成员资格:

y_clust = km.predict(X)

plt.scatter(X[y_clust == 0, 0],
            X[y_clust == 0, 1],
            s=50,
            c='lightgreen',
            marker='s',
            label='cluster 1')

plt.scatter(X[y_clust == 1,0],
            X[y_clust == 1,1],
            s=50,
            c='orange',
            marker='o',
            label='cluster 2')

plt.scatter(X[y_clust == 2,0],
            X[y_clust == 2,1],
            s=50,
            c='lightblue',
            marker='v',
            label='cluster 3')


plt.scatter(km.centroids_[:,0],
            km.centroids_[:,1],
            s=250,
            marker='*',
            c='red',
            label='centroids')

plt.legend(loc='lower left',
           scatterpoints=1)
plt.grid()
plt.show()

png

API

Kmeans(k, max_iter=10, convergence_tolerance=1e-05, random_seed=None, print_progress=0)

K-means clustering class.

Added in 0.4.1dev

Parameters

Attributes

Examples

For usage examples, please see https://rasbt.github.io/mlxtend/user_guide/classifier/Kmeans/

Methods


fit(X, init_params=True)

Learn model from training data.

Parameters

Returns


predict(X)

Predict targets from X.

Parameters

Returns