• Ultra-Fast-Lane-Detection-v2 {后处理优化}//参考


    采用三次多项式对模型生成的anchor特征点进行拟合。在下面给出的polyfit_draw函数中,degree参数表示拟合多项式的次数。

    具体来说,当我们使用np.polyfit函数进行数据点的多项式拟合时,我们需要指定一个度数。这个度数决定了多项式的复杂度。例如:

    • degree = 1:线性拟合,也就是最简单的直线拟合。拟合的多项式形式为 $f(y)=ay+b$。

    • degree = 2:二次多项式拟合。拟合的多项式形式为 $f(y)=ay^2+by+c$。

    • degree = 3:三次多项式拟合。拟合的多项式形式为 $f(y)=ay^3+by^2+cy+d$。

    ...以此类推。

    次数越高,多项式越复杂,可以更准确地拟合数据点,但也更容易过拟合(即模型过于复杂,过于依赖已有的数据点,对新数据的适应性差)。

    1. import torch, os, cv2
    2. from utils.dist_utils import dist_print
    3. import torch, os
    4. from utils.common import merge_config, get_model
    5. import tqdm
    6. import torchvision.transforms as transforms
    7. from data.dataset import LaneTestDataset
    def pred2coords(pred, row_anchor, col_anchor, local_width = 1, original_image_width = 1640, original_image_height = 590):
        """Decode lane points in image coordinates from the network output.

        Only batch element 0 is decoded.

        :param pred: dict with 4-D tensors 'loc_row' / 'loc_col' (indexed as
            [batch, grid, anchor, lane]) and 'exist_row' / 'exist_col' (arg-maxed
            over dim 1 and then indexed as [batch, anchor, lane]). The
            'loc_row' and 'loc_col' entries are replaced by their CPU copies.
        :param row_anchor: per-anchor relative y positions, scaled by the image height
        :param col_anchor: per-anchor relative x positions, scaled by the image width
        :param local_width: half-width of the grid window around the arg-max cell
            used for the soft-argmax refinement
        :param original_image_width: width used to scale x coordinates
        :param original_image_height: height used to scale y coordinates
        :return: list of lanes, each a list of (x, y) integer tuples; row-anchor
            lanes (indices 1, 2) come first, then column-anchor lanes (0, 3)
        """
        _, num_grid_row, num_cls_row, _ = pred['loc_row'].shape
        _, num_grid_col, num_cls_col, _ = pred['loc_col'].shape

        # Each result below has shape (n, num_cls, num_lanes).
        peak_row = pred['loc_row'].argmax(1).cpu()
        exist_row = pred['exist_row'].argmax(1).cpu()
        peak_col = pred['loc_col'].argmax(1).cpu()
        exist_col = pred['exist_col'].argmax(1).cpu()

        pred['loc_row'] = pred['loc_row'].cpu()
        pred['loc_col'] = pred['loc_col'].cpu()

        def _refine(logits, peaks, num_grid, anchor, lane):
            # Soft-argmax over the cells within local_width of the peak,
            # returned as a fraction of the grid extent.
            peak = peaks[0, anchor, lane]
            first = max(0, peak - local_width)
            last = min(num_grid - 1, peak + local_width)
            window = torch.tensor(list(range(first, last + 1)))
            weights = logits[0, window, anchor, lane].softmax(0)
            position = (weights * window.float()).sum() + 0.5
            return position / (num_grid - 1)

        lanes = []

        # Row-anchor lanes: y is fixed by the anchor, x is regressed.
        for lane in (1, 2):
            if not exist_row[0, :, lane].sum() > num_cls_row / 2:
                continue
            points = []
            for anchor in range(exist_row.shape[1]):
                if not exist_row[0, anchor, lane]:
                    continue
                x = _refine(pred['loc_row'], peak_row, num_grid_row, anchor, lane) * original_image_width
                points.append((int(x), int(row_anchor[anchor] * original_image_height)))
            lanes.append(points)

        # Column-anchor lanes: x is fixed by the anchor, y is regressed.
        for lane in (0, 3):
            if not exist_col[0, :, lane].sum() > num_cls_col / 4:
                continue
            points = []
            for anchor in range(exist_col.shape[1]):
                if not exist_col[0, anchor, lane]:
                    continue
                y = _refine(pred['loc_col'], peak_col, num_grid_col, anchor, lane) * original_image_height
                points.append((int(col_anchor[anchor] * original_image_width), int(y)))
            lanes.append(points)

        return lanes
    def polyfit_draw(img, coords, degree=3, color=(144, 238, 144), thickness=2):
        """Fit a polynomial x = f(y) to one lane's points and draw it on the image.

        The curve is sampled at 100 evenly spaced y values between the lowest
        and highest input y and drawn as connected line segments. The image is
        modified in place and also returned.

        :param img: image to draw on
        :param coords: list of (x, y) points belonging to one lane
        :param degree: requested polynomial degree; lowered automatically when
            the lane has too few distinct y values to determine that many
            coefficients
        :param color: curve colour
        :param thickness: curve thickness in pixels
        :return: the image with the curve drawn on it
        """
        # Imported here because the surrounding file never imports numpy.
        import numpy as np

        if len(coords) == 0:
            return img

        lane_x = [point[0] for point in coords]
        lane_y = [point[1] for point in coords]

        # A degree-d fit needs d + 1 distinct sample positions; with fewer the
        # least-squares system is rank deficient and the curve is unreliable.
        fit_degree = max(0, min(degree, len(set(lane_y)) - 1))

        # Fit x as a function of y.
        poly = np.poly1d(np.polyfit(lane_y, lane_x, fit_degree))

        sample_y = np.linspace(min(lane_y), max(lane_y), 100)
        sample_x = poly(sample_y)
        pixels = [(int(px), int(py)) for px, py in zip(sample_x, sample_y)]

        for start_point, end_point in zip(pixels, pixels[1:]):
            cv2.line(img, start_point, end_point, color, thickness)

        return img
    if __name__ == "__main__":
        # Demo entry point: run the model over the test split(s), draw a fitted
        # curve for every detected lane and write the frames to a video file.
        torch.backends.cudnn.benchmark = True

        args, cfg = merge_config()
        cfg.batch_size = 1
        print('setting batch_size to 1 for demo generation')

        dist_print('start testing...')
        assert cfg.backbone in ['18','34','50','101','152','50next','101next','50wide','101wide']

        if cfg.dataset == 'CULane':
            cls_num_per_lane = 18
        elif cfg.dataset == 'Tusimple':
            cls_num_per_lane = 56
        else:
            raise NotImplementedError

        net = get_model(cfg)

        state_dict = torch.load(cfg.test_model, map_location='cpu')['model']
        # Strip a leading 'module.' prefix only; a plain substring test would
        # also truncate keys that merely contain 'module.' further in.
        prefix = 'module.'
        compatible_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith(prefix):
                compatible_state_dict[k[len(prefix):]] = v
            else:
                compatible_state_dict[k] = v

        net.load_state_dict(compatible_state_dict, strict=False)
        net.eval()

        img_transforms = transforms.Compose([
            transforms.Resize((int(cfg.train_height / cfg.crop_ratio), cfg.train_width)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        if cfg.dataset == 'CULane':
            splits = ['test0_normal.txt']
            datasets = [LaneTestDataset(cfg.data_root,os.path.join(cfg.data_root, 'list/test_split/'+split),img_transform = img_transforms, crop_size = cfg.train_height) for split in splits]
            img_w, img_h = 1570, 660
        elif cfg.dataset == 'Tusimple':
            splits = ['test.txt']
            datasets = [LaneTestDataset(cfg.data_root,os.path.join(cfg.data_root, split),img_transform = img_transforms, crop_size = cfg.train_height) for split in splits]
            img_w, img_h = 1280, 720
        else:
            raise NotImplementedError

        for split, dataset in zip(splits, datasets):
            loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle = False, num_workers=1)
            fourcc = cv2.VideoWriter_fourcc(*'MJPG')
            print(split[:-3]+'avi')
            # NOTE(review): the name printed above is derived from the split, but
            # the video is written to the fixed file '4.avi' -- confirm which one
            # is intended. Frames are presumably expected to be (img_w, img_h);
            # verify against the actual size of the source images.
            vout = cv2.VideoWriter('4.'+'avi', fourcc , 30.0, (img_w, img_h))
            try:
                for i, data in enumerate(tqdm.tqdm(loader)):
                    imgs, names = data
                    imgs = imgs.cuda()
                    with torch.no_grad():
                        pred = net(imgs)

                    image_path = os.path.join(cfg.data_root, names[0])
                    vis = cv2.imread(image_path)
                    if vis is None:
                        # cv2.imread signals failure by returning None.
                        raise FileNotFoundError('cannot read image: ' + image_path)

                    coords = pred2coords(pred, cfg.row_anchor, cfg.col_anchor, original_image_width = img_w, original_image_height = img_h)
                    # Draw a fitted curve for each lane instead of raw points.
                    for lane in coords:
                        vis = polyfit_draw(vis, lane)
                    vout.write(vis)
            finally:
                # Release the writer even if a frame fails.
                vout.release()

     ps:

    优化前

    优化后

    显存利用情况

     

  • 相关阅读:
    python使用mysql添加全文索引
    Vue3中el-table表格数据不显示
    SpringBoot——原理(起步依赖+自动配置(概述和案例))
    Java------Stream流式编程常用API【.stream,filter(),map()】(三)
    记一次性能飙升的Mysql CRUD数据表迁移到Clickhouse表的过程
    OpenCV快速入门:绘制图形、图像金字塔和感兴趣区域
    编译支持国密的抓包工具 WireShark
    Django重定向类HttpResponseRedirect、HttpResponsePermanentRedirect和重定向函数redirect
    UE5- c++ websocket里实现调用player里的方法
    DETR纯代码分享(八)position_encoding.py(models)
  • 原文地址:https://blog.csdn.net/weixin_64043217/article/details/133499492