• SVM学习笔记


    一、SVM算法简介

    支持向量机(Support Vector Machine, SVM)是一类按监督学习(supervised learning)方式对数据进行二元分类的广义线性分类器(generalized linear classifier),其决策边界是对学习样本求解的最大边距超平面(maximum-margin hyperplane)。

    SVM使用铰链损失函数(hinge loss)计算经验风险(empirical risk)并在求解系统中加入了正则化项以优化结构风险(structural risk),是一个具有稀疏性和稳健性的分类器 。SVM可以通过核方法(kernel method)进行非线性分类,是常见的核学习(kernel learning)方法之一 。

    SVM在各领域的模式识别问题中有应用,包括人像识别 、文本分类 、手写字符识别 、生物信息学 等。

    SVM与KNN对比分析

    • svm,就像是在河北和北京之间有一条边界线,如果一个人居住在北京一侧就预测为北京人,在河北一侧,就预测为河北人。但是住在河北的北京人和住在北京的河北人就会被误判。
    • knn,就是物以类聚,人以群分。如果你的朋友里大部分是北京人,就预测你也是北京人。如果你的朋友里大部分是河北人,那就预测你是河北人。不管你住哪里。

    可惜河北和北京直接并不能以一条边界进行划分,如果用kernel trick(核函数、核技巧), SVM 也可以画出非线性边界。

    • knn没有训练过程,他的基本原理就是找到训练数据集里面离需要预测的样本点距离最近的k个值(距离可以使用比如欧式距离,k的值需要自己调参),然后把这k个点的label做个投票,选出一个label做为预测。对于KNN,没有训练过程。只是将训练数据与训练数据进行距离度量来实现分类。
    • svm需要超平面wx+b来分割数据集(此处以线性可分为例),因此会有一个模型训练过程来找到w和b的值。训练完成之后就可以拿去预测了,根据函数y=wx+b的值来确定样本点x的label,不需要再考虑训练集。对于SVM,是先在训练集上训练一个模型,然后用这个模型直接对测试集进行分类。
    • knn没有训练过程,但是预测过程需要挨个计算每个训练样本和测试样本的距离,当训练集和测试集很大时,预测效率感人。 svm有一个训练过程,训练完直接得到超平面函数,根据超平面函数直接判定预测点的label,预测效率很高
    • 两者调参过程不一样。 knn只有一个参数k,而svm的参数更多,在线性不可分的情况下(这种情况更普遍),有松弛变量的系数,有具体的核函数。

    总结:

    KNN优点:

    简单易于理解;无需训练,无需估计参数;准确性高;适合多标签问题
    KNN缺点:

    懒惰算法,预测慢,开销大;类的样本数不平衡时准确率受影响;可解释性差
    SVM优点:

    适合小样本、非线性、高维模式识别
    SVM缺点:

    对于大规模数据开销大;不合适多分类

    二、理论推导过程

    学习链接:
    svm原理从头到尾详细推导
    SVM详细讲解_知天易or逆天难的博客-CSDN博客_svm方法
    学习视频资料
    以电影为背景的短视频SVM讲解
    白话支持向量机
    SVM讲解短视频
    长视频,深入原理
    长视频,含有公式分析
    推导过程还在学习了解中

    三、百度飞桨案例实现(鸢尾花分类)

    1. 加载相关包
      在这里插入图片描述
    2. 加载数据,分割数据集(原始数据分类好了,这回打乱使其效果较好)

    在这里插入图片描述

    1. 构建SVM分类器,训练函数在这里插入图片描述

    2. 初始化分类器实例,训练模型在这里插入图片描述

    3. 展示训练结果及验证结果
      代码展示:

    # ======判断a,b是否相等计算acc的均值
    def show_accuracy(a, b, tip):
        acc = a.ravel() == b.ravel()
        print('%s Accuracy:%.3f' %(tip, np.mean(acc)))
        
    # 分别打印训练集和测试集的准确率 score(x_train, y_train)表示输出 x_train,y_train在模型上的准确率
    def print_accuracy(clf, x_train, y_train, x_test, y_test):
        print('training prediction:%.3f' %(clf.score(x_train, y_train)))
        print('test data prediction:%.3f' %(clf.score(x_test, y_test)))
        # 原始结果和预测结果进行对比 predict() 表示对x_train样本进行预测,返回样本类别
        show_accuracy(clf.predict(x_train), y_train, 'traing data')
        show_accuracy(clf.predict(x_test), y_test, 'testing data')
        # 计算决策函数的值 表示x到各个分割平面的距离
        print('decision_function:\n', clf.decision_function(x_train))
        
    def draw(clf, x):   
        iris_feature = 'sepal length', 'sepal width', 'petal length', 'petal width'
        # 开始画图 
        x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
        x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
        # 生成网格采样点
        x1, x2 = np.mgrid[x1_min:x1_max:200j, x2_min:x2_max:200j]  
        # 测试点
        grid_test = np.stack((x1.flat, x2.flat), axis = 1)
        print('grid_test:\n', grid_test)
        # 输出样本到决策面的距离
        z = clf.decision_function(grid_test)
        print('the distance to decision plane:\n', z)
        grid_hat = clf.predict(grid_test)
        # 预测分类值 得到[0, 0, ..., 2, 2]
        print('grid_hat:\n', grid_hat)
        # 使得grid_hat 和 x1 形状一致
        grid_hat = grid_hat.reshape(x1.shape)
        cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
        cm_dark = mpl.colors.ListedColormap(['g', 'b', 'r'])
        
        plt.pcolormesh(x1, x2, grid_hat, cmap = cm_light) 
        plt.scatter(x[:, 0], x[:, 1], c=np.squeeze(y), edgecolor='k', s=50, cmap=cm_dark )
        plt.scatter(x_test[:, 0], x_test[:, 1], s=120, facecolor='none', zorder=10 )
        plt.xlabel(iris_feature[0], fontsize=20) 
        plt.ylabel(iris_feature[1], fontsize=20)
        plt.xlim(x1_min, x1_max)
        plt.ylim(x2_min, x2_max)
        plt.title('Iris data classification via SVM', fontsize=30)
        plt.grid()
        plt.show()
    
    # 4 模型评估
    print('-------- eval ----------')
    print_accuracy(clf, x_train, y_train, x_test, y_test)
    # 5 模型使用
    print('-------- show ----------')
    draw(clf, x) 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53

    在这里插入图片描述

  • 相关阅读:
    网页加载有哪些事件
    zip()并行迭代多个序列
    Ansible自动化运维工具
    Scikit-Learn支持向量机分类
    利用Helm在K8S上部署 PolarDB-X 集群(详细步骤--亲测!!!)
    Vue2.js迁移到Vue3.js的API变化
    MCU 的 TOP 15 图形GUI库:选择最适合你的图形用户界面(一)
    对话 ONES 联合创始人兼 CTO 冯斌:技术管理者如何打造一支自驱型团队?
    能快速构建和定制网络拓扑图的WPF开源项目-NodeNetwork
    60 个前端 Web 开发流行语你都知道哪些?
  • 原文地址:https://blog.csdn.net/weixin_61587867/article/details/126390057