• K-Means++代码实现


    K-Means++代码实现

    数据集(即代码中读取的 data.csv)下载地址:
    https://download.csdn.net/download/qq_43629083/87246495

    import pandas as pd
    import numpy as np
    import random
    import math
    %matplotlib inline
    from matplotlib import pyplot as plt
    
    # Read the whole CSV file 'data.csv' into a pandas DataFrame.
    data = pd.read_csv(filepath_or_buffer='data.csv')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    class MyKmeansPlusPlus:
        """K-Means clustering with K-Means++ seeding, implemented with NumPy.

        Typical use: ``MyKmeansPlusPlus(k).fit(df)``. After ``fit`` the cluster
        index of every sample is in ``labels`` and the final centers are in
        ``centers``.
        """

        def __init__(self, k, max_iter=10):
            """Create an unfitted model.

            Args:
                k: number of clusters.
                max_iter: maximum number of iterations ``fit`` will run.
            """
            self.k = k
            # Maximum number of iterations.
            self.max_iter = max_iter
            # Training samples as a 2-D float array; set by fit().
            self.data_set = None
            # Cluster index of each sample; set by fit().
            self.labels = None
            # Final cluster centers, shape (k, n_features); set by fit().
            self.centers = None

        def euler_distance(self, point1, point2):
            """Return the Euclidean distance between two points.

            The metric is Euclidean, despite the "euler" in the method name.
            Coordinates are paired with ``zip``, so extra coordinates of the
            longer point are ignored.
            """
            return math.sqrt(sum((a - b) ** 2 for a, b in zip(point1, point2)))

        def nearest_distance(self, point, cluster_centers):
            """Return the distance from ``point`` to its closest center.

            Returns ``math.inf`` when ``cluster_centers`` is empty.
            """
            return min(
                (self.euler_distance(point, center) for center in cluster_centers),
                default=math.inf,
            )

        def get_centers(self):
            """Choose ``k`` initial centers using K-Means++ seeding.

            The first center is a uniformly random sample. Each further center
            is a sample drawn with probability proportional to D(x)**2, the
            squared distance from x to its nearest already-chosen center.
            Already-chosen samples have weight 0 and are never picked again.

            Returns:
                ndarray of shape (k, n_features).
            """
            n_samples, n_features = np.shape(self.data_set)
            cluster_centers = np.zeros((self.k, n_features))
            cluster_centers[0] = self.data_set[np.random.randint(0, n_samples)]

            for i in range(1, self.k):
                # Squared distance of every sample to its nearest chosen center.
                weights = np.array([
                    self.nearest_distance(sample, cluster_centers[:i]) ** 2
                    for sample in self.data_set
                ])
                total = weights.sum()
                if total > 0:
                    index = np.random.choice(n_samples, p=weights / total)
                else:
                    # Every sample already coincides with a chosen center
                    # (fewer distinct points than k): fall back to uniform.
                    index = np.random.randint(0, n_samples)
                cluster_centers[i] = self.data_set[index]
            return cluster_centers

        def get_closest_index(self, point, centers):
            """Return the index of the center nearest to ``point``.

            Ties go to the lowest index; returns -1 if ``centers`` is empty.
            """
            label, min_dist = -1, math.inf
            for i, center in enumerate(centers):
                dist = self.euler_distance(center, point)
                if dist < min_dist:
                    label, min_dist = i, dist
            return label

        def update_centers(self, old_centers=None):
            """Recompute each center as the mean of the samples assigned to it.

            Args:
                old_centers: optional previous centers, shape (k, n_features).
                    A cluster that received no samples keeps its previous
                    center; without ``old_centers`` it is re-seeded at a random
                    sample. This avoids producing a NaN center.

            Returns:
                ndarray of shape (k, n_features).
            """
            samples = np.asarray(self.data_set)
            labels = np.asarray(self.labels)
            centers = np.empty((self.k, samples.shape[1]))
            for i in range(self.k):
                members = samples[labels == i]
                if len(members):
                    centers[i] = members.mean(axis=0)
                elif old_centers is not None:
                    centers[i] = old_centers[i]
                else:
                    centers[i] = samples[np.random.randint(0, len(samples))]
            return centers

        def stop_iter(self, old_centers, centers, step):
            """Return True when iteration should stop.

            Iteration stops once ``step`` (the 1-based count of completed
            iterations) reaches ``max_iter``, or when the centers did not change.
            """
            return step >= self.max_iter or np.array_equal(old_centers, centers)

        def fit(self, data_set, plot=True, verbose=True):
            """Cluster ``data_set`` and store the result in ``labels``/``centers``.

            Args:
                data_set: DataFrame or 2-D array-like of samples. A ``labels``
                    column, if present, is dropped before clustering and is not
                    otherwise used.
                plot: after each iteration, scatter-plot the first two features
                    coloured by cluster (requires matplotlib's ``plt``).
                verbose: print the iteration number and the centers used.

            Returns:
                self

            Raises:
                ValueError: if ``data_set`` has no rows.
            """
            if 'labels' in getattr(data_set, 'columns', ()):
                data_set = data_set.drop(columns=['labels'])
            self.data_set = np.asarray(data_set, dtype=float)
            n_samples = len(self.data_set)
            if n_samples == 0:
                raise ValueError("data_set is empty")
            # Integer result buffer: the entries are used as cluster indices.
            self.labels = np.zeros(n_samples, dtype=int)

            centers = self.get_centers()
            step = 0
            while True:
                # Centers used for this iteration's assignment.
                old_centers = np.copy(centers)
                step += 1
                if verbose:
                    print("current iteration: ", step)
                    print("current centers: ", old_centers)
                # Assign every sample to its nearest current center.
                for i, point in enumerate(self.data_set):
                    self.labels[i] = self.get_closest_index(point, centers)
                centers = self.update_centers(old_centers)
                if plot:
                    self._plot_iteration(old_centers)
                if self.stop_iter(old_centers, centers, step):
                    break
            self.centers = centers
            return self

        def _plot_iteration(self, centers):
            """Scatter the first two features by cluster and mark ``centers``."""
            plt.figure()
            handles = [
                plt.scatter(self.data_set[self.labels == c, 0],
                            self.data_set[self.labels == c, 1])
                for c in range(self.k)
            ]
            plt.scatter(centers[:, 0], centers[:, 1], marker='^', edgecolor='black', s=128)
            plt.title('labeled data')
            plt.xlabel('V1')
            plt.ylabel('V2')
            plt.legend(handles, ['label%d' % c for c in range(self.k)])
            plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    # Build a 3-cluster model with the default iteration limit.
    myKmeansPP = MyKmeansPlusPlus(k=3)
    
    • 1
    # Cluster the loaded DataFrame; fit reports progress after each iteration.
    myKmeansPP.fit(data_set=data)
    
    • 1

    current iteration: 1
    current centers:
    [[55.97659 75.71833 ]
    [43.75808 67.45812 ]
    [71.72321 -7.872746]]

    (图:第 1 次迭代的聚类散点图,三角形为本次迭代使用的聚类中心)

    current iteration: 2
    current centers:
    [[55.83404759 70.21560931]
    [30.35261288 47.71518861]
    [50.15798861 -5.34769581]]

    (图:第 2 次迭代的聚类散点图,三角形为本次迭代使用的聚类中心)

    current iteration: 3
    current centers:
    [[47.66230967 65.1238036 ]
    [22.93488 39.05383154]
    [52.52023009 -6.18734425]]

    (图:第 3 次迭代的聚类散点图,三角形为本次迭代使用的聚类中心)

    current iteration: 4
    current centers:
    [[42.96329079 61.70702396]
    [12.28521822 20.36196405]
    [63.73622886 -9.02914858]]

    (图:第 4 次迭代的聚类散点图,三角形为本次迭代使用的聚类中心)

    current iteration: 5
    current centers:
    [[ 40.8388755 59.95703427]
    [ 9.62033389 11.15366963]
    [ 69.77599323 -10.09654797]]

    (图:第 5 次迭代的聚类散点图,三角形为本次迭代使用的聚类中心)

  • 相关阅读:
    springboot+自行车网上商城 毕业设计-附源码130948
    java jvm用到的各种工具
    二叉树的Morris遍历
    去除社区版本 idea 没有添加括号的报红
    docker安装
    mybatis config 配置
    列表—list 使用
    前端静态页面基本开发思路(二)
    【C++ 程序】级数求和
    SpringCloud系列——Ribbon day2-2
  • 原文地址:https://blog.csdn.net/qq_43629083/article/details/128198222