目录
在接触一个新技术之前,肯定是因为遇到了新的难题,但这可以促使我们前进。
损失函数当中加入OHEM思想——图像分割损失函数OhemCELoss
之前做个一个目标检测任务,每个类别之间的数据量差距较大,有明显的类别不均衡现象(当样本比例大于4:1时)。
解决类别不平衡问题现以有了较多的可行解决方案:
本篇博客主要会记录在研究OHEM技术时的一些总结。
在two-stage检测算法中,RPN阶段会生成大量的检测框,由于很多时候一张图片可能只会有少量几个标注框(真实框),也就是说绝大部分检测框与真实框没有很大的交集,一般计算的IOU大于设置阈值时认为是正样本,小于设置阈值时是负样本。
但是这样选出来的框不一定是最容易错的框。
我们通常在生成检测框负样本中选出容易预测错误的(当成正样本的),作为新的数据集进行参与训练。
即hard Negative Mining(困难样本挖掘)。
思想:
你不会把所有错题都放到错题集中,只对当中最容易的错放入。
实现思想:
迭代地交替训练,用样本集更新模型,然后再固定模型,来选择分辨错的目标框并加入到样本集中继续训练。
缺点:
hard Negative Mining(困难样本挖掘)需要在不断的训练当中冻结参数、预测选出hard Negative再放入训的训练集,这大幅度的增加了工作量,加大了模型训练的时间。
注:一般使用 SVM 分类器才能使用此方法(SVM 分类器和 Hard Negative Mining Method 交替训练)
前言:
hard Negative Mining(困难样本挖掘)思想值得我们去使用和学习,但是我们试图在不影响效果的前提下去提高模型的迭代训练速度。故我们提出了OHEM(在线难例挖掘)。
论文:
1604.03540.pdf (arxiv.org)https://arxiv.org/pdf/1604.03540.pdfOHEM(在线难例挖掘)流程概述:
1、进行一次的前向传播,获得每个Region proposal单独的损失值。
2、对每个Region proposal进行NMS计算。
3、对剩下的Region proposal按照损失值进行排序,然后选取损失最大的前一部分Region当做输入再次输入分类回归网络,对于训练多次loss还较高的我们可以认为其是困难样本。
4、将困难样本输入图中的(b)模块,(b)模块是(a)模块的复制版,(b)模块是用来反向传播的部分,然后吧更新的参数共享到(a)部分。
注:所谓的线上挖掘,就是先计算loss→筛选→得到困难负样本。
前言:
其实在mmdetection当中,已经封装好了OHEM的代码,但是大家可能都不知道他在哪,这里我给大家找一下他的位置。
彩蛋:
如何在mmdetection查找自己想要的东西(类或者类的调用等)
前言:
虽然他在目标检测当中被提出,但是不仅仅是目标检测问题,其他问题都会出现类别不平衡的问题,我们试图把他应用到其他方向当中(例如语义分割)
代码实现:
- class OhemCELoss(nn.Module):
- """
- Online hard example mining cross-entropy loss:在线难样本挖掘
- if loss[self.n_min] > self.thresh: 最少考虑 n_min 个损失最大的 pixel,
- 如果前 n_min 个损失中最小的那个的损失仍然大于设定的阈值,
- 那么取实际所有大于该阈值的元素计算损失:loss=loss[loss>thresh]。
- 否则,计算前 n_min 个损失:loss = loss[:self.n_min]
- """
- def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
- super(OhemCELoss, self).__init__()
- self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda() # 将输入的概率 转换为loss值
- self.n_min = n_min
- self.ignore_lb = ignore_lb
- self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') #交叉熵
-
- def forward(self, logits, labels):
- N, C, H, W = logits.size()
- loss = self.criteria(logits, labels).view(-1)
- loss, _ = torch.sort(loss, descending=True) # 排序
- if loss[self.n_min] > self.thresh: # 当loss大于阈值(由输入概率转换成loss阈值)的像素数量比n_min多时,取所以大于阈值的loss值
- loss = loss[loss>self.thresh]
- else:
- loss = loss[:self.n_min]
- return torch.mean(loss)
(32条消息) 【每日一网】Day18:OHEM简单理解_陈子文好帅的博客-CSDN博客_ohem