目录
1.find_closest_centroids :寻找最近的质心
2.compute_centroids :根据样本的分配结果重新计算质心
3.kMeans_init_centroids :随机初始化质心
4.run_kMeans :迭代运行k-means算法
5.实例:使用k-means算法对图片像素进行压缩,将图片中的颜色压缩为16种(K=16)
- import numpy as np
- import matplotlib.pyplot as plt
# Find the closest centroid for every example.
def find_closest_centroids(X, centroids):
    """Assign each example in X to its nearest centroid (Euclidean distance).

    Args:
        X: (m, n) array of examples.
        centroids: (K, n) array of centroid positions.

    Returns:
        Length-m integer array; entry i is the index (0..K-1) of the centroid
        closest to X[i]. Ties resolve to the lowest centroid index, because
        np.argmin returns the first occurrence of the minimum.
    """
    X = np.asarray(X)
    centroids = np.asarray(centroids)
    K = centroids.shape[0]

    # One vectorised NumPy pass per centroid instead of a Python-level loop
    # over every (example, centroid) pair. Extra memory is only the (m, K)
    # distance table.
    distances = np.empty((X.shape[0], K))
    for j in range(K):
        distances[:, j] = np.linalg.norm(X - centroids[j], axis=1)

    return np.argmin(distances, axis=1)
def compute_centroids(X, idx, K, previous_centroids=None):
    """Move each centroid to the mean of the examples assigned to it.

    Args:
        X: (m, n) array of examples.
        idx: length-m array; idx[i] is the cluster index assigned to X[i].
        K: number of clusters.
        previous_centroids: optional (K, n) array. When given, a cluster that
            received no examples keeps its previous centroid.

    Returns:
        (K, n) float array of centroids. Without previous_centroids, the row
        of a cluster that received no examples is NaN (its mean is undefined).
        This matches the old result but is produced explicitly, without the
        "Mean of empty slice" RuntimeWarning.
    """
    n = X.shape[1]
    # Start from NaN so that any cluster left untouched below is visibly
    # undefined rather than silently sitting at the origin.
    centroids = np.full((K, n), np.nan)

    for k in range(K):
        members = X[idx == k]
        if members.shape[0] > 0:
            centroids[k] = members.mean(axis=0)
        elif previous_centroids is not None:
            # Empty cluster: reuse its last position instead of NaN.
            centroids[k] = previous_centroids[k]

    return centroids
def kMeans_init_centroids(X, K):
    """Choose K distinct examples of X at random to serve as starting centroids.

    Args:
        X: (m, n) array of examples.
        K: number of centroids to pick.

    Returns:
        The rows of X at the first K positions of a random permutation of
        range(m). If X has fewer than K rows, all of them are returned,
        shuffled.
    """
    num_rows = X.shape[0]
    shuffled_order = np.random.permutation(num_rows)
    return X[shuffled_order[:K]]
def run_kMeans(X, initial_centroids, max_iters=10, plot_progress=False):
    """Run the K-means algorithm on X, starting from initial_centroids.

    Every iteration assigns each example to its nearest centroid, then moves
    each centroid to the mean of the examples assigned to it.

    Args:
        X: (m, n) array of examples.
        initial_centroids: (K, n) array of starting centroids; K is read from
            its first dimension.
        max_iters: number of assignment/update iterations to run.
        plot_progress: if True, call plot_progress_kMeans on every iteration
            and show the resulting figure(s) once the loop finishes.

    Returns:
        (centroids, idx): the centroids after the last update, and the
        assignments from the last iteration. idx was computed against the
        centroids *before* that final update, so call
        find_closest_centroids(X, centroids) if the assignments must match the
        returned centroids. When max_iters <= 0, centroids is
        initial_centroids and idx is all zeros.
    """
    m = X.shape[0]
    K = initial_centroids.shape[0]
    centroids = initial_centroids
    previous_centroids = centroids
    # Integer dtype so idx can still be used to index centroids even when the
    # loop body never runs (float zeros cannot be used as indices).
    idx = np.zeros(m, dtype=int)

    for i in range(max_iters):
        print("K-Means iteration %d/%d" % (i, max_iters-1))
        idx = find_closest_centroids(X, centroids)
        if plot_progress:
            # NOTE(review): plot_progress_kMeans is not defined in this part of
            # the file; it must be provided elsewhere for plot_progress=True.
            plot_progress_kMeans(X, centroids, previous_centroids, idx, K, i)
            previous_centroids = centroids

        new_centroids = compute_centroids(X, idx, K)
        # A cluster that received no examples has no mean, so its row comes
        # back as NaN. Every distance to a NaN centroid is NaN, and np.argmin
        # returns the first NaN it meets, so that cluster would capture every
        # point on the next iteration. Keep its previous position instead.
        empty = np.isnan(new_centroids).any(axis=1)
        if empty.any():
            new_centroids[empty] = np.asarray(centroids)[empty]
        centroids = new_centroids

    # Only show figures this function actually drew; otherwise leave the
    # caller's figures (and the blocking show() call) alone.
    if plot_progress:
        plt.show()
    return centroids, idx
# Image compression with K-means: replace every pixel by one of K colours.

# Load the image.
original_img = plt.imread('bird_small.png')
# Scale pixel values into [0, 1]. plt.imread already returns floats in [0, 1]
# for PNG files, so only integer images (e.g. 8-bit JPEG) need dividing by
# their maximum representable value.
if np.issubdtype(original_img.dtype, np.integer):
    original_img = original_img / np.iinfo(original_img.dtype).max
# Keep only the RGB channels; an alpha channel would break the 3-column reshape.
original_img = original_img[:, :, :3]

# One row per pixel, one column per colour channel.
X_img = np.reshape(original_img, (original_img.shape[0] * original_img.shape[1], 3))
K = 16
max_iters = 10

initial_centroids = kMeans_init_centroids(X_img, K)
centroids, idx = run_kMeans(X_img, initial_centroids, max_iters)
# run_kMeans returns assignments computed before its final centroid update;
# reassign against the final centroids so each pixel gets its nearest colour.
idx = find_closest_centroids(X_img, centroids)
# Replace every pixel by its centroid colour and restore the image shape.
X_recovered = centroids[idx, :]
X_recovered = np.reshape(X_recovered, original_img.shape)

fig, ax = plt.subplots(1, 2, figsize=(8, 8))

# Show the original image (float values in [0, 1], as imshow expects).
ax[0].imshow(original_img)
ax[0].set_title('Original')
ax[0].set_axis_off()

# Show the compressed image.
ax[1].imshow(X_recovered)
ax[1].set_title('Compressed with %d colours' % K)
ax[1].set_axis_off()

plt.show()
