• PyTorch Lightning - LightningModule 训练逻辑 (training_step) 异常处理 try-except


    欢迎关注我的CSDN:https://spike.blog.csdn.net/
    本文地址:https://spike.blog.csdn.net/article/details/133673820

    LightningModule

    在使用 LightningModule 框架训练模型时,因数据导致的训练错误,严重影响训练稳定性,因此需要使用 try-except 及时捕获错误。即 当错误发生时,在 training_step 异常返回 None,同时,on_before_zero_grad 也需要进行异常处理,处理 training_step 的异常返回 None。

    同样的,validation_step 也可以这样处理。

    源码如下:

    class MyObject(pl.LightningModule):
    	def __init__(self, config, args):
    		# ...
    		
    	def training_step_wrapper(self, batch, batch_idx, log_interval=10):
    		# train key process
    		
    	def training_step(self, batch, batch_idx, log_interval=10):
            """
            typically, each step costs 50 seconds
            参考: https://github.com/Lightning-AI/lightning/pull/3566
            """
            try:
                res = self.training_step_wrapper(batch, batch_idx, log_interval)
                return res
            except Exception as e:
                logger.info(f"[CL] training_step, exception: {e}")
                return None
                
    	def on_before_zero_grad(self, *args, **kwargs):
            try:
                self.ema.update(self.model)
            except Exception as e:
                # 支持 training_step return None
                logger.info(f"[CL] on_before_zero_grad, exception: {e}")
                return
                
    	def validation_step_wrapper(self, batch, batch_idx):
            # val key process
    
        def validation_step(self, batch, batch_idx):
            try:
                self.validation_step_wrapper(batch, batch_idx)
            except Exception as e:
                logger.info(f"[CL] validation_step, exception: {e}")
                return
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36

    常见错误如下

    数组越界:

    index 0 is out of bounds for dimension 0 with size 0

    字典错误字段:

    num_res = int(np_example["seq_length"])
    KeyError: 'seq_length'

    计算输入数值为空:

    V, _, W = torch.linalg.svd(C)

    free()异常:

    free(): invalid next size (fast)

    munmap_chunk() 空指针:

    munmap_chunk(): invalid pointer

  • 相关阅读:
    linux文件查看和文件查找
    在线负载离线负载与在线算法离线算法
    SwiftUI Swift 多个 sheet
    STM32智能物流机器人系统教程
    python读取CSV格式文件,遇到的问题20231007
    嵌入式学习笔记(1)ARM的编程模式和7种工作模式
    Flume集成Kafka
    clion2020 中文版安装
    异步同步调用
    17-js原型链
  • 原文地址:https://blog.csdn.net/u012515223/article/details/133673820