• 脉冲神经网络:MATLAB实现脉冲神经网络(Spiking Neural Network,SNN) 用于图像分类(提供MATLAB代码)


    一、脉冲神经网络

    脉冲神经网络 (Spiking Neural Network,SNN) ,是第三代神经网络。其旨在弥合神经科学和机器学习之间的差距,使用最拟合生物神经元机制的模型来进行计算,更接近生物神经元机制。脉冲神经网络与目前流行的神经网络和机器学习方法有着根本上的不同。SNN 使用脉冲——这是一种发生在时间点上的离散事件——而非常见的连续值。每个峰值由代表生物过程的微分方程表示出来,其中最重要的是神经元的膜电位。本质上,一旦神经元达到了某一电位,脉冲就会出现,随后达到电位的神经元会被重置。对此,最常见的模型是 Leaky Integrate-And-Fire (LIF) 模型。此外,SNN 通常是稀疏连接的,并会利用特殊的网络拓扑。

    二、数据集简介

    训练集共有十张光学字符图片构成分别是1,2,3,4,5,6,7,8,9,0。其对应类别可表示为:

    1: 1 0 0 0 0 0 0 0 0 0

    2: 0 1 0 0 0 0 0 0 0 0

    3: 0 0 1 0 0 0 0 0 0 0

    4: 0 0 0 1 0 0 0 0 0 0

    5: 0 0 0 0 1 0 0 0 0 0

    6: 0 0 0 0 0 1 0 0 0 0

    7: 0 0 0 0 0 0 1 0 0 0

    8: 0 0 0 0 0 0 0 1 0 0

    9: 0 0 0 0 0 0 0 0 1 0

    0: 0 0 0 0 0 0 0 0 0 1

    原始图像(训练集):
    在这里插入图片描述

    将上述10个光学字符图像编码成时间脉冲:

    在这里插入图片描述

    含噪图像(测试集):
    在这里插入图片描述

    三、MATLAB实现

    3.1部分代码如下:

    Tmax=30;%最大训练次数(可以自己修改)
    spiking = cell2mat(struct2cell(load('spiking.mat')));%已经为时间脉冲的10张图像
    label1 = 0;%标签1为不点火,若测试样本点火则调整权重
    label2 = 1;%标签2为点火,若测试样本点火则调整权重
    % 随机生成10组权重值 
    w1 = rand(1,400);
    w2 = rand(1,400);
    w3 = rand(1,400);
    w4 = rand(1,400);
    w5 = rand(1,400);
    w6 = rand(1,400);
    w7 = rand(1,400);
    w8 = rand(1,400);
    w9 = rand(1,400);
    w10 = rand(1,400);
    W = [w1; w2; w3; w4; w5; w6; w7; w8; w9; w10];
    %训练十个输出神经元识别image1~10,使对应序号神经元点火,其他不点火
    for k = 1 : 1 : Tmax 
        fprintf('第%d次训练:\n',k);
        for i = 1 : 1 : 10 %10个输出神经元
            for j = 1 : 1 : 10 %10个脉冲时间序列 
                [val,maxT,maxU] = tempotron(spiking(j,:), W(i,:));
                if j == i
                    if val == 0 %理应为1,但为0,与标签不符合,则需要更新权值
                        W(i,:) = train(spiking(j,:), W(i,:), label2);
                       [val,maxT,maxI] = tempotron(spiking(j,:),W(i,:));
                    end
                else
                    if val == 1 %理应为0,但为1,与标签不符合,则需要更新权值
                        W(i,:) = train(spiking(j,:), W(i,:), label1);
                        [val,maxT,maxI] = tempotron(spiking(j,:),W(i,:));
                        
                    end                
                end
            end
        end
        %% 预测结果
        CorrectRate=0;
        for i = 1 : 1 : 10 %10个输出神经元
            DataPre=zeros(1,10);%预测值
            fprintf('%d:\t',i);
            for j = 1 : 1 : 10 %10个脉冲时间序列
                [val,maxT,maxU] = tempotron(spiking(j,:), W(i,:));
                fprintf('%d\t',val);
                DataPre(j)=val;
            end
            fprintf('\n');
            %% 判断预测分类是否正确
            if (sum(DataPre)==1)
                if (find(DataPre==1)==i)
                CorrectRate=CorrectRate+1;%分类准确则加1
                end
            end
        end
    fprintf('正确率:%f:\n',CorrectRate/10);
    Curve(k)=CorrectRate/10;
    end
    save W_new W;
    figure
    plot(Curve,'r-*')
    xlabel('训练次数')
    ylabel('准确率')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62

    最大训练次数为30次,训练集上准确率随迭代次数的变化图如下:

    在这里插入图片描述

    可以看出SNN在训练集上准确率达到90%

    3.2测试集上的预测效果:

    预测值为:

    1: 1 0 0 0 0 0 0 0 0 0 (正确)

    2: 0 1 0 0 0 0 1 0 0 0 (错误

    3: 0 0 1 0 0 0 0 0 0 0 (正确)

    4: 0 0 0 1 0 0 0 0 0 0 (正确)

    5: 0 0 0 0 1 0 0 0 0 0 (正确)

    6: 0 0 0 0 0 1 0 0 0 0 (正确)

    7: 0 0 0 0 0 0 1 0 0 0 (正确)

    8: 0 0 0 1 0 0 0 1 0 0 (错误

    9: 0 0 0 0 0 0 0 0 1 0 (正确)

    0: 0 0 0 0 0 0 0 0 0 1 (正确)

    由此可以看出,SNN在测试集上的准确率为80%

    四、参考代码见博主微信朋友圈

  • 相关阅读:
    two point(双指针)
    C陷阱——数组越界引发的死循环问题
    C++包含整数各位重组
    力扣shell刷题
    海豚调度系列之:任务类型——SQL节点
    Python解离散数学
    pod&node选择部署策略: nodeSelector和nodeAffinity
    编辑任何场景! 3DitScene:通过语言引导的解耦 Gaussian Splatting开源来袭!
    什么是无线传输技术,如Wi-Fi、蓝牙和NFC的特点和应用场景
    swagger在线api文档搭建指南
  • 原文地址:https://blog.csdn.net/weixin_46204734/article/details/125510128