摘要:本案例代码是FCOS论文复现的体验案例,此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现。本代码支持FCOS + ResNet-101在MS-COCO数据集上完整的训练和测试流程
本文分享自华为云社区《通用物体检测算法 FCOS(目标检测/Pytorch)》,作者:HWCloudAI。
FCOS: Fully Convolutional One-Stage Object Detection
本案例代码是FCOS论文复现的体验案例
此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现。该算法使用MS-COCO公共数据集进行训练和评估。本代码支持FCOS + ResNet-101在MS-COCO数据集上完整的训练和测试流程
注意事项:
1.本案例使用框架: PyTorch 1.0.0,Python 3.6(预编译的 framework 安装包为 cp36 版本;若在配置中启用 MODEL.USE_SYNCBN,则需要 PyTorch >= 1.1.0)
2.本案例使用硬件: GPU
3.运行代码方法: 点击本页面顶部菜单栏的三角形运行按钮,或按Ctrl+Enter键,依次运行每个代码块中的代码
1.数据和代码下载
import os
import subprocess

import moxing as mox

# Download the data + code archive from OBS into the working directory.
mox.file.copy_parallel('obs://obs-aigallery-zc/algorithm/FCOS.zip', 'FCOS.zip')

# Extract it into the current directory (the training cell later chdirs into ./FCOS/).
# -o overwrites existing files instead of stopping at unzip's interactive
# "replace?" prompt when this cell is re-run; check=True raises
# CalledProcessError on a failed/corrupt extraction instead of silently
# continuing and failing later with a confusing error.
subprocess.run(['unzip', '-o', 'FCOS.zip', '-d', './'], check=True)
2.模型训练
2.1依赖库安装及加载
"""
Basic training script for PyTorch
"""
import os
import argparse
import subprocess
import sys

import torch
import shutil

# All following paths (requirements, wheel, setup.py, configs, outputs) are
# relative to the extracted project directory.
src_dir = './FCOS/'
os.chdir(src_dir)

# Install the project's dependencies with the interpreter that runs this
# notebook (sys.executable): a bare `pip` / `python` from PATH may belong to a
# different environment, leaving the `framework` imports below unresolved.
# check=True stops here with CalledProcessError instead of failing later on import.
_INSTALL_STEPS = (
    ['-m', 'pip', 'install', '-r', './pip-requirements.txt'],
    ['-m', 'pip', 'install',
     './trained_model/model/framework-2.0-cp36-cp36m-linux_x86_64.whl'],
    ['setup.py', 'build', 'develop'],
)
for _step in _INSTALL_STEPS:
    subprocess.run([sys.executable] + _step, check=True)

# Set up custom environment before nearly anything else is imported.
# NOTE: per the upstream comment this should stay the first `framework`
# import (do not reorder).
from framework.utils.env import setup_environment
from framework.config import cfg
from framework.data import make_data_loader
from framework.solver import make_lr_scheduler
from framework.solver import make_optimizer
from framework.engine.inference import inference
from framework.engine.trainer import do_train
from framework.modeling.detector import build_detection_model
from framework.utils.checkpoint import DetectronCheckpointer
from framework.utils.collect_env import collect_env_info
from framework.utils.comm import synchronize, \
    get_rank, is_pytorch_1_1_0_or_later
from framework.utils.logger import setup_logger
from framework.utils.miscellaneous import mkdir
2.2训练函数
def train(cfg, local_rank, distributed, new_iteration=False):
    """Build the detector described by ``cfg`` and run the training loop.

    Args:
        cfg: project config node; reads MODEL.DEVICE, MODEL.USE_SYNCBN,
            MODEL.WEIGHT and OUTPUT_DIR here, and is passed on to the
            model/optimizer/scheduler/data-loader builders.
        local_rank: device index given to DistributedDataParallel.
        distributed: when True, wrap the model in DistributedDataParallel.
        new_iteration: when True, reset the iteration counter to 0 after the
            weights in ``cfg.MODEL.WEIGHT`` are loaded, instead of keeping the
            ``iteration`` value returned by the checkpointer.

    Returns:
        The trained model (DDP-wrapped when ``distributed`` is True).
    """
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        # convert_sync_batchnorm does not exist before PyTorch 1.1.0.
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        ddp_options = dict(
            device_ids=[local_rank],
            output_device=local_rank,
            # Upstream note: this should be removed if we update BatchNorm
            # stats broadcast.
            broadcast_buffers=False,
        )
        model = torch.nn.parallel.DistributedDataParallel(model, **ddp_options)

    # Training state shared with do_train; the checkpoint may override it.
    arguments = dict(iteration=0)

    # save_to_disk flag: True only on rank 0.
    is_rank_zero = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, cfg.OUTPUT_DIR, is_rank_zero
    )

    weight_file = cfg.MODEL.WEIGHT
    print(weight_file)
    checkpoint_extras = checkpointer.load_from_file(weight_file)
    print(checkpoint_extras)
    arguments.update(checkpoint_extras)

    if new_iteration:
        # Restart the iteration count from scratch (e.g. fine-tuning from
        # already-trained weights rather than resuming).
        arguments["iteration"] = 0

    loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    do_train(model, loader, optimizer, scheduler, checkpointer, device, arguments)
    return model
2.3设置参数,开始训练
def main(): parser = argparse.ArgumentParser(description="PyTorch Object Detection Training") parser.add_argument( '--train_url', default='./outputs', type=str, help='the path to save training outputs' ) parser.add_argument( "--config-file", default="./trained_model/model/fcos_resnet_101_fpn_2x.yaml", metavar="FILE", help="path to config file", type=str, ) parser.add_argument("--local_rank", type=int, default=0) parser.add_argument('--train_iterations', default=0, type=int) parser.add_argument('--warmup_iterations', default=500, type=int) parser.add_argument('--train_batch_size', default=8, type=int) parser.add_argument('--solver_lr', default=0.01, type=float) parser.add_argument('--decay_steps', default='120000,160000', type=str) parser.add_argument('--new_iteration',default=False, action='store_true') args, unknown = parser.parse_known_args() cfg.merge_from_file(args.config_file) # load the model trained on MS-COCO if args.train_iterations > 0: cfg.SOLVER.MAX_ITER = args.train_iterations if args.warmup_iterations > 0: cfg.SOLVER.WARMUP_ITERS = args.warmup_iterations if args.train_batch_size > 0: cfg.SOLVER.IMS_PER_BATCH = args.train_batch_size if args.solver_lr > 0: cfg.SOLVER.BASE_LR = args.solver_lr if len(args.decay_steps) > 0: steps = args.decay_steps.replace(' ', ',') steps = steps.replace(';', ',') steps = steps.replace(';', ',') steps = steps.replace(',', ',') steps = steps.split(',') steps = tuple([int(x) for x in steps]) cfg.SOLVER.STEPS = steps cfg.freeze() num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 args.distributed = num_gpus > 1 if args.distributed: torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group( backend="nccl", init_method="env://" ) synchronize() output_dir = args.train_url if output_dir: mkdir(output_dir) logger = setup_logger("framework", output_dir, get_rank()) logger.info("Using {} GPUs".format(num_gpus)) logger.info(args) logger.info("Loaded configuration file 
{}".format(args.config_file)) train(cfg, args.local_rank, args.distributed, args.new_iteration) if __name__ == "__main__": main()
3.模型测试
3.1预测函数
from framework.engine.predictor import Predictor
from PIL import Image, ImageDraw
import numpy as np
import matplotlib.pyplot as plt


def predict(img_path, model_path,
            config_file="./trained_model/model/fcos_resnet_101_fpn_2x.yaml",
            min_image_size=800):
    """Run the detector on one image and return it annotated with detections.

    Args:
        img_path: path of the image to run detection on.
        model_path: checkpoint to load (stored into ``cfg.MODEL.WEIGHT``).
        config_file: model config merged into the global ``cfg``.
        min_image_size: passed to ``Predictor(min_image_size=...)``.

    Returns:
        An RGB ``PIL.Image.Image`` with a red rectangle and ``"label:score"``
        text (score rounded to 4 decimals) for every detection kept by
        ``predictor.select_top_predictions``.

    Note:
        Mutates and re-freezes the module-level ``cfg``.
    """
    # Defrost before merging: the shared cfg may already be frozen (the
    # training main() freezes it), and every field below is modified.
    cfg.defrost()
    cfg.merge_from_file(config_file)
    cfg.MODEL.WEIGHT = model_path
    cfg.OUTPUT_DIR = None
    cfg.freeze()

    predictor = Predictor(cfg=cfg, min_image_size=min_image_size)

    # Work on an RGB copy so the RGB tuple colours below are valid even for
    # grayscale/palette inputs, and close the source file promptly.
    with Image.open(img_path) as src_img:
        canvas = src_img.convert('RGB')

    # Reverse the channel order (RGB -> BGR) before prediction, as the original
    # pipeline does; presumably the Predictor expects OpenCV-style BGR arrays.
    # np.array copies, so drawing on `canvas` later does not affect this input.
    bgr = np.array(canvas)[:, :, ::-1]
    predictions = predictor.compute_prediction(bgr)
    top_predictions = predictor.select_top_predictions(predictions)

    boxes = top_predictions.bbox.int().numpy().tolist()
    scores = [round(s, 4) for s in top_predictions.get_field("scores").numpy().tolist()]
    labels = [predictor.CATEGORIES[i]
              for i in top_predictions.get_field("labels").numpy().tolist()]

    red = (255, 0, 0)
    draw = ImageDraw.Draw(canvas)
    for (x1, y1, x2, y2), label, score in zip(boxes, labels, scores):
        draw.text((x1, y1), label + ':' + str(score), fill=red)
        draw.rectangle([x1, y1, x2, y2], fill=None, outline=red)
    return canvas
3.2开始预测
if __name__ == "__main__":
    # Weights produced by the training step, and the image to run detection on.
    trained_weights = "./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth"
    demo_image = "./trained_model/model/demo_image.jpg"

    annotated = predict(demo_image, trained_weights)

    plt.figure(figsize=(10, 10))  # figure (window) size in inches
    plt.imshow(annotated)
    plt.show()

    # Example log output from running this cell:
    # 2021-06-09 15:33:15,362 framework.utils.checkpoint INFO: Loading checkpoint from ./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth