• 05. 聚类---K(k-means)均值


     一、概念    

     K均值(k-means)是聚类算法中最为简单和高效的算法,属于无监督的算法。

          核心思想:由用户指定K个初始质心(initial centroids),以作为聚类的类别(cluster),重复迭代直至算法收敛 

    基本算法流程:

         1.选取K个初始质心(作为初始的cluster)

         2.repeat:

                  对每个样本点,计算得到距其最近的质心,将其类别标为该质心所对应的cluster

                  重新计算K个cluser对应的质心

          3.until 质心不在发生变化或迭代达到上限

    二、python简单方法讲解

    dist = np.array([[121,34,43,32],
                     [121,221,12,23],
                     [65,21,2,43],
                     [1,221,32,43],
                    [21,0,2,3]])
    c_index = np.argmin(dist)
    print(c_index)
     ##输出17,把所有的二维数据当一维数据做处理,显示出最小的索引,0所在的位置在第17索引上


    c_index = np.argmin(dist,axis=1)
    print(c_index)
    ##输出 [3 2 2 0 1] axis将二维数据求最小值当一列处理,返回的是每行的最小值索引
    print(c_index==2)
    # [False  True  True False False]

    x_new=np.array(
    [[-0.02708305  5.0215929 ],
     [-5.49252256  6.27366991],
     [-5.37691608  1.51403209],
     [-5.37872006  2.16059225],
     [ 9.58333171  8.10916554]])


    x_new[c_index==2]
    #array([[-5.49252256,  6.27366991],
    #       [-5.37691608,  1.51403209]])

    np.mean(x_new[c_index==2],axis=0)
    #输出 array([-5.43471932, 3.893851 ]) 列加起来求平均

    三、python实现kmeans

    1. ### 0. 引入依赖
    2. import numpy as np
    3. import matplotlib.pyplot as plt
    4. # 从sklearn 中直接生成聚类数据
    5. from sklearn.datasets import make_blobs
    6. ### 1.数据加载
    7. # n_samples 表示生成100个样本点 centers 生成6个中心点
    8. # cluster_std 聚类的标准差
    9. x,y=make_blobs(n_samples=100,centers=6,random_state=1234,cluster_std=0.6)
    10. plt.figure(figsize=(6,6))
    11. plt.scatter(x[:,0],x[:,1],c=y)
    12. plt.show()

     

    1. ## 2.算法实现
    2. ## 引用scipy的距离函数 默认欧式距离
    3. from scipy.spatial.distance import cdist
    4. class K_Means(object):
    5. # 初始化,参数n_clusters(K)聚类的类别 、max_iter最大迭代次数、初始质心centroids
    6. def __init__(self,n_clusters=6,max_iter = 300,centroids=[]):
    7. self.n_clusters=n_clusters
    8. self.max_iter=max_iter
    9. self.centroids=np.array(centroids,dtype=np.float64)
    10. # 训练模型方法,k-means聚类过程,传入原始数据
    11. # data是个二维举证
    12. def fit(self,data):
    13. # 假如没有制定初心质心,就随机选取data中的点作为初始质心
    14. if(self.centroids.shape == (0,)):
    15. ## 随机生成n_clusters个0到len(data)的索引值从data中获取数据
    16. self.centroids = data[ np.random.randint(0,data.shape[0],self.n_clusters),: ]
    17. #开始迭代
    18. for i in range(self.max_iter):
    19. # 1. 计算距离矩阵,得到的是一个100 * 6 的矩阵
    20. # 就是每个data的数据与不同质心点的距离
    21. distances = cdist(data,self.centroids)
    22. # 2. 对距离按由近到远排序,选取最近的质心点类别,作为当前点的分类
    23. c_index = np.argmin(distances,axis=1 )
    24. # 3. 对每一类数据进行均值计算,更新质心点坐标
    25. for i in range(self.n_clusters):
    26. # 首先排出掉没有出现在c_index的类别
    27. # 因为可能存在某个质心没有数据聚集到
    28. if i in c_index:
    29. # 选出所有列表是i的点,取data里面坐标的均值,更新第i个质心
    30. #data[c_index==i] 布尔索引,拿出来的是true的索引的值
    31. self.centroids[i] = np.mean(data[c_index==i],axis=0)
    32. # 实现预测方法
    33. def predict(self,samples):
    34. # 跟上面一样,先计算距离矩阵,然后选取距离最近的那个质心的类别
    35. distances = cdist(samples,self.centroids)
    36. c_index = np.argmin(distances,axis=1 )
    37. return c_index
    1. ### 3. 测试
    2. # 定义一个绘制子图函数
    3. def plotKMean(x,y,centroids,subplot,title):
    4. # 分配子图
    5. plt.subplot(subplot)
    6. plt.scatter(x[:,0],x[:,1],c='r')
    7. #画出质心点 s为size
    8. plt.scatter(centroids[:,0],centroids[:,1],c=np.array(range(6)),s=100)
    9. plt.title(title)
    10. kmeans = K_Means(max_iter = 300,centroids=np.array([[2,1],[2,2],[2,3],[2,4],[2,5],[2,6]]))
    11. plt.figure(figsize=(16,6))
    12. # 121 表示 1行2列的第一个子图
    13. plotKMean(x,y,kmeans.centroids,121,'Initial State')
    14. # 开始聚类
    15. kmeans.fit(x)
    16. plotKMean(x,y,kmeans.centroids,122,'Final State')
    17. # 预测新数据点的类别
    18. x_new = np.array([[0,0],[10,7]])
    19. y_pred= kmeans.predict(x_new)
    20. print(kmeans.centroids)
    21. #[[ 5.76444812 -4.67941789]
    22. # [-2.89174024 -0.22808556]
    23. # [-5.89115978 2.33887408]
    24. # [-4.53406813 6.11523454]
    25. # [-1.15698106 5.63230377]
    26. # [ 9.20551979 7.56124841]]
    27. print(y_pred)
    28. # [1 5]
    29. plt.scatter(x_new[:,0],x_new[:,1],s=100,c='black')

     

  • 相关阅读:
    12000字解读瑞幸咖啡:“异军突起”与“绝处逢生”的奥秘
    java基于ssm的汽车维修保养管理系统
    文件对比工具Beyond Compare 4(4.4.7) for Mac
    [Lingo编程极速入门]——基础01
    eyb:Vue学习1
    ES6知识点总结——学习网站及环境搭建
    Linux之Nignx及负载均衡&动静分离
    机器学习第七课--情感分析系统
    java:springboot 整理webSocket
    ceph 14.2.10 aarch64 非集群内 客户端 挂载块设备
  • 原文地址:https://blog.csdn.net/oracle8090/article/details/126004096