• 【论文阅读】Distilling the Knowledge in a Neural Network


    论文原地址:http://arxiv.org/abs/1503.02531
    这篇文章是Hinton在2015发表的文章,是知识蒸馏(knowledge distillation)的开山之作。
    知识蒸馏是一种教师-学生模型(Teacher-Student Model),其核心思想是,将复杂模型(教师模型)中的知识转移到一个简单模型(学生模型)中,使学生模型拥有能够媲美教师模型准确率的能力。

    1介绍

    背景

    机器学习的实际应用中分为两步,模型训练和部署。
    在机器学习中,我们倾向于使用非常相似的模型进行训练(train)和部署(depolyment),尽管这两个阶段显然有不同的需求:
    在训练过程中,我们需要使用复杂的模型,大量的计算资源,以便从非常大、高度冗余的数据集中提取出信息。在实验中,效果最好的模型往往规模很大,甚至由多个模型集成得到。而大模型不方便部署到服务中去(受限于推理速度,资源设备)

    暗知识

    soft targethard target
    这里有在这里插入图片描述
    soft target的输出信息量要大于hard target的信息量。
    例如这里的手写数字2,它的标签就是2,也就是说这个数字是2的概率为1,是其他数字的概率为0,我们把这种标签就称作hard targets。(后面也称真实标签,标记,ground truth
    但是对于soft target,此时2的概率最高,3的概率次之,其他相同,则我们可以知道错认为3会比错认为其他有更大的可能。这可能是因为3与2之间有跟多的相似之处

    值得注意的是 对于Mnist数据集中总是产生具有信心的正确结果,很多非正确数字的概率都是非常小的。例如,对于正确的2来说,被分类为3的概率为 1 0 − 6 10^{-6} 106 ,被分类为7的概率为 1 0 − 9 10^{-9} 109 。在这种情况下soft target的帮助就很小,所以Caruana在其文章中使用log函数解决这一问题。本文利用蒸馏解决这个问题(蒸馏会使得target变得更加平滑)

    蒸馏温度T

    原softmax公式:
    q i = e x p ( z i ) ∑ j e x p ( z j ) q_{i}=\frac {exp({z_i})}{\sum_jexp(z_j)} qi=jexp(zj)exp(zi)

    加入温度系数T后的公式:

    q i = e x p ( z i / T ) ∑ j e x p ( z j / T ) q_{i}=\frac {exp({z_i}/{T})}{\sum_jexp(z_j/T)} qi=jexp(zj/T)exp(zi/T)

    在加入T之后的softmax的概率分布更加平滑,作为soft target时简易(student)网络能学到更多东西。
    在这里插入图片描述

    步骤

    在这里插入图片描述

    1.教师模型的训练,这一步通过带有标签的数据集对复杂模型进行训练。最终我们得到了一个已经训练好的教师网络。

    2.通过教师模型对学生模型的训练,这里我们就引入了超参数T改进后的Softmax函数
    3.学生模型单独做预测
    在这里插入图片描述
    在这里插入图片描述

    具体来说
    ①在将大模型生成的logits送入输出层softmax之前,除以一个比较大的值T=
    t,得到更加平滑的概率分布作为soft target(软目标、软标签)。这个概率分布在0~1之间,取值分布缓和

    tip:T数值越大,分布越缓和;而T数值减小,容易放大错误分类的概率,
    引入不必要的噪声。针对较困难的分类或检测任务,T通常取1,
    确保教师网络中正确预测的贡献。
    
    • 1
    • 2
    • 3

    ②学生模型会有两个softmax函数,一个T=t,一个T=1

    学生模型经过T=t的softmax函数后会产生soft prediction,此时soft target和soft prediction计算交叉熵损失,这个损失称为蒸馏损失(distillation loss)。通过其来减小soft target和soft prediction之间的差别。

    学生模型经过T=1的softmax函数后产生的hard prediction与硬目标hard target
    计算交叉熵损失,这个损失称为学生损失(student loss),这个就是传统的损失计算。 这里的硬目标则是样本的真实标注,可以用One-hot矢量表示。

    总损失Total loss设计为软目标与硬目标所对应的交叉熵的加权平均,公式如下
    L o s s = λ L s + ( 1 − λ ) L h Loss=λL_{s}+(1−λ)L_{h} Loss=λLs+(1λ)Lh
    其中, L s L_{s} Ls表示蒸馏损失, L h L_{h} Lh表示蒸馏损失, ​λ表示系数,

    软目标交叉熵的加权系数越大,表明迁移诱导越依赖教师网络的贡献,通常初始Ls权重更高,之后慢慢减小这对训练初期阶段是很有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期需要适当减小软目标的比重,让真实标注帮助鉴别困难样本。

    作用:
    蒸馏损失:保证小模型和大模型的结果一致性
    学生损失:保证小模型的结果与实际类别标签一致性

    参考

    Distilling the Knowledge in a Neural Network[论文阅读笔记]
    如何理解soft target这一做法?
    【知识蒸馏】Distilling the Knowledge in a Neural Network
    【论文阅读】Distilling the Knowledge in a Neural Network

  • 相关阅读:
    tensorflow数据类型转换
    【设计模式】原型模式:猴头,我叫你一声你敢答应吗?
    Redis不止能存储字符串,还有List、Set、Hash、Zset,用对了能给你带来哪些优势?
    Python - GFPGAN + MoviePy 提高人物视频画质
    记 : CTF2023羊城杯 - Reverse 方向 Blast 题目复现and学习记录
    大数据项目之电商数仓、业务数据介绍、MySQL安装、更改MySQL密码策略
    2023去水印小程序源码修复版-前端后端内置接口+第三方接口
    Nacos安装指南(Windows环境)
    Redis - String|Hash|List|Set|Zset数据类型的基本操作和使用场景
    Springboot企业的信息管理系统5qs0a计算机毕业设计-课程设计-期末作业-毕设程序代做
  • 原文地址:https://blog.csdn.net/weixin_45592399/article/details/125990939