• PyTorch Lightning: setting up multiple optimizers


    Key code:

    class BodyVQModel(pl.LightningModule):
        def __init__(self, code_num=2048, embedding_dim=64, num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512):
            super().__init__()
            self.save_hyperparameters()
    
            self.automatic_optimization = False
    		
            ...

        def configure_optimizers(self):
            body_optimizer = torch.optim.AdamW(self.body_model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=args.weight_decay)  # args: external config/argparse namespace (not shown here)
            hand_optimizer = torch.optim.AdamW(self.hand_model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-2)
            body_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(body_optimizer, mode='min', factor=0.1, patience=200, verbose=True)
            hand_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(hand_optimizer, mode='min', factor=0.1, patience=200, verbose=True)
            return ({"optimizer": body_optimizer, "lr_scheduler": {"scheduler": body_lr_scheduler, "monitor": "val/loss"}},
                    {"optimizer": hand_optimizer, "lr_scheduler": {"scheduler": hand_lr_scheduler, "monitor": "val/loss"}})
    
        def training_step(self, batch, batch_idx):
    
            opt1, opt2 = self.optimizers()
            opt1.zero_grad()
            opt2.zero_grad()
    
            loss_dict = {}
    
            loss_b, loss_dict = self._calc_loss(self.body_model, batch['motion'][:, :, upper_body_idx], loss_dict, prefix="train/body_", is_body=True)   # upper body, (B, T=88, 39)
            loss_h, loss_dict = self._calc_loss(self.hand_model, batch['motion'][:, :, hands_idx], loss_dict, prefix="train/hand_", is_body=False)       # hands, (B, T=88, 90)
            loss = loss_b + loss_h
            self.log_dict(loss_dict)
            self.log("train/loss", loss)
    
            rec_loss = loss_dict['train/body_rec_loss'] + loss_dict['train/hand_rec_loss']
            self.log("train/rec_loss", rec_loss)
    
            # return loss  # not needed: with manual optimization, backward and step are called explicitly below
    
            self.manual_backward(loss)
            # clip gradients
            self.clip_gradients(opt1, gradient_clip_val=10, gradient_clip_algorithm="norm")
            self.clip_gradients(opt2, gradient_clip_val=10, gradient_clip_algorithm="norm")
            opt1.step()
            opt2.step()
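
    With automatic_optimization = False, Lightning no longer steps the returned lr_scheduler entries for you (the "monitor" key has no effect in manual mode), so the two ReduceLROnPlateau schedulers have to be stepped by hand. A minimal sketch, assuming "val/loss" is logged during validation; self.lr_schedulers(), on_validation_epoch_end and trainer.callback_metrics are standard Lightning APIs, everything else follows the code above:

        def on_validation_epoch_end(self):
            # the two schedulers returned by configure_optimizers(), in the same order
            body_sch, hand_sch = self.lr_schedulers()
            # ReduceLROnPlateau must be given the monitored metric explicitly
            val_loss = self.trainer.callback_metrics.get("val/loss")
            if val_loss is not None:
                body_sch.step(val_loss)
                hand_sch.step(val_loss)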
    

    ref:https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
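
    Because gradient clipping is done by hand inside training_step (self.clip_gradients), a Trainer-level gradient_clip_val should not be set: Lightning does not support automatic clipping together with manual optimization. A minimal, hypothetical launch sketch; the dataloader names are placeholders:

        import pytorch_lightning as pl

        model = BodyVQModel()
        trainer = pl.Trainer(max_epochs=100)  # no gradient_clip_val here; clipping happens in training_step
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)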

  • Original article: https://blog.csdn.net/qq_42363032/article/details/133805361