InstDisc Code Walkthrough


    Contents

    Unsupervised Feature Learning via Non-Parametric Instance Discrimination: Code Walkthrough

    0. Overview

    1. lemniscate

    1.1 Definition of lemniscate

    1.2 NCEAverage

    1.3 How the NCEAverage object is used during training

    1.4 NCEFunction

    2. criterion


    Unsupervised Feature Learning via Non-Parametric Instance Discrimination: Code Walkthrough

    Paper: https://arxiv.org/pdf/1805.01978.pdf

    Code: https://github.com/zhirongw/lemniscate.pytorch


    0. Overview

    This walkthrough covers the core of the codebase: computing the loss and updating the memory bank.

    In main.py, the forward pass that computes the loss looks like this:

    https://github.com/zhirongw/lemniscate.pytorch/blob/master/main.py

    # compute output
    feature = model(input)
    output = lemniscate(feature, index)
    loss = criterion(output, index) / args.iter_size
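
    For context, here is a hedged sketch of how this snippet sits in a standard PyTorch training step. The loop structure, the (input, target, index) batch layout, and the gradient-accumulation handling of args.iter_size are assumptions for illustration, not a verbatim copy of main.py:

    for i, (input, target, index) in enumerate(train_loader):
        input = input.cuda()
        index = index.cuda()

        feature = model(input)                # 128-d embedding for each image
        output = lemniscate(feature, index)   # similarities against the memory bank
        loss = criterion(output, index) / args.iter_size

        loss.backward()                       # also triggers the memory bank update (see 1.4)
        if (i + 1) % args.iter_size == 0:     # accumulate gradients over iter_size batches
            optimizer.step()
            optimizer.zero_grad()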


    1. lemniscate

    1.1 Definition of lemniscate

    https://github.com/zhirongw/lemniscate.pytorch/blob/master/main.py

    # define lemniscate and loss function (criterion)
    ndata = train_dataset.__len__()  # ndata: size of the whole dataset, i.e. the length of the memory bank
    if args.nce_k > 0:  # args.nce_k: number of negatives to sample, 4096 by default
        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t, args.nce_m).cuda()
        # args.low_dim: dimension of the features stored in the memory bank (128)
        # args.nce_t: temperature in the NCE loss
        # args.nce_m: momentum for updating features in the memory bank
        criterion = NCECriterion(ndata).cuda()
    else:
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t, args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()

    So, with the default arguments, lemniscate is an NCEAverage object.

    1.2 NCEAverage

    https://github.com/zhirongw/lemniscate.pytorch/blob/master/lib/NCEAverage.py#L72

    class NCEAverage(nn.Module):

        def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5, Z=None):
            super(NCEAverage, self).__init__()
            self.nLem = outputSize  # ndata: size of the dataset, i.e. the length of the memory bank
            self.unigrams = torch.ones(self.nLem)  # tensor of shape (ndata,) filled with ones
            self.multinomial = AliasMethod(self.unigrams)  # used here to sample negatives at random
            self.multinomial.cuda()
            self.K = K  # number of negatives to sample

            # stash K, T, momentum (and a placeholder -1 for the normalization
            # constant Z) for the NCE loss computation later
            self.register_buffer('params', torch.tensor([K, T, -1, momentum]))
            stdv = 1. / math.sqrt(inputSize / 3)
            # randomly initialize the memory bank, uniform in [-stdv, stdv]
            self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))

        def forward(self, x, y):
            batchSize = x.size(0)
            # AliasMethod draws K+1 indices per sample (4096 negatives + 1 slot for the positive)
            idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1)
            out = NCEFunction.apply(x, y, self.memory, idx, self.params)
            return out

    To sum up, when the NCEAverage object lemniscate is initialized, it sets up an AliasMethod object for sampling negatives at random and a randomly initialized memory bank.
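
    Because unigrams is all ones, the noise distribution is uniform over the dataset, and AliasMethod is just an O(1)-per-draw way to sample from it. A minimal sketch of the equivalent sampling with plain torch (illustrative only; this is not the repo's AliasMethod implementation):

    import torch

    ndata, batchSize, K = 50000, 256, 4096

    # uniform noise: each draw picks any dataset index with probability 1/ndata
    idx = torch.randint(0, ndata, (batchSize, K + 1))
    print(idx.shape)  # torch.Size([256, 4097])

    # column 0 is later overwritten with the positive index (see 1.4),
    # so only columns 1..K actually act as negatives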

    1.3 How the NCEAverage object is used during training

    output = lemniscate(feature, index)

    Since lemniscate is an NCEAverage object, this call automatically invokes NCEAverage's forward method, passing in feature (the CNN features of this batch of images) and index (the positions of these images within the dataset).

    Inside NCEAverage's forward method, two things happen:

    • AliasMethod samples 4096 negative indices (idx) for each sample in the batch
    • NCEFunction.apply computes out, the probability that the input feature x belongs to each selected instance in the memory bank; because NCEFunction subclasses torch.autograd.Function, the model's backward pass also calls NCEFunction's backward, which performs the momentum update of the memory bank (see the toy sketch after this list)
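
    To make the mechanics concrete, here is a toy torch.autograd.Function (not from the repo) that follows the same pattern: forward saves tensors, and backward returns the gradient for the input while mutating a buffer as a side effect, just as NCEFunction's backward mutates the memory bank:

    import torch
    from torch.autograd import Function

    class ScaleWithSideEffect(Function):
        """Toy example: y = x * s; backward also updates buf in place."""

        @staticmethod
        def forward(ctx, x, buf, s):
            ctx.save_for_backward(x, buf)
            ctx.s = s
            return x * s

        @staticmethod
        def backward(ctx, grad_output):
            x, buf = ctx.saved_tensors
            # side-effect "momentum" update, analogous to the memory bank update
            buf.mul_(0.5).add_(0.5 * x.detach().mean())
            # one gradient per forward input (buf and s get None)
            return grad_output * ctx.s, None, None

    x = torch.randn(4, requires_grad=True)
    buf = torch.zeros(1)
    ScaleWithSideEffect.apply(x, buf, 2.0).sum().backward()
    print(x.grad)  # every entry is 2.0
    print(buf)     # updated as a side effect of backward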

    1.4 NCEFunction

    https://github.com/zhirongw/lemniscate.pytorch/blob/master/lib/NCEAverage.py#L7

    The forward function computes the probability that the input feature x matches the i-th instance in the memory bank. It corresponds to the non-parametric softmax in the paper, P(i | v) = exp(vᵢᵀ v / τ) / Z with Z = Σⱼ exp(vⱼᵀ v / τ), except that here the normalizer Z is treated as a constant estimated from the first batch rather than summed over all n instances:

    class NCEFunction(Function):
        @staticmethod
        def forward(self, x, y, memory, idx, params):
            K = int(params[0].item())
            T = params[1].item()
            Z = params[2].item()
            momentum = params[3].item()
            batchSize = x.size(0)
            outputSize = memory.size(0)
            inputSize = memory.size(1)

            # overwrite column 0 of idx with the positive indices y
            idx.select(1, 0).copy_(y.data)

            # gather the corresponding features (weights) from the memory bank
            weight = torch.index_select(memory, 0, idx.view(-1))
            weight.resize_(batchSize, K + 1, inputSize)

            # inner product
            out = torch.bmm(weight, x.data.resize_(batchSize, inputSize, 1))
            out.div_(T).exp_()  # batchSize x (K+1)
            x.data.resize_(batchSize, inputSize)

            if Z < 0:
                # estimate the normalization constant Z from the first batch
                params[2] = out.mean() * outputSize
                Z = params[2].item()
                print("normalization constant Z is set to {:.1f}".format(Z))

            out.div_(Z).resize_(batchSize, K + 1)

            # save tensors so that backward can update the memory bank
            self.save_for_backward(x, memory, y, weight, out, params)
            return out
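
    The Z < 0 branch fires only once: the exact normalizer Z = Σⱼ exp(vⱼᵀ x / T) over all n instances is too expensive to compute every step, so it is approximated by n times the mean of the K+1 sampled exponentiated similarities (a Monte Carlo estimate) and then kept fixed in params[2]. A small numeric sketch of that estimator with toy tensors (not from the repo):

    import torch
    import torch.nn.functional as F

    n, K, T = 50000, 4096, 0.07
    memory = F.normalize(torch.randn(n, 128), dim=1)  # toy unit-norm memory bank
    f = F.normalize(torch.randn(128), dim=0)          # toy query feature

    exact = memory.matmul(f).div(T).exp().sum()        # true Z, O(n)
    sample = memory[torch.randint(0, n, (K + 1,))]     # uniformly sampled rows
    approx = sample.matmul(f).div(T).exp().mean() * n  # the estimator used above

    print(exact.item(), approx.item())  # close on average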

    The backward function computes the gradient with respect to the input feature and performs the momentum update of the memory bank:

    @staticmethod
    def backward(self, gradOutput):
        x, memory, y, weight, out, params = self.saved_tensors
        K = int(params[0].item())
        T = params[1].item()
        Z = params[2].item()
        momentum = params[3].item()
        batchSize = gradOutput.size(0)

        # gradients d Pm / d linear = exp(linear) / Z
        gradOutput.data.mul_(out.data)
        # add temperature
        gradOutput.data.div_(T)

        gradOutput.data.resize_(batchSize, 1, K + 1)

        # gradient of linear
        gradInput = torch.bmm(gradOutput.data, weight)
        gradInput.resize_as_(x)

        # update the non-parametric data: momentum update of the memory bank
        weight_pos = weight.select(1, 0).resize_as_(x)
        weight_pos.mul_(momentum)
        weight_pos.add_(torch.mul(x.data, 1 - momentum))
        w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
        updated_weight = weight_pos.div(w_norm)
        memory.index_copy_(0, y, updated_weight)

        return gradInput, None, None, None, None
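
    Written out, the update at the end of backward is vᵢ ← (m·vᵢ + (1−m)·xᵢ) / ‖m·vᵢ + (1−m)·xᵢ‖₂: an exponential moving average of each instance's feature, re-projected onto the unit sphere. A standalone sketch of the same update with a hypothetical helper (tensor shapes assumed from the code above):

    import torch
    import torch.nn.functional as F

    def momentum_update(memory, feats, idx, m=0.5):
        """EMA-update memory rows at idx with new features, then L2-renormalize."""
        v = m * memory.index_select(0, idx) + (1 - m) * feats
        v = v / v.norm(dim=1, keepdim=True)
        memory.index_copy_(0, idx, v)

    memory = F.normalize(torch.randn(10, 128), dim=1)  # toy memory bank
    feats = F.normalize(torch.randn(3, 128), dim=1)    # new batch features
    momentum_update(memory, feats, torch.tensor([0, 4, 7]))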


    2. criterion

    Computing the NCE loss:

    https://github.com/zhirongw/lemniscate.pytorch/blob/master/lib/NCECriterion.py#L6

    eps = 1e-7  # module-level constant in the repo's NCECriterion.py, needed below

    class NCECriterion(nn.Module):

        def __init__(self, nLem):
            super(NCECriterion, self).__init__()
            self.nLem = nLem

        def forward(self, x, targets):
            batchSize = x.size(0)
            K = x.size(1) - 1
            Pnt = 1 / float(self.nLem)  # noise distribution is uniform: Pn = 1/n
            Pns = 1 / float(self.nLem)

            # eq 5.1 : P(origin=model) = Pmt / (Pmt + k*Pnt)
            Pmt = x.select(1, 0)
            Pmt_div = Pmt.add(K * Pnt + eps)
            lnPmt = torch.div(Pmt, Pmt_div)

            # eq 5.2 : P(origin=noise) = k*Pns / (Pms + k*Pns)
            Pon_div = x.narrow(1, 1, K).add(K * Pns + eps)
            Pon = Pon_div.clone().fill_(K * Pns)
            lnPon = torch.div(Pon, Pon_div)

            # equation 6 in ref. A
            lnPmt.log_()
            lnPon.log_()

            lnPmtsum = lnPmt.sum(0)
            lnPonsum = lnPon.view(-1, 1).sum(0)

            loss = - (lnPmtsum + lnPonsum) / batchSize
            return loss
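
    In words: column 0 of x holds P(i|v) for the positive and columns 1..K hold P(i'|v) for the negatives; with the uniform noise distribution Pn = 1/n, the posterior that a sample came from the data rather than the noise is h(i, v) = P(i|v) / (P(i|v) + K·Pn(i)), and the loss is J = −E[log h(i, v)] − K·E[log(1 − h(i', v))]. A hedged usage sketch with random stand-in probabilities (assumes the NCECriterion class above is in scope; shapes follow the code, and note that targets is unused by this forward):

    import torch

    batchSize, K, ndata = 256, 4096, 50000
    criterion = NCECriterion(ndata)

    # stand-in for lemniscate(feature, index): column 0 = positive, 1..K = negatives
    output = torch.rand(batchSize, K + 1) / ndata
    loss = criterion(output, torch.randint(0, ndata, (batchSize,)))
    print(loss)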
