• KNN实现鸢尾花分类



    前言

    如下提供了两种训练方式

    1. 常规训练 的话需要 自己去试那个K的值,一般试个 3、5、7、9 就行
    2. 网格搜索训练 可以让 机器自己去试这个K的值,训练结束后使用最好的模型预测即可
    3. N折交叉验证训练 会让训练量提升N倍,但是会最大化的利用已有数据进行训练和验证,一般来说折数多一些训练结果会变好,但也不宜过多,该方法 常用在数据量较少或者获取训练数据成本较高的情况

    一、安装sklearn

    pip install scikit-learn
    
    • 1

    二、常规训练

    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split, GridSearchCV
    from sklearn.preprocessing import StandardScaler
    from sklearn.neighbors import KNeighborsClassifier
    import pandas as pd
    
    
    if __name__ == '__main__':
    
        # 读取数据
        iris = load_iris()
        print(iris)
        # {
        #   'data': array([[5.1, 3.5, 1.4, 0.2], [4.9, 3. , 1.4, 0.2], ... [5.9, 3. , 5.1, 1.8]]),
        #   'target': array([0, 0, ... 2]),
        #   'frame': None,
        #   'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='
        #   'DESCR': '.. _iris_dataset:\n\nIris plants dataset ...\n\n|details-end|',
        #   'feature_names': ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'],
        #   'filename': 'iris.csv',
        #   'data_module': 'sklearn.datasets.data'
        # }
    
        # 训练集、测试集拆分
        x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=233, test_size=0.2)
        print(x_train)
        # [[4.8 3.  1.4 0.3]
        #  [6.3 3.3 4.7 1.6]
        #       ...
        #  [5.1 3.4 1.5 0.2]]
        print(x_test)
        # [[5.2 2.7 3.9 1.4]
        #  [6.8 3.2 5.9 2.3]
        #       ...
        #  [5.7 2.6 3.5 1. ]]
        print(y_train)
        # [0 1 ... 0]
        print(y_test)
        # [1 2 ... 1]
    
        # 参数标准化【特征缩放到均值为0标准差为1】
        scaler = StandardScaler()
        x_train = scaler.fit_transform(x_train)
        x_test = scaler.fit_transform(x_test)
    
        # 【常规训练】初始化KNN并指定K为5
        model = KNeighborsClassifier(n_neighbors=5)
    
        # 训练
        model.fit(x_train, y_train)
    
        # 评估
        predict = model.predict(x_test)
        print(predict)
        # [1 2 ... 1]
    
        # 打印评价指标
        accuracy = model.score(x_test, y_test)
        print(accuracy)
        # 0.9666666666666667
    
    
    
    
    
    
    
    
    
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71

    三、网格搜素训练+N折交叉验证

    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split, GridSearchCV
    from sklearn.preprocessing import StandardScaler
    from sklearn.neighbors import KNeighborsClassifier
    import pandas as pd
    
    
    if __name__ == '__main__':
    
        # 读取数据
        iris = load_iris()
        print(iris)
        # {
        #   'data': array([[5.1, 3.5, 1.4, 0.2], [4.9, 3. , 1.4, 0.2], ... [5.9, 3. , 5.1, 1.8]]),
        #   'target': array([0, 0, ... 2]),
        #   'frame': None,
        #   'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='
        #   'DESCR': '.. _iris_dataset:\n\nIris plants dataset ...\n\n|details-end|',
        #   'feature_names': ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'],
        #   'filename': 'iris.csv',
        #   'data_module': 'sklearn.datasets.data'
        # }
    
        # 训练集、测试集拆分
        x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=233, test_size=0.2)
        print(x_train)
        # [[4.8 3.  1.4 0.3]
        #  [6.3 3.3 4.7 1.6]
        #       ...
        #  [5.1 3.4 1.5 0.2]]
        print(x_test)
        # [[5.2 2.7 3.9 1.4]
        #  [6.8 3.2 5.9 2.3]
        #       ...
        #  [5.7 2.6 3.5 1. ]]
        print(y_train)
        # [0 1 ... 0]
        print(y_test)
        # [1 2 ... 1]
    
        # 参数标准化【特征缩放到均值为0标准差为1】
        scaler = StandardScaler()
        x_train = scaler.fit_transform(x_train)
        x_test = scaler.fit_transform(x_test)
    
        # # 【常规训练】初始化KNN并指定K为5
        # model = KNeighborsClassifier(n_neighbors=5)
    
        # 【网格搜索训练+N折交叉验证】初始化KNN随便指定一个K值,将 3, 5, 7, 9 分别进行4折交叉验证训练,总训练次数为 16 次
        model = KNeighborsClassifier(n_neighbors=1)
        model = GridSearchCV(model, param_grid={'n_neighbors': [3, 5, 7, 9]}, cv=4)
    
        # 训练
        model.fit(x_train, y_train)
    
    
        # 评估【使用最优模型评估】
        predict = model.best_estimator_.predict(x_test)
        print(predict)
        # [1 2 ... 1]
    
        # 打印评价指标
        print(f'best accuracy: {model.best_score_}')
        print(f'results: {model.cv_results_}')
    
        # accuracy = model.score(x_test, y_test)
        # print(accuracy)
        # # 0.9666666666666667
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
  • 相关阅读:
    Mysql8.x版本主从加读写分离(一) mysql8.x主从
    太绝了!这份Python爬虫入门『最强教程』当之无愧
    计算机毕业设计Java城市交通海量数据管理系统(源码+系统+mysql数据库+lw文档)
    MP3算法及代码例程
    Spring6学习技术|Junit
    1000套安卓(Android)毕业设计(带论文)、大作业、实例快速下载 (Android Studio)
    [iOS]-单例模式\通知\代理
    bootstrap系列-1.简单的Demo
    番外--Task2:
    以写Hbase表的方式更新Phoenix索引
  • 原文地址:https://blog.csdn.net/weixin_43721000/article/details/133850999