• Step-by-Step Tutorial: Training RT-DETR on Your Own Dataset | NEU-DET Steel Surface Defect Detection


    🚀🚀🚀 This article covers: 1) an introduction to how RT-DETR works; 2) how to train RT-DETR on your own dataset


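    Contents

    1. Introduction to RT-DETR

    2. How to train an RT-DETR model

    2.1 Dataset overview

    2.2 Configuring NEU-DET.yaml

    2.3 Adjusting hyperparameters in ultralytics/cfg/default.yaml

    2.4 How to start training

    2.5 Training begins

    3. Visualizing and analyzing RT-DETR training results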


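    1. Introduction to RT-DETR

    Paper: https://arxiv.org/pdf/2304.08069.pdf

    Abstract: RT-DETR is presented by its authors as the first real-time end-to-end object detector. Specifically, we design an efficient hybrid encoder that processes multi-scale features efficiently by decoupling intra-scale interaction from cross-scale fusion, and we propose an IoU-aware query selection mechanism to improve the initialization of the decoder queries. In addition, RT-DETR supports flexible adjustment of inference speed by using different decoder layers, without retraining, which helps real-time object detectors in practical applications. RT-DETR-L achieves 53.0% AP on COCO val2017 at 114 FPS on a T4 GPU, and RT-DETR-X achieves 54.8% AP at 74 FPS, outperforming all YOLO detectors of the same scale in both speed and accuracy. RT-DETR-R50 achieves 53.1% AP at 108 FPS and RT-DETR-R101 achieves 54.3% AP at 74 FPS, surpassing in accuracy all DETR detectors that use the same backbone.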

     

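    What is the problem with YOLO?

    A major weakness of YOLO detectors is that they need NMS post-processing, which is usually hard to optimize and not robust enough, so it adds latency to the detector. To avoid this problem, we turned to DETR, a Transformer-based end-to-end object detector that needs no NMS post-processing. However, DETR-family detectors are much slower than YOLO-family detectors, so being "NMS-free" has not translated into a speed advantage. This motivated us to explore real-time end-to-end detectors, with the aim of designing a new real-time detector on top of DETR's strong architecture and removing, at the root, the latency that NMS imposes on real-time detectors.

    NMS is a post-processing technique commonly used in object detection to remove heavily overlapping boxes produced by a detector. It has two hyperparameters: a confidence threshold and an IoU threshold. Boxes below the confidence threshold are filtered out directly, and whenever the IoU of two boxes exceeds the IoU threshold, the one with the lower confidence is discarded. This process repeats until every class has been processed. The execution time of NMS therefore depends on the number of predicted boxes and on the two thresholds. To illustrate this, we ran statistics and measurements with YOLOv5 (anchor-based) and YOLOv8 (anchor-free), recording the number of boxes that remain at different confidence thresholds, as well as the detector's accuracy on the COCO validation set and the NMS execution time under different hyperparameter combinations. The experiments show that NMS not only slows down inference but is also not robust: suitable hyperparameters must be chosen to reach the best accuracy. These results make a strong case for designing a real-time end-to-end detector.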
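
The NMS procedure described above can be sketched in a few lines of plain Python. This is only an illustrative single-class version, not the implementation used by any YOLO release: the `(x1, y1, x2, y2, score)` tuple layout and the helper names `iou` and `nms` are assumptions made for this example, and the default thresholds simply reuse the 0.25 and 0.7 values that appear in default.yaml later in this article.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(detections, conf_threshold=0.25, iou_threshold=0.7):
    """Greedy NMS for one class; detections are (x1, y1, x2, y2, score) tuples."""
    # Step 1: drop boxes below the confidence threshold.
    candidates = [d for d in detections if d[4] >= conf_threshold]
    # Step 2: visit boxes from most to least confident.
    candidates.sort(key=lambda d: d[4], reverse=True)
    kept = []
    for det in candidates:
        # Keep a box only if it does not overlap a kept box beyond the IoU threshold.
        if all(iou(det[:4], k[:4]) <= iou_threshold for k in kept):
            kept.append(det)
    return kept


# Hypothetical detections for a single class.
boxes = [
    (10, 10, 110, 110, 0.9),    # highest score, kept
    (12, 12, 112, 112, 0.8),    # overlaps the first box heavily, suppressed
    (200, 200, 300, 300, 0.6),  # no overlap, kept
    (0, 0, 50, 50, 0.1),        # below the confidence threshold, filtered
]
kept = nms(boxes)
print(kept)
```

For a multi-class detector the same function would be run once per class. The inner loop compares each candidate against every kept box, which is why the running time grows with the number of boxes that survive the confidence threshold.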

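    RT-DETR model structure

    (1) Backbone: Two backbones are used, the classic ResNet and Baidu's in-house HGNet-v2. The backbone is scalable: the L and X versions of HGNetv2 correspond to the classic ResNet50 and ResNet101 respectively. Unlike DINO and other DETR-style detectors, which use the outputs of the last 4 stages, RT-DETR uses only the last 3 to gain speed, which also matches the YOLO style.

    (2) Neck: Named HybridEncoder, it plays the role of the Encoder in DETR and also resembles the FPN commonly used in classic detection models. The paper's analysis finds that the Encoder's computation is fairly redundant, so the authors decouple the Transformer-based global feature encoding and design a new Efficient Hybrid Encoder that combines AIFI (intra-scale feature interaction) with CCFM (cross-scale feature fusion). In addition, the number of encoder layers is reduced from 6 to 1. The L and X versions are distinguished by a few channel dimensions, which, together with the number of RepBlocks in CCFM, adjust width and depth to produce Scaled RT-DETR.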

    Neck

    RT-DETR uses a single-layer Transformer Encoder that processes only the S5 feature output by the backbone. This is the AIFI (Attention-based Intra-scale Feature Interaction) module.
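
To get a feel for why restricting the encoder to S5 saves computation, the sketch below counts the tokens self-attention would see. It rests on assumptions that are not stated in this article: a 640×640 input (the `imgsz` default shown later) and the usual strides of 8, 16 and 32 for S3, S4 and S5. It is a back-of-the-envelope estimate, not a measurement of RT-DETR.

```python
def token_count(image_size: int, stride: int) -> int:
    """Number of spatial positions in a square feature map at the given stride."""
    side = image_size // stride
    return side * side


IMAGE_SIZE = 640                          # assumption: same as imgsz in default.yaml
STRIDES = {"S3": 8, "S4": 16, "S5": 32}   # assumption: typical backbone strides

tokens = {name: token_count(IMAGE_SIZE, s) for name, s in STRIDES.items()}
s5_tokens = tokens["S5"]
all_tokens = sum(tokens.values())

# Self-attention compares every token with every other token,
# so its cost grows with the square of the token count.
pairwise_ratio = (all_tokens ** 2) / (s5_tokens ** 2)

print(tokens)
print(f"S5 only: {s5_tokens} tokens, all three scales: {all_tokens} tokens")
print(f"pairwise attention cost ratio: {pairwise_ratio:.0f}x")
```

Under these assumptions, attending over S5 alone involves far fewer token pairs than attending over the concatenation of all three scales, which is consistent with the motivation given for AIFI.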

    Experimental results

    RT-DETR-R50 reaches 53.1% AP on COCO val2017 at 108 FPS on a T4 GPU, and RT-DETR-R101 reaches 54.3% AP at 74 FPS. In summary, RT-DETR is noticeably more accurate and faster than DETR-family detectors that use the same backbone.

    2. How to train an RT-DETR model

    2.1 Dataset overview

    The classic NEU-DET dataset contains 1,800 images, randomly split into train:val:test at a ratio of 7:2:1.
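
A random 7:2:1 split like the one described above could be produced as follows. This is a minimal sketch rather than the script used for this article: the function name `split_dataset`, the fixed seed and the placeholder file names are all assumptions made for illustration.

```python
import random


def split_dataset(items, parts=(7, 2, 1), seed=0):
    """Shuffle items and split them into train/val/test by integer ratio parts."""
    items = sorted(items)               # sort first so the result does not depend on input order
    random.Random(seed).shuffle(items)  # seeded shuffle for a reproducible split
    total = sum(parts)
    n_train = len(items) * parts[0] // total
    n_val = len(items) * parts[1] // total
    train = items[:n_train]
    val = items[n_train:n_train + n_val]
    test = items[n_train + n_val:]      # the remainder goes to the test set
    return train, val, test


# Placeholder names standing in for the 1,800 NEU-DET images.
image_names = [f"img_{i:04d}.jpg" for i in range(1800)]
train, val, test = split_dataset(image_names)
print(len(train), len(val), len(test))
```

Integer ratio parts are used instead of fractions such as 0.7 so that the split sizes do not depend on floating-point rounding.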

    2.2 Configuring NEU-DET.yaml

    path: ./ultralytics-rt-detr/data/NEU-DET # dataset root dir
    train: train.txt # train images
    val: val.txt # val images
    # number of classes
    nc: 6
    # class names
    names:
      0: crazing
      1: inclusion
      2: patches
      3: pitted_surface
      4: rolled-in_scale
      5: scratches
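
The configuration above points `train` and `val` at text files. Assuming each file simply lists one image path per line (a list format that Ultralytics datasets accept, as far as I know), such a file could be written with a small helper like the one below. The helper name `write_image_list` and the example paths are hypothetical, and the demo writes into a temporary directory rather than the real dataset root.

```python
import tempfile
from pathlib import Path


def write_image_list(image_paths, list_file):
    """Write one image path per line to list_file and return the number of lines."""
    lines = [str(p) for p in image_paths]
    Path(list_file).write_text("\n".join(lines) + "\n", encoding="utf-8")
    return len(lines)


# Demo with hypothetical image paths in a throwaway directory.
with tempfile.TemporaryDirectory() as root:
    train_images = [f"images/crazing_{i}.jpg" for i in range(1, 4)]
    list_path = Path(root) / "train.txt"
    written = write_image_list(train_images, list_path)
    content = list_path.read_text(encoding="utf-8").splitlines()

print(written, content)
```

In practice the same helper would be called once for train.txt and once for val.txt, using the image lists produced by the dataset split.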

    2.3 Adjusting hyperparameters in ultralytics/cfg/default.yaml

    For a first baseline, the default parameters are sufficient.

    # Ultralytics YOLO 🚀, AGPL-3.0 license
    # Default training settings and hyperparameters for medium-augmentation COCO training
    task: detect # (str) YOLO task, i.e. detect, segment, classify, pose
    mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark
    # Train settings -------------------------------------------------------------------------------------------------------
    model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
    data: # (str, optional) path to data file, i.e. coco128.yaml
    epochs: 100 # (int) number of epochs to train for
    patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
    batch: 16 # (int) number of images per batch (-1 for AutoBatch)
    imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
    save: True # (bool) save train checkpoints and predict results
    save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1)
    cache: True # (bool) True/ram, disk or False. Use cache for data loading
    device: # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
    workers: 0 # (int) number of worker threads for data loading (per RANK if DDP)
    project: # (str, optional) project name
    name: # (str, optional) experiment name, results saved to 'project/name' directory
    exist_ok: False # (bool) whether to overwrite existing experiment
    pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str)
    optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
    verbose: True # (bool) whether to print verbose output
    seed: 0 # (int) random seed for reproducibility
    deterministic: True # (bool) whether to enable deterministic mode
    single_cls: False # (bool) train multi-class data as single-class
    rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val'
    cos_lr: False # (bool) use cosine learning rate scheduler
    close_mosaic: 10 # (int) disable mosaic augmentation for final epochs (0 to disable)
    resume: False # (bool) resume training from last checkpoint
    amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
    fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set)
    profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers
    freeze: None # (int | list, optional) freeze first n layers, or freeze list of layer indices during training
    # Segmentation
    overlap_mask: True # (bool) masks should overlap during training (segment train only)
    mask_ratio: 4 # (int) mask downsample ratio (segment train only)
    # Classification
    dropout: 0.0 # (float) use dropout regularization (classify train only)
    # Val/Test settings ----------------------------------------------------------------------------------------------------
    val: True # (bool) validate/test during training
    split: val # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train'
    save_json: False # (bool) save results to JSON file
    save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions)
    conf: # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val)
    iou: 0.7 # (float) intersection over union (IoU) threshold for NMS
    max_det: 300 # (int) maximum number of detections per image
    half: False # (bool) use half precision (FP16)
    dnn: False # (bool) use OpenCV DNN for ONNX inference
    plots: True # (bool) save plots during train/val
    # Prediction settings --------------------------------------------------------------------------------------------------
    source: # (str, optional) source directory for images or videos
    show: False # (bool) show results if possible
    save_txt: False # (bool) save results as .txt file
    save_conf: False # (bool) save results with confidence scores
    save_crop: False # (bool) save cropped images with results
    show_labels: True # (bool) show object labels in plots
    show_conf: True # (bool) show object confidence scores in plots
    vid_stride: 1 # (int) video frame-rate stride
    stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False)
    line_width: # (int, optional) line width of the bounding boxes, auto if missing
    visualize: False # (bool) visualize model features
    augment: False # (bool) apply image augmentation to prediction sources
    agnostic_nms: False # (bool) class-agnostic NMS
    classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
    retina_masks: False # (bool) use high-resolution segmentation masks
    boxes: True # (bool) Show boxes in segmentation predictions
    # Export settings ------------------------------------------------------------------------------------------------------
    format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
    keras: False # (bool) use Keras
    optimize: False # (bool) TorchScript: optimize for mobile
    int8: False # (bool) CoreML/TF INT8 quantization
    dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes
    simplify: False # (bool) ONNX: simplify model
    opset: # (int, optional) ONNX: opset version
    workspace: 4 # (int) TensorRT: workspace size (GB)
    nms: False # (bool) CoreML: add NMS
    # Hyperparameters ------------------------------------------------------------------------------------------------------
    lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
    lrf: 0.01 # (float) final learning rate (lr0 * lrf)
    momentum: 0.937 # (float) SGD momentum/Adam beta1
    weight_decay: 0.0005 # (float) optimizer weight decay 5e-4
    warmup_epochs: 3.0 # (float) warmup epochs (fractions ok)
    warmup_momentum: 0.8 # (float) warmup initial momentum
    warmup_bias_lr: 0.1 # (float) warmup initial bias lr
    box: 7.5 # (float) box loss gain
    cls: 0.5 # (float) cls loss gain (scale with pixels)
    dfl: 1.5 # (float) dfl loss gain
    pose: 12.0 # (float) pose loss gain
    kobj: 1.0 # (float) keypoint obj loss gain
    label_smoothing: 0.0 # (float) label smoothing (fraction)
    nbs: 64 # (int) nominal batch size
    hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction)
    hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction)
    hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction)
    degrees: 0.0 # (float) image rotation (+/- deg)
    translate: 0.1 # (float) image translation (+/- fraction)
    scale: 0.5 # (float) image scale (+/- gain)
    shear: 0.0 # (float) image shear (+/- deg)
    perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001
    flipud: 0.0 # (float) image flip up-down (probability)
    fliplr: 0.5 # (float) image flip left-right (probability)
    mosaic: 1.0 # (float) image mosaic (probability)
    mixup: 0.0 # (float) image mixup (probability)
    copy_paste: 0.0 # (float) segment copy-paste (probability)
    # Custom config.yaml ---------------------------------------------------------------------------------------------------
    cfg: # (str, optional) for overriding defaults.yaml
    # Tracker settings ------------------------------------------------------------------------------------------------------
    tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml]
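
Some of the values above interact with each other. The sketch below works out two derived quantities for the defaults shown. The final learning rate follows directly from the `lrf` comment (`lr0 * lrf`). The gradient-accumulation step count is an assumption about how the trainer uses `nbs` (accumulating until roughly `nbs` images have been seen) and should be checked against your installed Ultralytics version. The function name `derived_settings` is made up for this example.

```python
def derived_settings(lr0=0.01, lrf=0.01, batch=16, nbs=64):
    """Return quantities implied by the hyperparameters in default.yaml."""
    # Stated in the lrf comment: final learning rate = lr0 * lrf.
    final_lr = lr0 * lrf
    # Assumption: gradients are accumulated over about nbs / batch steps.
    accumulate = max(round(nbs / batch), 1)
    return {
        "final_lr": final_lr,
        "accumulate": accumulate,
        "effective_batch": batch * accumulate,
    }


settings = derived_settings()
print(settings)
```

If that assumption holds, lowering `batch` to fit a smaller GPU changes how often the optimizer steps, but keeps the effective batch size close to `nbs`.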

    2.4 How to start training

    from ultralytics.cfg import entrypoint

    # Same form as the `yolo` command line: task, mode, then key=value overrides
    arg = "yolo detect train model=rtdetr-l.yaml data=ultralytics/cfg/datasets/NEU-DET.yaml"
    entrypoint(arg)
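
When you want to override more settings than `model` and `data`, assembling the argument string by hand gets error-prone. The helper below is a small sketch for building it from keyword arguments. `build_yolo_command` is a hypothetical name, not part of Ultralytics, and the extra overrides (`epochs`, `batch`, `imgsz`) just repeat the defaults from default.yaml.

```python
def build_yolo_command(task, mode, **overrides):
    """Build a 'yolo <task> <mode> key=value ...' string from keyword overrides."""
    parts = ["yolo", task, mode]
    for key, value in overrides.items():
        text = f"{key}={value}"
        # The argument string is split on spaces, so a value containing one would break.
        if " " in text:
            raise ValueError(f"override must not contain spaces: {text!r}")
        parts.append(text)
    return " ".join(parts)


arg = build_yolo_command(
    "detect",
    "train",
    model="rtdetr-l.yaml",
    data="ultralytics/cfg/datasets/NEU-DET.yaml",
    epochs=100,
    batch=16,
    imgsz=640,
)
print(arg)
```

The resulting string can then be passed to `entrypoint(arg)` exactly as in the snippet above.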

    2.5 Training begins

    3. Visualizing and analyzing RT-DETR training results

    rtdetr-l summary: 498 layers, 31996070 parameters, 0 gradients
                   Class     Images  Instances      Box(P          R      mAP50 mAP50-95): 100%|██████████| 31/31 [00:14<00:00, 2.13it/s]
                     all        486       1069      0.696      0.673        0.7      0.398
                 crazing        486        149      0.576      0.208      0.303      0.113
               inclusion        486        222      0.721       0.78      0.805      0.437
                 patches        486        243      0.798      0.881      0.894      0.575
          pitted_surface        486        130      0.758      0.754       0.78      0.479
         rolled-in_scale        486        171      0.559      0.538      0.533      0.229
               scratches        486        154      0.761      0.877      0.885      0.553
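
As a quick sanity check on the table above, the sketch below recomputes the `all` row from the per-class rows and picks out the weakest class. The numbers are copied from the log. The assumption that the `all` row is an unweighted mean over classes is mine, though the values are consistent with it.

```python
# (instances, mAP50, mAP50-95) per class, copied from the validation log above.
PER_CLASS = {
    "crazing": (149, 0.303, 0.113),
    "inclusion": (222, 0.805, 0.437),
    "patches": (243, 0.894, 0.575),
    "pitted_surface": (130, 0.780, 0.479),
    "rolled-in_scale": (171, 0.533, 0.229),
    "scratches": (154, 0.885, 0.553),
}

total_instances = sum(row[0] for row in PER_CLASS.values())
# Assumption: the "all" row is the unweighted mean of the per-class values.
mean_map50 = sum(row[1] for row in PER_CLASS.values()) / len(PER_CLASS)
mean_map50_95 = sum(row[2] for row in PER_CLASS.values()) / len(PER_CLASS)
weakest = min(PER_CLASS, key=lambda name: PER_CLASS[name][1])

print(f"instances: {total_instances}")
print(f"mean mAP50: {mean_map50:.3f}, mean mAP50-95: {mean_map50_95:.3f}")
print(f"lowest mAP50: {weakest}")
```

The class with the lowest mAP50 is a natural first target when trying out model changes or data augmentation.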

  • Original article: https://blog.csdn.net/CV_20231007/article/details/134281222