Implementing YOLOX with MegEngine (with partial source code)



    Preface

    This article implements YOLOX with MegEngine as the deep learning framework, supporting training on a custom dataset as well as testing on the COCO dataset.


    I. About YOLOX

    YOLOX is a high-performance anchor-free detector that outperforms models from YOLOv3 through YOLOv5, and it provides acceleration and deployment support for MegEngine, ONNX, TensorRT, ncnn, and OpenVINO.

    As the title of the original paper suggests, YOLOX surpassed all YOLO-series versions in 2021.

    Paper: https://arxiv.org/pdf/2107.08430.pdf

    Performance figure from the paper:
    (figure omitted)
    As shown, YOLOX clearly outperforms YOLOv5 and EfficientDet.


    II. YOLOX Code Implementation

    The project is organized as follows:
    (figure omitted)
    BaseModel: weight files pretrained on COCO
    core: evaluation-metric code
    data: data loaders
    nets: network definitions
    utils: data-processing utilities
    Ctu_Yolox.py: main entry point

    1. Network Structure

    The YOLOX variants implemented here are:
    yolox_darknet, yolox_nano, yolox_tiny, yolox_s, yolox_m, yolox_l, yolox_x

    Main YOLOX model entry point:

    class YOLOX(M.Module):
        yolox_model={
            'yolox_x':{
                'depth':1.33,
                'width':1.25
            },
            'yolox_l':{
                'depth':1.0,
                'width':1.0
            },
            'yolox_m':{
                'depth':0.67,
                'width':0.75
            },
            'yolox_s':{
                'depth':0.33,
                'width':0.50
            },
            'yolox_tiny':{
                'depth':0.33,
                'width':0.375
            },
            'yolox_nano':{
                'depth':0.33,
                'width':0.25
            },
            'yolox_darknet':{
                'depth':1.0,
                'width':1.0
            },
        }
        def __init__(self, num_classes=80,network_name='yolox_s'):
            super().__init__()
            self.network_name=network_name
            if self.network_name !='yolox_darknet':
                in_channels = [256, 512, 1024]
                self.backbone = YOLOPAFPN(self.yolox_model[self.network_name]['depth'], self.yolox_model[self.network_name]['width'], in_channels=in_channels)
                self.head = YOLOXHead(num_classes, self.yolox_model[self.network_name]['width'], in_channels=in_channels)
            else:
                self.backbone = YOLOFPN()
                self.head = YOLOXHead(num_classes, self.yolox_model[self.network_name]['width'], in_channels=[128, 256, 512], act="lrelu")
        def forward(self, x, targets=None):
            fpn_outs = self.backbone(x)
    
            if self.training:
                assert targets is not None
                loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(fpn_outs, targets, x)
                outputs = {
                    "total_loss": loss,
                    "iou_loss": iou_loss,
                    "l1_loss": l1_loss,
                    "conf_loss": conf_loss,
                    "cls_loss": cls_loss,
                    "num_fg": num_fg,
                }
            else:
                outputs = self.head(fpn_outs)
    
            return outputs
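
    The depth and width entries in the yolox_model table scale the network: depth multiplies the number of repeated CSP bottleneck blocks and width multiplies the channel counts. A minimal sketch of that scaling rule, with illustrative base values (the helper names are my own, not from the repo):

```python
# Sketch of how YOLOX depth/width multipliers scale the backbone.
# The base values below are illustrative only.

def scaled_depth(base_repeats, depth_mult):
    # Number of repeated bottleneck blocks; never less than 1.
    return max(round(base_repeats * depth_mult), 1)

def scaled_width(base_channels, width_mult):
    # Channel count scaled by the width multiplier.
    return int(base_channels * width_mult)

# yolox_s: depth=0.33, width=0.50
print(scaled_depth(9, 0.33))     # 9 base repeats -> 3
print(scaled_width(1024, 0.50))  # 1024 base channels -> 512
```

    This is why yolox_s is far smaller than yolox_x (depth 1.33, width 1.25) while sharing the same architecture definition.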
    

    2. Framework Code

    This section covers the implementation of Ctu_Yolox.py.

    GPU setup

    Only an environment variable needs to be modified:

    os.environ['CUDA_VISIBLE_DEVICES']=USEGPU
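
    Setting CUDA_VISIBLE_DEVICES controls which GPUs the framework can see, and it must happen before MegEngine initializes CUDA; a value of '-1' hides all GPUs and forces CPU execution. For example (USEGPU is just a string such as '0' or '0,1'):

```python
import os

# Must run before importing megengine (or any CUDA-backed library),
# otherwise the visible-device list is already fixed.
USEGPU = '-1'   # '-1': CPU only; '0': first GPU; '0,1': first two GPUs
os.environ['CUDA_VISIBLE_DEVICES'] = USEGPU

print(os.environ['CUDA_VISIBLE_DEVICES'])  # -1
```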
    

    Data preprocessing

    The input data is normalized here, using the standard ImageNet channel statistics:

    self.rgb_means = (0.485, 0.456, 0.406)
    self.std = (0.229, 0.224, 0.225)
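
    A minimal sketch of how such a normalization is applied, assuming the image is already in RGB order and scaled to [0, 1] (the repo's actual preprocess also resizes and pads):

```python
import numpy as np

rgb_means = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

def normalize(img):
    # img: H x W x 3 float array in [0, 1], RGB channel order
    return (img - rgb_means) / std

img = np.full((2, 2, 3), 0.485)   # dummy image equal to the R-channel mean
out = normalize(img)
print(out[0, 0, 0])               # 0.0 for the red channel
```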
    

    Data loader

    The VOC dataset format is used here:

    def CreateDataList_Detection(IMGDir,XMLDir,train_split):
        ImgList = os.listdir(IMGDir)
        XmlList = os.listdir(XMLDir)
        classes = []
        dataList=[]
        for each_jpg in ImgList:
            each_xml = each_jpg.split('.')[0] + '.xml'
            if each_xml in XmlList:
                dataList.append([os.path.join(IMGDir,each_jpg),os.path.join(XMLDir,each_xml)])
                with open(os.path.join(XMLDir,each_xml), "r", encoding="utf-8") as in_file:
                    tree = ET.parse(in_file)
                    root = tree.getroot()
                    for obj in root.iter('object'):
                        cls = obj.find('name').text
                        if cls not in classes:
                            classes.append(cls)
        random.shuffle(dataList)
        if train_split <=0 or train_split >=1:
            train_data_list = dataList
            val_data_list = dataList
        else:
            train_data_list = dataList[:int(len(dataList)*train_split)]
            val_data_list = dataList[int(len(dataList)*train_split):]
        return train_data_list, val_data_list, classes
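
    Note the split rule at the end: after shuffling, the first train_split fraction becomes the training list and the remainder becomes validation; any value outside (0, 1) reuses the full list for both. The cut in isolation:

```python
# The proportional train/val cut used by CreateDataList_Detection.
def split_data(data_list, train_split):
    if train_split <= 0 or train_split >= 1:
        return data_list, data_list   # degenerate case: reuse everything
    cut = int(len(data_list) * train_split)
    return data_list[:cut], data_list[cut:]

train, val = split_data(list(range(10)), 0.9)
print(len(train), len(val))  # 9 1
```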
    

    Data iterator

    class VOCDetection(Dataset):
        def __init__(self, data_list,classes_names, img_size=(640, 640), preproc=None):
            super().__init__(img_size)
            self.data_list = data_list
            self.image_set = classes_names
            self.img_size = img_size
            self.preproc = preproc
            self.target_transform = AnnotationTransform(classes_names)
            self.ids = list()
    
        def __len__(self):
            return len(self.data_list)
    
        def load_anno(self, index):
            target = ET.parse(self.data_list[index][1]).getroot()
            if self.target_transform is not None:
                target = self.target_transform(target)
            return target
    
        def pull_item(self, index):
            img = cv2.imread(self.data_list[index][0], cv2.IMREAD_COLOR)
            height, width, _ = img.shape
            target = self.load_anno(index)
            img_info = (width, height)
            return img, target, img_info, index
    
        @Dataset.resize_getitem
        def __getitem__(self, index):
            img, target, img_info, img_id = self.pull_item(index)
    
            if self.preproc is not None:
                img, target = self.preproc(img, target, self.input_dim)
    
            return img, target, img_info, img_id
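
    The AnnotationTransform used above turns a parsed VOC XML root into per-object [xmin, ymin, xmax, ymax, class_index] targets. Its source isn't shown here, but the core idea can be sketched with ElementTree (field names follow the VOC annotation schema; this is my simplification, not the repo's code):

```python
import xml.etree.ElementTree as ET

VOC_XML = """
<annotation>
  <object>
    <name>cat</name>
    <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>110</xmax><ymax>220</ymax></bndbox>
  </object>
</annotation>
"""

def parse_voc(root, class_names):
    # One [xmin, ymin, xmax, ymax, class_index] entry per <object>
    targets = []
    for obj in root.iter('object'):
        name = obj.find('name').text
        box = obj.find('bndbox')
        coords = [int(box.find(k).text) for k in ('xmin', 'ymin', 'xmax', 'ymax')]
        targets.append(coords + [class_names.index(name)])
    return targets

print(parse_voc(ET.fromstring(VOC_XML), ['cat', 'dog']))  # [[10, 20, 110, 220, 0]]
```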
    

    Model construction

    self.network_name = network_name
    self.model = YOLOX(num_classes=len(self.classes_names),network_name=self.network_name)
    

    Here, network_name specifies which YOLOX variant to use.


    Model training

    for self.epoch in range(0, TrainNum):
        if self.epoch + 1 == (TrainNum - no_aug_epochs) or (not self.aug_Flag):
            self.train_loader.close_mosaic()
            self.model.head.use_l1 = True
    
        for self.iter in range(self.max_iter):
            iter_start_time = time.time()
            inps, targets = self.prefetcher.next()
            inps, targets = mge.tensor(inps), mge.tensor(targets)
            data_end_time = time.time()
    
            with self.grad_manager:
                outputs = self.model(inps, targets)
                loss = outputs["total_loss"]
                self.grad_manager.backward(loss)
    
            self.optimizer.step().clear_grad()
    
            if self.emaFlag:
                self.ema_model.update()
    
            lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = lr
    
            iter_end_time = time.time()
            self.meter.update(
                iter_time=iter_end_time - iter_start_time,
                data_time=data_end_time - iter_start_time,
                lr=lr,
                **outputs,
            )
            
            left_iters = self.max_iter * TrainNum - (self.progress_in_iter + 1)
            eta_seconds = self.meter["iter_time"].global_avg * left_iters
            eta_str = "ETA: {}".format(datetime.timedelta(seconds=int(eta_seconds)))
    
            progress_str = "epoch: {}/{}, iter: {}/{}".format(
                self.epoch + 1, TrainNum, self.iter + 1, self.max_iter
            )
            loss_meter = self.meter.get_filtered_meter("loss")
            loss_str_list = []
            for k, v in loss_meter.items():
                single_loss_str = "{}: {:.6f}".format(k, float(v.latest))
    
                loss_str_list.append(single_loss_str)
    
            loss_str = ", ".join(loss_str_list)
    
            time_meter = self.meter.get_filtered_meter("time")
            time_str = ", ".join([
                "{}: {:.3f}s".format(k, float(v.avg))
                for k, v in time_meter.items()
            ])
    
            print(
                "{}, {}, {}, lr: {:.3e}".format(
                    progress_str,
                    time_str,
                    loss_str,
                    self.meter["lr"].latest,
                )
                + (", size: {:d}, {}".format(self.image_size, eta_str))
            )
            self.meter.clear_meters()
    
        mge.save(
            {"epoch": 0, "state_dict": self.model.state_dict(),'classes_names':self.classes_names,'image_size':self.image_size,'network_name':self.network_name}, os.path.join(ModelPath,'final.pkl'),
        )
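
    The ETA printed each iteration is simply the running average iteration time multiplied by the number of iterations remaining across all epochs, exactly as computed in the loop above. In isolation:

```python
import datetime

def eta_string(avg_iter_time, max_iter, total_epochs, progress_in_iter):
    # Iterations left across all remaining epochs times avg seconds/iter.
    left_iters = max_iter * total_epochs - (progress_in_iter + 1)
    eta_seconds = avg_iter_time * left_iters
    return "ETA: {}".format(datetime.timedelta(seconds=int(eta_seconds)))

# 0.5 s/iter, 100 iters/epoch, 10 epochs, 499 iterations already done
print(eta_string(0.5, 100, 10, 499))  # ETA: 0:04:10
```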
    
    

    Model prediction

    def predict(self,img,score=0.35,nms=0.35):
        img_info = {"id": 0}
        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img
        img, ratio = preprocess(img, (self.image_size,self.image_size), self.rgb_means, self.std)
        img_info["ratio"] = ratio
        img = F.expand_dims(mge.tensor(img), 0)
        outputs = self.model(img)
        outputs = postprocess(outputs, len(self.classes_names), 0.01, nms)
        # draw the detection results on the original image
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        output = outputs[0].numpy()
        bboxes = output[:, 0:4] / ratio
        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]
        vis_res = vis(img, bboxes, scores, cls, score, self.classes_names)
        return vis_res
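
    preprocess resizes (letterboxes) the image to the square network input and returns the scale ratio; after inference the predicted boxes are divided by that ratio to map them back to original-image coordinates, which is what bboxes = output[:, 0:4] / ratio does above. A numpy-only sketch of that round trip (the real preprocess also normalizes and pads):

```python
import numpy as np

def letterbox_ratio(orig_h, orig_w, input_size):
    # Scale that fits the image inside an input_size x input_size square.
    return min(input_size / orig_h, input_size / orig_w)

ratio = letterbox_ratio(960, 1280, 640)              # 0.5
boxes_net = np.array([[100.0, 50.0, 300.0, 200.0]])  # boxes in network coords
boxes_orig = boxes_net / ratio                       # back to original coords
print(ratio, boxes_orig[0])                          # 0.5 [200. 100. 600. 400.]
```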
    

    III. Main Program Entry

    Out of personal habit, everything is wrapped in a class, so usage is straightforward: a few lines in the main function are enough to run both model training and prediction.

    if __name__ == "__main__":
        # ctu = Ctu_YoloX(USEGPU='-1',image_size=416)
        # ctu.InitModel(r'./DataSet_Detection_YaoPian',train_split=0.9,batch_size = 1,Pre_Model=None,network_name='yolox_nano')
        # ctu.train(TrainNum=400,learning_rate=0.0001, ModelPath='result_Model')
    
        ctu = Ctu_YoloX(USEGPU='-1')
        ctu.LoadModel('./BaseModel/yolox_tiny.pkl')
        for root, dirs, files in os.walk(r'./BaseModel/test_image'):
            for f in files:
                img_cv = cv2.imread(os.path.join(root, f))
                if img_cv is None:
                    continue
                res_img = ctu.predict(img_cv)
                cv2.imshow("result", res_img)
                cv2.waitKey()
    
    
    

    IV. Results

    Custom dataset

    TBD

    COCO dataset

    (figures omitted)

  • Original article: https://blog.csdn.net/ctu_sue/article/details/127609536