• FCOS论文复现:通用物体检测算法


    摘要:本案例代码是FCOS论文复现的体验案例,此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现。本代码支持FCOS + ResNet-101在MS-COCO数据集上完整的训练和测试流程。

    本文分享自华为云社区《通用物体检测算法 FCOS(目标检测/Pytorch)》,作者: HWCloudAI 。

    FCOS:Fully Convolutional One-Stage Object Detection

    本案例代码是FCOS论文复现的体验案例

    此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现。该算法使用MS-COCO公共数据集进行训练和评估。本代码支持FCOS + ResNet-101在MS-COCO数据集上完整的训练和测试流程。

    具体的算法介绍:https://marketplace.huaweicloud.com/markets/aihub/modelhub/detail/?id=ce7acc40-0540-45c9-a0c6-e2fda8d1ac7e

    注意事项:

    1.本案例使用框架: PyTorch1.0.0

    2.本案例使用硬件: GPU

    3.运行代码方法: 点击本页面顶部菜单栏的三角形运行按钮或按Ctrl+Enter键 运行每个方块中的代码

    1.数据和代码下载

    复制代码
    import os
    import moxing as mox
    # Download the data-and-code archive from the OBS bucket to ./FCOS.zip
    mox.file.copy_parallel('obs://obs-aigallery-zc/algorithm/FCOS.zip','FCOS.zip')
    # Unpack the archive into the current directory
    # NOTE(review): the os.system return code is ignored, so a failed unzip
    # is not reported here.
    os.system('unzip  FCOS.zip -d ./')
    复制代码

    2.模型训练

    2.1依赖库安装及加载

    复制代码
    """
    Basic training script for PyTorch
    """
    # Set up custom environment before nearly anything else is imported
    # NOTE: this should be the first import (do not reorder)
    import os
    import argparse
    import torch
    import shutil
    # Switch into the unpacked FCOS directory: every relative path used below
    # (requirements file, wheel, setup.py) is resolved against it.
    src_dir = './FCOS/'
    os.chdir(src_dir)
    # Install the Python requirements, then the bundled `framework` wheel, then
    # run the project's setup.py in develop mode. These steps presumably have to
    # finish before the `framework` imports that follow can succeed.
    # NOTE(review): os.system return codes are ignored, so a failed install is
    # not reported at this point.
    os.system('pip install -r ./pip-requirements.txt')
    os.system('python -m pip install ./trained_model/model/framework-2.0-cp36-cp36m-linux_x86_64.whl')
    os.system('python setup.py build develop')
    from framework.utils.env import setup_environment
    from framework.config import cfg
    from framework.data import make_data_loader
    from framework.solver import make_lr_scheduler
    from framework.solver import make_optimizer
    from framework.engine.inference import inference
    from framework.engine.trainer import do_train
    from framework.modeling.detector import build_detection_model
    from framework.utils.checkpoint import DetectronCheckpointer
    from framework.utils.collect_env import collect_env_info
    from framework.utils.comm import synchronize, \
     get_rank, is_pytorch_1_1_0_or_later
    from framework.utils.logger import setup_logger
    from framework.utils.miscellaneous import mkdir
    复制代码

    2.2训练函数

    复制代码
    def train(cfg, local_rank, distributed, new_iteration=False):
        """Build the detection model described by ``cfg`` and run training.

        Args:
            cfg: Configuration object. This function reads ``MODEL.DEVICE``,
                ``MODEL.USE_SYNCBN``, ``MODEL.WEIGHT`` and ``OUTPUT_DIR``.
            local_rank: Device index handed to ``DistributedDataParallel`` as
                both ``device_ids`` and ``output_device``.
            distributed: When true, the model is wrapped in
                ``DistributedDataParallel`` and the data loader is created
                with ``is_distributed=True``.
            new_iteration: When true, the iteration counter is forced back to
                0 after the checkpoint has been loaded.

        Returns:
            The model object passed to ``do_train`` (the DDP wrapper when
            ``distributed`` is true).
        """
        model = build_detection_model(cfg)
        device = torch.device(cfg.MODEL.DEVICE)
        model.to(device)

        if cfg.MODEL.USE_SYNCBN:
            assert is_pytorch_1_1_0_or_later(), \
                "SyncBatchNorm is only available in pytorch >= 1.1.0"
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)

        if distributed:
            ddp_options = dict(
                device_ids=[local_rank],
                output_device=local_rank,
                # this should be removed if we update BatchNorm stats
                broadcast_buffers=False,
            )
            model = torch.nn.parallel.DistributedDataParallel(model, **ddp_options)

        # Only the rank-0 process is flagged to save to disk.
        rank_zero_saves = get_rank() == 0
        checkpointer = DetectronCheckpointer(
            cfg, model, optimizer, scheduler, cfg.OUTPUT_DIR, rank_zero_saves
        )

        print(cfg.MODEL.WEIGHT)
        restored_state = checkpointer.load_from_file(cfg.MODEL.WEIGHT)
        print(restored_state)

        # Start from iteration 0, let the checkpoint override it, and reset it
        # again when the caller explicitly asked for a fresh schedule.
        arguments = {"iteration": 0}
        arguments.update(restored_state)
        if new_iteration:
            arguments["iteration"] = 0

        data_loader = make_data_loader(
            cfg,
            is_train=True,
            is_distributed=distributed,
            start_iter=arguments["iteration"],
        )

        do_train(
            model,
            data_loader,
            optimizer,
            scheduler,
            checkpointer,
            device,
            arguments,
        )
        return model
    复制代码

    2.3设置参数,开始训练

    复制代码
    def main():
        """Parse command-line options, apply them to ``cfg`` and start training.

        Options (unknown extra arguments are ignored):
            --train_url: directory created for outputs and passed to the logger
                setup (default ``./outputs``).
            --config-file: YAML file merged into ``cfg``.
            --local_rank: device index used when running distributed.
            --train_iterations / --warmup_iterations / --train_batch_size /
            --solver_lr: override ``SOLVER.MAX_ITER`` / ``WARMUP_ITERS`` /
                ``IMS_PER_BATCH`` / ``BASE_LR`` only when greater than 0.
            --decay_steps: list of integers for ``SOLVER.STEPS``; items may be
                separated by commas, semicolons (ASCII or full-width) and/or
                whitespace.
            --new_iteration: forwarded to ``train`` to restart the iteration
                counter at 0.

        Raises:
            ValueError: if ``--decay_steps`` contains a non-integer item.
        """
        import re  # local import: only needed for the decay-step parsing below

        parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
        parser.add_argument(
            '--train_url',
            default='./outputs',
            type=str,
            help='the path to save training outputs'
        )
        parser.add_argument(
            "--config-file",
            default="./trained_model/model/fcos_resnet_101_fpn_2x.yaml",
            metavar="FILE",
            help="path to config file",
            type=str,
        )
        parser.add_argument("--local_rank", type=int, default=0)
        parser.add_argument('--train_iterations', default=0, type=int)
        parser.add_argument('--warmup_iterations', default=500, type=int)
        parser.add_argument('--train_batch_size', default=8, type=int)
        parser.add_argument('--solver_lr', default=0.01, type=float)
        parser.add_argument('--decay_steps', default='120000,160000', type=str)
        parser.add_argument('--new_iteration',default=False, action='store_true')
        # parse_known_args: arguments this script does not declare are ignored.
        args, _ = parser.parse_known_args()

        cfg.merge_from_file(args.config_file)

        # Non-positive values leave the setting from the config file untouched.
        if args.train_iterations > 0:
            cfg.SOLVER.MAX_ITER = args.train_iterations
        if args.warmup_iterations > 0:
            cfg.SOLVER.WARMUP_ITERS = args.warmup_iterations
        if args.train_batch_size > 0:
            cfg.SOLVER.IMS_PER_BATCH = args.train_batch_size
        if args.solver_lr > 0:
            cfg.SOLVER.BASE_LR = args.solver_lr

        if args.decay_steps:
            # Split on any run of separators instead of chaining str.replace:
            # replacing the empty string '' inserts a comma between every
            # character and made int('') fail even for the default value.
            # \uff0c and \uff1b are the full-width comma and semicolon
            # (presumably what the two empty-string replacements originally
            # targeted).
            tokens = re.split(r"[\s,;\uff0c\uff1b]+", args.decay_steps)
            steps = tuple(int(tok) for tok in tokens if tok)
            # Only override when at least one step was given, so an empty or
            # separator-only value keeps the config file's schedule.
            if steps:
                cfg.SOLVER.STEPS = steps
        cfg.freeze()

        # WORLD_SIZE is read from the environment; absent means one process.
        num_gpus = int(os.environ.get("WORLD_SIZE", 1))
        args.distributed = num_gpus > 1
        if args.distributed:
            torch.cuda.set_device(args.local_rank)
            torch.distributed.init_process_group(
                backend="nccl", init_method="env://"
            )
            synchronize()

        output_dir = args.train_url
        if output_dir:
            mkdir(output_dir)

        logger = setup_logger("framework", output_dir, get_rank())
        logger.info("Using {} GPUs".format(num_gpus))
        logger.info(args)
        logger.info("Loaded configuration file {}".format(args.config_file))

        train(cfg, args.local_rank, args.distributed, args.new_iteration)


    if __name__ == "__main__":
        main()
    复制代码

    3.模型测试

    3.1预测函数

    复制代码
    from framework.engine.predictor import Predictor
    from PIL import Image,ImageDraw
    import numpy as np
    import matplotlib.pyplot as plt
    def predict(img_path, model_path):
        """Run the detector on one image and return it with detections drawn.

        Args:
            img_path: Path of the image file, opened with PIL.
            model_path: Checkpoint path written into ``cfg.MODEL.WEIGHT``.

        Returns:
            The PIL image opened from ``img_path``, with a red rectangle and a
            ``label:score`` caption drawn for every selected prediction.
        """
        config_file = "./trained_model/model/fcos_resnet_101_fpn_2x.yaml"
        cfg.merge_from_file(config_file)
        cfg.defrost()
        cfg.MODEL.WEIGHT = model_path
        cfg.OUTPUT_DIR = None
        cfg.freeze()
        predictor = Predictor(cfg=cfg, min_image_size=800)

        src_img = Image.open(img_path)
        # Convert to RGB, then reverse the channel axis before prediction.
        pixels = np.array(src_img.convert('RGB'))[:, :, ::-1]

        top_predictions = predictor.select_top_predictions(
            predictor.compute_prediction(pixels)
        )
        boxes = top_predictions.bbox.int().numpy().tolist()
        raw_scores = top_predictions.get_field("scores").numpy().tolist()
        scores = [round(value, 4) for value in raw_scores]
        label_ids = top_predictions.get_field("labels").numpy().tolist()
        labels = [predictor.CATEGORIES[idx] for idx in label_ids]

        red = (255, 0, 0)
        canvas = ImageDraw.Draw(src_img)
        # Box values are used in the order the predictor returns them, which is
        # the order ImageDraw.rectangle receives them in.
        for (left, top, right, bottom), label, score in zip(boxes, labels, scores):
            canvas.text((left, top), label + ':' + str(score), fill=red)
            canvas.rectangle([left, top, right, bottom], fill=None, outline=red)
        return src_img
    复制代码

    3.2开始预测

    复制代码
    if __name__ == "__main__":
        # Model produced by training
        model_path = "./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth"
        # Image to run prediction on
        image_path = "./trained_model/model/demo_image.jpg"

        img = predict(image_path, model_path)

        # Set the figure size, then display the annotated image
        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        plt.show()
    2021-06-09 15:33:15,362 framework.utils.checkpoint INFO: Loading checkpoint from ./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth
    复制代码

     

    点击关注,第一时间了解华为云新鲜技术~

  • 相关阅读:
    ​iOS安全加固方法及实现
    第1章 引论
    多商户商城系统功能拆解41讲-平台端应用-客服设置
    大空间享大智慧 奇瑞新能源奇瑞大蚂蚁
    深入体会线程状态的切换
    基于JavaSwing开发排课系统 毕业设计 大作业 课程设计
    5自由度雄克机械臂仿真描点
    C++无依赖库的websocket实现
    ES 实战复杂sql查询、修改字段类型
    值得收藏的30道Python练手题(附详解)
  • 原文地址:https://www.cnblogs.com/huaweiyun/p/16931547.html