• Attention Mechanisms Explained with Code Walkthrough


    1. SEBlock (Channel Attention)

    First squeeze along the spatial (H*W) dimensions: global average pooling reduces each channel to a single value.
    (B, C, H, W) --- (B, C, 1, 1)

    Then exploit the correlations between channels to compute per-channel weights:
    (B, C, 1, 1) --- (B, C//K, 1, 1) --- (B, C, 1, 1) --- sigmoid

    Finally, multiply the weights with the original features to obtain the reweighted output.

    import torch
    import torch.nn as nn

    class SELayer(nn.Module):
        def __init__(self, channel, reduction=4):
            super(SELayer, self).__init__()
            # adaptive global pooling: only the output spatial size needs to be given
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.fc1 = nn.Sequential(
                nn.Conv2d(channel, channel // reduction, 1, bias=False),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // reduction, channel, 1, bias=False),
                nn.Sigmoid()
            )

        def forward(self, x):
            y = self.avg_pool(x)    # (B, C, H, W) -> (B, C, 1, 1)
            y_out = self.fc1(y)     # per-channel weights in (0, 1)
            return x * y_out        # reweight the original features
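
    A minimal usage sketch, assuming the SELayer class defined above; the tensor shape and channel count are illustration values of my own choosing:

    x = torch.randn(2, 64, 32, 32)        # (B, C, H, W), illustration values
    se = SELayer(channel=64, reduction=4)
    y = se(x)
    print(y.shape)                        # torch.Size([2, 64, 32, 32]): same shape, channels reweighted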

    2. CBAM (Channel Attention + Spatial Attention)

    CBAM contains both a channel attention module and a spatial attention module.
    The channel attention is largely the same as in SE, but it additionally runs global max pooling in parallel with global average pooling.

    Spatial attention: first take the max and the mean along the channel dimension, concatenate the two maps along the channel dimension, fuse them with a convolution, and finally multiply the result with the original features. The two modules are applied one after the other; see the composition sketch after their code below.

    import torch
    import torch.nn as nn

    class ChannelAttention(nn.Module):
        def __init__(self, channel, rate=4):
            super(ChannelAttention, self).__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.max_pool = nn.AdaptiveMaxPool2d(1)
            self.fc1 = nn.Sequential(
                nn.Conv2d(channel, channel // rate, 1, bias=False),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // rate, channel, 1, bias=False)
            )
            self.sig = nn.Sigmoid()

        def forward(self, x):
            avg = self.avg_pool(x)            # (B, C, 1, 1)
            avg_feature = self.fc1(avg)
            max_ = self.max_pool(x)           # (B, C, 1, 1); avoid shadowing the built-in max
            max_feature = self.fc1(max_)      # both branches share the same bottleneck
            out = max_feature + avg_feature   # fuse the two pooled branches
            out = self.sig(out)
            return x * out
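
    A minimal usage sketch for the channel attention module above; the shapes are illustration values:

    x = torch.randn(2, 64, 32, 32)            # (B, C, H, W), illustration values
    ca = ChannelAttention(channel=64, rate=4)
    print(ca(x).shape)                        # torch.Size([2, 64, 32, 32])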

    import torch
    import torch.nn as nn

    class SpatialAttention(nn.Module):
        def __init__(self):
            super(SpatialAttention, self).__init__()
            # (B, C, H, W) --- (B, 1, H, W) --- (B, 2, H, W) --- (B, 1, H, W)
            self.conv1 = nn.Conv2d(2, 1, kernel_size=3, padding=1, bias=False)
            self.sigmoid = nn.Sigmoid()

        def forward(self, x):
            mean_f = torch.mean(x, dim=1, keepdim=True)        # channel-wise mean: (B, 1, H, W)
            max_f = torch.max(x, dim=1, keepdim=True).values   # channel-wise max:  (B, 1, H, W)
            cat = torch.cat([mean_f, max_f], dim=1)            # (B, 2, H, W)
            out = self.conv1(cat)                              # (B, 1, H, W)
            return x * self.sigmoid(out)
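
    CBAM applies the two modules sequentially: channel attention first, then spatial attention. A minimal composition sketch, assuming the ChannelAttention and SpatialAttention classes above; the wrapper name CBAM and the shapes are my own for illustration:

    class CBAM(nn.Module):
        def __init__(self, channel, rate=4):
            super(CBAM, self).__init__()
            self.ca = ChannelAttention(channel, rate)
            self.sa = SpatialAttention()

        def forward(self, x):
            x = self.ca(x)    # reweight channels
            x = self.sa(x)    # then reweight spatial positions
            return x

    x = torch.randn(2, 64, 32, 32)    # illustration values
    print(CBAM(64)(x).shape)          # torch.Size([2, 64, 32, 32])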

    3. Attention in the Transformer

    Scaled Dot-Product Attention

    The inputs to this attention mechanism are Q, K and V.

    1. Multiply Q with K (transposed).

    2. Scale the result.

    3. Apply softmax.

    4. Multiply with V to get the output.

    import torch
    import torch.nn as nn
    import numpy as np

    class ScaledDotProductAttention(nn.Module):
        def __init__(self, scale):
            super(ScaledDotProductAttention, self).__init__()
            self.scale = scale
            self.softmax = nn.Softmax(dim=2)

        def forward(self, q, k, v):
            u = torch.bmm(q, k.transpose(1, 2))   # (B, n_q, n_k) attention scores
            u = u / self.scale                    # scale the scores
            attn = self.softmax(u)                # attention weights
            output = torch.bmm(attn, v)           # (B, n_q, d_v)
            return output

    # scale = np.power(d_k, 0.5): the scaling factor is the square root of K's feature dimension.
    # Q: (B, n_q, d_q), K: (B, n_k, d_k), V: (B, n_v, d_v).
    # Q and K must have the same feature dimension; K and V must have the same number of tokens.
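
    A minimal usage sketch, assuming the class above; batch size, token counts and feature dimensions are arbitrary illustration values:

    batch, n_q, n_k, d_k, d_v = 2, 4, 6, 8, 16   # illustration values
    q = torch.randn(batch, n_q, d_k)
    k = torch.randn(batch, n_k, d_k)
    v = torch.randn(batch, n_k, d_v)             # n_v must equal n_k
    attn = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
    print(attn(q, k, v).shape)                   # torch.Size([2, 4, 16])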

    MultiHeadAttention

    Reshape the channel dimension of Q, K and V into n * C form, i.e. split them into n heads and run attention on each head separately.

    1. Project Q, K and V from single-head to multi-head: channel --- n_head * new_channel via a linear layer, then fold the head dimension into the batch dimension.

    2. Compute the single-head attention output for each head.

    3. Split the dimensions back and concatenate the heads along the channel dimension.

    4. A final linear layer maps to the output dimension.

    import torch
    import torch.nn as nn
    import numpy as np

    class MultiHeadAttention(nn.Module):
        def __init__(self, n_head, d_k, d_k_, d_v, d_v_, d_o):
            super(MultiHeadAttention, self).__init__()
            self.n_head = n_head
            self.d_k = d_k
            self.d_v = d_v
            self.fc_q = nn.Linear(d_k_, n_head * d_k)
            self.fc_k = nn.Linear(d_k_, n_head * d_k)
            self.fc_v = nn.Linear(d_v_, n_head * d_v)
            self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
            self.fc_o = nn.Linear(n_head * d_v, d_o)

        def forward(self, q, k, v):
            n_head, d_k, d_v = self.n_head, self.d_k, self.d_v
            batch, n_q, d_q_ = q.size()
            batch, n_k, d_k_ = k.size()
            batch, n_v, d_v_ = v.size()
            q = self.fc_q(q)   # (batch, n_q, n_head * d_k)
            k = self.fc_k(k)
            v = self.fc_v(v)
            # split the heads and fold them into the batch dimension: (n_head * batch, n, d)
            q = q.view(batch, n_q, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_q, d_k)
            k = k.view(batch, n_k, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, d_k)
            v = v.view(batch, n_v, n_head, d_v).permute(2, 0, 1, 3).contiguous().view(-1, n_v, d_v)
            output = self.attention(q, k, v)   # (n_head * batch, n_q, d_v)
            # unfold the heads and concatenate them back along the channel dimension
            output = output.view(n_head, batch, n_q, d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)
            output = self.fc_o(output)   # (batch, n_q, d_o)
            return output
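
    A minimal usage sketch, assuming the MultiHeadAttention and ScaledDotProductAttention classes above; all dimensions are illustration values:

    batch, n_q, n_k = 2, 4, 6                                  # illustration values
    n_head, d_k, d_v, d_k_, d_v_, d_o = 4, 8, 16, 32, 32, 64
    q = torch.randn(batch, n_q, d_k_)
    k = torch.randn(batch, n_k, d_k_)
    v = torch.randn(batch, n_k, d_v_)                          # n_v must equal n_k
    mha = MultiHeadAttention(n_head, d_k, d_k_, d_v, d_v_, d_o)
    print(mha(q, k, v).shape)                                  # torch.Size([2, 4, 64])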

  • Original article: https://blog.csdn.net/slamer111/article/details/132788865