• TVRNet网络PyTorch实现


    文章地址

    • An End-to-End Traffic Visibility Regression Algorithm
    • 文章通过训练搜集得到的真实道路图像数据集(Actual Road dense image Dataset, ARD),通过专业的能见度计和多人标注,获得可靠的能见度标签数据集。构建网络,进行训练,获得了较好的能见度识别网络。网络包括特征提取​、多尺度映射​、特征融合​、非线性输出(回归范围为[0,1],需要经过(0,0),(1,1)改用修改的sigmoid函数,相较于ReLU更好)。结构如下​
      在这里插入图片描述

    网络各层结构

    在这里插入图片描述

    • 我认为红框位置与之相应的参数不匹配,在Feature Extraction部分Reshape之后得到的特征图大小为4124124。紧接着接了一个卷积层Conv,显示输入是3128128
    • 第二处红框,MaxPool的kernel设置为88,特征图没有进行padding,到全连接层的输入变为64117*117,参数不对应
      在这里插入图片描述

    代码实现

    """
        Based on the ideas of the below paper, using PyTorch to build TVRNet.
        Reference: Qin H, Qin H. An end-to-end traffic visibility regression algorithm[J]. IEEE Access, 2021, 10: 25448-25454.​
        @muyeqingfeng
    """
    
    import torch
    from torch import nn
    import math
    
    
    class Inception(nn.Module):
        def __init__(self, in_planes, out_planes):
            super(Inception, self).__init__()
            self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, padding=0)
            self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1)
            self.conv5 = nn.Conv2d(in_planes, out_planes, kernel_size=5, padding=2)
            self.conv7 = nn.Conv2d(in_planes, out_planes, kernel_size=7, padding=3)
    
        def forward(self, x):
            out_1 = self.conv1(x)
            out_3 = self.conv3(x)
            out_5 = self.conv5(x)
            out_7 = self.conv7(x)
    
            out = torch.cat((out_1, out_3, out_5, out_7), dim=1)
            return out
    
    def modify_sigmoid(x):
        return 1 / (1 + torch.exp(-10*(x-0.5)))
    
    class TVRNet(nn.Module):
        def __init__(self, in_planes, out_planes):
            super(TVRNet, self).__init__()
            # (B, 3, 224, 224)  ——>  (B, 3, 220, 220)
            self.FeatureExtraction_onestep = nn.Sequential(nn.Conv2d(in_planes, 20, kernel_size=5, padding=0),
                                                           nn.ReLU(inplace=True),)
            self.FeatureExtraction_maxpool = nn.MaxPool2d((5, 1))
    
            self.MultiScaleMapping = nn.Sequential(Inception(4, 16),
                                                   nn.ReLU(inplace=True),
                                                   nn.MaxPool2d(kernel_size=8))
    
            self.FeatureIntegration = nn.Sequential(nn.Linear(46656, 100),
                                                    nn.ReLU(inplace=True),
                                                    nn.Dropout(0.4),
                                                    nn.Linear(100, out_planes))
    
            self.NonLinearRegression = modify_sigmoid
    
    
        def forward(self, x):
            x = self.FeatureExtraction_onestep(x)
            x = x.view((x.shape[0], 1, x.shape[1], -1))
            x = self.FeatureExtraction_maxpool(x)
            x = x.view(x.shape[0], x.shape[2], int(math.sqrt(x.shape[3])), int(math.sqrt(x.shape[3])))
            # print(x.shape)
    
            x = self.MultiScaleMapping(x)
            # print(x.shape)
            x = x.view(x.shape[0], -1)
    
            x = self.FeatureIntegration(x)
            out = self.NonLinearRegression(x)
    
            return out
    
    
    if __name__ == '__main__':
        a = torch.randn(1,3,224,224)
        net = TVRNet(3,3)
        b = net(a)
        print(b.shape)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
  • 相关阅读:
    索引介绍及索引的分类
    Java Maven Tomcat使用Tesseract-OCR文字识别(Tess4j)
    Yii2 init 初始化脚本分析
    手机怎么把图片转换成Word?这个小妙招大家要学会
    docker-compose:搭建酷炫私有云相册photoprism
    【免费源码下载】完美运营版商城 虚拟商品全功能商城 全能商城小程序 智慧商城系统 全品类百货商城php+uniapp
    Python编程陷阱(十一)
    npm和package.json
    OpenCV学习(二)——OpenCV中绘图功能
    【UV打印机】PrintExp打印软件教程(五)-高级
  • 原文地址:https://blog.csdn.net/qq_38734327/article/details/134080834