• 【目标检测】yolov5的pth模型转onnx及模型推理


    1. yolov5源代码

    https://github.com/ultralytics/yolov5

    2. pth2onnx

    （此处为原文截图）
    以上内容均可在 yolov5 源码中找到。
    pth2onnx的代码如下:

    #-*- codeing = utf-8 -*-
    #@Function: 
    #@Time : 2022/4/19 18:33
    #@Author : yx
    #@File : pth2onnx.py
    #@Software : PyCharm
    
    
    import argparse
    import sys
    import time
    
    sys.path.append('./')  # to run '$ python *.py' files in subdirectories
    
    import torch
    import torch.nn as nn
    
    import models
    from models.experimental import attempt_load
    from utils.activations import Hardswish, SiLU
    from utils.general import set_logging, check_img_size
    import onnx
    
    if __name__ == '__main__':
        # Export a YOLOv5 .pt checkpoint to ONNX, optionally simplify it, and (by default)
        # compare the ONNX Runtime output against the PyTorch dry-run output.
        from pathlib import Path  # local import: only used to derive the .onnx filename

        parser = argparse.ArgumentParser()
        parser.add_argument('--weights', type=str, default='./yolov5m.pt', help='weights path')  # from yolov5/models/
        parser.add_argument('--img_size', nargs='+', type=int, default=[640, 640], help='image size')  # height, width
        parser.add_argument('--batch_size', type=int, default=1, help='batch size')
        parser.add_argument('--simplify', action='store_true', default=False, help='simplify onnx')
        parser.add_argument('--dynamic', action='store_true', default=False, help='enable dynamic axis in onnx model')
        parser.add_argument('--onnx2pb', action='store_true', default=False, help='export onnx to pb')
        # The ONNX Runtime check runs by default. `--onnx_infer` is kept so existing command
        # lines still parse; `--no_onnx_infer` is the switch that actually disables the check
        # (a store_true flag whose default is already True can never turn it off).
        parser.add_argument('--onnx_infer', action='store_true', default=True, help='onnx infer test')
        parser.add_argument('--no_onnx_infer', dest='onnx_infer', action='store_false',
                            help='skip the onnx infer test')

        opt = parser.parse_args()
        opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # a single value means a square (h, w) input
        print(opt)
        set_logging()
        t = time.time()

        # Load PyTorch model
        model = attempt_load(opt.weights, map_location=torch.device('cpu'))  # load FP32 model
        # Replace the detection head's anchor_grid with placeholders and set its export_cat flag
        delattr(model.model[-1], 'anchor_grid')
        model.model[-1].anchor_grid = [torch.zeros(1)] * 3  # nl=3 number of detection layers
        model.model[-1].export_cat = True
        model.eval()

        # Checks
        gs = int(max(model.stride))  # grid size (max stride)
        opt.img_size = [check_img_size(x, gs) for x in opt.img_size]  # verify img_size are gs-multiples

        # Dummy input shared by the dry run, the ONNX export trace and the ORT comparison
        img = torch.zeros(opt.batch_size, 3, *opt.img_size)  # image size(1,3,320,192) iDetection

        # ShuffleV2Block is not defined in every yolov5 repository; look it up defensively so
        # the export does not die with AttributeError where models.common lacks it.
        shuffle_block = getattr(models.common, 'ShuffleV2Block', None)

        # Update model
        for _, m in model.named_modules():
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
            if isinstance(m, models.common.Conv):  # assign export-friendly activations
                if isinstance(m.act, nn.Hardswish):
                    m.act = Hardswish()
                elif isinstance(m.act, nn.SiLU):
                    m.act = SiLU()
            if shuffle_block is not None and isinstance(m, shuffle_block):  # shufflenet block nn.SiLU
                for branch in (m.branch1, m.branch2):
                    # iterate a snapshot because entries are replaced in place
                    for i, layer in enumerate(list(branch)):
                        if isinstance(layer, nn.SiLU):
                            branch[i] = SiLU()

        # Dry run without autograd so the output can later be converted to numpy
        with torch.no_grad():
            y = model(img)

        # ONNX export
        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        # Swap only the file suffix: str.replace('.pt', '.onnx') mangled '*.pth' into
        # '*.onnxh' and returned the weights path itself when it had no '.pt' in it,
        # so the export overwrote the checkpoint.
        f = str(Path(opt.weights).with_suffix('.onnx'))  # filename
        model.fuse()  # only for ONNX
        input_names = ['input']
        output_names = ['output']
        torch.onnx.export(model, img, f, verbose=False, opset_version=12,
                          input_names=input_names,
                          output_names=output_names,
                          dynamic_axes={'input': {0: 'batch'},
                                        'output': {0: 'batch'}
                                        } if opt.dynamic else None)

        # Checks
        onnx_model = onnx.load(f)  # load onnx model
        onnx.checker.check_model(onnx_model)  # check onnx model

        # https://github.com/daquexian/onnx-simplifier
        if opt.simplify:
            try:
                import onnxsim
                print(f'simplifying with onnx-simplifier {onnxsim.__version__}...')
                onnx_model, check = onnxsim.simplify(onnx_model,
                                                     dynamic_input_shape=opt.dynamic,
                                                     input_shapes={'input': list(img.shape)} if opt.dynamic else None)
                assert check, "simplify check failed "
                onnx.save(onnx_model, f)
            except Exception as e:  # best effort: keep the unsimplified model on failure
                print(f"simplifer failure: {e}")

        print('ONNX export success, saved as %s' % f)
        # Finish
        print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))

        # onnx infer: run the exported file on the same dummy input and report the max abs diff
        if opt.onnx_infer:
            import onnxruntime
            import numpy as np
            providers = ['CPUExecutionProvider']
            session = onnxruntime.InferenceSession(f, providers=providers)
            im = img.cpu().numpy().astype(np.float32)  # torch to numpy
            y_onnx = session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: im})[0]
            print("pred's shape is ", y_onnx.shape)
            # NOTE(review): the model output may be a tuple/list (presumably prediction first);
            # only its first element is compared — confirm for the Detect head in use.
            y_torch = y[0] if isinstance(y, (list, tuple)) else y
            print("max(|torch_pred - onnx_pred|) =", abs(y_torch.detach().cpu().numpy() - y_onnx).max())
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118

    3. 模型推理

    （此处为原文截图）

    3.1 infer

    # coding:gbk
    # coding:utf-8
    import cv2.cv2 as cv2
    import numpy as np
    import onnxruntime
    import torch
    import torchvision
    import time
    import random
    from utils.general import non_max_suppression
    import pandas as pd
    
    
    class YOLOV5_ONNX(object):
        """YOLOv5 detector backed by an ONNX Runtime session.

        ``infer`` letterboxes an image, runs the ONNX model, applies NMS, maps the
        boxes back onto the source image and draws them.
        """

        def __init__(self, onnx_path, providers=None):
            """Create the inference session and cache the model's input/output node names.

            :param onnx_path: path to the exported ``.onnx`` file.
            :param providers: execution providers for ``onnxruntime.InferenceSession``;
                ``None`` (default) uses ``onnxruntime.get_available_providers()``.
            """
            # Providers are passed explicitly, as the export script does. NOTE(review):
            # newer onnxruntime builds shipping several providers may refuse to create a
            # session without this argument — confirm against the onnxruntime docs.
            if providers is None:
                providers = onnxruntime.get_available_providers()
            self.onnx_session = onnxruntime.InferenceSession(onnx_path, providers=providers)
            print(onnxruntime.get_device())
            self.input_name = self.get_input_name()
            self.output_name = self.get_output_name()
            # Class names are not hard-coded here; draw() loads them from labels.txt.

        def readClasses(self, txtPath):
            """Return the class names from the third column of a header-less CSV file.

            Each row is expected to look like ``1,1,person``; only column index 2 is kept,
            in file order.
            """
            table = pd.read_csv(txtPath, header=None)  # the file has no header row
            return [table.loc[i].values[2] for i in table.index]

        def get_input_name(self):
            """Return the names of all model input nodes."""
            return [node.name for node in self.onnx_session.get_inputs()]

        def get_output_name(self):
            """Return the names of all model output nodes."""
            return [node.name for node in self.onnx_session.get_outputs()]

        def get_input_feed(self, image_tensor):
            """Build the ``session.run`` feed dict, mapping every input name to ``image_tensor``."""
            return {name: image_tensor for name in self.input_name}

        def letterbox(self, img, new_shape=(640, 640), color=(114, 114, 114), auto=False, scaleFill=False,
                      scaleup=True, stride=32):
            """Resize ``img`` keeping its aspect ratio, then pad it with ``color`` to ``new_shape``.

            :param img: image array whose first two dims are (height, width).
            :param new_shape: target (height, width), or an int for a square target.
            :param auto: pad only up to the next multiple of ``stride``.
            :param scaleFill: stretch to ``new_shape`` without padding.
            :param scaleup: allow enlarging; when False the image is only shrunk.
            :return: ``(padded_img, (w_ratio, h_ratio), (dw, dh))`` where dw/dh is the
                padding applied to each side.
            """
            shape = img.shape[:2]  # current shape [height, width]
            if isinstance(new_shape, int):
                new_shape = (new_shape, new_shape)

            # Scale ratio (new / old)
            r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
            if not scaleup:  # only scale down, do not scale up (for better test mAP)
                r = min(r, 1.0)

            # Compute padding
            ratio = r, r  # width, height ratios
            new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
            dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding

            if auto:  # minimum rectangle
                dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
            elif scaleFill:  # stretch
                dw, dh = 0.0, 0.0
                new_unpad = (new_shape[1], new_shape[0])
                ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

            dw /= 2  # divide padding into 2 sides
            dh /= 2

            if shape[::-1] != new_unpad:  # resize
                img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

            # +/-0.1 before rounding splits an odd padding unevenly instead of dropping a pixel
            top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
            left, right = int(round(dw - 0.1)), int(round(dw + 0.1))

            img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
            return img, ratio, (dw, dh)

        def xywh2xyxy(self, x):
            """Convert nx4 boxes from [cx, cy, w, h] to [x1, y1, x2, y2] (top-left, bottom-right).

            Returns a new numpy array; ``x`` is not modified.
            """
            y = np.copy(x)
            y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
            y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
            y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
            y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
            return y

        def nms(self, prediction, conf_thres=0.1, iou_thres=0.6, agnostic=False):
            """Class-aware NMS over raw predictions (a torch tensor, one entry per image).

            Each prediction row is read as [cx, cy, w, h, obj_conf, cls_scores...]. Returns
            a list with, per image, either None or a tensor of rows
            [x1, y1, x2, y2, conf, cls]. Not used by ``infer``, which calls
            ``non_max_suppression`` instead.
            """
            if prediction.dtype is torch.float16:
                prediction = prediction.float()  # to FP32
            xc = prediction[..., 4] > conf_thres  # candidates
            max_wh = 4096  # (pixels) class offset so boxes of different classes never overlap
            max_det = 300  # maximum number of detections per image
            output = [None] * prediction.shape[0]
            for xi, x in enumerate(prediction):  # image index, image inference
                x = x[xc[xi]]  # confidence
                if not x.shape[0]:
                    continue

                x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
                box = self.xywh2xyxy(x[:, :4])

                conf, j = x[:, 5:].max(1, keepdim=True)
                x = torch.cat((torch.tensor(box), conf, j.float()), 1)[conf.view(-1) > conf_thres]
                n = x.shape[0]  # number of boxes
                if not n:
                    continue
                c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
                boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
                i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
                if i.shape[0] > max_det:  # limit detections
                    i = i[:max_det]
                output[xi] = x[i]
            return output

        def clip_coords(self, boxes, img_shape):
            """Clamp xyxy ``boxes`` in place to an image of shape (height, width)."""
            boxes[:, 0].clamp_(0, img_shape[1])  # x1
            boxes[:, 1].clamp_(0, img_shape[0])  # y1
            boxes[:, 2].clamp_(0, img_shape[1])  # x2
            boxes[:, 3].clamp_(0, img_shape[0])  # y2

        def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None):
            """Map xyxy boxes from the letterboxed image back onto the original image.

            Inverse of ``letterbox``: subtract the padding, divide by the scale ratio, then
            clip to the original image bounds. ``coords`` is modified in place and returned.

            :param img1_shape: (height, width) of the network input.
            :param coords: tensor of xyxy boxes in network-input coordinates.
            :param img0_shape: (height, width) of the original image.
            :param ratio_pad: optional ``((gain, _), (pad_w, pad_h))``; computed from the
                shapes when None.
            """
            if ratio_pad is None:  # calculate from img0_shape
                gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
                pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
            else:
                gain = ratio_pad[0][0]
                pad = ratio_pad[1]

            coords[:, [0, 2]] -= pad[0]  # remove x padding
            coords[:, [1, 3]] -= pad[1]  # remove y padding
            coords[:, :4] /= gain  # undo the resize
            self.clip_coords(coords, img0_shape)  # keep boxes inside the original image
            return coords

        def sigmoid(self, x):
            """Element-wise logistic function."""
            return 1 / (1 + np.exp(-x))

        def infer(self, img_path, img_size=(640, 640), conf_thres=0.5, iou_thres=0.5):
            """Detect objects in the image at ``img_path`` and draw the detections.

            :param img_path: image file readable by ``cv2.imread``.
            :param img_size: network input (height, width) used for letterboxing.
            :param conf_thres: confidence threshold passed to ``non_max_suppression``.
            :param iou_thres: IoU threshold passed to ``non_max_suppression``.
            :raises FileNotFoundError: if ``cv2.imread`` cannot read ``img_path``.
            """
            src_img = cv2.imread(img_path)
            if src_img is None:  # cv2.imread returns None instead of raising
                raise FileNotFoundError(f"could not read image: {img_path}")
            start = time.time()
            src_size = src_img.shape[:2]

            # Letterbox to the network input size (resize + pad)
            img = self.letterbox(src_img, img_size, stride=32)[0]

            # HWC BGR -> CHW RGB, made contiguous after the negative-stride slice
            img = img[:, :, ::-1].transpose(2, 0, 1)
            img = np.ascontiguousarray(img)

            # Scale pixel values to [0, 1]
            img = img.astype(dtype=np.float32)
            img /= 255.0

            # Add the batch dimension: (1, 3, H, W)
            img = np.expand_dims(img, axis=0)
            print('img resuming: ', time.time() - start)

            # Forward pass + NMS
            input_feed = self.get_input_feed(img)
            pred = torch.tensor(self.onnx_session.run(None, input_feed)[0])
            results = non_max_suppression(pred, conf_thres, iou_thres)
            print('onnx resuming: ', time.time() - start)

            # Map boxes back onto the original image
            img_shape = img.shape[2:]
            for det in results:  # detections per image
                if det is not None and len(det):
                    det[:, :4] = self.scale_coords(img_shape, det[:, :4], src_size).round()
            print(time.time() - start)

            # Draw per image; do not rely on the loop variable surviving an empty result list
            for det in results:
                if det is not None and len(det):
                    self.draw(src_img, det)

        def plot_one_box(self, x, img, color=None, label=None, line_thickness=None):
            """Draw one xyxy box ``x`` (and an optional filled label) on ``img`` in place."""
            tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
            color = color or [random.randint(0, 255) for _ in range(3)]
            c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
            cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
            if label:
                tf = max(tl - 1, 1)  # font thickness
                t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
                c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
                cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
                cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)

        def draw(self, img, boxinfo):
            """Draw ``boxinfo`` rows (x1, y1, x2, y2, conf, cls) on ``img``, show and save it.

            Class names come from ``labels.txt``; the annotated image is written to
            ``res1.jpg`` and displayed until a key is pressed. Returns 0.
            """
            txt_path = "labels.txt"
            self.classes = self.readClasses(txt_path)
            colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(self.classes))]
            for *xyxy, conf, cls in boxinfo:
                class_id = int(cls)
                label = '%s %.2f' % (self.classes[class_id], conf)
                # one colour per class (previously every box used colors[0])
                self.plot_one_box(xyxy, img, label=label, color=colors[class_id], line_thickness=2)

            cv2.namedWindow("dst", 0)
            cv2.imshow("dst", img)
            cv2.imwrite("res1.jpg", img)
            cv2.waitKey(0)

            return 0
    
    
    if __name__ == "__main__":
        # Demo: load the exported model and run detection on a sample image
        detector = YOLOV5_ONNX(onnx_path="yolov5s.onnx")
        detector.infer(img_path="zidane.jpg")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254

    3.2 labels.txt

    1,1,person
    2,2,bicycle
    3,3,car
    4,4,motorcycle
    5,5,airplane
    6,6,bus
    7,7,train
    8,8,truck
    9,9,boat
    10,10,traffic light
    11,11,fire hydrant
    13,12,stop sign
    14,13,parking meter
    15,14,bench
    16,15,bird
    17,16,cat
    18,17,dog
    19,18,horse
    20,19,sheep
    21,20,cow
    22,21,elephant
    23,22,bear
    24,23,zebra
    25,24,giraffe
    27,25,backpack
    28,26,umbrella
    31,27,handbag
    32,28,tie
    33,29,suitcase
    34,30,frisbee
    35,31,skis
    36,32,snowboard
    37,33,sports ball
    38,34,kite
    39,35,baseball bat
    40,36,baseball glove
    41,37,skateboard
    42,38,surfboard
    43,39,tennis racket
    44,40,bottle
    46,41,wine glass
    47,42,cup
    48,43,fork
    49,44,knife
    50,45,spoon
    51,46,bowl
    52,47,banana
    53,48,apple
    54,49,sandwich
    55,50,orange
    56,51,broccoli
    57,52,carrot
    58,53,hot dog
    59,54,pizza
    60,55,donut
    61,56,cake
    62,57,chair
    63,58,couch
    64,59,potted plant
    65,60,bed
    67,61,dining table
    70,62,toilet
    72,63,tv
    73,64,laptop
    74,65,mouse
    75,66,remote
    76,67,keyboard
    77,68,cell phone
    78,69,microwave
    79,70,oven
    80,71,toaster
    81,72,sink
    82,73,refrigerator
    84,74,book
    85,75,clock
    86,76,vase
    87,77,scissors
    88,78,teddy bear
    89,79,hair drier
    90,80,toothbrush
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81

    3.3 ReadTxt

    #-*- codeing = utf-8 -*-
    #@Function: 
    #@Time : 2022/5/8 19:25
    #@Author : yx
    #@File : ReadTxt.py
    #@Software : PyCharm
    
    
    import pandas as pd
    
    label_table = pd.read_csv("labels.txt", header=None)  # the labels file has no header row
    # Keep only the third column (the class name) of each row, in file order
    col_data = [label_table.loc[row_idx].values[2] for row_idx in label_table.index]

    print(col_data)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    （此处为原文截图）

  • 相关阅读:
    mybatis
    陕西省助理评审申报,看这文章就够了
    在CSDN上挣点外快的小tips
    Node + Express 后台开发 —— 上传、下载和发布
    opencv-形态学处理
    rust编程-通用编程概念(chapter 3.4 & 3.5 注释和控制语句)
    原三高搜索条件-多选问题
    “元宇宙”最权威的解释来了!全国科技名词委研讨会达成共识
    下载文件时的文件名中文乱码问题,文件名丢失
    聊一下Glove
  • 原文地址:https://blog.csdn.net/qq_44747572/article/details/127604416