• 基于SNN脉冲神经网络的Hebbian学习训练过程matlab仿真


    目录

    一、理论基础

    二、案例背景

    1.问题描述

    2.思路流程

    三、部分MATLAB仿真

    四、仿真结论分析

    五、参考文献


    一、理论基础

            近年来,深度学习彻底改变了机器学习领域,尤其是计算机视觉。在这种方法中,使用反向传播以监督的方式训练深层(多层)人工神经网络(ANN)。虽然需要大量带标签的训练样本,但是最终的分类准确性确实令人印象深刻,有时甚至胜过人类。人工神经网络中的神经元的特征在于单个、静态、连续值的激活。然而生物神经元使用离散的脉冲来计算和传输信息,并且除了脉冲发放率外,脉冲时间也很重要。因此脉冲神经网络(SNN)在生物学上比ANN更现实,并且如果有人想了解大脑的计算方式,它无疑是唯一可行的选择。 SNN也比ANN更具硬件友好性和能源效率,因此对技术,尤其是便携式设备具有吸引力。但是训练深度SNN仍然是一个挑战。脉冲神经元的传递函数通常是不可微的,从而阻止了反向传播。在这里,我们回顾了用于训练深度SNN的最新监督和无监督方法,并在准确性、计算成本和硬件友好性方面进行了比较。目前的情况是,SNN在准确性方面仍落后于ANN,但差距正在缩小,甚至在某些任务上可能消失,而SNN通常只需要更少的操作。

           SNN中的无监督学习通常将STDP纳入学习机制。生物STDP的最常见形式具有非常直观的解释。如果突触前神经元在突触后神经元之前不久触发(大约10毫秒),则连接它们的权重会增加。如果突触后神经元在突触后神经元后不久触发,则时间事件之间的因果关系是虚假的,权重会减弱。增强称为长时程增强(LTP),减弱称为长时程抑制(LTD)。短语“长时程”用于区分实验中观察到的几毫秒范围内的非常短暂的影响。

          下面的公式是通过拟合实验数据对一对脉冲进行了实验上最常见的STDP规则的理想化。
         

           以上公式中的第一种情况描述LTP,而第二种情况描述LTD。效果的强度由衰减指数调制,衰减指数的大小由突触前和突触后脉冲之间的时间常数比例时间差控制。人工SNN很少使用此确切规则。他们通常使用变体,以实现更多的简单性或满足便利的数学特性。

    二、案例背景

    1.问题描述

            SNN神经网络的学习方法也不是很好,作为传统的基于速率的网络而发展,使用反向传播学习算法。使用高效的Hebbian学习方法:棘突神经元网络的稳态。类似于STDP,尖峰之间的计时用于突触修饰。内稳态确保了突触权重是有界的学习是稳定的。赢家通吃机制也很重要实施以促进输出之间的竞争性学习神经元。我们已经在一个C++对象中实现了这个方法面向对象的代码(称为CSpike)。我们已经在四个服务器上测试了代码Gabor滤波器的图像,并发现钟形调谐曲线使用不同类型的Gabor滤波器的36个测试集图像方向。这些钟形曲线与这些曲线相似实验上观察到的V1和MT/V5区域哺乳动物的大脑。

    2.思路流程

    SNN即目前的最新的第三代神经网络,具体的仿真步骤如下所示:

     

    三、部分MATLAB仿真

    matlab仿真程序如下所示:

    1. clc;
    2. clear;
    3. close all;
    4. warning off;
    5. addpath 'func\'
    6. RandStream.setDefaultStream(RandStream('mt19937ar','seed',1));
    7. %%
    8. load Character\Character_set.mat
    9. %显示论文fig1. character set used
    10. func_view_character;
    11. %**************************************************************************
    12. %%
    13. %显示论文fig2. Representation of 'A'
    14. A_Line = Representation_Character(A_ch);
    15. figure;
    16. stem(1:5*3,A_Line,'LineWidth',2);
    17. hold on
    18. stairs(1:5*3,A_Line,'r');
    19. axis([0,5*3+1,-0.25,1.25]);
    20. title('Fig.2. Representation of A');
    21. %**************************************************************************
    22. %% 以下是程序的第一部分,即4个字符的仿真
    23. %建立SNN神经网络模型
    24. Rm = 80;
    25. theta = 10; %10mv
    26. rs = 2; %2ms
    27. rm = 30; %30ms
    28. rmin = 2; %2ms
    29. rmax = 30; %30ms;
    30. lin = 0.3;
    31. lin_dec = 0.05;
    32. A1 = 0.1;
    33. A2 =-0.105;
    34. r1 = 1; %1ms
    35. r2 = 1; %1ms
    36. tstep = 0.2; %0.2ms;
    37. times = 200; %训练次数
    38. error = 1e-3;%训练目标误差
    39. vth = 7;
    40. %通过神经网络对A,B,C,D进行训练识别
    41. %对应论文fig.3. Output when each character is presented individually
    42. %随机产生初始的权值wij
    43. N_in = 15;
    44. w_initial = 0.5+0.5*rand(N_in,4);
    45. for i = 1:times
    46. w{i} = w_initial;
    47. end
    48. w0 = w_initial;
    49. dew = 0;
    50. wmax = 0;
    51. wmin = 0;
    52. det = zeros(N_in,4);
    53. tpre = 300*ones(N_in,times);
    54. tpost = 301*ones(N_in,times);
    55. Time2 = 4000;
    56. dt = 0.05;
    57. STIME = 24;
    58. %%
    59. %字符转换为电平
    60. A_Line = Representation_Character(A_ch);
    61. B_Line = Representation_Character(B_ch);
    62. C_Line = Representation_Character(C_ch);
    63. D_Line = Representation_Character(D_ch);
    64. Lines = [A_Line B_Line C_Line D_Line];
    65. for num = 1:4
    66. Dat = Lines(:,num);
    67. for ij = 1:1
    68. for i = 1:times
    69. w{i} = w_initial;
    70. end
    71. ind = 0;
    72. for i = 1:times
    73. i
    74. ind = ind + 1;
    75. %计算rd
    76. for n1 = 1:N_in
    77. rd(n1) = rmax - abs(w{i}(n1))*(rmax - rmin);
    78. end
    79. %计算Rd
    80. for n1 = 1:N_in
    81. Rd(n1) = (rd(n1)*theta/Rm) * (rm/rd(n1))^(rm/(rm-rd(n1)));
    82. end
    83. %计算delta t
    84. for n1 = 1:N_in
    85. if Dat(n1) == 1
    86. det(n1) = tpre(n1,i) - tpost(n1,i);
    87. else
    88. det(n1) = -(tpre(n1,i) - tpost(n1,i));
    89. end
    90. end
    91. %计算delta w
    92. for n1 = 1:N_in
    93. if det(n1) <= 0
    94. dew(n1) = A1*exp( det(n1)/r1);
    95. else
    96. dew(n1) = A2*exp(-det(n1)/r2);
    97. end
    98. if i > 1
    99. %计算权值更新
    100. if dew(n1) > 0
    101. w{i}(n1) = w{i-1}(n1) + lin*dew(n1)*(w{i-1}(n1));
    102. else
    103. w{i}(n1) = w{i-1}(n1) + lin*dew(n1)*(w{i-1}(n1));
    104. end
    105. end
    106. end
    107. %计算Id
    108. for n1 = 1:N_in
    109. Id{n1} = func_Id(Rd(n1),rd(n1),w{i}(n1),Dat(n1),Time2,dt,tpre(n1,i));
    110. end
    111. %计算Is
    112. Is = func_Is(N_in,rs,w{i},Dat,Time2,dt,tpre(n1,i));
    113. %计算u
    114. Um = func_um(N_in,w{i},Rm,Id,Is,rm,Time2,dt,Dat,tpre(n1,i),vth);
    115. %计算训练误差
    116. if i > 1
    117. err2(ind-1) = abs(norm(w{i}/max(w{i})) - norm(w{i-1}/max(w{i-1})));
    118. if abs(norm(w{i}/max(w{i}) - norm(w{i-1}/max(w{i-1})))) <= error
    119. break;
    120. end
    121. end
    122. end
    123. end
    124. Ws(:,num) = w{end}/max(w{end});
    125. clear Is Id Um w
    126. end
    127. %%
    128. %训练完之后进行测试
    129. %A测试
    130. %A测试
    131. UmA = [];
    132. for ij = 1:STIME
    133. ij
    134. for j = 1:4
    135. ind = 0;
    136. ind = ind + 1;
    137. %计算Id
    138. for n1 = 1:N_in
    139. Id{n1} = func_Id(Rd(n1),rd(n1),Ws(n1,1),Lines(n1,j),Time2,dt,tpre(n1,j));
    140. end
    141. %计算Is
    142. Is{j} = func_Is(N_in,rs,Ws(:,1),Lines(:,j),Time2,dt,tpre(n1,j));
    143. %计算u
    144. Um(j,:) = func_um(N_in,Ws(:,1),Rm,Id,Is{j},rm,Time2,dt,Lines(:,j),tpre(n1,j),vth);
    145. end
    146. UmA = [UmA,Um];
    147. end
    148. figure;
    149. plot(1:Time2*STIME,UmA(1,:),'r');
    150. hold on
    151. plot(1:Time2*STIME,UmA(2,:),'g');
    152. hold on
    153. plot(1:Time2*STIME,UmA(3,:),'b');
    154. hold on
    155. plot(1:Time2*STIME,UmA(4,:),'c');
    156. hold off
    157. legend('Neuron1','Neuron2','Neuron3','Neuron4');
    158. axis([1,Time2*STIME,0,30]);
    159. clear Id Is Um w
    160. %B测试
    161. %B测试
    162. UmB = [];
    163. for ij = 1:STIME
    164. ij
    165. for j = 1:4
    166. ind = 0;
    167. ind = ind + 1;
    168. %计算Id
    169. for n1 = 1:N_in
    170. Id{n1} = func_Id(Rd(n1),rd(n1),Ws(n1,2),Lines(n1,j),Time2,dt,tpre(n1,j));
    171. end
    172. %计算Is
    173. Is{j} = func_Is(N_in,rs,Ws(:,2),Lines(:,j),Time2,dt,tpre(n1,j));
    174. %计算u
    175. Um(j,:) = func_um(N_in,Ws(:,2),Rm,Id,Is{j},rm,Time2,dt,Lines(:,j),tpre(n1,j),vth);
    176. end
    177. UmB = [UmB,Um];
    178. end
    179. figure;
    180. plot(1:Time2*STIME,UmB(1,:),'r');
    181. hold on
    182. plot(1:Time2*STIME,UmB(2,:),'g');
    183. hold on
    184. plot(1:Time2*STIME,UmB(3,:),'b');
    185. hold on
    186. plot(1:Time2*STIME,UmB(4,:),'c');
    187. hold off
    188. legend('Neuron1','Neuron2','Neuron3','Neuron4');
    189. axis([1,Time2*STIME,0,30]);
    190. clear Id Is Um w
    191. %C测试
    192. %C测试
    193. UmC = [];
    194. for ij = 1:STIME
    195. ij
    196. for j = 1:4
    197. ind = 0;
    198. ind = ind + 1;
    199. %计算Id
    200. for n1 = 1:N_in
    201. Id{n1} = func_Id(Rd(n1),rd(n1),Ws(n1,3),Lines(n1,j),Time2,dt,tpre(n1,j));
    202. end
    203. %计算Is
    204. Is{j} = func_Is(N_in,rs,Ws(:,3),Lines(:,j),Time2,dt,tpre(n1,j));
    205. %计算u
    206. Um(j,:) = func_um(N_in,Ws(:,3),Rm,Id,Is{j},rm,Time2,dt,Lines(:,j),tpre(n1,j),vth);
    207. end
    208. UmC = [UmC,Um];
    209. end
    210. figure;
    211. plot(1:Time2*STIME,UmC(1,:),'r');
    212. hold on
    213. plot(1:Time2*STIME,UmC(2,:),'g');
    214. hold on
    215. plot(1:Time2*STIME,UmC(3,:),'b');
    216. hold on
    217. plot(1:Time2*STIME,UmC(4,:),'c');
    218. hold off
    219. legend('Neuron1','Neuron2','Neuron3','Neuron4');
    220. axis([1,Time2*STIME,0,30]);
    221. clear Id Is Um w
    222. %D测试
    223. %D测试
    224. UmD = [];
    225. for ij = 1:STIME
    226. ij
    227. for j = 1:4
    228. ind = 0;
    229. ind = ind + 1;
    230. %计算Id
    231. for n1 = 1:N_in
    232. Id{n1} = func_Id(Rd(n1),rd(n1),Ws(n1,4),Lines(n1,j),Time2,dt,tpre(n1,j));
    233. end
    234. %计算Is
    235. Is{j} = func_Is(N_in,rs,Ws(:,4),Lines(:,j),Time2,dt,tpre(n1,j));
    236. %计算u
    237. Um(j,:) = func_um(N_in,Ws(:,4),Rm,Id,Is{j},rm,Time2,dt,Lines(:,j),tpre(n1,j),vth);
    238. end
    239. UmD = [UmD,Um];
    240. end
    241. figure;
    242. plot(1:Time2*STIME,UmD(1,:),'r');
    243. hold on
    244. plot(1:Time2*STIME,UmD(2,:),'g');
    245. hold on
    246. plot(1:Time2*STIME,UmD(3,:),'b');
    247. hold on
    248. plot(1:Time2*STIME,UmD(4,:),'c');
    249. hold off
    250. legend('Neuron1','Neuron2','Neuron3','Neuron4');
    251. axis([1,Time2*STIME,0,30]);
    252. clear Id Is Um w
    253. %连续码流测试
    254. %连续码流测试
    255. Umss = [];
    256. for ij = 1:STIME
    257. ij
    258. if ij>=1 &ij <= 2
    259. W = Ws(:,3);
    260. end
    261. if ij>=3 &ij <= 5
    262. W = Ws(:,4);
    263. end
    264. if ij>=6 &ij <= 8
    265. W = Ws(:,1);
    266. end
    267. if ij>=9 &ij <= 12
    268. W = Ws(:,3);
    269. end
    270. if ij>=13 &ij <= 15
    271. W = Ws(:,1);
    272. end
    273. if ij>=16 &ij <= 18
    274. W = Ws(:,2);
    275. end
    276. if ij>=19 &ij <= 21
    277. W = Ws(:,3);
    278. end
    279. if ij>=22 &ij <= 24
    280. W = Ws(:,4);
    281. end
    282. for j = 1:4
    283. ind = 0;
    284. ind = ind + 1;
    285. %计算Id
    286. for n1 = 1:N_in
    287. Id{n1} = func_Id(Rd(n1),rd(n1),W(n1),Lines(n1,j),Time2,dt,tpre(n1,j));
    288. end
    289. %计算Is
    290. Is{j} = func_Is(N_in,rs,W,Lines(:,j),Time2,dt,tpre(n1,j));
    291. %计算u
    292. Um(j,:) = func_um(N_in,W,Rm,Id,Is{j},rm,Time2,dt,Lines(:,j),tpre(n1,j),vth);
    293. end
    294. Umss = [Umss,Um];
    295. end
    296. figure;
    297. plot(1:Time2*STIME,Umss(1,:),'r');
    298. hold on
    299. plot(1:Time2*STIME,Umss(2,:),'g');
    300. hold on
    301. plot(1:Time2*STIME,Umss(3,:),'b');
    302. hold on
    303. plot(1:Time2*STIME,Umss(4,:),'c');
    304. hold off
    305. legend('Neuron1','Neuron2','Neuron3','Neuron4');
    306. axis([1,Time2*STIME,0,30]);
    307. clear Id Is Um w
    308. %%
    309. %Fig.5.Weight distribution
    310. figure;
    311. subplot(121);
    312. bar3(w0(1:15,:),0.8,'r');hold on
    313. bar3(w0(1:12,:),0.8,'y');hold on
    314. bar3(w0(1:9,:) ,0.8,'g');hold on
    315. bar3(w0(1:6,:) ,0.8,'c');hold on
    316. bar3(w0(1:3,:) ,0.8,'b');hold on
    317. xlabel('Output Neruon');
    318. ylabel('Input Neruon');
    319. zlabel('Weight');
    320. title('Before training');
    321. axis([0,5,0,16,0,1.3]);
    322. view([-126,36]);
    323. subplot(122);
    324. bar3(Ws(1:15,:),0.8,'r');hold on
    325. bar3(Ws(1:12,:),0.8,'y');hold on
    326. bar3(Ws(1:9,:) ,0.8,'g');hold on
    327. bar3(Ws(1:6,:) ,0.8,'c');hold on
    328. bar3(Ws(1:3,:) ,0.8,'b');hold on
    329. xlabel('Output Neruon');
    330. ylabel('Input Neruon');
    331. zlabel('Weight');
    332. title('After training');
    333. axis([0,5,0,16,0,1.3]);
    334. view([-126,36]);

    四、仿真结论分析

         将SNN进行仿真,并得到类似论文中的仿真效果,具体的仿真结果如下图所示:

     

     

    上述就是实际的仿真效果图。

    五、参考文献

    [1] Gupta A ,  Long L N . Hebbian learning with winner take all for spiking neural networks[C]// International Joint Conference on Neural Networks. IEEE, 2009.A05-12

  • 相关阅读:
    MySQL学习笔记:索引2
    如何在一个pycharm项目中创建jupyter notebook文件,并切换到conda环境中
    java开放式实验室预约系统计算机毕业设计MyBatis+系统+LW文档+源码+调试部署
    【STM32单片机】贪吃蛇游戏设计
    Wonderware 实时库——一套可落地的传统工业实时库
    SpringBoot定时任务 - 什么是ElasticJob?如何集成ElasticJob实现分布式任务调度?
    做软件测试三年,薪资不到20K,今天,我提出了辞职…
    【CV】第 15 章:结合计算机视觉和 NLP 技术
    Ajax进阶
    接口测试--Postman变量
  • 原文地址:https://blog.csdn.net/ccsss22/article/details/126313976