• Replacing Detectron2 with YOLOv8 in SlowFast


    I. Requirement

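    In the SlowFast source code from Facebook Research, the detection boxes are produced by Detectron2 object detection. The goal of this article is to replace Detectron2 with YOLOv8.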

    II. Implementation

    First, install ultralytics, the Python package that provides YOLOv8:

    pip install ultralytics

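    In the SlowFast source, around line 43 of the visualization code (the Predictor class in slowfast/visualization/predictor.py), the detector is created as shown below. This is the line to change so that it builds YOLOPredictor instead of Detectron2Predictor: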

    if cfg.DETECTION.ENABLE:
        self.object_detector = Detectron2Predictor(cfg, gpu_id=self.gpu_id)

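    Following the ultralytics documentation, create the corresponding YOLOPredictor class (it also attaches the detection boxes and their labels; see the previous article for details):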

    import torch
    from ultralytics import YOLO
    from slowfast.visualization.utils import get_class_names


    class YOLOPredictor:

        def __init__(self, cfg, gpu_id=None):
            # Load the custom-trained YOLOv8 weights (best.pt produced by training);
            # YOLO('yolov8n.pt') would load the pretrained YOLOv8n model instead.
            self.model = YOLO('/root/autodl-tmp/data/runs/detect/train/weights/best.pt')
            # Detector class names; DEMO.Detect_File_Path is the custom config key
            # from the previous article, not an upstream SlowFast option.
            self.detect_names, _, _ = get_class_names(cfg.DEMO.Detect_File_Path, None, None)

        def __call__(self, task):
            """
            Return bounding boxes predictions as a tensor.
            Args:
                task (TaskInfo object): task object that contains
                    the necessary information for action prediction (e.g. frames).
            Returns:
                task (TaskInfo object): the same task info object but filled with
                    prediction values (a tensor) and the corresponding boxes for
                    action detection task.
            """
            # Previous Detectron2 version, kept for reference:
            # scores = outputs["instances"].scores[mask].tolist()        # confidences
            # pred_labels = outputs["instances"].pred_classes[mask]      # class ids
            # pred_labels = pred_labels.tolist()
            # for i in range(len(pred_labels)):                          # ids -> names
            #     pred_labels[i] = self.detect_names[pred_labels[i]]
            # preds = [
            #     "[{:.4f}] {}".format(s, labels) for s, labels in zip(scores, pred_labels)
            # ]
            # task.add_detect_preds(preds)                               # attach labels
            # task.add_bboxes(pred_boxes)

            # Run YOLOv8 on the middle frame of the clip, as Detectron2Predictor does.
            middle_frame = task.frames[len(task.frames) // 2]
            outputs = self.model(middle_frame)
            boxes = outputs[0].boxes
            # Keep every detection with confidence >= 0.5 (all classes of the custom model).
            mask = boxes.conf >= 0.5
            pred_boxes = boxes.xyxy[mask]
            scores = boxes.conf[mask].tolist()
            pred_labels = boxes.cls[mask].to(torch.int).tolist()
            # Map class ids to class names.
            pred_labels = [self.detect_names[i] for i in pred_labels]
            preds = [
                "[{:.4f}] {}".format(s, label) for s, label in zip(scores, pred_labels)
            ]
            # Attach the prediction labels and boxes
            # (add_detect_preds is the custom TaskInfo method from the previous article).
            task.add_detect_preds(preds)
            task.add_bboxes(pred_boxes)

            return task
    
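    The filtering and labelling steps in YOLOPredictor.__call__ do not depend on YOLO itself. The sketch below reproduces them with plain Python lists so they can be checked without a GPU or model weights. `format_detections` is a hypothetical helper, and the 0.5 default mirrors the threshold hard-coded in the class.

```python
from typing import List, Sequence, Tuple

# (x1, y1, x2, y2): the same corner layout as boxes.xyxy in ultralytics.
Box = Tuple[float, float, float, float]


def format_detections(
    xyxy: Sequence[Box],
    conf: Sequence[float],
    class_ids: Sequence[float],
    names: Sequence[str],
    threshold: float = 0.5,
) -> Tuple[List[Box], List[str]]:
    """Keep detections with conf >= threshold and build "[score] name" labels.

    Mirrors YOLOPredictor.__call__: mask by confidence, cast class ids to int,
    look up class names, and format each label as "[{:.4f}] {}".
    """
    kept_boxes: List[Box] = []
    labels: List[str] = []
    for box, score, class_id in zip(xyxy, conf, class_ids):
        if score < threshold:
            continue  # same effect as the boolean mask boxes.conf >= 0.5
        kept_boxes.append(tuple(box))
        labels.append("[{:.4f}] {}".format(score, names[int(class_id)]))
    return kept_boxes, labels


boxes, labels = format_detections(
    xyxy=[(10, 20, 110, 220), (5, 5, 50, 50), (0, 0, 30, 60)],
    conf=[0.91, 0.32, 0.5],
    class_ids=[0.0, 1.0, 1.0],
    names=["person", "bicycle"],
)
print(labels)  # ['[0.9100] person', '[0.5000] bicycle']
```

    Because the comparison is `>=`, a detection scoring exactly 0.5 is kept, matching the mask in the class.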
  • Original article: https://blog.csdn.net/qq_59159431/article/details/134484535