• 【论文笔记】RS-Mamba for Large Remote Sensing Image Dense Prediction(附Code)


    论文作者提出了RS-Mamba(RSM)用于高分辨率遥感图像遥感的密集预测任务。RSM设计用于模拟具有线性复杂性的遥感图像的全局特征,使其能够有效地处理大型VHR图像。它采用全向选择性扫描模块,从多个方向对图像进行全局建模,从多个方向捕捉大的空间特征。

    论文链接:https://arxiv.org/abs/2404.02668

    code链接:https://github.com/walking-shadow/Official_Remote_Sensing_Mamba

    2D全向扫描机制是本研究的主要创新点。作者考虑到遥感影像地物多方向的特点,在VMamba2D双向扫描机制的基础上增加了斜向扫描机制。

     以下是作者针对该部分进行改进的代码:

    1. def antidiagonal_gather(tensor):
    2. # 取出矩阵所有反斜向的元素并拼接
    3. B, C, H, W = tensor.size()
    4. shift = torch.arange(H, device=tensor.device).unsqueeze(1) # 创建一个列向量[H, 1]
    5. index = (torch.arange(W, device=tensor.device) - shift) % W # 利用广播创建索引矩阵[H, W]
    6. # 扩展索引以适应B和C维度
    7. expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    8. # 使用gather进行索引选择
    9. return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W)
    10. def diagonal_gather(tensor):
    11. # 取出矩阵所有反斜向的元素并拼接
    12. B, C, H, W = tensor.size()
    13. shift = torch.arange(H, device=tensor.device).unsqueeze(1) # 创建一个列向量[H, 1]
    14. index = (shift + torch.arange(W, device=tensor.device)) % W # 利用广播创建索引矩阵[H, W]
    15. # 扩展索引以适应B和C维度
    16. expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    17. # 使用gather进行索引选择
    18. return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W)
    19. def diagonal_scatter(tensor_flat, original_shape):
    20. # 把斜向元素拼接起来的一维向量还原为最初的矩阵形式
    21. B, C, H, W = original_shape
    22. shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1) # 创建一个列向量[H, 1]
    23. index = (shift + torch.arange(W, device=tensor_flat.device)) % W # 利用广播创建索引矩阵[H, W]
    24. # 扩展索引以适应B和C维度
    25. expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    26. # 创建一个空的张量来存储反向散布的结果
    27. result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    28. # 将平铺的张量重新变形为[B, C, H, W],考虑到需要使用transpose将H和W调换
    29. tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    30. # 使用scatter_根据expanded_index将元素放回原位
    31. result_tensor.scatter_(3, expanded_index, tensor_reshaped)
    32. return result_tensor
    33. def antidiagonal_scatter(tensor_flat, original_shape):
    34. # 把反斜向元素拼接起来的一维向量还原为最初的矩阵形式
    35. B, C, H, W = original_shape
    36. shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1) # 创建一个列向量[H, 1]
    37. index = (torch.arange(W, device=tensor_flat.device) - shift) % W # 利用广播创建索引矩阵[H, W]
    38. expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    39. # 初始化一个与原始张量形状相同、元素全为0的张量
    40. result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    41. # 将平铺的张量重新变形为[B, C, W, H],因为操作是沿最后一个维度收集的,需要调整形状并交换维度
    42. tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    43. # 使用scatter_将元素根据索引放回原位
    44. result_tensor.scatter_(3, expanded_index, tensor_reshaped)
    45. return result_tensor
    46. class CrossScan(torch.autograd.Function):
    47. # ZSJ 这里是把图像按照特定方向展平的地方,改变扫描方向可以在这里修改
    48. @staticmethod
    49. def forward(ctx, x: torch.Tensor):
    50. B, C, H, W = x.shape
    51. ctx.shape = (B, C, H, W)
    52. # xs = x.new_empty((B, 4, C, H * W))
    53. xs = x.new_empty((B, 8, C, H * W))
    54. # 添加横向和竖向的扫描
    55. xs[:, 0] = x.flatten(2, 3)
    56. xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
    57. xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
    58. # 提供斜向和反斜向的扫描
    59. xs[:, 4] = diagonal_gather(x)
    60. xs[:, 5] = antidiagonal_gather(x)
    61. xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])
    62. return xs
    63. @staticmethod
    64. def backward(ctx, ys: torch.Tensor):
    65. # out: (b, k, d, l)
    66. B, C, H, W = ctx.shape
    67. L = H * W
    68. # 把横向和竖向的反向部分再反向回来,并和原来的横向和竖向相加
    69. # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
    70. y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
    71. # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
    72. # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
    73. y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
    74. y_rb = y_rb.view(B, -1, H, W)
    75. # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
    76. y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L)
    77. # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
    78. y_da = diagonal_scatter(y_da[:, 0], (B,C,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,C,H,W))
    79. y_res = y_rb + y_da
    80. # return y.view(B, -1, H, W)
    81. return y_res
    82. class CrossMerge(torch.autograd.Function):
    83. @staticmethod
    84. def forward(ctx, ys: torch.Tensor):
    85. B, K, D, H, W = ys.shape
    86. ctx.shape = (H, W)
    87. ys = ys.view(B, K, D, -1)
    88. # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
    89. # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
    90. y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
    91. # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
    92. y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
    93. y_rb = y_rb.view(B, -1, H, W)
    94. # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
    95. y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1)
    96. # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
    97. y_da = diagonal_scatter(y_da[:, 0], (B,D,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,D,H,W))
    98. y_res = y_rb + y_da
    99. return y_res.view(B, D, -1)
    100. # return y
    101. @staticmethod
    102. def backward(ctx, x: torch.Tensor):
    103. # B, D, L = x.shape
    104. # out: (b, k, d, l)
    105. H, W = ctx.shape
    106. B, C, L = x.shape
    107. # xs = x.new_empty((B, 4, C, L))
    108. xs = x.new_empty((B, 8, C, L))
    109. # 横向和竖向扫描
    110. xs[:, 0] = x
    111. xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
    112. xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
    113. # xs = xs.view(B, 4, C, H, W)
    114. # 提供斜向和反斜向的扫描
    115. xs[:, 4] = diagonal_gather(x.view(B,C,H,W))
    116. xs[:, 5] = antidiagonal_gather(x.view(B,C,H,W))
    117. xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])
    118. # return xs
    119. return xs.view(B, 8, C, H, W)

  • 相关阅读:
    遗传算法python
    windows10配置paddleOCR的CPU版本总结
    adb使用笔记
    MySQL中char_length()和length()的区别
    在Windows上使用nginx具体步骤
    java计算机毕业设计足球赛会管理系统源程序+mysql+系统+lw文档+远程调试
    LLM 推理 - Nvidia TensorRT-LLM 与 Triton Inference Server
    Maven Wrapper 之 SpringBoot 项目下的 mvnw.cmd
    Kotlin 核心语法,为什么选择Kotlin ?
    selenium自动化测试神器
  • 原文地址:https://blog.csdn.net/qq_43456016/article/details/137873136