• 支持向量机:原理与python案例


    支持向量机浅析

    支持向量机(SVM,support vector machine)是一种二类分类模型,其基本模型定义为特征空间上的间隔最大的线性分类器,其学习策略便是间隔最大化,最终可转化为一个凸二次规划问题的求解。

    线性分类器

    线性回归模型的公式为:
    $ y=w^T X +b$

    广义线性回归模型的公式为:
    $ y=g(w^T X +b)$

    分类任务的预测值是离散的,比如二分类问题,可以用0和1来表示两个类别,通过样本训练,即可得到线性分类器

    间隔最大超平的含义

    我们的任务是找到一个线性分类器( w T X + b w^T X +b wTX+b),将样本分开,以图中二维坐标为例,能完成分割任务的有无数条线,如何确定一个最近的线性分类器呢?

    泛化性能是分类器好坏的一个重要指标。图中所观察到的样本都是训练样本,我们希望面对未来的样本数据,分类器也有良好性能。

    在这里插入图片描述

    一般而言,相同类别样本的距离较近,同时根据结构风险最小化原则,线性分类器的决策边界最大,则边界的泛化误差最小。图中B2的决策边界到B2的距离明显小于B1,泛化性能较差。寻找间隔最大的超平面,实际上是找部分关键样本点,使得它们到超平面的距离最大。

    在这里插入图片描述

    线性支持向量机原理

    训练数据集: D={(x_1,y_1),(x_2,y_2),…,(x_n,y_n)},y{-1,+1}

    线性分类器决策边界的线性方程:$ y=w^T X +b$,其中w表示超平面的法向量,决定了决策边界的方向,b表示位移量,决定了决策边界与原点间的距离。

    训练数据集中的样本点,如果为正类,那么 y i = + 1 y_i = +1 yi=+1,$ w^T X +b > 0 ; 如果为负类,那么 ;如果为负类,那么 ;如果为负类,那么y_i = -1 , , , w^T X +b < 0$;

    在训练过程中,我们可以不断的调整决策边界的超参数w和b,总可以得到:
    { w T x i + b > = + 1 , y i = + 1 ; w T x i + b < = − 1 , y i = − 1 , {wTxi+b>=+1,yi=+1;wTxi+b<=1,yi=1,

    {wTxi+b>=+1,yi=+1;wTxi+b<=1,yi=1,
    {wTxi+b>=+1,yi=+1;wTxi+b<=1,yi=1,

    此时,距离决策边界最近的样本点,刚好使得上式中等号成立,这些关键样本点,就叫做支持向量,两个异类支持向量(关键样本点)之间的距离(margin)称为决策边界的“边缘”,这个边缘的距离就是我们要找的最大间隔,决策边界B1对应的一组平行的超平面(b11,b12)之间的距离,将间隔记为 γ \gamma γ

    将样本点 x 1 , X 2 x_1,X_2 x1,X2 带入上式,得:
    { b 11 : w T x 1 + b = 1 ; b 12 : w T x 2 + b = − 1 , {b11:wTx1+b=1;b12:wTx2+b=1,

    {b11:wTx1+b=1;b12:wTx2+b=1,
    {b11:wTx1+b=1;b12:wTx2+b=1,两式相减得:$ w^T(x_1-x_2)=2 , ( , ( ,x_1-x_2就是x_1x_2$的模乘上夹角余弦值),这里再变换最终得到 ∥ w ∥ × γ = 2 \parallel w\parallel \times\gamma=2 w×γ=2

    支持向量机的学习就是,寻找参数w,b,使得 2 ∥ w ∥ \frac {2}{\parallel w\parallel} w2取得最大值。等价于寻找参数w,b,使得 1 2 ∥ w ∥ 2 \frac {1}{2}\parallel w\parallel^2 21w2取得最小值,优化函数为:$ y_i(w^T x_i +b)>=1,i =1,2,…,n$,具体优化函数的求解可以使用拉格朗日乘子法求解。

    在这里插入图片描述

    python实现案例

    案例根据y=0.5x+0.5,y=2.1x+6.5两个函数为基准,引入了较大的随机误差后,各生成了100个样本点。生成的样本有明显的线性边界,我们尝试使用支持向量机模型,找出决策边界,并进行绘制软间隔分类决策边界。

    在这里插入图片描述

    from sklearn.svm import LinearSVC
    import numpy as np
    import pandas as pd
    import random
    import  matplotlib.pyplot as plt #类似 MATLAB 中绘图函数的相关函数
    import seaborn as sns
    
    np.random.seed(2)
    count=100
    data=[]
    for i in range(count):
        x1=np.random.normal(0.00,0.55)
        res1=x1*0.1+0.5+np.random.normal(0.00,0.9)
        data.append([x1,res1,1])
        
        x2=np.random.normal(0.00,0.55)
        res2=x2*2.1+6.5+np.random.normal(0.00,0.9)
        data.append([x2,res2,0])
    
    data =pd.DataFrame(data)
    # print(data[data[2]==1])
    x1_data=np.array(data[0])
    x2_data=np.array(data[1])
    plt.scatter(x1_data,x2_data,c=data[2])
    plt.show()
    
    
    #Pipeline通过将这些数据处理步骤结合成一条算法链,以更加高效地完成整个机器学习流程
    svm_clf = Pipeline(( ("scaler", StandardScaler()),
                         ("linear_svc", LinearSVC(C=1, loss="hinge")) ,))
    
    # 调用linear_svc
    svm_clf = svm_clf.fit(data.iloc[:,:2], data[2])
    
    # SVC求解可视化函数
    def decision_boundary( model) :
        # 取出两个坐标轴的上下限
        xmin, xmax, ymin, ymax =x1_data.min(), x1_data.max(), x2_data.min(), x2_data.max()
    
        # 坐标轴等分为50份,共可创建50x50=2500个点
        xloc = np.linspace(xmin, xmax, 50)
        yloc = np.linspace(ymin, ymax, 50)
    
        # 相当于一个数据复制的作用,将shape=(n,)的数据变为shape=(n,n)
        xloc, yloc = np.meshgrid(xloc, yloc)
    
        # 组合坐标点coordinate, 将(n, n)的数据展开成(n*n, )的数据在组合为坐标
        coo = np.vstack([xloc.ravel(), yloc.ravel()]).T
        
        # 通过decision_function函数计算出每个点到决策边界的距离
        dis = model.decision_function(coo)
        
        # contour要求X,Y,Z具有相同的维度,所以需要将预测结果reshape
        dis = dis.reshape(xloc.shape)
        
        # 画出原始数据的散点图
        plt.scatter(x1_data,x2_data, c=y)
        
        # 添加决策边界和两个超平面
        plt.contour(xloc, yloc, dis, alpha=.8, linestyles=['-', '--', '-'], levels=[-1, 0, 1])
        
        plt.show()
        
        
    
    decision_boundary(svm_clf)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
  • 相关阅读:
    4G DTU流量不要钱!再也不用买卡充值啦!
    vue使用原生video标签基本功能(不含样式)
    Cholesterol-PEG-Maleimide,CLS-PEG-MAL,胆固醇-聚乙二醇-马来酰亚胺一种修饰性PEG
    mysql基于Java web的电动车销售平台毕业设计源码201524
    高性能MySQL实战第09讲:如何做到MySQL的高可用?
    leetcode 1812
    Runtime——探索类,对象,分类本质
    【JAVA-Day03】JDK安装与IntelliJ IDEA安装、配置环境变量
    【好书推荐】计算机网络:自顶向下方法(第七版)
    掌握这四步,月收入1万+的自媒体人可能就是你
  • 原文地址:https://blog.csdn.net/zzh1464501547/article/details/126801279