Key code:
import torch
import pytorch_lightning as pl  # or: import lightning.pytorch as pl

class BodyVQModel(pl.LightningModule):
    def __init__(self, code_num=2048, embedding_dim=64, num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512):
        super().__init__()
        self.save_hyperparameters()
        # Two separate optimizers (body / hand), so optimization is handled manually.
        self.automatic_optimization = False
        ...
    def configure_optimizers(self):
        # `args` is assumed to be a global training config (e.g. an argparse Namespace) defined outside this excerpt.
        body_optimizer = torch.optim.AdamW(self.body_model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=args.weight_decay)
        hand_optimizer = torch.optim.AdamW(self.hand_model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-2)
        body_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(body_optimizer, mode='min', factor=0.1, patience=200, verbose=True)
        hand_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(hand_optimizer, mode='min', factor=0.1, patience=200, verbose=True)
        # Note: with manual optimization, Lightning does not step these schedulers and does not
        # use the "monitor" key; they have to be stepped manually (see the sketch after the code).
        return ({"optimizer": body_optimizer, "lr_scheduler": {"scheduler": body_lr_scheduler, "monitor": "val/loss"}},
                {"optimizer": hand_optimizer, "lr_scheduler": {"scheduler": hand_lr_scheduler, "monitor": "val/loss"}})
    def training_step(self, batch, batch_idx):
        opt1, opt2 = self.optimizers()
        opt1.zero_grad()
        opt2.zero_grad()

        loss_dict = {}
        # upper_body_idx / hands_idx select the corresponding channels of the motion tensor.
        loss_b, loss_dict = self._calc_loss(self.body_model, batch['motion'][:, :, upper_body_idx], loss_dict, prefix="train/body_", is_body=True)   # upper body, (B, T=88, 39)
        loss_h, loss_dict = self._calc_loss(self.hand_model, batch['motion'][:, :, hands_idx], loss_dict, prefix="train/hand_", is_body=False)       # hands, (B, T=88, 90)
        loss = loss_b + loss_h

        self.log_dict(loss_dict)
        self.log("train/loss", loss)
        rec_loss = loss_dict['train/body_rec_loss'] + loss_dict['train/hand_rec_loss']
        self.log("train/rec_loss", rec_loss)

        # Manual optimization: backward, clip gradients, and step each optimizer ourselves
        # instead of returning the loss.
        self.manual_backward(loss)
        self.clip_gradients(opt1, gradient_clip_val=10, gradient_clip_algorithm="norm")
        self.clip_gradients(opt2, gradient_clip_val=10, gradient_clip_algorithm="norm")
        opt1.step()
        opt2.step()
Ref: https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
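One caveat worth noting: with automatic_optimization set to False, Lightning does not step the returned schedulers and ignores the monitor key, so the two ReduceLROnPlateau schedulers must be stepped by hand. A minimal sketch (assuming a val/loss metric is logged during validation, as the monitor keys above suggest):

    def on_validation_epoch_end(self):
        # Manual optimization: Lightning does not call scheduler.step() for us, and
        # ReduceLROnPlateau needs the monitored metric passed in explicitly.
        body_sch, hand_sch = self.lr_schedulers()
        val_loss = self.trainer.callback_metrics.get("val/loss")
        if val_loss is not None:
            body_sch.step(val_loss)
            hand_sch.step(val_loss)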