目录
为MNIST数据集构建一个分类器,并在测试集上达成超过97%的精度。
下面进行代码展示:
- #1、获取MNIST数据集
- from sklearn.datasets import fetch_openml
- mnist = fetch_openml('mnist_784', version=1, cache=True, as_frame=False)
-
- #2、划分数据集
- import numpy as np
-
- X, y = mnist["data"], mnist["target"]
-
- #MNIST默认划分的训练集和测试集
- X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
-
- #数据重新洗牌,防止算法对训练实例的顺序敏感
- shuffle_index = np.random.permutation(60000)#生成一个随机排列的数组
- X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
-
-
- #注意对自己电脑硬件不自信不要运行下面代码,以防蓝屏,可以了解一下思想
- from sklearn.neighbors import KNeighborsClassifier
- from sklearn.model_selection import GridSearchCV
-
- param_grid = [{'weights': ["uniform", "distance"], 'n_neighbors': [3, 4, 5]}]
-
- knn_clf = KNeighborsClassifier()
- grid_search = GridSearchCV(knn_clf, param_grid, cv=5, verbose=3, n_jobs=-1)
- grid_search.fit(X_train, y_train)
找到合适的超参数:
grid_search.best_params_
运行结果如下:
{'n_neighbors': 4, 'weights': 'distance'}
得分:
grid_search.best_score_
运行结果如下:
0.97325
预测精度:
- from sklearn.metrics import accuracy_score
-
- y_pred = grid_search.predict(X_test)
- accuracy_score(y_test, y_pred)
运行结果如下:
0.9714
我们就代码中包含的知识点进行简单讲解:
对给定的数组重新排列。
- import numpy as np
-
- arr = np.random.permutation(6)
- print(arr)
运行结果如下:
[2 5 4 0 3 1]
另外对数组进行重新排列的还包括:np.random.shuffle(arr)
- arr = np.arange(6)
- print(arr)
- np.random.shuffle(arr)
- print(arr)
运行结果如下:
- [0 1 2 3 4 5]
- [4 5 1 2 0 3]
中文文档说明:sklearn.neighbors.KNeighborsClassifier-scikit-learn中文社区
英文文档说明:sklearn.neighbors.KNeighborsClassifier — scikit-learn 1.1.2 documentation
我们看一下文档中参数:
sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, *, weights='uniform', algorithm='auto', leaf_size=30, p=2, metric='minkowski', metric_params=None, n_jobs=None, **kwargs)
参数 | 说明 |
---|---|
n_neighbors | int, default=5 默认情况下用于kneighbors查询的近邻数 |
weights | {‘uniform’, ‘distance’} or callable, default=’uniform’ 预测中使用的权重函数。 可能的值: “uniform”:统一权重。 每个邻域中的所有点均被加权。 “distance”:权重点与其距离的倒数。 在这种情况下,查询点的近邻比远处的近邻具有更大的影响力。 [callable]:用户定义的函数,该函数接受距离数组,并返回包含权重的相同形状的数组。 |
algorithm | {‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’}, default=’auto’ 用于计算最近临近点的算法: “ ball_tree”将使用BallTree kd_tree”将使用KDTree “brute”将使用暴力搜索。 “auto”将尝试根据传递给fit方法的值来决定最合适的算法。 注意:在稀疏输入上进行拟合将使用蛮力覆盖此参数的设置。 |
leaf_size | int, default=30 叶大小传递给Bal |