• 竞赛trick-AWP对抗训练的即插即用实现


    1、对抗训练的简单概括

    对抗训练就是使用对抗样本去训练模型,从而通过对原始训练数据添加噪声便得到了对抗样本。在竞赛中常在BERT的embedding阶段进行扰动,常使用的对抗训练有pgd,fgm,freelb等。本文主要记录awp对抗训练即插即用实现。(缘故:在刚结束的腾讯微信大数据挑战赛-多模态短视频分类竞赛中有大幅度提升)。

    2、实现

    2.1 awp实现

    class AWP:
        def __init__(self, model, optimizer, adv_param="weight", adv_lr=1, adv_eps=0.0001):
            self.model = model
            self.optimizer = optimizer
            self.adv_param = adv_param
            self.adv_lr = adv_lr
            self.adv_eps = adv_eps
            self.backup = {}
            self.backup_eps = {}
    
        def attack_backward(self, inputs, labels):
            if self.adv_lr == 0:
                return
            self._save()
            self._attack_step()
    
            y_preds = self.model(inputs)
    
            adv_loss = self.criterion(y_preds, labels)
            self.optimizer.zero_grad()
            return adv_loss
    
        def _attack_step(self):
            e = 1e-6
            for name, param in self.model.named_parameters():
                if param.requires_grad and param.grad is not None and self.adv_param in name:
                    norm1 = torch.norm(param.grad)
                    norm2 = torch.norm(param.data.detach())
                    if norm1 != 0 and not torch.isnan(norm1):
                        # 在损失函数之前获得梯度
                        r_at = self.adv_lr * param.grad / (norm1 + e) * (norm2 + e)
                        param.data.add_(r_at)
                        param.data = torch.min(
                            torch.max(param.data, self.backup_eps[name][0]), self.backup_eps[name][1]
                        )
    
        def _save(self):
            for name, param in self.model.named_parameters():
                if param.requires_grad and param.grad is not None and self.adv_param in name:
                    if name not in self.backup:
                        self.backup[name] = param.data.clone()
                        grad_eps = self.adv_eps * param.abs().detach()
                        self.backup_eps[name] = (
                            self.backup[name] - grad_eps,
                            self.backup[name] + grad_eps,
                        )
    
        def _restore(self,):
            for name, param in self.model.named_parameters():
                if name in self.backup:
                    param.data = self.backup[name]
            self.backup = {}
            self.backup_eps = {}
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54

    2.2 经典训练模式

    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        
        # 将模型的参数梯度初始化为0
        optimizer.zero_grad()
        
        # forward + backward + optimize
        predicts = model(inputs)          # 前向传播计算预测值
        loss = loss_fn(predicts, labels)  # 计算当前损失
        loss.backward()       # 反向传播计算梯度
    	loss.backward()
        optimizer.step()                  # 更新所有参数 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    2.3 awp训练模式

    # 初始化AWP
    awp = AWP(model, loss_fn, optimizer, adv_lr=awp_lr, adv_eps=awp_eps)
    
    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        
        # 将模型的参数梯度初始化为0
        optimizer.zero_grad()
        
        # forward + backward + optimize
        predicts = model(inputs)          # 前向传播计算预测值
        loss = loss_fn(predicts, labels)  # 计算当前损失
        loss.backward()       # 反向传播计算梯度
        # 指定从第几个epoch开启awp,一般先让模型学习到一定程度之后
        if awp_start >= epoch:
            loss = awp.attack_backward(inputs, labels)
            loss.backward()
            awp._restore()                    # 恢复到awp之前的model
        optimizer.step()                  # 更新所有参数 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    2.4 超参数

    • adv_param (str): 要攻击的layer name,一般攻击第一层或者全部weight参数效果较好
    • adv_lr (float): 主要调整的参数,表示攻击步长,该参数相对难调节,如果只攻击第一层embedding,一般用1比较好,全部参数用0.1比较好
    • adv_eps (float): 主要调整参数,表示参数扰动最大幅度限制,范围(0 ~ +∞),一般设置(0,1)之间相对合理一点
    • start_epoch (int): (0 ~ +∞)什么时候开始扰动,默认是0,如果效果不好可以调节值模型收敛一半的时候再开始攻击

    参考文献

  • 相关阅读:
    golang 中 channel 的详细使用、使用注意事项及死锁分析
    【美团3.18校招真题2】
    SpringBoot —— 整合RabbitMQ常见问题及解决方案
    Matlab reconstruct signal form sample points, convulsion
    【Java】反射是什么?
    数据结构与算法—双链表
    mybaits-plus lambdaQuery() 和 lambdaUpdate() 比较常见的使用方法
    Plato Farm有望通过Elephant Swap,进一步向外拓展生态
    redis基础知识总结——数据类型(字符串,列表,集合,哈希,集合)
    带你深入了解git
  • 原文地址:https://blog.csdn.net/yjh_SE007/article/details/126933154