网络压缩方法主要4种:1.Network pruning
(网络剪枝) 2.Sparse representation
稀疏表示 3.Bits precision 可以利用低比特的表示,甚至二值化的表示 4. Konwledge distillation
知识蒸馏
知识蒸馏最早的一篇论文:Distilling the Knowledge in a Netrual Network, 论文在2015年由深度学习领域的大牛Geffrey Hinton
和Google Jeff Dean等人提出。
知识蒸馏类似于迁移学习,可以看成是一种知识的迁移。
由教师网络模型(Teacher mode)将知识迁移给学生(Student model)模型,通俗的讲就是用教师网络指导学生网络的学习。
知识蒸馏(Knowledge distillation)
是提升网络性能的方法,通过一个教师网络指导学生网络的学习
,教师网络是一个较大的模型,学生网络是一个较小的模型,通过教师网络学习到的知识迁移到学生网络上。训练学生网络的时候可以得到教师网络蒸馏到的 soft label知识,同时学生网络也可以从ground truth来进行预测学习。
学生网络通过教师网络的网络训练得到的soft label和ground truth进行联合学习,蒸馏的关键是:教师网络训练得到一个soft label
以手写数字为例,教师网络对数字1
的预测标签为"1":0.7,"7":2,"9":0.1
,这里1的预测概率最大为0.7
是正确的分类,但是标签"7","9"
的预测概率也能提供一些信息,就是说7,9和预测标签1还是有某种预测的相似度的。
如果把这个信息也教会学生网络,学生网络就可以了解到这种类别之间的相似度,可以看作为学习到了教师网络中隐藏的知识,对于学生网络的分类是有帮助的。
蒸馏的时候一般都需要进行升温操作,以分类网络为例,需要改进softmax
,除以T
左边是正常的softmax
计算,右边引入一个超参T
来计算softmax,即把
y
i
y_i
yi都除以T
再进行softmax
计算。结果会有什么不同呢?
假设左边的y1,y2,y3
的分类结果,分别是100,10,1
,没有进行升温操作,带入softmax
之后
y
1
′
y_1^{'}
y1′接近1
,
y
2
′
y_2^{'}
y2′和
y
3
′
y_3^{'}
y3′都接近于0
。经过升温操作之后,
y
1
′
=
0.56
,
y
2
′
=
0.23
,
y
3
′
=
0.21
y_1^{'}=0.56 ,y_2^{'}=0.23,y_3^{'}=0.21
y1′=0.56,y2′=0.23,y3′=0.21。此时
y
1
′
,
y
2
′
,
y
3
′
y_1^{'},y_2^{'},y_3^{'}
y1′,y2′,y3′之间的差异就没有那么大了,也就是说能够把类别之间隐藏的关系给暴露出来
,所以做蒸馏操作一般都需要除以T
来进行升温操作。
如图所示,对首先数1
识别进一步说明:
1
和7
比较相似,在没有升温T
之前,预测标签1
和7
经过softmax输出的概率分别为0.99
和0.01
,1和7
之间的差异是比较大的,网络没有学习到1与7比较相似这种隐藏信息。T=5
的话,得到的最终概率为0.62
和0.28
,相当于把预测值进行了软化,这种预测值称为softened predictions
在分类网络中知识蒸馏的Loss计算
softmax
要进行升温,升温后的预测结果我们称为软标签(soft label)
softmax
的时候也进行升温,在预测的时候得到软预测(soft predictions),然后对soft label
和soft predictions
计算损失函数,称为distillation loss
,让学生网络的预测结果接近教师网络;softmax
的时候不进行升温T=1
,此时预测的结果叫做hard prediction
。然后和hard label
也就是ground truth
直接计算损失,称为student loss
Loss
,比如与教师网络通过MSE
损失,学生网络与ground truth通过 cross entropy
损失,Loss的公式可表示如下:
对目标检测知识蒸馏稍微复杂点,目标检测既有物体分类又有物体边界框的预测,学生网络可以从教师网络中进行物体分类的学习,也可以进行边界框的学习,同时也可以从ground truth
进行学习。
在论文Object scaled Distllation
介绍了知识蒸馏相关的内容。
训练的损失函数由3部分构成
三种损失分别为:
f
b
b
C
o
m
b
f_{bb}^{Comb}
fbbComb 边界框的损失,
f
c
l
C
o
m
b
f_{cl}^{Comb}
fclComb 分类的损失,
f
o
b
j
C
o
m
b
f_{obj}^{Comb}
fobjComb 目标置信度的损失。在目标检测的知识蒸馏中,每种损失都包括:Detection loss
和Distillation loss
,Detection loss
是学生网络和ground truth
之间的损失;Distillation loss
是学生网络和教师网络之间的损失。
object scaled
,在蒸馏损失部分引入了
o
i
T
o_i^{T}
oiT加权因子,
o
i
T
o_i^{T}
oiT为0~1之间的值,object scaled
越接近1表示越可能是目标,越接近0表示越可能是背景,如果是目标比较低,乘以一个比较小的值,对模型的训练是有帮助的。object scaled
,在蒸馏损失部分引入了
o
i
T
o_i^{T}
oiT加权因子
我们使用Yolov5m
较大的网络作为教师网络,指导学生网络YOLOv5s
的学习,对
蒸馏之后,学生网络YOLOv5s
的Precision由蒸馏前的0.891
变为蒸馏后的0.924
,效果提升了3
个点。 mAP0.5
由蒸馏前的0.898
,变为蒸馏后的0.935
将近提升了4
个点。
这种知识蒸馏方法,通过训练过程做一些工作,就能达到预测性能提升,几乎不需要任何代价就能提升网络的性能。
后续博客将会针对Yolov5 目标检测之知识蒸馏实战
进行详细讲解,项目内容包括: