• 第12章 PyTorch图像分割代码框架-2


    模型模块

    本书的第5-9章重点介绍了各种2D3D语义分割和实例分割网络模型,所以在模型模块中,我们需要做的事情就是将要实验的分割网络写在该目录下。有时候我们可能想尝试不同的分割网络结构,所以在该目录下可以存在多个想要实验的网络模型定义文件。对于PASCAL VOC这样的自然数据集,我们可能想实验Deeplab v3+PSPNetRefineNet等网络的训练效果。代码11-3给出了Deeplab v3+网络封装后的主体部分,完整网络搭建代码可参考本书配套代码对应章节。

    代码11-3 Deeplab v3+网络的主体部分

    1. # 定义Deeplab V3+类
    2. class DeepLabHeadV3Plus(nn.Module):
    3. def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
    4. super(DeepLabHeadV3Plus, self).__init__()
    5. self.project = nn.Sequential(
    6. nn.Conv2d(low_level_channels, 48, 1, bias=False),
    7. nn.BatchNorm2d(48),
    8. nn.ReLU(inplace=True),
    9. )
    10. # ASPP
    11. self.aspp = ASPP(in_channels, aspp_dilate)
    12. # classifier head
    13. self.classifier = nn.Sequential(
    14. nn.Conv2d(304, 256, 3, padding=1, bias=False),
    15. nn.BatchNorm2d(256),
    16. nn.ReLU(inplace=True),
    17. nn.Conv2d(256, num_classes, 1)
    18. )
    19. self._init_weight()
    20. # forward method
    21. def forward(self, feature):
    22. # print(feature['low_level'].shape)
    23. # print(feature['out'].shape)
    24. low_level_feature = self.project(feature['low_level'])
    25. output_feature = self.aspp(feature['out'])
    26. output_feature = F.interpolate(
    27. output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
    28. return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))
    29. # weight initilize
    30. def _init_weight(self):
    31. for m in self.modules():
    32. if isinstance(m, nn.Conv2d):
    33. nn.init.kaiming_normal_(m.weight)
    34. elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
    35. nn.init.constant_(m.weight, 1)
    36. nn.init.constant_(m.bias, 0)

    对于复杂网络搭建,一般都是采用自下而上的搭建方法,先搭建底层组件,再逐步向上封装,对于本例中的Deeplab v3+,可以先分别搭建backbone骨干网络、ASPP和编解码结构,最后再进行封装。

    工具函数模块

    工具函数是为项目完成各项功能所自定义的辅助函数,可以统一定义在utils文件夹下,根据实际项目的不同,工具函数也各不相同。常用的工具函数包括各种损失函数的定义loss.py、训练可视化函数的定义visualize.py、用于记录训练日志的log.py等。代码11-4给出了一个关于Focal loss损失函数的定义,该损失函数作为工具函数可放在loss.py文件中。

    代码11-4 工具函数示例:定义一个Focal loss

    1. # 导入相关库
    2. import torch
    3. import torch.nn as nn
    4. import torch.nn.functional as F
    5. # 定义一个Focal loss类
    6. class FocalLoss(nn.Module):
    7. def __init__(self, alpha=1, gamma=2):
    8. super(FocalLoss, self).__init__()
    9. self.alpha = alpha
    10. self.gamma = gamma
    11. def forward(self, inputs, targets):
    12. # Compute cross-entropy loss
    13. ce_loss = F.cross_entropy(inputs, targets, reduction='none')
    14. # Compute the focal loss
    15. pt = torch.exp(-ce_loss)
    16. focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
    17. return focal_loss.mean()

    配置模块

    配置模块是为项目模型训练传入各种参数而进行设置的模块,比如训练数据所在目录、训练所需要的各种参数、训练过程是否需要可视化等。一般来说,我们有两种方式来对项目执行参数进行配置管理,一种是直接在主函数main.py中使用argparse库对参数进行配置,然后再命令行中进行传入;另一种则是单独定义一个config.py或者config.yaml文件来对所有参数进行统一配置。基于argparse库的参数配置管理简单示例如代码11-5所示。

    代码11-5 argparser参数配置管理

    1. # 导入argparse库
    2. import argparse
    3. # 创建参数管理器
    4. parser = argparse.ArgumentParser()
    5. # 涉及数据相关的参数管理
    6. parser.add_argument("--data_root", type=str, default='./dataset',
    7. help="path to Dataset")
    8. parser.add_argument("--save_root", type=str, default='./',
    9. help="path to save result")
    10. parser.add_argument("--dataset", type=str, default='voc',
    11. choices=['voc', 'cityscapes', 'ade'], help='Name of dataset')
    12. parser.add_argument("--num_classes", type=int, default=None,
    13. help="num classes (default: None)")

    在上述代码中,我们基于argparse给出了一小部分参数配置管理代码,涉及训练数据相关的部分参数,包括数据读取路径、存放路径、训练所用数据集、分割类别数量等。

    主函数模块

    主函数模块main.py是项目的启动模块,该模块将定义好的数据和模型模块进行组装,并结合损失函数、优化器、评估方法和可视化等组件,将config.py中配置好的项目参数传入,根据训练-验证的模式,执行图像分割项目模型训练和验证。代码11-6VOC数据集训练验证部分代码。

    代码11-6 主函数模块中的训练迭代部分

    1. # 初始化区间损失
    2. interval_loss = 0
    3. while True:
    4. # 执行训练
    5. model.train()
    6. cur_epochs += 1
    7. for (images, labels) in train_loader:
    8. cur_itrs += 1
    9. images = images.to(device, dtype=torch.float32)
    10. labels = labels.to(device, dtype=torch.long)
    11. optimizer.zero_grad()
    12. outputs = model(images)
    13. loss = criterion(outputs, labels)
    14. loss.backward()
    15. optimizer.step()
    16. np_loss = loss.detach().cpu().numpy()
    17. interval_loss += np_loss
    18. if vis is not None:
    19. vis.vis_scalar('Loss', cur_itrs, np_loss)
    20. # 打印训练信息
    21. if (cur_itrs) % opts.print_interval == 0:
    22. pass
    23. # 保存模型
    24. if (cur_itrs) % opts.val_interval == 0:
    25. pass
    26. # 日志记录
    27. logger.info("Save the latest model to %s" % save_path_checkpoints)
    28. # 模型验证
    29. print("validation...")
    30. model.eval()
    31. val_score, ret_samples = validate(
    32. opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,
    33. ret_samples_ids=vis_sample_id)
    34. logger.info("Validation performance: %s", val_score)
    35. # 保存最优模型
    36. if val_score['mean_dice'] > best_score:
    37. best_score = val_score['mean_dice']
    38. save_ckpt(os.path.join(save_path_checkpoints, 'best_%s_%s_os%d.pth' %
    39. (opts.model, opts.dataset, opts.output_stride)))
    40. logger.info("Save best-performance model so far to %s" % save_path_checkpoints)
    41. # 训练过程可视化
    42. if vis is not None:
    43. vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
    44. vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
    45. vis.vis_table("[Val] Class IoU", val_score['Class IoU'])
    46. for k, (img, target, lbl) in enumerate(ret_samples):
    47. img = (denorm(img) * 255).astype(np.uint8)
    48. target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
    49. lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
    50. concat_img = np.concatenate((img, target, lbl), axis=2)
    51. vis.vis_image('Sample %d' % k, concat_img)
    52.     scheduler.step()

    在代码11-6中,我们展示了一个图像分割项目主函数模块中最核心的训练和验证部分。在训练时,按照指定迭代次数保存模型和对训练过程进行可视化展示。图11-2为训练打印的部分信息。

    4b3c438a10649b752070ff5d0c7ce8fa.png

    11-2 VOC训练过程信息

    11-3为基于visdom的训练过程可视化展示,包括当前训练配置参数信息,训练损失函数变化曲线、验证集全局准确率、mIoU和类别IoU等指标变化曲线图。

    9907f54a46275eafc2e600b66157fc84.png

    11-3 Deeplab v3+训练过程可视化

    11-4展示了两组训练过程中验证集的输入图像、标签图像和模型预测图像的对比图。可以看到,基于Deeplab v3+的分割模型在PASCAL VOC 2012上表现还不错。

    dcb6cef8c0b0abdf98f2c64c97552bb0.png

    11-4 验证集模型效果图

    后续全书内容和代码将在github上开源,请关注仓库:

    https://github.com/luwill/Deep-Learning-Image-Segmentation

    (未完待续)

  • 相关阅读:
    树(二叉查找树BST、二叉平衡树AVL、红黑树R-B)
    UGUI源码解析——RawImage
    【python】--python基础学习
    xshell连接虚拟机慢
    IO流-数据流
    微服务的快速开始(nacos)最全快速配置图解
    怎么将vue的移动端项目打包成手机的app软件apk格式的。hbuilderx 云打包。
    SSH远程登录网络设备
    【python】系列之item.taobao 获取商品详情API接口调用
    在阿里干了6年自动化测试,30岁即将退休的我,告诉你自动化测试工程师有多吃香...
  • 原文地址:https://blog.csdn.net/weixin_37737254/article/details/134257973