• FCOS论文复现:通用物体检测算法


    摘要:本案例是 FCOS 论文复现的体验案例,模型为 FCOS 论文所提出算法在 ModelArts + PyTorch 框架下的实现。本代码支持 FCOS + ResNet-101 在 MS-COCO 数据集上完整的训练和测试流程。

    本文分享自华为云社区《通用物体检测算法 FCOS(目标检测/Pytorch)》,作者: HWCloudAI 。

    FCOS:Fully Convolutional One-Stage Object Detection

    本案例代码是FCOS论文复现的体验案例

    此模型为 FCOS 论文中所提出算法在 ModelArts + PyTorch 框架下的实现。该算法使用 MS-COCO 公共数据集进行训练和评估。本代码支持 FCOS + ResNet-101 在 MS-COCO 数据集上完整的训练和测试流程。

    具体的算法介绍请参见华为云 AI Gallery 上该算法的页面(原链接标题:AI Gallery_算法_模型_云市场-华为云)。

    注意事项:

    1.本案例使用框架:PyTorch 1.0.0

    2.本案例使用硬件: GPU

    3.运行代码方法:点击本页面顶部菜单栏的三角形运行按钮,或按 Ctrl+Enter 键,依次运行每个代码块。

    1.数据和代码下载

    import os
    import subprocess
    import moxing as mox
    # Download the packaged code and data (FCOS.zip) from OBS.
    mox.file.copy_parallel('obs://obs-aigallery-zc/algorithm/FCOS.zip', 'FCOS.zip')
    # Extract into the current directory.
    # -o overwrites existing files without prompting, so the cell can be re-run;
    # check=True raises CalledProcessError if extraction fails instead of
    # silently continuing with a missing or partial ./FCOS directory.
    subprocess.run(['unzip', '-o', 'FCOS.zip', '-d', './'], check=True)

    2.模型训练

    2.1依赖库安装及加载

    """
    Basic training script for PyTorch
    """
    # Set up the custom environment before nearly anything else is imported.
    # NOTE: keep this order (do not reorder): the `framework` imports must come
    # after the install steps below, which presumably provide that package.
    import os
    import sys
    import argparse
    import subprocess
    import torch
    import shutil

    src_dir = './FCOS/'
    os.chdir(src_dir)

    # Install with the interpreter that runs this notebook (sys.executable) rather
    # than whatever `pip`/`python` happens to be first on PATH, and use check_call
    # so a failed step stops here instead of surfacing later as an ImportError.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-r', './pip-requirements.txt'])
    subprocess.check_call([sys.executable, '-m', 'pip', 'install',
                           './trained_model/model/framework-2.0-cp36-cp36m-linux_x86_64.whl'])
    subprocess.check_call([sys.executable, 'setup.py', 'build', 'develop'])

    # Intended as the first framework import ("set up custom environment");
    # setup_environment is not called explicitly in this script.
    from framework.utils.env import setup_environment  # noqa: F401
    from framework.config import cfg
    from framework.data import make_data_loader
    from framework.solver import make_lr_scheduler
    from framework.solver import make_optimizer
    from framework.engine.inference import inference
    from framework.engine.trainer import do_train
    from framework.modeling.detector import build_detection_model
    from framework.utils.checkpoint import DetectronCheckpointer
    from framework.utils.collect_env import collect_env_info
    from framework.utils.comm import synchronize, \
        get_rank, is_pytorch_1_1_0_or_later
    from framework.utils.logger import setup_logger
    from framework.utils.miscellaneous import mkdir
    28. from framework.utils.miscellaneous import mkdir

    2.2训练函数

    def train(cfg, local_rank, distributed, new_iteration=False):
        """Build the detection model, load its weights and run the training loop.

        Args:
            cfg: the (frozen) framework config; only read here.
            local_rank: device index passed to DistributedDataParallel when
                `distributed` is True.
            distributed: wrap the model in DistributedDataParallel if True.
            new_iteration: if True, reset the iteration counter to 0 after loading
                cfg.MODEL.WEIGHT, discarding any iteration stored in the checkpoint.

        Returns:
            The trained model (DDP-wrapped when `distributed` is True).

        Raises:
            RuntimeError: if cfg.MODEL.USE_SYNCBN is set on PyTorch < 1.1.0.
        """
        import logging
        # Child of the "framework" logger configured by setup_logger in main(),
        # so these messages go to the same handlers/log file.
        logger = logging.getLogger("framework.trainer")

        model = build_detection_model(cfg)
        device = torch.device(cfg.MODEL.DEVICE)
        model.to(device)

        if cfg.MODEL.USE_SYNCBN:
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if not is_pytorch_1_1_0_or_later():
                raise RuntimeError("SyncBatchNorm is only available in pytorch >= 1.1.0")
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)

        if distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank,
                # this should be removed if we update BatchNorm stats
                broadcast_buffers=False,
            )

        arguments = {"iteration": 0}
        output_dir = cfg.OUTPUT_DIR
        # save_to_disk is enabled only on rank 0.
        save_to_disk = get_rank() == 0
        checkpointer = DetectronCheckpointer(
            cfg, model, optimizer, scheduler, output_dir, save_to_disk
        )
        logger.info("Loading weights from %s", cfg.MODEL.WEIGHT)
        extra_checkpoint_data = checkpointer.load_from_file(cfg.MODEL.WEIGHT)
        logger.info("Extra checkpoint data: %s", extra_checkpoint_data)
        # Resume bookkeeping (e.g. "iteration") from the checkpoint...
        arguments.update(extra_checkpoint_data)
        # ...unless the caller asked to start counting from scratch.
        if new_iteration:
            arguments["iteration"] = 0

        data_loader = make_data_loader(
            cfg,
            is_train=True,
            is_distributed=distributed,
            start_iter=arguments["iteration"],
        )
        do_train(
            model,
            data_loader,
            optimizer,
            scheduler,
            checkpointer,
            device,
            arguments,
        )
        return model

    2.3设置参数,开始训练

    def _parse_decay_steps(text):
        """Parse a learning-rate decay-step list such as '120000,160000' into ints.

        Commas, semicolons (ASCII or full-width), the ideographic comma and
        whitespace are all accepted as separators, in any combination. Empty
        tokens are ignored, so '120000, 160000' parses the same as '120000,160000'.
        Returns an empty tuple for blank input.
        """
        import re
        tokens = re.split(r'[\s,;\uff0c\uff1b\u3001]+', text.strip())
        return tuple(int(token) for token in tokens if token)


    def main():
        """Parse command-line options, apply them to the global cfg and train."""
        parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
        parser.add_argument(
            '--train_url',
            default='./outputs',
            type=str,
            help='the path to save training outputs'
        )
        parser.add_argument(
            "--config-file",
            default="./trained_model/model/fcos_resnet_101_fpn_2x.yaml",
            metavar="FILE",
            help="path to config file",
            type=str,
        )
        parser.add_argument("--local_rank", type=int, default=0)
        parser.add_argument('--train_iterations', default=0, type=int)
        parser.add_argument('--warmup_iterations', default=500, type=int)
        parser.add_argument('--train_batch_size', default=8, type=int)
        parser.add_argument('--solver_lr', default=0.01, type=float)
        parser.add_argument('--decay_steps', default='120000,160000', type=str)
        parser.add_argument('--new_iteration', default=False, action='store_true')
        # parse_known_args ignores unrecognised arguments (e.g. ones a notebook
        # kernel may pass) instead of exiting with an error.
        args, unknown = parser.parse_known_args()

        cfg.merge_from_file(args.config_file)
        # Command-line overrides for the solver settings; a non-positive value
        # (or an empty step list) keeps the value from the YAML config.
        if args.train_iterations > 0:
            cfg.SOLVER.MAX_ITER = args.train_iterations
        if args.warmup_iterations > 0:
            cfg.SOLVER.WARMUP_ITERS = args.warmup_iterations
        if args.train_batch_size > 0:
            cfg.SOLVER.IMS_PER_BATCH = args.train_batch_size
        if args.solver_lr > 0:
            cfg.SOLVER.BASE_LR = args.solver_lr
        steps = _parse_decay_steps(args.decay_steps)
        if steps:
            cfg.SOLVER.STEPS = steps
        cfg.freeze()

        # WORLD_SIZE comes from the environment (init_method="env://" below reads
        # the same variables); when it is absent, run as a single process.
        num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
        args.distributed = num_gpus > 1
        if args.distributed:
            torch.cuda.set_device(args.local_rank)
            torch.distributed.init_process_group(
                backend="nccl", init_method="env://"
            )
            synchronize()

        output_dir = args.train_url
        if output_dir:
            mkdir(output_dir)

        logger = setup_logger("framework", output_dir, get_rank())
        logger.info("Using {} GPUs".format(num_gpus))
        logger.info(args)
        logger.info("Loaded configuration file {}".format(args.config_file))

        train(cfg, args.local_rank, args.distributed, args.new_iteration)


    if __name__ == "__main__":
        main()

    3.模型测试

    3.1预测函数

    from framework.engine.predictor import Predictor
    from PIL import Image, ImageDraw
    import numpy as np
    import matplotlib.pyplot as plt


    def predict(img_path, model_path):
        """Run the trained FCOS model on one image and draw its detections.

        Args:
            img_path: path of the image to run detection on.
            model_path: path of the trained weights (.pth) to load.

        Returns:
            An RGB PIL image with each detection kept by
            `select_top_predictions` drawn as a red box, labelled
            "<category>:<score>" (score rounded to 4 decimals).
        """
        config_file = "./trained_model/model/fcos_resnet_101_fpn_2x.yaml"
        # The global cfg may already be frozen (main() freezes it after
        # training), so unfreeze it *before* merging the YAML file into it.
        cfg.defrost()
        cfg.merge_from_file(config_file)
        cfg.MODEL.WEIGHT = model_path
        cfg.OUTPUT_DIR = None
        cfg.freeze()
        predictor = Predictor(cfg=cfg, min_image_size=800)

        # Convert to RGB and close the file handle. Drawing on this RGB copy also
        # keeps the (255, 0, 0) colour valid for grayscale/palette input images.
        with Image.open(img_path) as src_img:
            canvas = src_img.convert('RGB')
        # Reverse the channel order (RGB -> BGR) before handing the array to the
        # predictor (presumably it expects OpenCV-style BGR input).
        img = np.array(canvas)[:, :, ::-1]

        predictions = predictor.compute_prediction(img)
        top_predictions = predictor.select_top_predictions(predictions)

        # .cpu() is a no-op for CPU tensors and avoids a failure in .numpy()
        # should the predictor leave its outputs on the GPU.
        boxes = top_predictions.bbox.int().cpu().numpy().tolist()
        scores = [round(s, 4) for s in top_predictions.get_field("scores").cpu().numpy().tolist()]
        labels = [predictor.CATEGORIES[c] for c in top_predictions.get_field("labels").cpu().numpy().tolist()]

        draw = ImageDraw.Draw(canvas)
        # Boxes are drawn as PIL rectangles [x0, y0, x1, y1], i.e. top-left and
        # bottom-right corners; the label is written at the top-left corner.
        for (x1, y1, x2, y2), label, score in zip(boxes, labels, scores):
            draw.text((x1, y1), label + ':' + str(score), fill=(255, 0, 0))
            draw.rectangle([x1, y1, x2, y2], fill=None, outline=(255, 0, 0))
        return canvas

    3.2开始预测

    if __name__ == "__main__":
        # Weights produced by the training step above.
        weights = "./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth"
        # Image to run detection on.
        demo_image = "./trained_model/model/demo_image.jpg"
        annotated = predict(demo_image, weights)
        # Show the annotated image in a 10x10-inch figure.
        plt.figure(figsize=(10, 10))
        plt.imshow(annotated)
        plt.show()
    运行输出(日志):
    2021-06-09 15:33:15,362 framework.utils.checkpoint INFO: Loading checkpoint from ./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth

    点击关注,第一时间了解华为云新鲜技术~

     

  • 相关阅读:
    Hashtable为什么效率很低
    Dapr v1.13 版本已发布
    linux x64 下的redis安装
    chatgpt技术总结(包括transformer,注意力机制,迁移学习,Ray,TensorFlow,Pytorch)
    java计算机毕业设计网上宠物售卖平台源代码+数据库+系统+lw文档
    HarmonyOS—编译构建概述
    【ArcGIS Pro二次开发】(72):PPT文件操作方法汇总
    postgresql数据库docker
    ECC(SM2) 简介及 C# 和 js 实现【加密知多少系列】
    解决爬虫在重定向(Redirect)情况下,URL没有变化的方法
  • 原文地址:https://blog.csdn.net/devcloud/article/details/128074766