• 训练DeeplabV3+来分割车道线


    本例我们训练DeepLabV3+语义分割模型来分割车道线。

    ead64233424728797832a74e8d4cbab4.png

    DeepLabV3+模型的原理有以下一些要点:

    1,采用Encoder-Decoder架构。

    2,Encoder使用类似Xception的结构作为backbone。

    3,Encoder还使用ASPP(Atrous Spatial Pyramid Pooling),即空洞卷积空间金字塔池化,来实现不同尺度的特征融合,ASPP由4个不同rate的空洞卷积和一个全局池化组成。

    4,Decoder再次使用跨层级的concat操作进行高低层次的特征融合。

    1. #!pip install segmentation_models_pytorch
    2. #!pip install albumentations
    1. import torchkeras 
    2. from argparse import Namespace
    3. config = Namespace(
    4.     img_size = 128
    5.     lr = 1e-4,
    6.     batch_size = 4,
    7. )

    一,准备数据

    公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook代码和车道线数据集下载链接。

    1. from pathlib import Path
    2. from PIL import Image
    3. import numpy as np 
    4. import torch 
    5. from torch import nn 
    6. from torch.utils.data import Dataset,DataLoader 
    7. import os 
    8. from torchkeras.data import resize_and_pad_image 
    9. from torchkeras.plots import joint_imgs_col 
    10. class MyDataset(Dataset):
    11.     def __init__(self, img_files, img_size, transforms = None):
    12.         self.__dict__.update(locals())
    13.         
    14.     def __len__(self) -> int:
    15.         return len(self.img_files)
    16.     def get(self, index):
    17.         img_path = self.img_files[index]
    18.         mask_path = img_path.replace('images','masks').replace('.jpg','.png')
    19.         image = Image.open(img_path).convert('RGB')
    20.         mask = Image.open(mask_path).convert('L')
    21.         return image, mask
    22.     
    23.     def __getitem__(self, index):
    24.         
    25.         image,mask = self.get(index)
    26.         
    27.         image = resize_and_pad_image(image,self.img_size,self.img_size)
    28.         mask = resize_and_pad_image(mask,self.img_size,self.img_size)
    29.         
    30.         image_arr = np.array(image, dtype=np.float32)/255.0
    31.         
    32.         mask_arr = np.array(mask,dtype=np.float32)
    33.         mask_arr = np.where(mask_arr>100.0,1.0,0.0).astype(np.int64)
    34.         
    35.         sample = {
    36.             "image": image_arr,
    37.             "mask": mask_arr
    38.         }
    39.         
    40.         if self.transforms is not None:
    41.             sample = self.transforms(**sample)
    42.             
    43.         sample['mask'] = sample['mask'][None,...]
    44.             
    45.         return sample
    46.     
    47.     def show_sample(self, index):
    48.         image, mask = self.get(index)
    49.         image_result = joint_imgs_col(image,mask)
    50.         return image_result
    1. import albumentations as A
    2. from albumentations.pytorch.transforms import ToTensorV2
    3. def get_train_transforms():
    4.     return A.Compose(
    5.         [
    6.             A.OneOf([A.HorizontalFlip(p=0.5),A.VerticalFlip(p=0.5)]),
    7.             ToTensorV2(p=1),
    8.         ],
    9.         p=1.0
    10.     )
    11. def get_val_transforms():
    12.     return A.Compose(
    13.         [
    14.             ToTensorV2(p=1),
    15.         ],
    16.         p=1.0
    17.     )
    1. train_transforms=get_train_transforms()
    2. val_transforms=get_val_transforms()
    3. ds_train = MyDataset(train_imgs,img_size=config.img_size,transforms=train_transforms)
    4. ds_val = MyDataset(val_imgs,img_size=config.img_size,transforms=val_transforms)
    5. dl_train = DataLoader(ds_train,batch_size=config.batch_size)
    6. dl_val = DataLoader(ds_val,batch_size=config.batch_size)
    ds_train.show_sample(10)

    13afdd40413852e18552320016680739.png

    二,定义模型

    1. import torch 
    2. num_classes = 1
    3. net = smp.DeepLabV3Plus(
    4.     encoder_name="mobilenet_v2", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    5.     encoder_weights='imagenet',     # use `imagenet` pretrained weights for encoder initialization
    6.     in_channels=3,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
    7.     classes=num_classes,            # model output channels (number of classes in your dataset)
    8. )
    9. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    三,训练模型

    下面使用我们的梦中情炉~torchkeras~来实现最优雅的训练循环。😋😋

    1. from torchkeras import KerasModel 
    2. from torch.nn import functional as F 
    3. # 由于输入数据batch结构差异,需要重写StepRunner并覆盖
    4. class StepRunner:
    5.     def __init__(self, net, loss_fn, accelerator, stage = "train", metrics_dict = None, 
    6.                  optimizer = None, lr_scheduler = None
    7.                  ):
    8.         self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
    9.         self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
    10.         self.accelerator = accelerator
    11.         
    12.         if self.stage=='train':
    13.             self.net.train() 
    14.         else:
    15.             self.net.eval()
    16.             
    17.     
    18.     def __call__(self, batch):
    19.         features,labels = batch['image'],batch['mask'
    20.         
    21.         #loss
    22.         preds = self.net(features)
    23.         loss = self.loss_fn(preds,labels)
    24.         #backward()
    25.         if self.optimizer is not None and self.stage=="train":
    26.             self.accelerator.backward(loss)
    27.             self.optimizer.step()
    28.             if self.lr_scheduler is not None:
    29.                 self.lr_scheduler.step()
    30.             self.optimizer.zero_grad()
    31.             
    32.         all_preds = self.accelerator.gather(preds)
    33.         all_labels = self.accelerator.gather(labels)
    34.         all_loss = self.accelerator.gather(loss).sum()
    35.         
    36.         #losses
    37.         step_losses = {self.stage+"_loss":all_loss.item()}
    38.         
    39.         #metrics
    40.         step_metrics = {self.stage+"_"+name:metric_fn(all_preds, all_labels).item() 
    41.                         for name,metric_fn in self.metrics_dict.items()}
    42.         
    43.         if self.optimizer is not None and self.stage=="train":
    44.             step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
    45.             
    46.         return step_losses,step_metrics
    47. KerasModel.StepRunner = StepRunner
    1. from torchkeras.metrics import IOU
    2. class DiceLoss(nn.Module):
    3.     def __init__(self,smooth=0.001,num_classes=1,weights = None):
    4.         ...
    5.     def forward(self, logits, targets):
    6.         
    7.         ...
    8.         
    9.     def compute_loss(self,preds,targets):
    10.         ...
    11.     
    12.     
    13. class MixedLoss(nn.Module):
    14.     def __init__(self,bce_ratio=0.5):
    15.         super().__init__()
    16.         self.bce = nn.BCEWithLogitsLoss()
    17.         self.dice = DiceLoss()
    18.         self.bce_ratio = bce_ratio
    19.         
    20.     def forward(self,logits,targets):
    21.         bce_loss = self.bce(logits,targets.float())
    22.         dice_loss = self.dice(logits,targets)
    23.         total_loss = bce_loss*self.bce_ratio + dice_loss*(1-self.bce_ratio)
    24.         return total_loss
    1. optimizer = torch.optim.AdamW(net.parameters(), lr=config.lr)
    2. lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    3.     optimizer = optimizer,
    4.     T_max=8,
    5.     eta_min=0
    6. )
    7. metrics_dict = {'iou': IOU(num_classes=1)}
    8. model = KerasModel(net,
    9.                    loss_fn=MixedLoss(bce_ratio=0.5),
    10.                    metrics_dict=metrics_dict,
    11.                    optimizer=optimizer,
    12.                    lr_scheduler = lr_scheduler
    13.                   )
    1. from torchkeras.kerascallbacks import WandbCallback
    2. wandb_cb = WandbCallback(project='unet_lane',
    3.                          config=config.__dict__,
    4.                          name=None,
    5.                          save_code=True,
    6.                          save_ckpt=True)
    7. dfhistory=model.fit(train_data=dl_train, 
    8.                     val_data=dl_val, 
    9.                     epochs=100
    10.                     ckpt_path='checkpoint.pt',
    11.                     patience=10
    12.                     monitor="val_iou",
    13.                     mode="max",
    14.                     mixed_precision='no',
    15.                     callbacks = [wandb_cb],
    16.                     plot = True 
    17.                    )

    <<<<<< ⚡️ cuda is used >>>>>>

    7bec7d9048d769278fba9fb625ec7365.png

    1. ================================================================================2023-05-21 20:45:27
    2. Epoch 1 / 100
    3. 100%|████████████████████| 20/20 [00:03<00:00, 6.60it/s, lr=5e-5, train_iou=0.15, train_loss=0.873]
    4. 100%|██████████████████████████████████| 5/5 [00:00<00:00, 8.54it/s, val_iou=0.162, val_loss=0.836]
    5. [0;31m<<<<<< reach best val_iou : 0.16249321401119232 >>>>>>[0m
    6. ================================================================================2023-05-21 20:45:30
    7. Epoch 2 / 100
    8. 100%|███████████████████████| 20/20 [00:02<00:00, 7.24it/s, lr=0, train_iou=0.25, train_loss=0.836]
    9. 100%|██████████████████████████████████| 5/5 [00:00<00:00, 8.49it/s, val_iou=0.291, val_loss=0.821]
    10. [0;31m<<<<<< reach best val_iou : 0.2905024290084839 >>>>>>[0m
    11. ================================================================================2023-05-21 20:51:06
    12. Epoch 95 / 100
    13. 100%|███████████████████| 20/20 [00:02<00:00, 7.21it/s, lr=5e-5, train_iou=0.721, train_loss=0.187]
    14. 100%|██████████████████████████████████| 5/5 [00:00<00:00, 8.71it/s, val_iou=0.665, val_loss=0.249]

    四,评估模型

    1. metrics_dict = {'iou': IOU(num_classes=1,if_print=True)}
    2. model = KerasModel(net,
    3.                    loss_fn=MixedLoss(bce_ratio=0.5),
    4.                    metrics_dict=metrics_dict,
    5.                    optimizer=optimizer,
    6.                    lr_scheduler = lr_scheduler
    7.                   )
    model.evaluate(dl_val)
    1. 100%|██████████████████████████████████| 5/5 [00:00<00:00, 8.91it/s, val_iou=0.667, val_loss=0.252]
    2. global correct: 0.9912
    3. IoU: ['0.9911', '0.3422']
    4. mean IoU: 0.6667

    五,使用模型

    1. batch = next(iter(dl_val))
    2. with torch.no_grad():
    3.     model.eval()
    4.     logits = model(batch["image"].cuda())
    5.     
    6. pr_masks = logits.sigmoid()
    1. from matplotlib import pyplot as plt 
    2. for image, gt_mask, pr_mask in zip(batch["image"], batch["mask"], pr_masks):
    3.     plt.figure(figsize=(1610))
    4.     plt.subplot(131)
    5.     plt.imshow(image.numpy().transpose(120))  # convert CHW -> HWC
    6.     plt.title("Image")
    7.     plt.axis("off")
    8.     plt.subplot(132)
    9.     plt.imshow(gt_mask.numpy().squeeze()) 
    10.     plt.title("Ground truth")
    11.     plt.axis("off")
    12.     plt.subplot(133)
    13.     plt.imshow(pr_mask.cpu().numpy().squeeze()) 
    14.     plt.title("Prediction")
    15.     plt.axis("off")
    16.     plt.show()

    543b40a4a1126920496567338f4973ae.png

    a360cb249f45048b509963d4a0e085e7.png

    d451854564cd1d290049d8aa386f895c.png

    a413800b49de3cbe8360f6d89906f0f1.png

    六,保存模型

    torch.save(model.net.state_dict(),'deeplab_v3_plus.pt')

    公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook代码和车道线数据集下载链接。

    万水千山总是情,点个赞赞行不行?😋😋

  • 相关阅读:
    uniapp中swiper 轮播带左右箭头,点击切换轮播效果demo(整理)
    每日一面系列之volatile 的理解
    基于Vue+Node+MySQL的美食菜谱食材网站设计与实现
    排序算法-归并排序
    Cy3/5/7标记多肽/PEG/聚合物/磷脂----为华生物
    云原生周刊:KubeSphere 3.4.1 发布 | 2023.11.13
    [R] Underline your idea with ggplot2
    SpringBoot入门
    后端各层的部署开发
    【深度学习】基于卷积神经网络(tensorflow)的人脸识别项目(二)
  • 原文地址:https://blog.csdn.net/Python_Ai_Road/article/details/130896298