可能原因 是cuda 版本导致的半精度浮点数计算出现nan的bug
- 解决办法 设置amp=False 就是不使用混合精度训练。
- 或者直接改用低版本的cuda和pytorch。cuda11.6 以下
直接有效
- 也有可能是学习率过高 降低学习率 。测试发现有些网络 学习率 要降低到0.000000001 可能解决问题
- 可能是batch 过小 逐步增大batch-size。实际发现 batch 逐步翻倍 有可能解决问题
设置amp=False之后还是存在问题 是因为yolov8库的问题 按以下修改
找到torch_utils.py 修改425行 去掉 half()