• 机器学习之感知机原理及Python实现


    机器学习算法感知机(perceptron)。感知机是一种较为简单的二分类模型,但由简至繁,感知机却是神经网络和支持向量机的基础。感知机旨在学习能够将输入数据划分为+1/-1的线性分离超平面,所以说整体而言感知机是一种线性模型。因为是线性模型,所以感知机的原理并不复杂,本节和大家来看一下感知机的基本原理和Python实现。

    感知机原理

         假设输入x表示为任意实例的特征向量,输出y={+1,-1}为实例的类别。感知机定义由输入到输出的映射函数如下:

    图片

         其中sign符号函数为:

    图片

         w和b为感知机模型参数,也是感知机要学习的东西。w和b构成的线性方程wx+b=0极为线性分离超平面。

    图片

         假设数据是线性可分的,当然有且仅在数据线性可分的情况下,感知机才能奏效。感知机模型简单,但这也是其缺陷之一。所谓线性可分,也即对于任何输入和输出数据都存在某个线性超平面wx+b=0能够将数据集中的正实例点和负实例点完全正确的划分到超平面两侧,这样数据集就是线性可分的。

         感知机的训练目标就是找到这个线性可分的超平面。为此,定义感知机模型损失函数如下:

    图片

         要优化这个损失函数,可采用梯度下降法对参数进行更新以最小化损失函数。计算损失函数关于参数w和b的梯度如下:

    图片

         由上可知完整的感知机算法包括参数初始化、对每个数据点判断其是否误分,如果误分,则按照梯度下降法更新超平面参数,直至没有误分类点。

    图片

         以上便是感知机算法的基本原理。当然这里说的感知机仅限于单层的感知机模型,仅适用于线性可分的情况。对于线性不可分的情形,笔者将在后续的神经网络和感知机两讲详细介绍。

    感知机算法实现

         完整的感知机算法包括参数初始化、模型主体、参数优化等部分,我们便可以按照这个思路来实现感知机算法。在正式写模型之前,我们先用sklearn来准备一下示例数据。

    1. # 导入相关库
    2. import pandas as pd
    3. import numpy as np
    4. from sklearn.datasets import load_iris
    5. import matplotlib.pyplot as plt
    6. # 导入iris数据集
    7. iris = load_iris()
    8. df = pd.DataFrame(iris.data, columns=iris.feature_names)
    9. df['label'] = iris.target
    10. df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
    11. # 绘制散点图
    12. plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], c='red', label='0')
    13. plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], c='green', label='1')
    14. plt.xlabel('sepal length')
    15. plt.ylabel('sepal width')
    16. plt.legend();

    图片

    1. # 取两列数据并将并将标签转化为1/-1
    2. data = np.array(df.iloc[:100, [0, 1, -1]])
    3. X, y = data[:,:-1], data[:,-1]
    4. y = np.array([1 if i == 1 else -1 for i in y])

     下面正式开始模型部分。先定义一个参数初始化函数:

    1. # 定义参数初始化函数
    2. def initilize_with_zeros(dim):
    3. w = np.zeros(dim)
    4. b = 0.0
    5. return w, b
    然后定义sign符号函数:
    1. # 定义sign符号函数
    2. def sign(x, w, b):
    3. return np.dot(x,w)+b
    最后定义模型训练和优化部分:
    1. # 定义感知机训练函数
    2. def train(X_train, y_train, learning_rate):
    3. # 参数初始化
    4. w, b = initilize_with_zeros(X_train.shape[1])
    5. # 初始化误分类
    6. is_wrong = False
    7. while not is_wrong:
    8. wrong_count = 0
    9. for i in range(len(X_train)):
    10. X = X_train[i]
    11. y = y_train[i]
    12. # 如果存在误分类点
    13. # 更新参数
    14. # 直到没有误分类点
    15. if y * sign(X, w, b) <= 0:
    16. w = w + learning_rate*np.dot(y, X)
    17. b = b + learning_rate*y
    18. wrong_count += 1
    19. if wrong_count == 0:
    20. is_wrong = True
    21. print('There is no missclassification!')
    22. # 保存更新后的参数
    23. params = {
    24. 'w': w,
    25. 'b': b
    26. }
    27. return params
    对示例数据进行训练:
    1. params = train(X, y, 0.01)
    2. params

    图片

    最后对训练结果进行可视化,绘制模型的决策边界:​​​​​​​

    1. x_points = np.linspace(4, 7, 10)
    2. y_hat = -(params['w'][0]*x_points + params['b'])/params['w'][1]
    3. plt.plot(x_points, y_hat)
    4. plt.plot(data[:50, 0], data[:50, 1], color='red', label='0')
    5. plt.plot(data[50:100, 0], data[50:100, 1], color='green', label='1')
    6. plt.xlabel('sepal length')
    7. plt.ylabel('sepal width')
    8. plt.legend()

    图片

    最后,我们也可以建一个perceptron类来方便调用。对上述代码进行整理:

    1. class Perceptron:
    2. def __init__(self):
    3. pass
    4. def sign(self, x, w, b):
    5. return np.dot(x, w) + b
    6. def train(self, X_train, y_train, learning_rate):
    7. # 参数初始化
    8. w, b = self.initilize_with_zeros(X_train.shape[1])
    9. # 初始化误分类
    10. is_wrong = False
    11. while not is_wrong:
    12. wrong_count = 0
    13. for i in range(len(X_train)):
    14. X = X_train[i]
    15. y = y_train[i]
    16. # 如果存在误分类点
    17. # 更新参数
    18. # 直到没有误分类点
    19. if y * self.sign(X, w, b) <= 0:
    20. w = w + learning_rate*np.dot(y, X)
    21. b = b + learning_rate*y
    22. wrong_count += 1
    23. if wrong_count == 0:
    24. is_wrong = True
    25. print('There is no missclassification!')
    26. # 保存更新后的参数
    27. params = {
    28. 'w': w,
    29. 'b': b
    30. }
    31. return params
    
                    
  • 相关阅读:
    STM32中的加速度计驱动程序与姿态控制实现
    Qt实现桌面画线、标记,流畅绘制,支持鼠标和多点触控绘制
    linux基本指令(下)
    JAVA生成20位LONG型UUID
    SSM+SB面试题收集
    【推荐】数字化转型和案例及IT规划资料整理合集
    闭关之现代 C++ 笔记汇总(二):特性演化
    μC/OS-II---计时器管理2(os_tmr.c)
    Oracle DBlink使用方法
    QLineEdit 使用QValidator 限制各种输入
  • 原文地址:https://blog.csdn.net/huanxiajioabu/article/details/133135825