深度神经网络(DNN)在各种任务中取得了前所未有的成功,但是,这些模型性能直接取决于它们的超参数的设置。在实践中,优化超参数仍是设计深度神经网络的一大障碍。在这项工作中,我们建议使用粒子群优化算法(PSO)来选择和优化模型参数。在MNIST数据集上的实验结果显示:通过PSO优化的CNN模型可以得到不错的分类精度,此外,PSO 还可以提高现有模型结构的性能,PSO是自动化超参数选择和有效利用计算资源的有效技术。
PSO是粒子群优化算法(Particle Swarm Optimization)的英文缩写,是一种基于种群的随机优化技术,由Eberhart和Kennedy于1995年提出。粒子群算法模仿昆虫、兽群、鸟群和鱼群等的群集行为,这些群体按照一种合作的方式寻找食物,群体中的每个成员通过学习它自身的经验和其他成员的经验来不断改变其搜索模式。
- def func(x):
- n,sf,sp,l = x[0],x[1],x[2],x[3]
-
- model = Sequential()
- model.add(Conv2D(32,kernel_size=(3, 3),
- activation='relu',
- input_shape=input_shape))
- model.add(Conv2D(64, (3, 3), activation='relu'))
- model.add(Conv2D(n, (sf, sf), activation='relu'))#待优化的参数
- model.add(MaxPooling2D(pool_size=(sp, sp),strides=(l,l)))#待优化的参数
- model.add(Flatten())
- model.add(Dense(num_classes, activation='softmax'))
- model.compile(loss=keras.losses.categorical_crossentropy,
- optimizer=keras.optimizers.Adam(),
- metrics=['accuracy'])
本文选取卷积核数量、大小,池化核大小作为待优化超参数。其实现算法的伪代码如下所示:
(1)对数据集进行预处理,加载图像化数据;
(2)选取训练集作为网络输入,测试集作为模型测试输入;
(3)初始化粒子群;
(4)计算每个粒子的适应度值,并将初始适应度值作为每个粒子的当前最优值;
(5)将粒子中最好的适应度值作为全局最优值;
(6)根据式 4 和式 5 更新粒子的位置和速度;
(7)将每个微粒的适应度值与其历史最优进行比较,如果较好,则进行替换;
(8)将每个微粒的适应度值与历史全局最优进行比较,如果较好,则进行替换;
(9)如未达到终止条件(通常达到预定的最大迭代次数),则返回步骤 6)继续执行;
(10)粒子最优解应用于 CNN 超参数设置,基于该网络进行图像分类识别。可以设置最大迭代次数为 30,粒子群大小为 10,w=0.73, 1 c = 2 c =1.45。
最后,该过程实现代码如下所示:
- pso = Pso(swarmsize=4,maxiter=14)
- # n,sf,sp,l
- bp,value = pso.run(func,[1,2,2,2],[16,8,4,4])
- v = func(bp);
- print('Test loss:', bp)
- print('Test accuracy:', value,v)
完整代码:找QQ525894654索要。
加载数据及训练过程:
实验结果:
模型 | 准确率 |
CNN | 87.33% |
PSO-CNN | 91.69% |
可以看出,PSO-CNN 模型的识别准确率均优于 CNN 模型,具有良好的鲁棒性。
针 对CNN 算法的收敛速度较慢、过 拟合 等问题, 文章提出一种基于PSO和 CNN 模型的图像分类方法,在分析完CNN各超参数对其性能的影响后,引入 PSO 算法进行寻优以增强CNN网络模型的特征提取能力,模型将CNN算法中需要训练的参数作为粒子进行优化,将 更 新 的 参 数 用 于CNN 算 法 的 前 向 传播,调整网络连接权矩阵迭代,直到误差收敛,停止算法,以达到最终的模型优化。
参考文献
1. Lorenzo P R , Nalepa J , Kawulok M , et al. Particle swarm optimization for hyper-parameter selection in deep neural networks[C]// the Genetic and Evolutionary Computation Conference. 2017.
2.王金哲, 王泽儒, 王红梅. 基于PSO算法与Dropout的改进CNN算法[J]. 长春工业大学学报:自然科学版, 2019(1):5.