• LPRNet: a license plate recognition network


    A complete license plate recognition pipeline has two stages: a detection stage, which localizes the plate (e.g. with YOLOv4), and a recognition stage, which reads the characters with LPRNet.
    The official GitHub implementation of LPRNet is LPRNet.py.
    Three parts of LPRNet are worth understanding:
    1. the STN (spatial transformer) front end; 2. the main backbone; 3. the loss.

    The backbone is a lightweight model whose basic building block is small_basic_block; see the annotated code below.

    import torch.nn as nn
    import torch
    
    class small_basic_block(nn.Module):
        def __init__(self, ch_in, ch_out):
            super(small_basic_block, self).__init__()
            self.block = nn.Sequential(
                nn.Conv2d(ch_in, ch_out // 4, kernel_size=1),
                # 1x1 conv: shrinks the channel count (and parameter count);
                # the 3x1 and 1x3 convs below factorize a 3x3 conv, and the final 1x1 conv restores the channels
                nn.ReLU(),
                nn.Conv2d(ch_out // 4, ch_out // 4, kernel_size=(3, 1), padding=(1, 0)),
                nn.ReLU(),
                nn.Conv2d(ch_out // 4, ch_out // 4, kernel_size=(1, 3), padding=(0, 1)),
                nn.ReLU(),
                nn.Conv2d(ch_out // 4, ch_out, kernel_size=1),
            )
        def forward(self, x):
            return self.block(x)
    
    class LPRNet(nn.Module):
        def __init__(self, lpr_max_len, phase, class_num, dropout_rate):
            super(LPRNet, self).__init__()
            self.phase = phase
            self.lpr_max_len = lpr_max_len
            self.class_num = class_num
            self.backbone = nn.Sequential(
                nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1),    # 0  [bs,3,24,94] -> [bs,64,22,92]
                nn.BatchNorm2d(num_features=64),                                       # 1  -> [bs,64,22,92]
                nn.ReLU(),                                                             # 2  -> [bs,64,22,92]
                nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1)),                 # 3  -> [bs,64,20,90]
                small_basic_block(ch_in=64, ch_out=128),                               # 4  -> [bs,128,20,90]
                nn.BatchNorm2d(num_features=128),                                      # 5  -> [bs,128,20,90]
                nn.ReLU(),                                                             # 6  -> [bs,128,20,90]
                nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(2, 1, 2)),                 # 7  -> [bs,64,18,44]  (MaxPool3d also strides over channels: 128 -> 64)
                small_basic_block(ch_in=64, ch_out=256),                               # 8  -> [bs,256,18,44]
                nn.BatchNorm2d(num_features=256),                                      # 9  -> [bs,256,18,44]
                nn.ReLU(),                                                             # 10 -> [bs,256,18,44]
                small_basic_block(ch_in=256, ch_out=256),                              # 11 -> [bs,256,18,44]
                nn.BatchNorm2d(num_features=256),                                      # 12 -> [bs,256,18,44]
                nn.ReLU(),                                                             # 13 -> [bs,256,18,44]
                nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)),                 # 14 -> [bs,64,16,21]  (channels pooled again: 256 -> 64)
                nn.Dropout(dropout_rate),  # 0.5 dropout rate                          # 15 -> [bs,64,16,21]
                nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1),   # 16 -> [bs,256,16,18]
                nn.BatchNorm2d(num_features=256),                                            # 17 -> [bs,256,16,18]
                nn.ReLU(),                                                                   # 18 -> [bs,256,16,18]
                nn.Dropout(dropout_rate),  # 0.5 dropout rate                          # 19 -> [bs,256,16,18]
                nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1),  # class_num=68  20  -> [bs,68,4,18]
                nn.BatchNorm2d(num_features=class_num),                                             # 21 -> [bs,68,4,18]
                nn.ReLU(),                                                                          # 22 -> [bs,68,4,18]
            )
            self.container = nn.Sequential(
                nn.Conv2d(in_channels=448+self.class_num, out_channels=self.class_num, kernel_size=(1, 1), stride=(1, 1)),
                # nn.BatchNorm2d(num_features=self.class_num),
                # nn.ReLU(),
                # nn.Conv2d(in_channels=self.class_num, out_channels=self.lpr_max_len+1, kernel_size=3, stride=2),
                # nn.ReLU(),
            )
        def forward(self, x):
            keep_features = list()
            for i, layer in enumerate(self.backbone.children()):
                x = layer(x)
                if i in [2, 6, 13, 22]: #2: [bs,64,22,92]  6:[bs,128,20,90] 13:[bs,256,18,44] 22:[bs,68,4,18]
                    keep_features.append(x)
            global_context = list()
            # keep_features: [bs,64,22,92]  [bs,128,20,90] [bs,256,18,44] [bs,68,4,18]
            for i, f in enumerate(keep_features):
                if i in [0, 1]:
                    # [bs,64,22,92] -> [bs,64,4,18]
                    # [bs,128,20,90] -> [bs,128,4,18]
                    f = nn.AvgPool2d(kernel_size=5, stride=5)(f)
                if i in [2]:
                    # [bs,256,18,44] -> [bs,256,4,18]
                    f = nn.AvgPool2d(kernel_size=(4, 10), stride=(4, 2))(f)
                # Normalize each feature map by its mean squared activation so that
                # features taken from different depths have comparable scales before concatenation.
                f_pow = torch.pow(f, 2)     # [bs,64,4,18]  square every element
                f_mean = torch.mean(f_pow)  # scalar: mean over all elements
                f = torch.div(f, f_mean)    # [bs,64,4,18]  divide every element by that mean
                global_context.append(f)
            x = torch.cat(global_context, 1) #[bs,64,4,18]+[bs,128,4,18]+[bs,256,4,18]+[bs,68,4,18]=[bs,516,4,18]
            x = self.container(x)  # [bs,516,4,18] -> [bs,68,4,18]  (1x1 conv head)
            logits = torch.mean(x, dim=2)  # -> [bs,68,18]  68 = number of character classes, 18 = sequence length
            return logits
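
    To sanity-check the shape annotations above, the model can be run on a dummy batch. A minimal sketch, assuming the 24x94 input size used in the shape comments (the phase argument is stored but unused in this forward pass):

    # Hedged usage sketch: push a random 24x94 RGB batch through the network
    # and confirm the output matches the annotated [bs, 68, 18] shape.
    model = LPRNet(lpr_max_len=8, phase=False, class_num=68, dropout_rate=0.5)
    model.eval()                       # disable dropout for a deterministic check
    dummy = torch.randn(2, 3, 24, 94)  # [bs, 3, 24, 94]
    out = model(dummy)
    print(out.shape)                   # expected: torch.Size([2, 68, 18])

    For training, labels are built from the CHARS table below and the model is optimized with CTC loss.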
    
    CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
             '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
             '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
             '新',
             '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
             'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
             'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
             'W', 'X', 'Y', 'Z', 'I', 'O', '-'
             ]
    assert len(CHARS) == 68  # 68 classes, including the '-' placeholder
    def sparse_tuple_for_ctc(T_length, lengths):
        # Build the two length tuples CTCLoss expects: every input sequence has
        # the same fixed length T_length (the model's 18 output steps), while
        # target lengths vary per plate.
        input_lengths = []
        target_lengths = []
        for ch in lengths:
            input_lengths.append(T_length)
            target_lengths.append(ch)
        return tuple(input_lengths), tuple(target_lengths)
    
    for iteration in range(start_iter, max_iter):
        .........
        images, labels, lengths = next(batch_iterator)  # labels: concatenated character indices; lengths: per-plate label lengths
        input_lengths, target_lengths = sparse_tuple_for_ctc(T_length, lengths)  # T_length=18
        # input_lengths: bs entries, all equal to 18; target_lengths: bs entries, the per-plate lengths
        # forward
        logits = lprnet(images)  # [bs, 68, 18]  68 character classes, 18 sequence steps
        log_probs = logits.permute(2, 0, 1)  # CTC loss expects T x N x C: [18, bs, 68]
        log_probs = log_probs.log_softmax(2).requires_grad_()
        # backprop
        optimizer.zero_grad()
        loss = ctc_loss(log_probs, labels, input_lengths=input_lengths, target_lengths=target_lengths)
        # Note that labels are variable length, e.g.
        #   18, 45, 33, 37, 40, 49, 63      -->> plate "湘E269JY"
        #   4, 54, 51, 34, 53, 37, 38       -->> plate "冀PL3N67"
        #   22, 56, 37, 38, 33, 39, 34, 46  -->> plate "川R67283F"
        #   2, 41, 44, 37, 39, 35, 33, 40   -->> plate "津AD68429"
        # with lengths 7, 7, 8, 8 respectively
        .........
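
    The mapping between a plate string and its label indices, and the greedy decode applied to the network output at inference time, can be sketched as follows. This is an illustrative sketch built on the CHARS table above; it assumes '-' (index 67) acts as the CTC blank, which may differ in detail from the official repo's decode code:

    # Hedged sketch: encode a plate into CHARS indices, and greedily decode
    # logits (argmax per step, collapse repeats, drop blanks).
    CHARS_DICT = {c: i for i, c in enumerate(CHARS)}
    BLANK_IDX = CHARS_DICT['-']  # assumption: '-' serves as the blank label

    def encode_plate(plate):
        # "湘E269JY" -> [18, 45, 33, 37, 40, 49, 63]
        return [CHARS_DICT[c] for c in plate]

    def greedy_decode(logits_one):
        # logits_one: [68, 18] output for a single image (one column per time step)
        best = logits_one.argmax(dim=0).tolist()  # best class index at each step
        out, prev = [], BLANK_IDX
        for idx in best:
            if idx != prev and idx != BLANK_IDX:  # collapse repeats, drop blanks
                out.append(CHARS[idx])
            prev = idx
        return ''.join(out)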
    

    About ctc_loss:

    ctc_loss = nn.CTCLoss()  # below, N = batch size
    log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
    # T=50 (input sequence length), N=16, C=20 (number of classes including blank)
    targets = torch.randint(1, 20, (16, 30), dtype=torch.long)       # [16, 30]
    input_lengths = torch.full((16,), 50, dtype=torch.long)          # shape (16,)
    target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)  # shape (16,)
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
    loss.backward()
    

    log_probs: the model output, a tensor of shape (T, N, C), where T is the input length to CTCLoss (i.e. the output sequence length), N is the batch size, and C is the total size of the character set including the blank label. log_probs should normally be passed through torch.nn.functional.log_softmax before being fed to CTCLoss.

    targets: either a tensor of shape (N, S), where N is the batch size and S is the (padded) label length, or a 1-D tensor containing all labels concatenated, with length equal to the sum of the target lengths. In either form, targets must not contain the blank label.

    input_lengths: a tensor or tuple of shape (N,) whose every element equals T, the output sequence length. Once the model's output length is fixed, all elements are typically identical.

    target_lengths: a tensor or tuple of shape (N,) whose elements give the label length of each training sample; label lengths may vary.
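
    Putting this together with the variable-length plates above, a call using the 1-D concatenated-targets format might look like this. A minimal sketch, assuming '-' (index 67) is used as the blank (CTCLoss defaults to blank=0, so it is set explicitly here):

    import torch
    import torch.nn as nn

    # Hedged sketch: CTCLoss with concatenated targets and the per-plate
    # lengths (7, 7, 8, 8) from the examples earlier.
    ctc_loss = nn.CTCLoss(blank=67)
    log_probs = torch.randn(18, 4, 68).log_softmax(2).requires_grad_()  # T=18, N=4, C=68
    targets = torch.tensor(
        [18, 45, 33, 37, 40, 49, 63,         # "湘E269JY"
         4, 54, 51, 34, 53, 37, 38,          # "冀PL3N67"
         22, 56, 37, 38, 33, 39, 34, 46,     # "川R67283F"
         2, 41, 44, 37, 39, 35, 33, 40])     # "津AD68429"
    input_lengths = torch.full((4,), 18, dtype=torch.long)  # every output has 18 steps
    target_lengths = torch.tensor([7, 7, 8, 8])             # per-plate label lengths
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
    loss.backward()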

    An explanation of CTC loss: https://zhuanlan.zhihu.com/p/67415439

    https://blog.csdn.net/ckqsars/article/details/108312750?spm=1001.2101.3001.6650.10&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Edefault-10-108312750-blog-106143755.pc_relevant_aa_2&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Edefault-10-108312750-blog-106143755.pc_relevant_aa_2&utm_relevant_index=11

    https://blog.csdn.net/qq_38253797/article/details/125054464

    https://blog.csdn.net/weixin_39027619/article/details/106143755

  • Original article: https://blog.csdn.net/Bismarckczy/article/details/126096826