• Training YOLOv7 on Your Own Dataset


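    Contents

    1. Building a YOLO-format dataset

    1.1 The dataset

    1.2 Converting to the format YOLOv7 expects

    1.3 Batch-generating YOLO-format txt annotations

    1.4 Splitting into train, val and test sets

    2. Training your own model with YOLOv7

    2.1 Testing the pretrained yolov7.pt

    (1) Testing on images

    (2) Testing with a local camera

    (3) Testing on video streams

    2.2 Training a YOLOv7 model on your own data

    2.3 Testing your own trained model

    2.4 Testing keypoint detection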


    YOLOv7 download link: YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors


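    1. Building a YOLO-format dataset

    1.1 The dataset

    This article uses the EDS dataset: 14219 images from 3 different X-ray machines, covering 10 object categories and 31655 object instances in total, all labeled by professional annotators.

    Each machine corresponds to one dataset: domain1, domain2 and domain3. The figure below shows the class distribution of each dataset, which is relatively uniform.

    Code to display some of the images: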

    import glob
    import os

    import cv2
    import matplotlib.pyplot as plt


    def show_multi_img(imgpath, num):
        """
        :param imgpath: directory containing the images
        :param num: grid size, e.g. 6 shows 6*6 = 36 images in one figure
        :return: None
        """
        img_path = sorted(glob.glob(os.path.join(imgpath, "*")))
        plt.figure()
        # Subplot indices start at 1 while list indices start at 0;
        # the slice also stops early if the folder holds fewer than num*num images
        for i, path in enumerate(img_path[:num * num], start=1):
            img = cv2.imread(path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR, matplotlib expects RGB
            title = os.path.basename(path)  # portable, unlike splitting on "\\"
            plt.subplot(num, num, i)
            plt.imshow(img)
            plt.title(title, fontsize=6)
            plt.xticks([])
            plt.yticks([])
            plt.axis("on")
        plt.savefig("final.png")
        plt.show()


    if __name__ == "__main__":
        image_dir = "./domain2/image"
        show_multi_img(image_dir, 6)

    Each domain consists of an image folder and a txt folder:
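Before converting anything, it can help to confirm that every image has a matching annotation file. The following is a minimal sketch, assuming images and annotations share a base name (for example `00004.jpg` and `00004.txt`); the helper name `find_unpaired` is hypothetical and not part of the dataset tooling.

```python
import os


def find_unpaired(image_names, label_names):
    """Return (images without a label, labels without an image), by base name."""
    image_stems = {os.path.splitext(n)[0] for n in image_names}
    label_stems = {os.path.splitext(n)[0] for n in label_names}
    return sorted(image_stems - label_stems), sorted(label_stems - image_stems)


# Tiny in-memory example; for a real check, pass os.listdir("./domain1/image")
# and os.listdir("./domain1/txt") instead.
missing_labels, missing_images = find_unpaired(
    ["00001.jpg", "00002.jpg", "00003.jpg"],
    ["00001.txt", "00003.txt", "00004.txt"],
)
print(missing_labels)  # ['00002']
print(missing_images)  # ['00004']
```

Working on file names rather than directories keeps the check easy to try without the dataset on disk.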

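    1.2 Converting to the format YOLOv7 expects

    First, look at how YOLO annotations are laid out:

    EDS dataset format:

    Let the top-left corner of the bbox be (xmin, ymin) and the bottom-right corner be (xmax, ymax), and let W and H be the width and height of the bbox. The center point (x_center, y_center) and the box size are then:

    x_center = xmin + (xmax - xmin)/2

    y_center = ymin + (ymax - ymin)/2

    W = xmax - xmin

    H = ymax - ymin

    The YOLO format is: label, x_, y_, w_, h_, with the following correspondence:

    x_ = x_center / img_width

    y_ = y_center / img_height

    w_ = W / img_width

    h_ = H / img_height

    Here label is an integer, so the EDS class names must be mapped to numbers. img_width and img_height are the original width and height of the image; they can be obtained by reading the image with cv2.imread() and taking its shape: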

    img = cv2.imread("./domain/image/00001.jpg")
    img_height, img_width, _ = img.shape

    Display an image and draw the bboxes on it:

    import cv2

    img = cv2.imread('./domain1/image/00004.jpg')
    img_height, img_width, _ = img.shape
    with open("./domain1/txt/00004.txt", encoding="utf-8") as f:
        for line in f:
            fields = line.split(" ")
            text = fields[1]  # class name
            xmin, ymin, xmax, ymax = (float(v) for v in fields[2:6])
            print("before: xmin:{},ymin:{},xmax:{},ymax:{}".format(xmin, ymin, xmax, ymax))
            x_center = xmin + (xmax - xmin) / 2
            y_center = ymin + (ymax - ymin) / 2
            w = xmax - xmin
            h = ymax - ymin
            # Normalize and keep 6 decimal places
            x_center = round(x_center / img_width, 6)
            y_center = round(y_center / img_height, 6)
            w = round(w / img_width, 6)
            h = round(h / img_height, 6)
            # Convert the YOLO values back to pixel corners to verify them;
            # int() truncates, so a result may be one pixel smaller than the input
            x1 = int((x_center - w / 2) * img_width)
            y1 = int((y_center - h / 2) * img_height)
            x2 = int((x_center + w / 2) * img_width)
            y2 = int((y_center + h / 2) * img_height)
            print("after: xmin:{},ymin:{},xmax:{},ymax:{}".format(x1, y1, x2, y2))
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 255), 3)
            cv2.putText(img, text, (int(xmin), int(ymin) - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
    cv2.imshow("show", img)
    cv2.waitKey(0)
    cv2.imwrite("bbox.png", img)

    before: xmin:84.0,ymin:369.0,xmax:342.0,ymax:554.0
    after: xmin:83,ymin:368,xmax:341,ymax:553
    before: xmin:210.0,ymin:409.0,xmax:591.0,ymax:691.0
    after: xmin:210,ymin:409,xmax:591,ymax:691
    before: xmin:182.0,ymin:457.0,xmax:364.0,ymax:550.0
    after: xmin:181,ymin:456,xmax:364,ymax:549

    -------------------------------------------------------------------------------------

    The conversion still introduces small errors (up to one pixel in the output above), but the impact is minor.
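The one-pixel differences are consistent with `int()` truncating values such as 83.9999 downward. Below is a minimal sketch of the round trip, assuming integer pixel coordinates and a hypothetical 1000x800 image (the real EDS image size is not stated in this article); the function names `to_yolo` and `from_yolo` are illustrative only.

```python
def to_yolo(xmin, ymin, xmax, ymax, img_w, img_h, ndigits=6):
    """Corner box in pixels -> normalized (x_center, y_center, w, h)."""
    w, h = xmax - xmin, ymax - ymin
    return (
        round((xmin + w / 2) / img_w, ndigits),
        round((ymin + h / 2) / img_h, ndigits),
        round(w / img_w, ndigits),
        round(h / img_h, ndigits),
    )


def from_yolo(xc, yc, w, h, img_w, img_h, to_int=round):
    """Normalized box -> pixel corners; to_int is int (truncate) or round."""
    return (
        to_int((xc - w / 2) * img_w),
        to_int((yc - h / 2) * img_h),
        to_int((xc + w / 2) * img_w),
        to_int((yc + h / 2) * img_h),
    )


IMG_W, IMG_H = 1000, 800  # assumed image size, for illustration only
box = (84, 369, 342, 554)
yolo = to_yolo(*box, IMG_W, IMG_H)
print(from_yolo(*yolo, IMG_W, IMG_H, to_int=int))    # truncation may lose a pixel
print(from_yolo(*yolo, IMG_W, IMG_H, to_int=round))  # rounding recovers the box
```

Since the labels written for training are the normalized values, the truncation only affects this visual check, not the dataset itself.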

    Note: if you have no annotated data to start from, you can download LabelImg and annotate in YOLO format, which generates a YOLO-format dataset directly.

    1.3 Batch-generating YOLO-format txt annotations

    import glob
    import os

    import cv2

    txt_file = "./domain1/txt"
    img_path = "./domain1/image"
    out_path = "./out"

    # First pass: collect every class name that appears in the txt files
    names = glob.glob(os.path.join(txt_file, "*.txt"))
    class_names = set()
    for txt in names:
        with open(txt, encoding="utf-8") as f:
            for line in f:
                class_names.add(line.split(" ")[1])

    # Map EDS class names to integer ids. Sorting makes the mapping reproducible;
    # the names list in dataset.yaml must follow the same order.
    class_to_id = {name: idx for idx, name in enumerate(sorted(class_names))}
    print(class_to_id)

    # A YOLOv7 label line is: cls_id x y w h, where (x, y) is the box center and
    # all four values are relative to the image width/height, not absolute pixels
    os.makedirs(out_path, exist_ok=True)
    for txt in names:
        stem = os.path.splitext(os.path.basename(txt))[0]  # portable file stem
        img = cv2.imread(os.path.join(img_path, stem + ".jpg"))
        if img is None:
            print("image not found for", txt)
            continue
        img_height, img_width, _ = img.shape
        with open(txt, encoding="utf-8") as f, \
                open(os.path.join(out_path, stem + ".txt"), "w") as f_out:
            for line in f:
                fields = line.split(" ")
                class_num = class_to_id[fields[1]]
                xmin, ymin, xmax, ymax = (float(v) for v in fields[2:6])
                x_center = round((xmin + (xmax - xmin) / 2) / img_width, 6)
                y_center = round((ymin + (ymax - ymin) / 2) / img_height, 6)
                w = round((xmax - xmin) / img_width, 6)
                h = round((ymax - ymin) / img_height, 6)
                info = [str(v) for v in [class_num, x_center, y_center, w, h]]
                print(info)
                f_out.write(" ".join(info) + "\n")
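After a batch conversion it is worth spot-checking the generated files. The following is a small sketch of a per-line check, assuming the five-field `cls_id x y w h` layout described above; `validate_yolo_line` is a hypothetical helper, not something YOLOv7 provides.

```python
def validate_yolo_line(line, num_classes):
    """Return a list of problems found in one YOLO label line (empty if OK)."""
    parts = line.split()
    if len(parts) != 5:
        return [f"expected 5 fields, got {len(parts)}"]
    try:
        cls_id = int(parts[0])
        values = [float(p) for p in parts[1:]]
    except ValueError:
        return ["non-numeric field"]
    problems = []
    if not 0 <= cls_id < num_classes:
        problems.append(f"class id {cls_id} outside [0, {num_classes})")
    if any(not 0.0 <= v <= 1.0 for v in values):
        problems.append("coordinate outside [0, 1]")
    return problems


print(validate_yolo_line("3 0.213 0.576875 0.258 0.23125", num_classes=10))  # []
print(validate_yolo_line("12 0.5 0.5 1.2 0.1", num_classes=10))
```

Running it over every line in the output folder would flag boxes that extend past the image border or class ids that do not fit the configured number of classes.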

    1.4 Splitting into train, val and test sets

    The dataset prepared for this article, the EDS dataset in YOLO format, is free to download. Thanks for your support!

    # Split the images and labels into training, validation and test sets by ratio
    import os
    import random
    import shutil

    # Source paths, adjust as needed
    image_original_path = './domain1/image/'
    label_original_path = './out/'
    # Training set paths, adjust as needed
    train_image_path = 'E:/yolov7/data/images/train/'
    train_label_path = 'E:/yolov7/data/labels/train/'
    # Validation set paths, adjust as needed
    val_image_path = 'E:/yolov7/data/images/val/'
    val_label_path = 'E:/yolov7/data/labels/val/'
    # Test set paths, adjust as needed
    test_image_path = 'E:/yolov7/data/images/test/'
    test_label_path = 'E:/yolov7/data/labels/test/'
    # Split ratio: 70% train, 15% val; the remaining 15% becomes the test set
    train_percent = 0.7
    val_percent = 0.15


    def mkdir():
        # Create the output folders if they do not exist yet
        for path in (train_image_path, train_label_path, val_image_path,
                     val_label_path, test_image_path, test_label_path):
            os.makedirs(path, exist_ok=True)


    def main():
        mkdir()
        total_txt = os.listdir(label_original_path)
        num_txt = len(total_txt)
        list_all_txt = range(num_txt)  # range(0, num_txt)
        num_train = int(num_txt * train_percent)
        num_val = int(num_txt * val_percent)
        # Draw num_train indices for training; sets keep the membership tests below fast
        train = set(random.sample(list_all_txt, num_train))
        # What is left is val + test
        val_test = [i for i in list_all_txt if i not in train]
        # Draw num_val of the remainder for validation; the rest is the test set
        val = set(random.sample(val_test, num_val))
        print("train: {}, val: {}, test: {}".format(len(train), len(val), len(val_test) - len(val)))
        for i in list_all_txt:
            name = os.path.splitext(total_txt[i])[0]
            srcImage = image_original_path + name + '.jpg'
            srcLabel = label_original_path + name + '.txt'
            if i in train:
                dst_image, dst_label = train_image_path, train_label_path
            elif i in val:
                dst_image, dst_label = val_image_path, val_label_path
            else:
                dst_image, dst_label = test_image_path, test_label_path
            shutil.copyfile(srcImage, dst_image + name + '.jpg')
            shutil.copyfile(srcLabel, dst_label + name + '.txt')


    if __name__ == '__main__':
        main()
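The script above draws a different split on every run. If a reproducible split matters, one option is to seed a local random generator. This is a minimal sketch under that assumption; `split_names` and its `seed` parameter are illustrative names, and it only partitions file stems without copying any files.

```python
import random


def split_names(names, train_percent=0.7, val_percent=0.15, seed=0):
    """Shuffle names reproducibly and cut them into train/val/test lists."""
    shuffled = sorted(names)               # fixed starting order
    random.Random(seed).shuffle(shuffled)  # local RNG, global state untouched
    num_train = int(len(shuffled) * train_percent)
    num_val = int(len(shuffled) * val_percent)
    train = shuffled[:num_train]
    val = shuffled[num_train:num_train + num_val]
    test = shuffled[num_train + num_val:]  # whatever remains
    return train, val, test


names = [f"{i:05d}" for i in range(1, 101)]
train, val, test = split_names(names)
print(len(train), len(val), len(test))  # 70 15 15
```

Calling `split_names` again with the same seed returns the same three lists, which makes it easier to compare models trained at different times.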

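    2. Training your own model with YOLOv7

    Official repository: https://github.com/wongkinyiu/yolov7

    Clone it with git:

    git clone https://github.com/wongkinyiu/yolov7

    2.1 Testing the pretrained yolov7.pt

    The official repository provides a download link for the weights; they can also be downloaded from CSDN: YOLOv7 pretrained weights

    Once the pretrained weights are downloaded, open detect.py

    and run it directly, leaving every other parameter at its default.

    (1) Testing on images

    Alternatively, set --source to the path of your own images; you can likewise set --weights=your_weight_path to test a model you trained yourself.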
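Instead of editing the defaults inside detect.py, the same options can be passed on the command line. The sketch below only assembles the argument list; `--weights`, `--source`, `--img-size` and `--conf-thres` are options defined in the repository's detect.py, while the helper name `build_detect_command` and the example paths are placeholders of mine.

```python
import subprocess
import sys


def build_detect_command(weights, source, img_size=640, conf_thres=0.25):
    """Assemble the argument list for running detect.py from the repo root."""
    return [
        sys.executable, "detect.py",
        "--weights", str(weights),
        "--source", str(source),
        "--img-size", str(img_size),
        "--conf-thres", str(conf_thres),
    ]


cmd = build_detect_command("yolov7.pt", "inference/images")
print(" ".join(cmd[1:]))
# To actually run it (from inside the cloned yolov7 directory):
# subprocess.run(cmd, check=True)
```

Swapping `"yolov7.pt"` for the path of your own checkpoint is all that is needed to test a self-trained model the same way.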

    Detection results of yolov7.pt

     

    Here is the image-loading code provided by the official repository, in utils.datasets:

    class LoadImages:  # for inference
        def __init__(self, path, img_size=640, stride=32):
            """
            path: path to the images
            img_size: target image size for inference
            stride: mainly used to pad smaller images up to the size actually used for testing
            Each iteration returns:
                path: path of the image
                img: the resized image
                img0: the original image
                self.cap
            """
            # Walk the input path; files holds the paths to test
            p = str(Path(path).absolute())  # os-agnostic absolute path
            if '*' in p:
                files = sorted(glob.glob(p, recursive=True))  # glob
            elif os.path.isdir(p):
                files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
            elif os.path.isfile(p):
                files = [p]  # files
            else:
                raise Exception(f'ERROR: {p} does not exist')

            # Use the file extension to tell images from videos, and keep each in a list
            images = [x for x in files if x.split('.')[-1].lower() in img_formats]
            videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
            # Number of images and number of videos
            ni, nv = len(images), len(videos)

            self.img_size = img_size
            self.stride = stride
            self.files = images + videos  # list
            self.nf = ni + nv  # number of files
            self.video_flag = [False] * ni + [True] * nv  # marks which entries are videos
            self.mode = 'image'
            if any(videos):  # at least one video present
                self.new_video(videos[0])  # new video
            else:
                self.cap = None
            assert self.nf > 0, f'No images or videos found in {p}. ' \
                                f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'

        # __iter__ is the iterator hook defined by Python
        def __iter__(self):
            self.count = 0
            return self

        def __next__(self):
            if self.count == self.nf:
                raise StopIteration
            path = self.files[self.count]

            if self.video_flag[self.count]:
                # Read video
                self.mode = 'video'
                ret_val, img0 = self.cap.read()
                if not ret_val:
                    self.count += 1
                    self.cap.release()
                    if self.count == self.nf:  # last video
                        raise StopIteration
                    else:
                        path = self.files[self.count]
                        self.new_video(path)
                        ret_val, img0 = self.cap.read()

                self.frame += 1
                print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')

            else:
                # Read image
                self.count += 1
                img0 = cv2.imread(path)  # BGR
                assert img0 is not None, 'Image Not Found ' + path
                #print(f'image {self.count}/{self.nf} {path}: ', end='')

            # Padded resize
            img = letterbox(img0, self.img_size, stride=self.stride)[0]

            # Convert
            img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
            img = np.ascontiguousarray(img)

            return path, img, img0, self.cap

        def new_video(self, path):
            self.frame = 0
            self.cap = cv2.VideoCapture(path)
            self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

        def __len__(self):
            return self.nf  # number of files
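The "padded resize" step relies on letterbox, which is not shown in this article. As I understand it, the idea is to scale the image to fit the target size while keeping its aspect ratio, then pad it so both sides are multiples of the stride. The sketch below reproduces only that shape arithmetic; it is a simplification for illustration, not the repository's implementation, and `letterbox_shape` is a hypothetical name.

```python
def letterbox_shape(h, w, new_size=640, stride=32):
    """Return ((resized_h, resized_w), (padded_h, padded_w)) for a letterbox resize."""
    r = min(new_size / h, new_size / w)       # scale so the image fits new_size
    resized_h, resized_w = round(h * r), round(w * r)
    pad_h = (new_size - resized_h) % stride   # smallest padding that reaches
    pad_w = (new_size - resized_w) % stride   # a multiple of stride
    return (resized_h, resized_w), (resized_h + pad_h, resized_w + pad_w)


print(letterbox_shape(720, 1280))  # ((360, 640), (384, 640))
```

For a 720x1280 frame the network therefore sees a 384x640 input rather than a full 640x640 square, which saves computation.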

    (2) Testing with a local camera

    A simple snippet for grabbing frames from a camera:

    import cv2


    def access_camera(url, output_path):
        # url is 0 for the built-in camera, or the address of an IP camera
        cap = cv2.VideoCapture(url)
        while cap.isOpened():
            # Capture frame-by-frame
            ret, frame = cap.read()
            if not ret:  # no frame received; stop instead of crashing in imshow
                break
            # Display the resulting frame
            cv2.imshow('frame', frame)
            cv2.imwrite(output_path, frame)  # overwrites the same file on every frame
            print("Image saved!")
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        # When everything done, release the capture
        cap.release()
        cv2.destroyAllWindows()


    if __name__ == "__main__":
        url = 'http://admin:admin@192.168.1.3:8081/video'  # use an IP camera
        output_path = "./runs/detect/img.png"
        # url = 0  # use the laptop camera
        access_camera(url, output_path)

    The code provided by yolov7 follows the same idea:

    class LoadWebcam:  # for inference
        def __init__(self, pipe='0', img_size=640, stride=32):
            """
            pipe: '0' means the local camera
            img_size: image size
            stride: model stride used for padding
            """
            self.img_size = img_size
            self.stride = stride

            if pipe.isnumeric():
                pipe = eval(pipe)  # local camera
            # pipe = 'rtsp://192.168.1.64/1'  # IP camera
            # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
            # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

            self.pipe = pipe
            self.cap = cv2.VideoCapture(pipe)  # video capture object
            self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

        def __iter__(self):
            self.count = -1
            return self

        def __next__(self):
            self.count += 1
            if cv2.waitKey(1) == ord('q'):  # q to quit
                self.cap.release()
                cv2.destroyAllWindows()
                raise StopIteration

            # Read frame
            if self.pipe == 0:  # local camera
                ret_val, img0 = self.cap.read()
                img0 = cv2.flip(img0, 1)  # flip left-right
            else:  # IP camera
                n = 0
                while True:
                    n += 1
                    self.cap.grab()
                    if n % 30 == 0:  # skip frames
                        ret_val, img0 = self.cap.retrieve()
                        if ret_val:
                            break

            # Print
            assert ret_val, f'Camera Error {self.pipe}'
            img_path = 'webcam.jpg'
            print(f'webcam {self.count}: ', end='')

            # Padded resize
            img = letterbox(img0, self.img_size, stride=self.stride)[0]

            # Convert
            img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
            img = np.ascontiguousarray(img)

            return img_path, img, img0, None

        def __len__(self):
            return 0

    (3) Testing on video streams

    class LoadStreams:  # multiple IP or RTSP cameras
        def __init__(self, sources='streams.txt', img_size=640, stride=32):
            self.mode = 'stream'
            self.img_size = img_size
            self.stride = stride

            if os.path.isfile(sources):
                with open(sources, 'r') as f:
                    sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
            else:
                sources = [sources]

            n = len(sources)
            self.imgs = [None] * n
            self.sources = [clean_str(x) for x in sources]  # clean source names for later
            for i, s in enumerate(sources):
                # Start the thread to read frames from the video stream
                print(f'{i + 1}/{n}: {s}... ', end='')
                url = eval(s) if s.isnumeric() else s
                if 'youtube.com/' in str(url) or 'youtu.be/' in str(url):  # if source is YouTube video
                    check_requirements(('pafy', 'youtube_dl'))
                    import pafy
                    url = pafy.new(url).getbest(preftype="mp4").url
                cap = cv2.VideoCapture(url)
                assert cap.isOpened(), f'Failed to open {s}'
                w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                self.fps = cap.get(cv2.CAP_PROP_FPS) % 100

                _, self.imgs[i] = cap.read()  # guarantee first frame
                thread = Thread(target=self.update, args=([i, cap]), daemon=True)
                print(f' success ({w}x{h} at {self.fps:.2f} FPS).')
                thread.start()
            print('')  # newline

            # check for common shapes
            s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0)  # shapes
            self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
            if not self.rect:
                print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

        def update(self, index, cap):
            # Read next stream frame in a daemon thread
            n = 0
            while cap.isOpened():
                n += 1
                # _, self.imgs[index] = cap.read()
                cap.grab()
                if n == 4:  # read every 4th frame
                    success, im = cap.retrieve()
                    self.imgs[index] = im if success else self.imgs[index] * 0
                    n = 0
                time.sleep(1 / self.fps)  # wait time

        def __iter__(self):
            self.count = -1
            return self

        def __next__(self):
            self.count += 1
            img0 = self.imgs.copy()
            if cv2.waitKey(1) == ord('q'):  # q to quit
                cv2.destroyAllWindows()
                raise StopIteration

            # Letterbox
            img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

            # Stack
            img = np.stack(img, 0)

            # Convert
            img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
            img = np.ascontiguousarray(img)

            return self.sources, img, img0, None

        def __len__(self):
            return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years

    Once all images or video frames have been collected, they are fed into the model. Here is the official detect.py code:

    def detect(save_img=False):
        source, weights, view_img, save_txt, imgsz, trace = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, not opt.no_trace
        save_img = not opt.nosave and not source.endswith('.txt')  # save inference images
        webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
            ('rtsp://', 'rtmp://', 'http://', 'https://'))

        # Directories
        save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
        (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

        # Initialize
        set_logging()
        device = select_device(opt.device)
        half = device.type != 'cpu'  # half precision only supported on CUDA

        # Load model
        model = attempt_load(weights, map_location=device)  # load FP32 model
        stride = int(model.stride.max())  # model stride
        imgsz = check_img_size(imgsz, s=stride)  # check img_size

        if trace:
            model = TracedModel(model, device, opt.img_size)

        if half:
            model.half()  # to FP16

        # Second-stage classifier
        classify = False
        if classify:
            modelc = load_classifier(name='resnet101', n=2)  # initialize
            modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()

        # Set Dataloader
        vid_path, vid_writer = None, None
        if webcam:
            view_img = check_imshow()
            cudnn.benchmark = True  # set True to speed up constant image size inference
            dataset = LoadStreams(source, img_size=imgsz, stride=stride)
        else:
            dataset = LoadImages(source, img_size=imgsz, stride=stride)

        # Get names and colors
        names = model.module.names if hasattr(model, 'module') else model.names
        colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]

        # Run inference
        if device.type != 'cpu':
            model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
        t0 = time.time()
        for path, img, im0s, vid_cap in dataset:
            img = torch.from_numpy(img).to(device)
            img = img.half() if half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)

            # Inference
            t1 = time_synchronized()
            pred = model(img, augment=opt.augment)[0]

            # Apply NMS
            pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
            t2 = time_synchronized()

            # Apply Classifier
            if classify:
                pred = apply_classifier(pred, modelc, img, im0s)

            # Process detections
            for i, det in enumerate(pred):  # detections per image
                if webcam:  # batch_size >= 1
                    p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
                else:
                    p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)

                p = Path(p)  # to Path
                save_path = str(save_dir / p.name)  # img.jpg
                txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
                s += '%gx%g ' % img.shape[2:]  # print string
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
                if len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                    # Print results
                    for c in det[:, -1].unique():
                        n = (det[:, -1] == c).sum()  # detections per class
                        s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                    # Write results
                    for *xyxy, conf, cls in reversed(det):
                        if save_txt:  # Write to file
                            xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                            line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label format
                            with open(txt_path + '.txt', 'a') as f:
                                f.write(('%g ' * len(line)).rstrip() % line + '\n')

                        if save_img or view_img:  # Add bbox to image
                            label = f'{names[int(cls)]} {conf:.2f}'
                            plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)

                # Print time (inference + NMS)
                #print(f'{s}Done. ({t2 - t1:.3f}s)')

                # Stream results
                if view_img:
                    cv2.imshow(str(p), im0)
                    cv2.waitKey(1)  # 1 millisecond

                # Save results (image with detections)
                if save_img:
                    if dataset.mode == 'image':
                        cv2.imwrite(save_path, im0)
                        print(f" The image with the result is saved in: {save_path}")
                    else:  # 'video' or 'stream'
                        if vid_path != save_path:  # new video
                            vid_path = save_path
                            if isinstance(vid_writer, cv2.VideoWriter):
                                vid_writer.release()  # release previous video writer
                            if vid_cap:  # video
                                fps = vid_cap.get(cv2.CAP_PROP_FPS)
                                w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                                h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                            else:  # stream
                                fps, w, h = 30, im0.shape[1], im0.shape[0]
                                save_path += '.mp4'
                            vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                        vid_writer.write(im0)

        if save_txt or save_img:
            s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
            #print(f"Results saved to {save_dir}{s}")

        print(f'Done. ({time.time() - t0:.3f}s)')
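The "rescale boxes from img_size to im0 size" step undoes the letterbox transform: subtract the padding, then divide by the scale factor. Here is a plain-Python sketch of that arithmetic for a single box, assuming symmetric padding; it illustrates the idea behind scale_coords and is not the repository's code (`scale_box_to_original` is a hypothetical name).

```python
def scale_box_to_original(box, model_shape, original_shape):
    """Map an (x1, y1, x2, y2) box from the letterboxed input back to the original image.

    Both shapes are (height, width).
    """
    gain = min(model_shape[0] / original_shape[0], model_shape[1] / original_shape[1])
    pad_x = (model_shape[1] - original_shape[1] * gain) / 2
    pad_y = (model_shape[0] - original_shape[0] * gain) / 2

    def clip(value, upper):
        # Keep the coordinate inside the original image
        return max(0.0, min(float(upper), value))

    x1, y1, x2, y2 = box
    return (
        clip((x1 - pad_x) / gain, original_shape[1]),
        clip((y1 - pad_y) / gain, original_shape[0]),
        clip((x2 - pad_x) / gain, original_shape[1]),
        clip((y2 - pad_y) / gain, original_shape[0]),
    )


# A box predicted on a 384x640 input, mapped back to a 720x1280 frame
print(scale_box_to_original((100, 112, 300, 212), (384, 640), (720, 1280)))
# (200.0, 200.0, 600.0, 400.0)
```

In this example the scale factor is 0.5 and the vertical padding is 12 pixels on each side, so only the y coordinates are shifted before scaling.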

    2.2 Training a YOLOv7 model on your own data

    Build your own dataset with the method described earlier and place it under yolov7/data

    Create a dataset.yaml file in the yolov7/data directory; the official coco.yaml can serve as a reference

    My own configuration file, dataset.yaml:

    train: E:/yolov7/data/images/train  # train images
    val: E:/yolov7/data/images/val  # val images
    test: E:/yolov7/data/images/test  # test images (optional)
    # Classes
    nc: 10  # number of classes
    names: ['laptop','pressure','device','plasticbottle','scissor','knife','lighter','powerbank','glassbottle','umbrella']  # class names; the order must match the class ids used in the label files
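Because the position of each entry in names is its class id, the list has to agree with the name-to-id mapping used when the label files were generated. One way to avoid a mismatch is to render the file from that mapping. This is a minimal sketch using plain string formatting; `build_dataset_yaml` is a hypothetical helper and the two-class mapping below is made up for illustration.

```python
def build_dataset_yaml(train_dir, val_dir, test_dir, class_to_id):
    """Render dataset.yaml text with names ordered by class id."""
    ids = sorted(class_to_id.values())
    if ids != list(range(len(ids))):
        raise ValueError("class ids must be 0..nc-1 without gaps")
    names = [name for name, _ in sorted(class_to_id.items(), key=lambda kv: kv[1])]
    return (
        f"train: {train_dir}\n"
        f"val: {val_dir}\n"
        f"test: {test_dir}\n"
        f"nc: {len(names)}\n"
        f"names: {names}\n"
    )


# Hypothetical two-class mapping, for illustration only
yaml_text = build_dataset_yaml(
    "data/images/train", "data/images/val", "data/images/test",
    {"knife": 1, "laptop": 0},
)
print(yaml_text)
```

Writing `yaml_text` to yolov7/data/dataset.yaml would then keep nc and names in step with the labels.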

    Start training, then settle in for a long wait. In the end, all training output is saved under yolov7/runs/train/exp
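The exact training command is not given here, so the following is only a sketch of how train.py might be invoked. The options `--data`, `--cfg`, `--weights`, `--hyp`, `--epochs`, `--batch-size`, `--img-size` and `--device` exist in the repository's train.py; the specific values (cfg/training/yolov7.yaml, data/hyp.scratch.p5.yaml, batch size 8, device 0) are my assumptions and should be adapted to your setup. `build_train_command` is a hypothetical helper.

```python
import subprocess
import sys


def build_train_command(data_yaml, weights="yolov7.pt", epochs=10,
                        batch_size=8, device="0"):
    """Assemble the argument list for running train.py from the repo root."""
    return [
        sys.executable, "train.py",
        "--data", data_yaml,
        "--cfg", "cfg/training/yolov7.yaml",   # assumed model config
        "--weights", weights,
        "--hyp", "data/hyp.scratch.p5.yaml",   # assumed hyperparameter file
        "--epochs", str(epochs),
        "--batch-size", str(batch_size),
        "--img-size", "640", "640",
        "--device", device,
    ]


train_cmd = build_train_command("data/dataset.yaml")
print(" ".join(train_cmd[1:]))
# To actually start training (from inside the cloned yolov7 directory):
# subprocess.run(train_cmd, check=True)
```

Lowering `batch_size` is usually the first thing to try if the GPU runs out of memory.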

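    2.3 Testing your own trained model

    Change the weights path in detect.py. I trained this model for only 10 epochs, and the results are still acceptable.

    2.4 Testing keypoint detection

    First download the official pretrained model yolov7-w6-pose.pt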

    import matplotlib
    # Without matplotlib.use('TkAgg') the following warning appeared:
    # UserWarning: Matplotlib is currently using agg, which is a non-GUI backend
    matplotlib.use('TkAgg')
    import matplotlib.pyplot as plt
    print(matplotlib.get_backend())
    import torch
    import cv2
    from torchvision import transforms
    import numpy as np
    from utils.datasets import letterbox
    # Print the GUI backend after each import to check that it stays the same
    print(matplotlib.get_backend())
    from utils.general import non_max_suppression_kpt
    print(matplotlib.get_backend())
    from utils.plots import output_to_keypoint, plot_skeleton_kpts
    # utils/plots.py sets a different matplotlib backend; remember to change it
    print(matplotlib.get_backend())

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    weights = torch.load('../yolov7-w6-pose.pt')
    model = weights['model']
    model = model.half().to(device)
    _ = model.eval()

    image = cv2.imread('../person.jpeg')
    image = letterbox(image, 960, stride=64, auto=True)[0]
    image_ = image.copy()
    image = transforms.ToTensor()(image)
    image = torch.tensor(np.array([image.numpy()]))
    image = image.to(device)
    image = image.half()

    # Inference only: without no_grad the outputs still require gradients,
    # which can make the numpy conversion in output_to_keypoint fail
    with torch.no_grad():
        output, _ = model(image)
    output = non_max_suppression_kpt(output, 0.25, 0.65, nc=model.yaml['nc'], nkpt=model.yaml['nkpt'], kpt_label=True)
    output = output_to_keypoint(output)
    nimg = image[0].permute(1, 2, 0) * 255
    nimg = nimg.cpu().numpy().astype(np.uint8)
    nimg = cv2.cvtColor(nimg, cv2.COLOR_RGB2BGR)
    for idx in range(output.shape[0]):
        plot_skeleton_kpts(nimg, output[idx, 7:].T, 3)
    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(nimg)
    plt.savefig("person_detection.png")
    plt.show()

    Error message:

    Adding detach() at lines 442-443 of utils/plots.py, so that the tensors no longer take part in backpropagation, fixes it

    To be continued...

  • Original article: https://blog.csdn.net/weixin_43687366/article/details/126166982