• YOLOv5学习笔记(五)——detect.py 与 train.py 代码讲解


    1 detect.py

    1. #!/usr/bin/env python
    2. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
    3. """
    4. Run inference on images, videos, directories, streams, etc.
    5. Usage - sources:
    6. $ python path/to/detect.py --weights yolov5s.pt --source 0 # webcam
    7. img.jpg # image
    8. vid.mp4 # video
    9. path/ # directory
    10. path/*.jpg # glob
    11. 'https://youtu.be/Zgi9g1ksQHc' # YouTube
    12. 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
    13. Usage - formats:
    14. $ python path/to/detect.py --weights yolov5s.pt # PyTorch
    15. yolov5s.torchscript # TorchScript
    16. yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
    17. yolov5s.xml # OpenVINO
    18. yolov5s.engine # TensorRT
    19. yolov5s.mlmodel # CoreML (macOS-only)
    20. yolov5s_saved_model # TensorFlow SavedModel
    21. yolov5s.pb # TensorFlow GraphDef
    22. yolov5s.tflite # TensorFlow Lite
    23. yolov5s_edgetpu.tflite # TensorFlow Edge TPU
    24. """
    25. import argparse #命令行解析模块
    26. import os
    27. import platform
    28. import sys
    29. from pathlib import Path
    30. import cv2
    31. import torch
    32. import torch.backends.cudnn as cudnn
    # Locate the YOLOv5 repository root (the directory that contains this script)
    # and put it on sys.path so the local `models` / `utils` packages import.
    FILE = Path(__file__).resolve()
    ROOT = FILE.parent  # YOLOv5 root directory
    if str(ROOT) not in sys.path:
        sys.path.append(str(ROOT))  # add ROOT to PATH
    # From here on ROOT is expressed relative to the current working directory.
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
    38. from models.common import DetectMultiBackend
    39. from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
    40. from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr,
    41. increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
    42. from utils.plots import Annotator, colors, save_one_box
    43. from utils.torch_utils import select_device, time_sync
    @torch.no_grad()
    def run(
            weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
            source=ROOT / 'data/images',  # file/dir/URL/glob, 0 for webcam
            data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
            imgsz=(640, 640),  # inference size (height, width)
            conf_thres=0.25,  # confidence threshold
            iou_thres=0.45,  # NMS IOU threshold
            max_det=1000,  # maximum detections per image
            device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
            view_img=False,  # show results
            save_txt=False,  # save results to *.txt
            save_conf=False,  # save confidences in --save-txt labels
            save_crop=False,  # save cropped prediction boxes
            nosave=False,  # do not save images/videos
            classes=None,  # filter by class: --class 0, or --class 0 2 3
            agnostic_nms=False,  # class-agnostic NMS
            augment=False,  # augmented inference
            visualize=False,  # visualize features
            update=False,  # update all models
            project=ROOT / 'runs/detect',  # save results to project/name
            name='exp',  # save results to project/name
            exist_ok=False,  # existing project/name ok, do not increment
            line_thickness=3,  # bounding box thickness (pixels)
            hide_labels=False,  # hide labels
            hide_conf=False,  # hide confidences
            half=False,  # use FP16 half-precision inference
            dnn=False,  # use OpenCV DNN for ONNX inference
    ):
        """Run YOLOv5 detection on ``source`` and log / save the results.

        Every batch from the dataloader is pre-processed (uint8 -> float in [0, 1]),
        passed through the model and filtered with NMS.  Each image in the batch is
        then annotated and, depending on the flags, shown in a window, saved as an
        image/video under ``project/name``, written as normalized-xywh label lines
        (``cls x y w h [conf]``) and/or cropped per detected box.

        The time spent in pre-processing, inference and NMS is accumulated and
        logged at the end as average milliseconds per image.
        """
        source = str(source)
        save_img = not nosave and not source.endswith('.txt')  # save inference images
        is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
        is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
        # Webcam index ("0"), a .txt list of streams, or a URL that is not a media file -> stream input.
        webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
        if is_url and is_file:
            source = check_file(source)  # download

        # Directories
        save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
        (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

        # Load model
        device = select_device(device)  # select the compute device
        model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
        stride, names, pt = model.stride, model.names, model.pt
        imgsz = check_img_size(imgsz, s=stride)  # make the image size a multiple of the model stride

        # Dataloader
        if webcam:  # camera / network streams
            view_img = check_imshow()
            cudnn.benchmark = True  # set True to speed up constant image size inference
            dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
            bs = len(dataset)  # batch_size
        else:  # images / videos / directories
            dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
            bs = 1  # batch_size
        vid_path, vid_writer = [None] * bs, [None] * bs

        # Run inference
        model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup
        seen, windows, dt = 0, [], [0.0, 0.0, 0.0]
        # Each dataset item is:
        #   path    - source path (indexed per image, path[i], for streams)
        #   im      - resized + padded image(s) as a numpy array, e.g. (3, 640, 512) as (c, h, w)
        #   im0s    - original-size image(s), e.g. (1080, 810, 3)
        #   vid_cap - None when reading images, the video capture object when reading a video
        #   s       - log string that is extended below and printed once per batch
        for path, im, im0s, vid_cap, s in dataset:
            t1 = time_sync()
            im = torch.from_numpy(im).to(device)  # numpy -> tensor on the target device
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim
            t2 = time_sync()
            dt[0] += t2 - t1

            # Inference
            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
            pred = model(im, augment=augment, visualize=visualize)
            t3 = time_sync()
            dt[1] += t3 - t2

            # NMS: keeps detections above conf_thres, optionally only the requested `classes`.
            # Afterwards pred holds one tensor per image whose rows are (x1, y1, x2, y2, conf, cls).
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
            dt[2] += time_sync() - t3

            # Second-stage classifier (optional)
            # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

            # Process predictions: i is the index of the IMAGE within the batch (not of a box).
            for i, det in enumerate(pred):  # per image
                seen += 1
                if webcam:  # batch_size >= 1
                    p, im0, frame = path[i], im0s[i].copy(), dataset.count
                    s += f'{i}: '
                else:
                    p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

                p = Path(p)  # to Path
                save_path = str(save_dir / p.name)  # output image/video path
                # Label file path; video/stream frames get a _<frame> suffix.
                txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')
                s += '%gx%g ' % im.shape[2:]  # network input height x width, for the log line
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
                imc = im0.copy() if save_crop else im0  # for save_crop
                annotator = Annotator(im0, line_width=line_thickness, example=str(names))
                if len(det):
                    # Map boxes from the resized + padded input back to original-image
                    # coordinates (still xyxy).
                    det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()

                    # Print results: number of detections per class
                    for c in det[:, -1].unique():
                        n = (det[:, -1] == c).sum()  # detections per class
                        s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                    # Write results
                    for *xyxy, conf, cls in reversed(det):
                        if save_txt:  # Write to file
                            # xyxy -> xywh, divided by image w/h so the values are normalized
                            xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                            line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
                            with open(f'{txt_path}.txt', 'a') as f:
                                f.write(('%g ' * len(line)).rstrip() % line + '\n')

                        if save_img or save_crop or view_img:  # Add bbox to image
                            c = int(cls)  # integer class
                            label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                            annotator.box_label(xyxy, label, color=colors(c, True))
                            if save_crop:
                                save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

                # Stream results
                im0 = annotator.result()
                if view_img:  # show the annotated image
                    if platform.system() == 'Linux' and p not in windows:
                        windows.append(p)
                        cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
                        cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
                    cv2.imshow(str(p), im0)
                    cv2.waitKey(1)  # 1 millisecond

                # Save results (image with detections)
                if save_img:
                    if dataset.mode == 'image':
                        cv2.imwrite(save_path, im0)
                    else:  # 'video' or 'stream'
                        if vid_path[i] != save_path:  # new video
                            vid_path[i] = save_path
                            if isinstance(vid_writer[i], cv2.VideoWriter):
                                vid_writer[i].release()  # release previous video writer
                            if vid_cap:  # video
                                fps = vid_cap.get(cv2.CAP_PROP_FPS)
                                w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                                h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                            else:  # stream
                                fps, w, h = 30, im0.shape[1], im0.shape[0]
                            save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
                            vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                        vid_writer[i].write(im0)

            # Print time (inference-only)
            LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')

        # Print results. max(seen, 1) avoids a ZeroDivisionError when the source
        # produced no frames (all timings are then 0).
        t = tuple(x / max(seen, 1) * 1E3 for x in dt)  # speeds per image
        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
        if save_txt or save_img:
            s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
            LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
        if update:
            strip_optimizer(weights)  # update model (to fix SourceChangeWarning)
    def parse_opt():
        """Build the detect.py command-line interface and return the parsed options.

        A single ``--imgsz`` value is duplicated so that ``opt.imgsz`` is a
        ``[height, width]`` pair.  The options are echoed with ``print_args``
        before being returned.
        """
        ap = argparse.ArgumentParser()
        # Model, input source and dataset description
        ap.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)')
        ap.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
        ap.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
        # Network input size (one value -> square)
        ap.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
        # Filtering thresholds
        ap.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
        ap.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
        ap.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
        ap.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
        # Output switches: display, label txt files, confidences, crops
        ap.add_argument('--view-img', action='store_true', help='show results')
        ap.add_argument('--save-txt', action='store_true', help='save results to *.txt')
        ap.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
        ap.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
        ap.add_argument('--nosave', action='store_true', help='do not save images/videos')
        # Keep only the listed class ids
        ap.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
        # NMS across classes rather than per class
        ap.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
        # Test-time augmentation during inference
        ap.add_argument('--augment', action='store_true', help='augmented inference')
        ap.add_argument('--visualize', action='store_true', help='visualize features')
        # Run strip_optimizer on the weights after inference
        ap.add_argument('--update', action='store_true', help='update all models')
        ap.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
        ap.add_argument('--name', default='exp', help='save results to project/name')
        ap.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
        # Drawing options
        ap.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
        ap.add_argument('--hide-labels', action='store_true', help='hide labels')
        ap.add_argument('--hide-conf', action='store_true', help='hide confidences')
        # Backend options
        ap.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
        ap.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')

        opt = ap.parse_args()  # every option becomes an attribute of opt
        if len(opt.imgsz) == 1:
            opt.imgsz = opt.imgsz * 2  # expand [640] -> [640, 640]
        print_args(FILE.stem, opt)
        return opt
    def main(opt):
        """Check installed requirements (skipping tensorboard and thop), then call run() with the CLI options."""
        check_requirements(exclude=('tensorboard', 'thop'))
        options = vars(opt)  # Namespace -> dict of run() keyword arguments
        run(**options)


    if __name__ == "__main__":
        # Script entry point: parse the command line, then run detection.
        opt = parse_opt()
        main(opt)

    2 train.py

    1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
    2. """
    3. Train a YOLOv5 model on a custom dataset.
    4. Models and datasets download automatically from the latest YOLOv5 release.
    5. Models: https://github.com/ultralytics/yolov5/tree/master/models
    6. Datasets: https://github.com/ultralytics/yolov5/tree/master/data
    7. Tutorial: https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data
    8. Usage:
    9. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640 # from pretrained (RECOMMENDED)
    10. $ python path/to/train.py --data coco128.yaml --weights '' --cfg yolov5s.yaml --img 640 # from scratch
    11. """
    12. import argparse
    13. import math
    14. import os
    15. import random
    16. import sys
    17. import time
    18. from copy import deepcopy
    19. from datetime import datetime
    20. from pathlib import Path
    21. import numpy as np
    22. import torch
    23. import torch.distributed as dist
    24. import torch.nn as nn
    25. import yaml
    26. from torch.cuda import amp
    27. from torch.nn.parallel import DistributedDataParallel as DDP
    28. from torch.optim import SGD, Adam, AdamW, lr_scheduler
    29. from tqdm import tqdm
    # Locate the YOLOv5 repository root (the directory that contains this script)
    # and put it on sys.path so the local `val`, `models` and `utils` modules import.
    FILE = Path(__file__).resolve()
    ROOT = FILE.parent  # YOLOv5 root directory
    if str(ROOT) not in sys.path:
        sys.path.append(str(ROOT))  # add ROOT to PATH
    # From here on ROOT is expressed relative to the current working directory.
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
    35. import val # for end-of-epoch mAP
    36. from models.experimental import attempt_load
    37. from models.yolo import Model
    38. from utils.autoanchor import check_anchors
    39. from utils.autobatch import check_train_batch_size
    40. from utils.callbacks import Callbacks
    41. from utils.datasets import create_dataloader
    42. from utils.downloads import attempt_download
    43. from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
    44. check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
    45. intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
    46. print_args, print_mutation, strip_optimizer)
    47. from utils.loggers import Loggers
    48. from utils.loggers.wandb.wandb_utils import check_wandb_resume
    49. from utils.loss import ComputeLoss
    50. from utils.metrics import fitness
    51. from utils.plots import plot_evolve, plot_labels
    52. from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
    # Distributed-training context, read from the environment variables that the
    # torch.distributed.run launcher exports for each worker process
    # (https://pytorch.org/docs/stable/elastic/run.html).
    # When unset (plain single-process run) the ranks default to -1 and the world size to 1.
    LOCAL_RANK = int(os.environ.get('LOCAL_RANK', -1))
    RANK = int(os.environ.get('RANK', -1))
    WORLD_SIZE = int(os.environ.get('WORLD_SIZE', 1))
    56. #**********************************************************************************************************************
    57. # *
    58. # 三、训练过程 *
    59. # *
    60. #**********************************************************************************************************************
    61. def train(hyp, # path/to/hyp.yaml or hyp dictionary
    62. opt,
    63. device,
    64. callbacks
    65. ):
    66. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
    67. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
    68. opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
    69. #**********************************************************************************************************************
    70. # 3.1 权重、数据集、参数、路径初始化 *
    71. #**********************************************************************************************************************
    72. # Directories
    73. w = save_dir / 'weights' # weights dir
    74. (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
    75. last, best = w / 'last.pt', w / 'best.pt' #保存权重的路径
    76. # Hyperparameters 超参数
    77. if isinstance(hyp, str):
    78. with open(hyp, errors='ignore') as f:
    79. hyp = yaml.safe_load(f) # load hyps dict
    80. LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
    81. # Save run settings
    82. if not evolve:
    83. with open(save_dir / 'hyp.yaml', 'w') as f: #创建yaml文件
    84. yaml.safe_dump(hyp, f, sort_keys=False)
    85. with open(save_dir / 'opt.yaml', 'w') as f:
    86. yaml.safe_dump(vars(opt), f, sort_keys=False)
    87. # Loggers
    88. data_dict = None
    89. if RANK in [-1, 0]:
    90. loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
    91. if loggers.wandb:
    92. data_dict = loggers.wandb.data_dict
    93. if resume:
    94. weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size
    95. # Register actions
    96. for k in methods(loggers):
    97. callbacks.register_action(k, callback=getattr(loggers, k))
    98. # Config
    99. plots = not evolve # create plots
    100. cuda = device.type != 'cpu' #选择设备
    101. init_seeds(1 + RANK) #随机化种子
    102. with torch_distributed_zero_first(LOCAL_RANK):
    103. data_dict = data_dict or check_dataset(data) # check if None 检查路径
    104. # 获取训练集、测试集图片路径
    105. train_path, val_path = data_dict['train'], data_dict['val']
    106. # 设置类别的数量nc 和对应的名字names
    107. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
    108. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
    109. # 确认name和nc的长度是想等的
    110. assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
    111. is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
    112. #**********************************************************************************************************************
    113. # 3.2 加载网络模型 *
    114. #**********************************************************************************************************************
    115. # Model
    116. check_suffix(weights, '.pt') # check weights 检查权重名
    117. pretrained = weights.endswith('.pt')
    118. if pretrained: #有预训练
    119. # 从谷歌云盘下载模型
    120. with torch_distributed_zero_first(LOCAL_RANK):
    121. weights = attempt_download(weights) # download if not found locally
    122. # 加载模型参数
    123. ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
    124. # 加载模型
    125. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
    126. # 获得anchor
    127. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
    128. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
    129. csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
    130. # 模型创建
    131. model.load_state_dict(csd, strict=False) # load
    132. # 如果pretrained为ture 则会少加载两个键对(anchors, anchor_grid)
    133. LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
    134. else: #直接加载模型
    135. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
    136. # 3.2.1 设置模型输入
    137. #**********************************************************************************************************************
    138. # Freeze 冻结层
    139. freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
    140. for k, v in model.named_parameters():
    141. v.requires_grad = True # train all layers
    142. if any(x in k for x in freeze):
    143. LOGGER.info(f'freezing {k}')
    144. v.requires_grad = False #不进行梯度计算
    145. # Image size
    146. gs = max(int(model.stride.max()), 32) # grid size (max stride)
    147. imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
    148. # Batch size
    149. if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
    150. batch_size = check_train_batch_size(model, imgsz)
    151. loggers.on_params_update({"batch_size": batch_size})
    152. # 3.2.2 优化器设置
    153. #**********************************************************************************************************************
    154. # Optimizer 优化器设置
    155. nbs = 64 # nominal batch size batch size为16 nbs为64 模型累计4次之后更新一次模型,变相扩大batch_size
    156. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
    157. # 根据accumulate设置权重衰减系数
    158. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
    159. LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
    160. # 将模型分成三组(weight,bias,其他所有参数)进行优化
    161. g0, g1, g2 = [], [], [] # optimizer parameter groups
    162. for v in model.modules():
    163. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
    164. g2.append(v.bias)
    165. if isinstance(v, nn.BatchNorm2d): # weight (no decay)
    166. g0.append(v.weight)
    167. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
    168. g1.append(v.weight)
    169. # 选用优化器,并设置pg0的优化方式
    170. if opt.optimizer == 'Adam':
    171. optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
    172. elif opt.optimizer == 'AdamW':
    173. optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
    174. else:
    175. optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
    176. # 设置weight的优化方式
    177. optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
    178. # 设置biases的优化方式
    179. optimizer.add_param_group({'params': g2}) # add g2 (biases)
    180. # 打印优化信息
    181. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
    182. f"{len(g0)} weight (no decay), {len(g1)} weight, {len(g2)} bias")
    183. del g0, g1, g2
    184. # 3.2.3 模型其他功能选择
    185. #**********************************************************************************************************************
    186. # Scheduler 设置学习率的衰减 余弦退火调整
    187. if opt.cos_lr:
    188. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
    189. else:
    190. lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
    191. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
    192. # EMA
    193. ema = ModelEMA(model) if RANK in [-1, 0] else None
    194. # Resume 断点续训
    195. start_epoch, best_fitness = 0, 0.0
    196. if pretrained:
    197. # Optimizer
    198. if ckpt['optimizer'] is not None:
    199. optimizer.load_state_dict(ckpt['optimizer'])
    200. best_fitness = ckpt['best_fitness']
    201. # EMA 指数移动平均:一种给予近期数据更高权重的平均方法
    202. if ema and ckpt.get('ema'):
    203. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
    204. ema.updates = ckpt['updates']
    205. # Epochs
    206. start_epoch = ckpt['epoch'] + 1
    207. if resume:
    208. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
    209. if epochs < start_epoch:
    210. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
    211. epochs += ckpt['epoch'] # finetune additional epochs
    212. del ckpt, csd
    213. # DP mode 是否有分布式训练
    214. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
    215. LOGGER.warning('WARNING: DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
    216. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
    217. model = torch.nn.DataParallel(model)
    218. # SyncBatchNorm 跨卡同步
    219. if opt.sync_bn and cuda and RANK != -1:
    220. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
    221. LOGGER.info('Using SyncBatchNorm()')
    222. #**********************************************************************************************************************
    223. # 3.3 数据集预处理 *
    224. #**********************************************************************************************************************
    225. # Trainloader 数据处理过程
    226. # 3.3.1 创建数据集
    227. #**********************************************************************************************************************
    228. #创建训练集
    229. train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
    230. hyp=hyp, augment=True, cache=None if opt.cache == 'val' else opt.cache,
    231. rect=opt.rect, rank=LOCAL_RANK, workers=workers,
    232. image_weights=opt.image_weights, quad=opt.quad,
    233. prefix=colorstr('train: '), shuffle=True)
    234. # 获取标签中最大的类别值与类别数做比较
    235. mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
    236. nb = len(train_loader) # number of batches
    237. # 如果小于则出现问题
    238. assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
    239. # Process 0
    240. if RANK in [-1, 0]:
    241. # 创建测试集
    242. val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
    243. hyp=hyp, cache=None if noval else opt.cache,
    244. rect=True, rank=-1, workers=workers * 2, pad=0.5,
    245. prefix=colorstr('val: '))[0]
    246. if not resume:
    247. labels = np.concatenate(dataset.labels, 0) #目标框数,不是图片数
    248. # c = torch.tensor(labels[:, 0]) # classes
    249. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
    250. # model._initialize_biases(cf.to(device))
    251. if plots:
    252. plot_labels(labels, names, save_dir)
    253. # 3.3.1 计算anchor
    254. #**********************************************************************************************************************
    255. # Anchors 计算最佳anchor
    256. if not opt.noautoanchor:
    257. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
    258. model.half().float() # pre-reduce anchor precision
    259. callbacks.run('on_pretrain_routine_end')
    260. # DDP mode
    261. if cuda and RANK != -1:
    262. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
    263. # 3.3.2 根据数据分布设置类别训练权重
    264. #**********************************************************************************************************************
    265. # Model attributes 根据自己类别数设置分类损失的系数
    266. nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
    267. hyp['box'] *= 3 / nl # scale to layers
    268. hyp['cls'] *= nc / 80 * 3 / nl # scale to classes and layers
    269. hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
    270. hyp['label_smoothing'] = opt.label_smoothing
    271. #设置模型的类别和超参数
    272. model.nc = nc # attach number of classes to model
    273. model.hyp = hyp # attach hyperparameters to model
    274. # 从训练的样本标签得到类别权重 和数量成反比
    275. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
    276. model.names = names #获取类别的名字
    277. #**********************************************************************************************************************
    278. # 3.4 模型训练 *
    279. #**********************************************************************************************************************
    280. # Start training 开始训练部分
    281. # 3.4.1 训练初始化
    282. #**********************************************************************************************************************
    283. t0 = time.time() #获取当前时间
    284. # 获取热身训练的迭代次数
    285. nw = max(round(hyp['warmup_epochs'] * nb), 100) # number of warmup iterations, max(3 epochs, 100 iterations)
    286. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
    287. last_opt_step = -1
    288. # 初始化 map和result
    289. maps = np.zeros(nc) # mAP per class
    290. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    291. # 设置学习率衰减所进行到的轮次 目的是打断训练后,--resume也能接着衰减学习率训练
    292. scheduler.last_epoch = start_epoch - 1 # do not move
    293. # 通过torch自带的api设置混合精度训练
    294. scaler = amp.GradScaler(enabled=cuda) #训练开始时实例化一个GradScaler对象
    295. stopper = EarlyStopping(patience=opt.patience)
    296. compute_loss = ComputeLoss(model) # init loss class
    297. # 打印信息
    298. LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
    299. f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
    300. f"Logging results to {colorstr('bold', save_dir)}\n"
    301. f'Starting training for {epochs} epochs...')
    302. # 3.4.2 训练过程
    303. #**********************************************************************************************************************
    304. # 开始训练
    305. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
    306. model.train()
    307. # Update image weights (optional, single-GPU only)
    308. if opt.image_weights: #图片采样
    309. # 获取图片采样的权重
    310. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
    311. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
    312. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
    313. # Update mosaic border (optional)
    314. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
    315. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
    316. # 初始化训练时打印的平均损失信息
    317. mloss = torch.zeros(3, device=device) # mean losses
    318. if RANK != -1:
    319. train_loader.sampler.set_epoch(epoch)
    320. pbar = enumerate(train_loader)
    321. LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
    322. if RANK in [-1, 0]:
    323. # 通过tqdm创建进度条,方便训练信息的展示
    324. pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
    325. optimizer.zero_grad() #梯度训练
    326. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
    327. # 计算迭代次数
    328. ni = i + nb * epoch # number integrated batches (since train start)
    329. imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0
    330. # Warmup
    331. if ni <= nw:
    332. xi = [0, nw] # x interp
    333. # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
    334. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
    335. for j, x in enumerate(optimizer.param_groups):
    336. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
    337. '''
    338. bias的学习率从0.1下降到基准学习率lr*lf(epoch)
    339. 其他的参数学习率从0增加到lr*lf(epoch)
    340. lf是余弦退火衰减函数
    341. '''
    342. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
    343. #动量momentum也从0.9慢慢变到hyp
    344. if 'momentum' in x:
    345. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
    346. # Multi-scale 多尺度训练 尺寸变为imgsz * 0.5, imgsz * 1.5 + gs随机选取尺寸
    347. if opt.multi_scale:
    348. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
    349. sf = sz / max(imgs.shape[2:]) # scale factor
    350. if sf != 1:
    351. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
    352. imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
    353. # Forward 前向传播
    354. with amp.autocast(enabled=cuda):
    355. pred = model(imgs) # forward 把图片送入前向传播得到预测值
    356. # 计算loss
    357. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
    358. if RANK != -1:
    359. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
    360. if opt.quad:
    361. loss *= 4.
    362. # Backward 反向传播
    363. scaler.scale(loss).backward()# scale(loss)是为了梯度放大
    364. # Optimize
    365. if ni - last_opt_step >= accumulate: #模型反向传播accumulate之后再根据累计值更新一次参数
    366. scaler.step(optimizer) # optimizer.step
    367. scaler.update()
    368. optimizer.zero_grad() #梯度清0
    369. if ema:
    370. ema.update(model)
    371. last_opt_step = ni
    372. # Log 打印信息
    373. if RANK in [-1, 0]:
    374. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
    375. mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
    376. # 通过进度条显示信息
    377. pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
    378. f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
    379. callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
    380. if callbacks.stop_training:
    381. return
    382. # end batch ------------------------------------------------------------------------------------------------
    383. # Scheduler
    384. # batch结束后进行学习率衰减
    385. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
    386. scheduler.step() #对lr进行调整
    387. # 3.4.2 训练完成保存模型
    388. #**********************************************************************************************************************
    389. if RANK in [-1, 0]:
    390. # mAP
    391. callbacks.run('on_train_epoch_end', epoch=epoch)
    392. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
    393. # 判断是否是最后一轮
    394. final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
    395. # 对测试集进行测试,计算指标
    396. if not noval or final_epoch: # Calculate mAP
    397. results, maps, _ = val.run(data_dict,
    398. batch_size=batch_size // WORLD_SIZE * 2,
    399. imgsz=imgsz,
    400. model=ema.ema,
    401. single_cls=single_cls,
    402. dataloader=val_loader,
    403. save_dir=save_dir,
    404. plots=False,
    405. callbacks=callbacks,
    406. compute_loss=compute_loss)
    407. # Update best mAP
    408. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
    409. if fi > best_fitness:
    410. best_fitness = fi
    411. log_vals = list(mloss) + list(results) + lr
    412. callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
    413. # Save model 保存模型
    414. if (not nosave) or (final_epoch and not evolve): # if save
    415. ckpt = {'epoch': epoch,
    416. 'best_fitness': best_fitness,
    417. 'model': deepcopy(de_parallel(model)).half(),
    418. 'ema': deepcopy(ema.ema).half(),
    419. 'updates': ema.updates,
    420. 'optimizer': optimizer.state_dict(),
    421. 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
    422. 'date': datetime.now().isoformat()}
    423. # Save last, best and delete
    424. torch.save(ckpt, last)
    425. if best_fitness == fi:
    426. torch.save(ckpt, best)
    427. if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
    428. torch.save(ckpt, w / f'epoch{epoch}.pt')
    429. del ckpt
    430. callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
    431. # Stop Single-GPU
    432. if RANK == -1 and stopper(epoch=epoch, fitness=fi):
    433. break
    434. # Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
    435. # stop = stopper(epoch=epoch, fitness=fi)
    436. # if RANK == 0:
    437. # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks
    438. # Stop DPP
    439. # with torch_distributed_zero_first(RANK):
    440. # if stop:
    441. # break # must break all DDP ranks
    442. # end epoch ----------------------------------------------------------------------------------------------------
    443. # end training -----------------------------------------------------------------------------------------------------
    444. # 3.4.2 模型压缩内存释放
    445. #**********************************************************************************************************************
    446. if RANK in [-1, 0]:
    447. LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
    448. for f in last, best:
    449. if f.exists():
    450. strip_optimizer(f) # strip optimizers 训练完成后会用strip_optimizer将优化器信息去除,并将32位变成16为浮点减少模型大小,提高前向推理速度
    451. if f is best:
    452. LOGGER.info(f'\nValidating {f}...')
    453. results, _, _ = val.run(data_dict,
    454. batch_size=batch_size // WORLD_SIZE * 2,
    455. imgsz=imgsz,
    456. model=attempt_load(f, device).half(),
    457. iou_thres=0.65 if is_coco else 0.60, # best pycocotools results at 0.65
    458. single_cls=single_cls,
    459. dataloader=val_loader,
    460. save_dir=save_dir,
    461. save_json=is_coco,
    462. verbose=True,
    463. plots=True,
    464. callbacks=callbacks,
    465. compute_loss=compute_loss) # val best model with plots
    466. if is_coco:
    467. callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
    468. callbacks.run('on_train_end', last, best, plots, epoch, results)
    469. LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
    470. torch.cuda.empty_cache() #显存释放
    471. return results
    472. #**********************************************************************************************************************
    473. # *
    474. # 一、设置模型参数 *
    475. # *
    476. #**********************************************************************************************************************
    def parse_opt(known=False):
        """Define and parse the command-line options for YOLOv5 training.

        Args:
            known: when True, parse with ``parse_known_args`` and silently drop unrecognised
                arguments (this is how ``run()`` calls it); when False, use ``parse_args``, which
                makes argparse exit with an error on unrecognised arguments.

        Returns:
            argparse.Namespace with one attribute per option defined below.
        """
        parser = argparse.ArgumentParser()
        # --- model / data / hyperparameters ---
        # Checkpoint whose weights initialise the model.
        parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
        # Model architecture yaml. The default used to be a machine-specific absolute path
        # ('/home/cxl/yolov5/src/yolov5/models/yolov5s.yaml'). It is now resolved relative to ROOT,
        # like the other defaults, so it points at the same file in any checkout.
        parser.add_argument('--cfg', type=str, default=ROOT / 'models/yolov5s.yaml', help='model.yaml path')
        parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
        # Hyperparameter yaml (learning rates, loss gains, augmentation); edit it to fine-tune training.
        parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
        parser.add_argument('--epochs', type=int, default=50, help='total training epochs')
        parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
        # Square input size used for both training and validation.
        parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
        # Rectangular training (reduces image padding).
        parser.add_argument('--rect', action='store_true', help='rectangular training')
        # A bare --resume continues the most recent run; --resume PATH continues from that checkpoint.
        parser.add_argument('--resume', nargs='?', const=True, default="", help='resume most recent training')
        parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
        parser.add_argument('--noval', action='store_true', help='only validate final epoch')
        parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
        # Hyperparameter evolution; a bare --evolve runs 300 generations.
        parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
        # Cloud bucket that evolve.csv is downloaded from during --evolve.
        parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
        parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
        # Re-sample training images using class weights scaled by (1 - per-class mAP)^2, so images of
        # poorly performing classes are drawn more often. Single-GPU only (main() rejects it under DDP).
        parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
        parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
        # Each batch is resized to a random size in [0.5*imgsz, 1.5*imgsz], rounded to a multiple of the stride.
        parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
        # Treat every class as a single class.
        parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
        parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
        parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
        parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
        # --- output location: results go to <project>/<name> ---
        parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
        parser.add_argument('--name', default='exp', help='save to project/name')
        # Reuse an existing project/name directory instead of creating a new, incremented one.
        parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
        # Quad dataloader (the training loss is multiplied by 4 when enabled).
        parser.add_argument('--quad', action='store_true', help='quad dataloader')
        # Use a cosine learning-rate schedule.
        parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
        parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
        parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
        parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
        parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
        parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
        # --- Weights & Biases arguments ---
        parser.add_argument('--entity', default=None, help='W&B: Entity')
        parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option')
        parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
        opt = parser.parse_known_args()[0] if known else parser.parse_args()
        return opt
    518. #**********************************************************************************************************************
    519. # *
    520. # 二、模型选择 *
    521. # *
    522. #**********************************************************************************************************************
    def main(opt, callbacks=None):
        """Run pre-flight checks, handle --resume and DDP setup, then train once or evolve hyperparameters.

        Args:
            opt: options namespace returned by ``parse_opt`` (replaced by the saved opt.yaml when resuming).
            callbacks: ``Callbacks`` instance forwarded to ``train``. When omitted, a fresh ``Callbacks()`` is
                created per call. The previous ``callbacks=Callbacks()`` default was evaluated only once, at
                definition time, so every call relying on it (e.g. repeated ``run()`` calls) shared one instance.
        """
        if callbacks is None:
            callbacks = Callbacks()

        # Checks, done only by the main process (RANK -1: no DDP, RANK 0: DDP master)
        if RANK in [-1, 0]:
            print_args(FILE.stem, opt)
            check_git_status()
            check_requirements(exclude=['thop'])  # requirement check, skipping 'thop'

        # Resume an interrupted run: `python train.py --resume [path/to/last.pt]`
        if opt.resume and not check_wandb_resume(opt) and not opt.evolve:
            # An explicit checkpoint path, or otherwise the most recent run found by get_latest_run().
            ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()
            assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
            # Reload the options that run was started with (opt.yaml two directories above the checkpoint).
            with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
                opt = argparse.Namespace(**yaml.safe_load(f))  # replace
            opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
            LOGGER.info(f'Resuming training from {ckpt}')
        else:
            # Resolve/validate the data, cfg and hyp files; weights and project become plain strings.
            opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
                check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)
            assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
            if opt.evolve:
                if opt.project == str(ROOT / 'runs/train'):  # if default project name, rename to runs/evolve
                    opt.project = str(ROOT / 'runs/evolve')
                opt.exist_ok, opt.resume = opt.resume, False  # pass resume to exist_ok and disable resume
            opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))

        # Device selection. LOCAL_RANK != -1 means this process is one rank of a multi-GPU DDP job.
        device = select_device(opt.device, batch_size=opt.batch_size)
        if LOCAL_RANK != -1:
            msg = 'is not compatible with YOLOv5 Multi-GPU DDP training'
            assert not opt.image_weights, f'--image-weights {msg}'
            assert not opt.evolve, f'--evolve {msg}'
            assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
            assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
            assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
            torch.cuda.set_device(LOCAL_RANK)
            device = torch.device('cuda', LOCAL_RANK)  # bind this rank to its own GPU
            # Initialise the process group: NCCL when available, otherwise Gloo.
            dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")

        if not opt.evolve:
            # Train once with the hyperparameters from opt.hyp (the default; --evolve not given).
            train(opt.hyp, opt, device, callbacks)
            if WORLD_SIZE > 1 and RANK == 0:
                LOGGER.info('Destroying process group... ')
                dist.destroy_process_group()
        else:
            # Evolve hyperparameters (optional): an evolutionary search that repeatedly mutates the
            # best previous hyperparameters and trains with them.
            # Metadata per hyperparameter: (mutation scale 0-1, lower limit, upper limit).
            # A mutation scale of 0 keeps that hyperparameter fixed.
            meta = {'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
                    'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
                    'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
                    'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
                    'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
                    'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
                    'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
                    'box': (1, 0.02, 0.2),  # box loss gain
                    'cls': (1, 0.2, 4.0),  # cls loss gain
                    'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
                    'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
                    'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
                    'iou_t': (0, 0.1, 0.7),  # IoU training threshold
                    'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
                    'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
                    'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
                    'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
                    'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
                    'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
                    'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
                    'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
                    'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
                    'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
                    'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
                    'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
                    'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
                    'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
                    'mixup': (1, 0.0, 1.0),  # image mixup (probability)
                    'copy_paste': (1, 0.0, 1.0)}  # segment copy-paste (probability)

            with open(opt.hyp, errors='ignore') as f:
                hyp = yaml.safe_load(f)  # load hyps dict
            if 'anchors' not in hyp:  # anchors commented in hyp.yaml
                hyp['anchors'] = 3
            opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir)  # only val/save final epoch
            # Evolution output. evolve.csv is also read back as the pool of previous results.
            evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
            if opt.bucket:
                os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {evolve_csv}')  # download evolve.csv if exists

            # opt.evolve generations (300 for a bare --evolve). Until evolve.csv exists, the base hyp is
            # trained unmutated; afterwards each generation mutates one of the best previous results.
            for _ in range(opt.evolve):
                if evolve_csv.exists():
                    # Select a parent among the top-n previous results ranked by fitness
                    parent = 'single'  # parent selection method: 'single' or 'weighted'
                    x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
                    n = min(5, len(x))  # number of previous results to consider
                    x = x[np.argsort(-fitness(x))][:n]  # top n mutations
                    w = fitness(x) - fitness(x).min() + 1E-6  # weights (sum > 0)
                    if parent == 'single' or len(x) == 1:
                        x = x[random.choices(range(n), weights=w)[0]]  # fitness-weighted random pick of one row
                    elif parent == 'weighted':
                        x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # fitness-weighted average of the rows

                    # Mutate: each gene is scaled by a random factor v, clipped to [0.3, 3.0]
                    mp, s = 0.8, 0.2  # mutation probability, sigma
                    npr = np.random
                    npr.seed(int(time.time()))
                    g = np.array([meta[k][0] for k in hyp.keys()])  # gains 0-1
                    ng = len(meta)
                    v = np.ones(ng)
                    while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                        v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
                    # NOTE(review): assumes each evolve.csv row holds 7 result columns followed by the
                    # hyperparameters in hyp.keys() order (as written by print_mutation) — confirm.
                    for i, k in enumerate(hyp.keys()):
                        hyp[k] = float(x[i + 7] * v[i])  # mutate

                # Constrain to limits: clip into [lower, upper] and keep 5 decimal places
                for k, (_, lower, upper) in meta.items():
                    hyp[k] = round(min(max(hyp[k], lower), upper), 5)

                # Train mutation
                results = train(hyp.copy(), opt, device, callbacks)
                callbacks = Callbacks()  # fresh callbacks for the next generation
                # Write mutation results
                print_mutation(results, hyp.copy(), save_dir, opt.bucket)

            # Plot results
            plot_evolve(evolve_csv)
            LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
                        f"Results saved to {colorstr('bold', save_dir)}\n"
                        f'Usage example: $ python train.py --hyp {evolve_yaml}')
    def run(**kwargs):
        """Train from Python code instead of the command line.

        Usage: ``import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')``

        The default options are parsed leniently (unrecognised entries in sys.argv are ignored).
        Every keyword argument then overrides the option of the same name, and ``main`` is invoked.

        Returns:
            The options namespace that was passed to ``main``.
        """
        opt = parse_opt(known=True)
        vars(opt).update(kwargs)  # same effect as setattr(opt, key, value) for every keyword
        main(opt)
        return opt
    if __name__ == "__main__":
        # Command-line entry point: strict parsing (unknown flags are an error), then train or evolve.
        opt = parse_opt(known=False)
        main(opt)
  • 相关阅读:
    大公司的Java面试题集
    [U3D ShaderGraph] 全面学习ShaderGraph节点 | 第一课 | 内置节点
    ArkTS声明式开发范式
    Zookeeper客户端Curator5.1节点事件监听CuratorCache用法
    【LeetCode】67. 二进制求和
    【正点原子STM32连载】第二十四章 内存保护(MPU)实验 摘自【正点原子】MiniPro STM32H750 开发指南_V1.1
    BUUCTF [BJDCTF2020]一叶障目 1
    《操作系统真象还原》 第 01 章 部署工作环境 学习笔记
    2024上海国际合成生物学与绿色生物制造展览会8月7-9号上海举办
    [附源码]java毕业设计基于JavaEE的在线学习平台
  • 原文地址:https://blog.csdn.net/HUASHUDEYANJING/article/details/126158028