• [Quick check that a model runs] Build an image as a tensor and feed it to your own model for testing




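    Abstract: Simulate image data and feed it to the model to check whether the model runs end to end. This is especially convenient for debugging. Its main advantage is that you need neither a dataset nor any tedious parameter setup: you only simulate one batch of data and pass it straight through the model. It is simple, fast, and saves a lot of time.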

    Steps

    ① Suppose the model I wrote is called ESA_blcok:

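    import torch.nn as nn
    import torch
    
    class ESA_blcok(nn.Module):  # This is only a template; fill in the body according to your own model
        def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout=0.):
            super().__init__()
            self.ESAlayer = nn.Identity()  # placeholder: replace with your own layer
            self.ff = nn.Identity()        # placeholder: replace with your own layer
    
        def forward(self, x):
            out = self.ESAlayer(x)  # placeholder: replace with your own forward computation
            out = self.ff(out)
    
            return out+x   # your model is now written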
    
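    The template returns `out+x`, so the residual connection only works if the block keeps the input shape. This is exactly the kind of bug the quick test catches. A minimal sketch, where `BadBlock` is a hypothetical example (not from the original post) whose inner layer changes the channel count:

```python
import torch
import torch.nn as nn


class BadBlock(nn.Module):
    """Hypothetical block whose inner layer changes the channel count."""

    def __init__(self, dim):
        super().__init__()
        # 1x1 conv maps dim -> 2*dim channels, so out and x no longer match
        self.proj = nn.Conv2d(dim, dim * 2, kernel_size=1)

    def forward(self, x):
        out = self.proj(x)
        return out + x  # residual add requires out.shape == x.shape


x = torch.rand((2, 3, 32, 32))
try:
    BadBlock(dim=3)(x)
    error = None
except RuntimeError as e:
    error = e  # the shape mismatch surfaces immediately, before any training
print(type(error).__name__)  # RuntimeError
```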

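    ② To check whether your code runs, first create a tensor. For example, set the shape of the image batch fed to the model to (4, 3, 320, 320), i.e. batch size 4, 3 channels, and 320×320 images. One line of code generates it: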

        x = torch.rand((4, 3, 320, 320))  # (B, C, H, W); named x because the model is called as esa(x) below
    
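    `torch.rand` draws values uniformly from [0, 1), which resembles a normalized image (`torch.randn`, standard normal, works as well). A short sketch, assuming you want the check to be reproducible; the seed value 0 is arbitrary:

```python
import torch

torch.manual_seed(0)  # fix the random values so repeated runs are identical

# Simulated batch: B=4 images, C=3 channels, H=W=320
x = torch.rand((4, 3, 320, 320))

print(x.shape)  # torch.Size([4, 3, 320, 320])
print(x.dtype)  # torch.float32, the default dtype that nn layers expect
print(x.min().item() >= 0.0, x.max().item() < 1.0)  # True True
```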

    ③ Instantiate the model:

        esa = ESA_blcok(dim=3)  # ESA_blcok is the model you wrote yourself
    
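    Instantiating the model already runs `__init__`, so mistakes in layer definitions show up at this step. Counting trainable parameters is a cheap extra check that the layers were actually registered. The sketch below uses two hypothetical toy modules to show a common pitfall: layers kept in a plain Python list are invisible to `parameters()`, while `nn.ModuleList` (the container the full example's `PPM` uses) registers them:

```python
import torch.nn as nn


class ListHeld(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: this conv is NOT registered as a submodule
        self.layers = [nn.Conv2d(3, 3, kernel_size=1)]


class ModuleListHeld(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers the conv as a submodule
        self.layers = nn.ModuleList([nn.Conv2d(3, 3, kernel_size=1)])


def count_params(model):
    # Sum the element counts of all trainable parameter tensors
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(count_params(ListHeld()))        # 0
print(count_params(ModuleListHeld()))  # 12 = 3*3*1*1 weights + 3 biases
```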

    ④ Feed the image into the model:

        output = esa(x)
    
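    For a pure shape check no gradients are needed. Wrapping the call in `torch.no_grad()` skips recording the autograd graph, which saves memory, and `model.eval()` switches layers such as `nn.Dropout` to inference behaviour. A sketch with a hypothetical stand-in model:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model that keeps the (B, C, H, W) shape
model = nn.Sequential(nn.Conv2d(3, 3, kernel_size=3, padding=1), nn.Dropout(0.5))
x = torch.rand((4, 3, 64, 64))

model.eval()           # dropout becomes a no-op in eval mode
with torch.no_grad():  # no autograd graph is recorded inside this block
    output = model(x)

print(output.shape)          # torch.Size([4, 3, 64, 64])
print(output.requires_grad)  # False
```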

    ⑤ Finally, print the output shape to check that the model produced a result:

        print(output.shape)
    
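    A printed shape still has to be read by eye. A small sketch that turns the check into assertions and also rejects NaN/Inf values, assuming the model should preserve the input shape (as a block ending in `out+x` must); the one-layer model is a hypothetical stand-in:

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 3, kernel_size=1)  # hypothetical stand-in for your model
x = torch.rand((4, 3, 320, 320))

with torch.no_grad():
    output = model(x)

# Fail loudly instead of relying on reading the printed shape
assert output.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(output.shape)}"
assert torch.isfinite(output).all(), "output contains NaN or Inf"
print("forward pass OK:", tuple(output.shape))
```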

    The complete workflow:

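    import torch.nn as nn
    import torch
    
    class ESA_blcok(nn.Module):  # This is only a template; fill in the body according to your own model
        def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout=0.):
            super().__init__()
            self.ESAlayer = nn.Identity()  # placeholder: replace with your own layer
            self.ff = nn.Identity()        # placeholder: replace with your own layer
    
        def forward(self, x):
            out = self.ESAlayer(x)  # placeholder: replace with your own forward computation
            out = self.ff(out)
    
            return out+x   # your model is now written
    
    
    # Check whether the model runs end to end
    # The code under if __name__ == '__main__': runs only when this file is executed directly
    if __name__ == '__main__':
        x = torch.rand((4, 3, 320, 320))  # simulated batch: (B, C, H, W)
        esa = ESA_blcok(dim=3)
        output = esa(x)
        print(output.shape)  # if a shape is printed, the model runs end to end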
    
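    The same workflow can be packaged into a reusable function and run over several input sizes, including a non-square one, to catch code that silently assumes H == W. `smoke_test` is a hypothetical helper name, and the sizes below are arbitrary choices:

```python
import torch
import torch.nn as nn


def smoke_test(model, input_shapes, device="cpu"):
    """Feed one random batch per shape through `model`; return the output shapes."""
    model = model.to(device).eval()
    results = {}
    for shape in input_shapes:
        x = torch.rand(shape, device=device)  # simulated (B, C, H, W) batch
        with torch.no_grad():
            results[shape] = tuple(model(x).shape)
    return results


if __name__ == '__main__':
    model = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # hypothetical stand-in model
    shapes = [(2, 3, 64, 64), (1, 3, 96, 128)]  # square and non-square inputs
    for in_shape, out_shape in smoke_test(model, shapes).items():
        print(in_shape, "->", out_shape)
```

    Passing `device="cuda"` (when available) moves both the model and the simulated input to the GPU, since a model and its input must live on the same device.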

    A complete example

    If you want to try the whole process, here is a complete example that runs as-is:

    
    # Two models are tested here; you don't need to understand their implementation, just learn the method
    import torch.nn as nn
    import torch
    from einops import rearrange, repeat
    from einops.layers.torch import Rearrange
    import torch.nn.functional as F
    
    
    class PreNorm(nn.Module):
        def __init__(self, dim, fn):
            super().__init__()
            self.norm = nn.LayerNorm(dim)
            self.fn = fn
    
        def forward(self, x, **kwargs):
            return self.fn(self.norm(x), **kwargs)
    
    
    class FeedForward(nn.Module):
        def __init__(self, dim, hidden_dim, dropout=0.):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, dim),
                nn.Dropout(dropout)
            )
    
        def forward(self, x):
            return self.net(x)
    
    
    class PPM(nn.Module):
        def __init__(self, pooling_sizes=(1, 3, 5)):
            super().__init__()
            self.layer = nn.ModuleList([nn.AdaptiveAvgPool2d(output_size=(size, size)) for size in pooling_sizes])
    
        def forward(self, feat):
            b, c, h, w = feat.shape
            output = [layer(feat).view(b, c, -1) for layer in self.layer]
            output = torch.cat(output, dim=-1)
            return output
    
    
    # Efficient self attention
    class ESA_layer(nn.Module):
        def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
            super().__init__()
            inner_dim = dim_head * heads
            project_out = not (heads == 1 and dim_head == dim)
    
            self.heads = heads
            self.scale = dim_head ** -0.5
    
            self.attend = nn.Softmax(dim=-1)
            self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False)
            self.ppm = PPM(pooling_sizes=(1, 3, 5))
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            ) if project_out else nn.Identity()
    
        def forward(self, x):
            # input x (b, c, h, w)
            b, c, h, w = x.shape
            q, k, v = self.to_qkv(x).chunk(3, dim=1)  # q/k/v shape: (b, inner_dim, h, w)
            q = rearrange(q, 'b (head d) h w -> b head (h w) d', head=self.heads)  # q shape: (b, head, n_q, d)
    
            k, v = self.ppm(k), self.ppm(v)  # k/v shape: (b, inner_dim, n_kv)
            k = rearrange(k, 'b (head d) n -> b head n d', head=self.heads)  # k shape: (b, head, n_kv, d)
            v = rearrange(v, 'b (head d) n -> b head n d', head=self.heads)  # v shape: (b, head, n_kv, d)
    
            dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # shape: (b, head, n_q, n_kv)
    
            attn = self.attend(dots)
    
            out = torch.matmul(attn, v)  # shape: (b, head, n_q, d)
            out = rearrange(out, 'b head n d -> b n (head d)')
            return self.to_out(out)
    
    
    class ESA_blcok(nn.Module):
        def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout=0.):
            super().__init__()
            self.ESAlayer = ESA_layer(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
    
        def forward(self, x):
            b, c, h, w = x.shape
            out = rearrange(x, 'b c h w -> b (h w) c')
            out = self.ESAlayer(x) + out
            out = self.ff(out) + out
            out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    
            return out+x
            # return out
    
    
    def MaskAveragePooling(x, mask):
        mask = torch.sigmoid(mask)
        b, c, h, w = x.shape
        eps = 0.0005
        x_mask = x * mask
        h, w = x.shape[2], x.shape[3]
        area = F.avg_pool2d(mask, (h, w)) * h * w + eps
        x_feat = F.avg_pool2d(x_mask, (h, w)) * h * w / area
        x_feat = x_feat.view(b, c, -1)
        return x_feat
    
    
    # Lesion-aware Cross Attention
    class LCA_layer(nn.Module):
        def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
            super().__init__()
            inner_dim = dim_head * heads
            project_out = not (heads == 1 and dim_head == dim)
            self.heads = heads
            self.scale = dim_head ** -0.5
    
            self.attend = nn.Softmax(dim=-1)
            self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False)
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            ) if project_out else nn.Identity()
    
        def forward(self, x, mask):
            # input x (b, c, h, w)
            b, c, h, w = x.shape
            q, k, v = self.to_qkv(x).chunk(3, dim=1)  # q/k/v shape: (b, inner_dim, h, w)
            q = rearrange(q, 'b (head d) h w -> b head (h w) d', head=self.heads)  # q shape: (b, head, n_q, d)
    
            k, v = MaskAveragePooling(k, mask), MaskAveragePooling(v, mask)  # k/v shape: (b, inner_dim, 1)
            k = rearrange(k, 'b (head d) n -> b head n d', head=self.heads)  # k shape: (b, head, 1, d)
            v = rearrange(v, 'b (head d) n -> b head n d', head=self.heads)  # v shape: (b, head, 1, d)
    
            dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # shape: (b, head, n_q, n_kv)
    
            attn = self.attend(dots)
    
            out = torch.matmul(attn, v)  # shape: (b, head, n_q, d)
            out = rearrange(out, 'b head n d -> b n (head d)')
            return self.to_out(out)
    
    
    class LCA_blcok(nn.Module):
        def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout=0.):
            super().__init__()
            self.LCAlayer = LCA_layer(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
    
        def forward(self, x, mask):
            b, c, h, w = x.shape
            out = rearrange(x, 'b c h w -> b (h w) c')
            out = self.LCAlayer(x, mask) + out
            out = self.ff(out) + out
            out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    
            return out
    
    
    # test
    if __name__ == '__main__':
        x = torch.rand((4, 3, 320, 320))
        mask = torch.rand(4, 1, 320, 320)
        lca = LCA_blcok(dim=3)
        esa = ESA_blcok(dim=3)
        print(lca(x, mask).shape)
        print(esa(x).shape)
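    With `dim=3`, `heads=8` and `dim_head=64`, `ESA_layer` projects to `inner_dim = 512` channels; the 320×320 input becomes 102,400 query tokens, while `PPM` pools keys and values down to 1 + 9 + 25 = 35 tokens. The plain-Python arithmetic below (assuming float32, 4 bytes per element) estimates a few intermediate tensor sizes from the shape comments in `ESA_layer.forward`. It suggests why this test can be slow or run out of memory, and why a smaller simulated input is often enough for a quick check:

```python
# Shape/memory arithmetic for ESA_blcok(dim=3) on a (4, 3, 320, 320) input,
# following the shape comments in ESA_layer.forward (float32 = 4 bytes).
B, H, W = 4, 320, 320
heads, dim_head = 8, 64
inner_dim = heads * dim_head              # 512
n_q = H * W                               # query tokens: 102400
n_kv = sum(s * s for s in (1, 3, 5))      # PPM tokens: 1 + 9 + 25 = 35


def mib(n_elements, bytes_per_element=4):
    return n_elements * bytes_per_element / 2**20


qkv = B * 3 * inner_dim * H * W           # output of to_qkv
q = B * heads * n_q * dim_head            # q after rearrange
dots = B * heads * n_q * n_kv             # attention scores

print(f"to_qkv output: {mib(qkv):.0f} MiB")   # 2400 MiB
print(f"q:             {mib(q):.0f} MiB")     # 800 MiB
print(f"dots:          {mib(dots):.1f} MiB")  # 437.5 MiB
```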
    
    
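    A forward pass only shows that inference runs. If "runs end to end" should also cover training, one possible extension (not part of the original post) is to backpropagate a dummy loss and confirm that every trainable parameter receives a gradient; a parameter without one usually belongs to a layer that `forward()` never uses. A minimal sketch with a hypothetical stand-in model, where the mean of the output is an arbitrary placeholder loss:

```python
import torch
import torch.nn as nn

model = nn.Sequential(  # hypothetical stand-in for ESA_blcok / LCA_blcok
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.GELU(),
    nn.Conv2d(8, 3, kernel_size=1),
)
x = torch.rand((2, 3, 64, 64))

model.train()
loss = model(x).mean()  # placeholder loss, only used to drive backward()
loss.backward()

# Parameters that never received a gradient point to layers unused in forward()
missing = [name for name, p in model.named_parameters()
           if p.requires_grad and p.grad is None]
print("parameters without gradients:", missing)  # []
```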
  • Original post: https://blog.csdn.net/weixin_44883789/article/details/133934068