本文后续会同步到阿里云。
训练之前要统计确定标签中的脚本都有哪些,以及不同种类的物体数量,确保没有部分目标过多或者过少的情况,导致训练结果不理想
话不多说直接上脚本
import os
import xml.etree.ElementTree as ET
import shutil
# 从所有类别中统计出不同物种的数据集个数,选择几个个数相近的数据集用于训练
classes = list()
ann_filepath = 'Annotations/'
cls_num = 0
matrix=[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
def findclass(file):
global cls_num
tree = ET.parse(ann_filepath + '/' + file)
root = tree.getroot()
result = root.findall("object")
bool_num = 0
for obj in result:
cls = obj.find('name').text
if cls not in classes:
classes.append(cls)
cls_num = cls_num + 1
cls_id = classes.index(cls)
matrix[cls_id]=matrix[cls_id]+1
if __name__ == '__main__':
for f in os.listdir(ann_filepath):
findclass(f)
print(classes)
print(matrix)
classes = list():定义统计种类用的列表
global cls_num 是声明这个变量和全局的那个cls_num是同一个东西 不写会报错
xml.etree.ElementTree 是Python3 xml解析模块
ET.parse:从文件读取数据
tree.getroot:获取根节点(个人感觉是xml文件的最底层)
root.findall(“object”):先找文件夹object标签下的内容
cls = obj.find(‘name’).text:找到 object 内name标签下的内容,组成列表
for obj in result: 遍历单一文件下的每个标签,判定是否已经统计到种类中,不存在就append到列表后面,存在的话就直接根据index索引到它在matrix中元素的个数,在原有的基础上 +1。
for f in os.listdir(ann_filepath):遍历文件夹下的所有xml文件
最后是打印统计结果
遍历路径下的标签以及对应目标物,筛选出需要的标签种类,并且将数据集分类为训练集和测试集,将图像和标签保存到对应的文件夹下,目前代码实现的功能是选取每一种的前700张作为测试测,剩下的作为训练集。yolo 的一般处理是7:2:1或者7:3,我这里没保留验证集,有需要自己加。txt生成的是个没有用的东西,忘了注释了
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import xml.etree.ElementTree as ET
import shutil
#根据自己的情况修改相应的路径
ann_filepath = 'Annotations/'
img_filepath = 'JPEGImages/'
img_savepath = 'tra/'
val_savepath = 'val/' #val_savepath : 'val/'
annv_savepath = 'annv/'
annt_savepath = 'annt/'
# 路径不存在则创建
if not os.path.exists(img_savepath):
os.mkdir(img_savepath)
if not os.path.exists(annv_savepath):
os.mkdir(annv_savepath)
if not os.path.exists(annt_savepath):
os.mkdir(annt_savepath)
if not os.path.exists(val_savepath):
os.mkdir(val_savepath)
classes = ['missing_hole', 'mouse_bite', 'open_circuit', 'short', 'spurious_copper','spur']#这里是需要提取的类别
matrix=[0,0,0,0,0,0]
def save_annotation(file):
tree = ET.parse(ann_filepath + '/' + file)
root = tree.getroot()
result = root.findall("object")
bool_num = 0
val =False
for obj in result:
if obj.find("name").text not in classes:
root.remove(obj)
else:
bool_num = 1
cls = obj.find('name').text
cls_id = classes.index(cls)
matrix[cls_id]=matrix[cls_id]+1
if matrix[cls_id]<700 :
val =True
if bool_num:
if val:
tree.write(annv_savepath + file)
else :
tree.write(annt_savepath + file)
return True,val
else:
return False,val
def save_images(file,val):
#文本文件名自己定义,主要用于生成相应的训练或测试的txt文件
if val :
with open('val.txt', 'a') as file_txt:
name_img = img_filepath + os.path.splitext(file)[0] + ".jpg"
shutil.copy(name_img, val_savepath)
file_txt.write(os.path.splitext(file)[0])
file_txt.write("\n")
else:
with open('tra.txt', 'a') as file_txt:
name_img = img_filepath + os.path.splitext(file)[0] + ".jpg"
shutil.copy(name_img, img_savepath)
file_txt.write(os.path.splitext(file)[0])
file_txt.write("\n")
return True
# 使用select脚本把图片筛选出来存到val 和tra文件夹中,并且生成val.txt和tra.txt
if __name__ == '__main__':
for f in os.listdir(ann_filepath):
val = False
a,val = save_annotation(f)
if a :
save_images(f,val)