yolov5优化模型时,一般需要继续标注一些检测错误的图片,将其标为xml数据。以下是根据训练好的模型自动标注xml数据的python代码:
注意:代码中包含了本人的yolov5的测试过程,测试过程可以自己根据yolov5的测试文件自行修改,只是测试返回的类格式为:
[["water",[15,20,30,40]],["red",[12,13,14,15]]]
二维数组表示测试的类为water和red,其中后面的数字表示类的坐标:[top,left,bottom,right],表示上、左、下、右4个坐标。
-
- import os
- import cv2
- from PIL import Image
-
- from yolo import YOLO
-
-
- #1.预测类,获得字符串
- class Predict():
-
- def a(self, img_path,save_path,img_name):
-
- image = Image.open(img_path)
-
- r_image, pred = yolo.detect_image(image, pred_class, img_name)
-
- if not os.path.exists(dir_save_path):
- os.makedirs(dir_save_path)
-
- r_image.save(save_path, quality=95, subsampling=0)
-
- return pred
-
-
- #2.写入xml文件
- def img_xml(img_path,xml_path,img_name,pred):
-
- if len(pred) != 0:
-
- #1.读取图片(xml需要写入图片的长宽高)
- img = cv2.imread(img_path)
-
- #2.写入xml文件
- #(1)写入文件头部
- files_path=img_path.split("\\")[-2]
- print("..:",files_path)
-
- xml_file = open((xml_path + img_name + '.xml'), 'w')
- xml_file.write('
\n' ) - xml_file.write('
' +files_path+ '\n') - xml_file.write('
' + img_name + '.jpg' + '\n') - xml_file.write('
' + img_path +'\n') -
- xml_file.write(' )
- xml_file.write('
Unknown \n') - xml_file.write(' \n')
-
- #(2)写入图片的长宽高信息
- xml_file.write('
\n' ) - xml_file.write('
' +str(img.shape[1])+'\n') - xml_file.write('
' + str(img.shape[0]) + '\n') - xml_file.write('
' + str(img.shape[2]) + '\n') - xml_file.write(' \n')
-
- xml_file.write('
0 \n') -
- #3.写入字符串信息:[["water",[15,20,30,40]],["red",[12,13,14,15]]]
- #if len(shuzu)!=0:
- for item in pred:
- xml_file.write(' )
- xml_file.write('
' + str(item[0]) + '\n') - xml_file.write('
Unspecified \n') - xml_file.write('
0 \n') - xml_file.write('
0 \n') - xml_file.write('
\n' ) -
- #写入字符串信息
- #[top, left, bottom, right]
- xml_file.write('
' + str(item[1][1]) + '\n') - xml_file.write('
' + str(item[1][0]) + '\n') - xml_file.write('
' + str(item[1][3]) + '\n') - xml_file.write('
' + str(item[1][2]) + '\n') -
- xml_file.write(' \n')
- xml_file.write(' \n')
-
- xml_file.write('\n')
-
-
-
-
-
- if __name__ == "__main__":
- yolo = YOLO()
- ss = Predict()
-
- #需要修改以下4个量,并且要去VOCdevkit/VOC2007/文件夹下替换训练好的模型best_epoch_weights.pth和voc_classes.txt
-
- pred_class = ["car", "moto", "persons"] # 填入需要检测的类名
- file_path = r"D:\AI\4.yolov5-pytorch-main_xml_write\save\image" # 填入测试的图片路径
- dir_save_path = r"D:\AI\4.yolov5-pytorch-main_xml_write\save\image_save"# 填入保存的图片路径
- xml_path="save\\xml_save\\"# 填入保存的xml文件的路径
-
- ls=os.listdir(file_path)
- for item in ls:
- img_name=item
- xml_name=img_name.split(".")[0]+".xml"
- img_names=img_name.split(".")[0]
-
- img_path=os.path.join(file_path,img_name)
- save_path=os.path.join(dir_save_path,img_name)
- #xml_path=os.path.join(xml_path,xml_name)
-
- pred=ss.a(img_path,save_path,img_name)
-
- img_xml(img_path, xml_path, img_names, pred)