• hrformer


    1. # --------------------------------------------------------
    2. # High Resolution Transformer
    3. # Copyright (c) 2021 Microsoft
    4. # Licensed under The MIT License [see LICENSE for details]
    5. # Written by Rao Fu, RainbowSecret
    6. # --------------------------------------------------------
    7. import os
    8. import math
    9. import logging
    10. import torch
    11. import torch.nn as nn
    12. from functools import partial
    13. from mmcv.cnn import build_conv_layer, build_norm_layer
    14. BN_MOMENTUM = 0.1
    15. # --------------------------------------------------------
    16. # Copyright (c) 2021 Microsoft
    17. # Licensed under The MIT License [see LICENSE for details]
    18. # Modified by Lang Huang, RainbowSecret from:
    19. # https://github.com/openseg-group/openseg.pytorch/blob/master/lib/models/modules/isa_block.py
    20. # --------------------------------------------------------
    21. import os
    22. import pdb
    23. import math
    24. import torch
    25. import torch.nn as nn
    26. # --------------------------------------------------------
    27. # Copyright (c) 2021 Microsoft
    28. # Licensed under The MIT License [see LICENSE for details]
    29. # Modified by Lang Huang, RainbowSecret from:
    30. # https://github.com/openseg-group/openseg.pytorch/blob/master/lib/models/modules/isa_block.py
    31. # --------------------------------------------------------
    32. import copy
    33. import math
    34. import warnings
    35. import torch
    36. from torch import nn, Tensor
    37. from torch.nn import functional as F
    38. from torch._jit_internal import Optional, Tuple
    39. from torch.overrides import has_torch_function, handle_torch_function
    40. from torch.nn.functional import linear, pad, softmax, dropout
    41. from einops import rearrange
    42. from timm.models.layers import to_2tuple, trunc_normal_
    43. # --------------------------------------------------------
    44. # Copyright (c) 2021 Microsoft
    45. # Licensed under The MIT License [see LICENSE for details]
    46. # Modified by RainbowSecret from:
    47. # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L852
    48. # --------------------------------------------------------
    49. import copy
    50. import math
    51. import warnings
    52. import torch
    53. import torch.nn.functional as F
    54. from torch import nn, Tensor
    55. from torch.nn.modules.module import Module
    56. from torch._jit_internal import Optional, Tuple
    57. from torch.overrides import has_torch_function, handle_torch_function
    58. from torch.nn.functional import linear, pad, softmax, dropout
    class MultiheadAttention(Module):
        """Multi-head attention with dedicated q/k/v projection layers.

        Adapted from ``torch.nn.MultiheadAttention``. As implemented here:

        * query/key/value are projected by three separate ``nn.Linear`` modules
          (``q_proj``, ``k_proj``, ``v_proj``); the packed ``in_proj_weight`` /
          ``in_proj_bias`` attributes are kept only as ``None`` placeholders.
        * an optional ``residual_attn`` term can be added to the attention
          logits before the softmax.
        * ``add_bias_kv`` is accepted for signature compatibility but ignored
          (``bias_k`` / ``bias_v`` are always ``None``).
        * ``out_proj`` is always created with a bias, regardless of ``bias``.

        Tensors are sequence-first: ``query`` is ``(tgt_len, batch, embed_dim)``,
        ``key`` is ``(src_len, batch, kdim)`` and ``value`` is
        ``(src_len, batch, vdim)``.

        Args:
            embed_dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            dropout: Dropout probability applied to the attention weights.
            bias: Whether the q/k/v projections carry a bias.
            add_bias_kv: Unused.
            add_zero_attn: Append an all-zero key/value slot to every sequence.
            kdim: Feature size of ``key`` (defaults to ``embed_dim``).
            vdim: Feature size of ``value`` (defaults to ``embed_dim``).
        """

        bias_k: Optional[torch.Tensor]
        bias_v: Optional[torch.Tensor]

        def __init__(
            self,
            embed_dim,
            num_heads,
            dropout=0.0,
            bias=True,
            add_bias_kv=False,
            add_zero_attn=False,
            kdim=None,
            vdim=None,
        ):
            super(MultiheadAttention, self).__init__()
            self.embed_dim = embed_dim
            self.kdim = kdim if kdim is not None else embed_dim
            self.vdim = vdim if vdim is not None else embed_dim
            self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

            self.num_heads = num_heads
            self.dropout = dropout
            self.head_dim = embed_dim // num_heads
            assert (
                self.head_dim * num_heads == self.embed_dim
            ), "embed_dim must be divisible by num_heads"

            # Creation order is kept as-is so parameter initialisation and
            # state_dict ordering match existing checkpoints.
            self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
            self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
            self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
            self.out_proj = nn.Linear(embed_dim, embed_dim)

            # Placeholders kept for API parity with torch.nn.MultiheadAttention.
            self.in_proj_bias = None
            self.in_proj_weight = None
            self.bias_k = self.bias_v = None
            self.q_proj_weight = None
            self.k_proj_weight = None
            self.v_proj_weight = None

            self.add_zero_attn = add_zero_attn

        def __setstate__(self, state):
            # Support loading old MultiheadAttention checkpoints generated by v1.1.0
            if "_qkv_same_embed_dim" not in state:
                state["_qkv_same_embed_dim"] = True
            super(MultiheadAttention, self).__setstate__(state)

        def forward(
            self,
            query,
            key,
            value,
            key_padding_mask=None,
            need_weights=False,
            attn_mask=None,
            residual_attn=None,
        ):
            """Run attention using this module's projection layers.

            Returns:
                ``attn_output`` alone when ``need_weights`` is false, otherwise
                the pair ``(attn_output, head_averaged_attention_weights)``.
            """
            extra = {}
            if not self._qkv_same_embed_dim:
                # Same flags the original passed on this path; the separate
                # weights are None placeholders and are not read downstream.
                extra = dict(
                    use_separate_proj_weight=True,
                    q_proj_weight=self.q_proj_weight,
                    k_proj_weight=self.k_proj_weight,
                    v_proj_weight=self.v_proj_weight,
                )
            return self.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                # v_proj maps vdim -> embed_dim, so the projected value width
                # (and therefore the per-head value size) is embed_dim, not vdim.
                out_dim=self.embed_dim,
                residual_attn=residual_attn,
                **extra,
            )

        def multi_head_attention_forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            embed_dim_to_check: int,
            num_heads: int,
            in_proj_weight: Optional[Tensor],
            in_proj_bias: Optional[Tensor],
            bias_k: Optional[Tensor],
            bias_v: Optional[Tensor],
            add_zero_attn: bool,
            dropout_p: float,
            out_proj_weight: Tensor,
            out_proj_bias: Tensor,
            training: bool = True,
            key_padding_mask: Optional[Tensor] = None,
            need_weights: bool = False,
            attn_mask: Optional[Tensor] = None,
            use_separate_proj_weight: bool = False,
            q_proj_weight: Optional[Tensor] = None,
            k_proj_weight: Optional[Tensor] = None,
            v_proj_weight: Optional[Tensor] = None,
            static_k: Optional[Tensor] = None,
            static_v: Optional[Tensor] = None,
            out_dim: Optional[int] = None,
            residual_attn: Optional[Tensor] = None,
        ):
            """Scaled dot-product multi-head attention.

            Args:
                query: ``(tgt_len, batch, embed_dim)``.
                key, value: ``(src_len, batch, feature)``; ``None`` falls back
                    to ``query``.
                embed_dim_to_check: Must equal ``query``'s last dimension.
                num_heads: Number of heads.
                add_zero_attn: Append one all-zero key/value slot.
                dropout_p: Dropout probability on the attention weights.
                out_proj_weight, out_proj_bias: Output projection parameters.
                training: Enables dropout.
                key_padding_mask: ``(batch, src_len)``; true entries are masked.
                need_weights: Also return head-averaged attention weights.
                attn_mask: 2D ``(tgt_len, src_len)`` or 3D
                    ``(batch * num_heads, tgt_len, src_len)`` mask; bool masks
                    fill with ``-inf``, float masks are added to the logits.
                out_dim: Width of the projected value; defaults to
                    ``embed_dim`` when ``None``.
                residual_attn: Optional term added to the logits, broadcast
                    over the batch dimension.
                in_proj_weight, in_proj_bias, bias_k, bias_v,
                use_separate_proj_weight, q_proj_weight, k_proj_weight,
                v_proj_weight, static_k, static_v: Accepted for signature
                    compatibility; not used by the computation below.

            Returns:
                ``attn_output`` of shape ``(tgt_len, batch, out_dim)`` passed
                through the output projection; when ``need_weights`` is true, a
                tuple ``(attn_output, weights)`` where ``weights`` is
                ``(batch, tgt_len, src_len)``.

            Raises:
                RuntimeError: If ``attn_mask`` has an unsupported rank or size.
                AssertionError: On inconsistent shapes or mask dtypes.
            """
            if not torch.jit.is_scripting():
                tens_ops = (
                    query,
                    key,
                    value,
                    in_proj_weight,
                    in_proj_bias,
                    bias_k,
                    bias_v,
                    out_proj_weight,
                    out_proj_bias,
                )
                if any(type(t) is not Tensor for t in tens_ops) and has_torch_function(
                    tens_ops
                ):
                    # Dispatch to the bound method: there is no module-level
                    # ``multi_head_attention_forward`` in this file, so the bare
                    # name used previously raised NameError. All keyword
                    # arguments are forwarded so the re-dispatched call is
                    # equivalent to this one.
                    return handle_torch_function(
                        self.multi_head_attention_forward,
                        tens_ops,
                        query,
                        key,
                        value,
                        embed_dim_to_check,
                        num_heads,
                        in_proj_weight,
                        in_proj_bias,
                        bias_k,
                        bias_v,
                        add_zero_attn,
                        dropout_p,
                        out_proj_weight,
                        out_proj_bias,
                        training=training,
                        key_padding_mask=key_padding_mask,
                        need_weights=need_weights,
                        attn_mask=attn_mask,
                        use_separate_proj_weight=use_separate_proj_weight,
                        q_proj_weight=q_proj_weight,
                        k_proj_weight=k_proj_weight,
                        v_proj_weight=v_proj_weight,
                        static_k=static_k,
                        static_v=static_v,
                        out_dim=out_dim,
                        residual_attn=residual_attn,
                    )

            tgt_len, bsz, embed_dim = query.size()
            key = query if key is None else key
            value = query if value is None else value

            assert embed_dim == embed_dim_to_check
            # allow MHA to have different sizes for the feature dimension
            assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

            if out_dim is None:
                out_dim = embed_dim
            head_dim = embed_dim // num_heads
            v_head_dim = out_dim // num_heads
            assert (
                head_dim * num_heads == embed_dim
            ), "embed_dim must be divisible by num_heads"
            scaling = float(head_dim) ** -0.5

            q = self.q_proj(query) * scaling
            k = self.k_proj(key)
            v = self.v_proj(value)

            if attn_mask is not None:
                assert (
                    attn_mask.dtype == torch.float32
                    or attn_mask.dtype == torch.float64
                    or attn_mask.dtype == torch.float16
                    or attn_mask.dtype == torch.uint8
                    or attn_mask.dtype == torch.bool
                ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
                    attn_mask.dtype
                )
                if attn_mask.dtype == torch.uint8:
                    warnings.warn(
                        "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
                    )
                    attn_mask = attn_mask.to(torch.bool)

                if attn_mask.dim() == 2:
                    attn_mask = attn_mask.unsqueeze(0)
                    if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                        raise RuntimeError("The size of the 2D attn_mask is not correct.")
                elif attn_mask.dim() == 3:
                    if list(attn_mask.size()) != [
                        bsz * num_heads,
                        query.size(0),
                        key.size(0),
                    ]:
                        raise RuntimeError("The size of the 3D attn_mask is not correct.")
                else:
                    raise RuntimeError(
                        "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                    )

            # convert ByteTensor key_padding_mask to bool
            if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
                warnings.warn(
                    "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
                )
                key_padding_mask = key_padding_mask.to(torch.bool)

            # Fold the heads into the batch dimension: (batch * heads, len, head_dim).
            q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
            if k is not None:
                k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
            if v is not None:
                v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)

            src_len = k.size(1)

            if key_padding_mask is not None:
                assert key_padding_mask.size(0) == bsz
                assert key_padding_mask.size(1) == src_len

            if add_zero_attn:
                src_len += 1
                k = torch.cat(
                    [
                        k,
                        torch.zeros(
                            (k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device
                        ),
                    ],
                    dim=1,
                )
                v = torch.cat(
                    [
                        v,
                        torch.zeros(
                            (v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device
                        ),
                    ],
                    dim=1,
                )
                if attn_mask is not None:
                    attn_mask = pad(attn_mask, (0, 1))
                if key_padding_mask is not None:
                    key_padding_mask = pad(key_padding_mask, (0, 1))

            attn_output_weights = torch.bmm(q, k.transpose(1, 2))
            assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

            # Attention weight for the invalid region is -inf
            if attn_mask is not None:
                if attn_mask.dtype == torch.bool:
                    attn_output_weights.masked_fill_(attn_mask, float("-inf"))
                else:
                    attn_output_weights += attn_mask

            if key_padding_mask is not None:
                attn_output_weights = attn_output_weights.view(
                    bsz, num_heads, tgt_len, src_len
                )
                attn_output_weights = attn_output_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    float("-inf"),
                )
                attn_output_weights = attn_output_weights.view(
                    bsz * num_heads, tgt_len, src_len
                )

            if residual_attn is not None:
                attn_output_weights = attn_output_weights.view(
                    bsz, num_heads, tgt_len, src_len
                )
                attn_output_weights += residual_attn.unsqueeze(0)
                attn_output_weights = attn_output_weights.view(
                    bsz * num_heads, tgt_len, src_len
                )

            # Reweight the attention map before softmax().
            # attn_output_weights: (b*n_head, n, hw)
            attn_output_weights = softmax(attn_output_weights, dim=-1)
            attn_output_weights = dropout(
                attn_output_weights, p=dropout_p, training=training
            )

            attn_output = torch.bmm(attn_output_weights, v)
            assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
            attn_output = (
                attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
            )
            attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

            if need_weights:
                # average attention weights over heads
                attn_output_weights = attn_output_weights.view(
                    bsz, num_heads, tgt_len, src_len
                )
                return attn_output, attn_output_weights.sum(dim=1) / num_heads
            else:
                return attn_output
    class MHA_(MultiheadAttention):
        """Multihead Attention with extra flags on the q/k/v and out projections.

        Extends ``MultiheadAttention`` with:

        * ``do_qkv_proj`` / ``do_out_proj`` switches to bypass the input and
          output projections;
        * an optional relative-position-bias table (``rpe=True``) defined over a
          ``window_size x window_size`` window;
        * the per-head query and key tensors returned alongside the output.

        NOTE(review): the relative position bias is currently NOT applied to the
        attention logits. The original code gathered the bias on every call but
        the addition itself was commented out, so the table is a registered but
        unused parameter. The gather now lives in ``_relative_position_bias`` and
        is not called; confirm the intended behaviour before re-enabling it, as
        doing so changes model outputs.

        Args:
            *args, **kwargs: Forwarded to ``MultiheadAttention.__init__``.
            rpe: Create the relative-position-bias table and index buffer.
            window_size: Side length of the (square) attention window.
        """

        bias_k: Optional[torch.Tensor]
        bias_v: Optional[torch.Tensor]

        def __init__(self, *args, rpe=False, window_size=7, **kwargs):
            super(MHA_, self).__init__(*args, **kwargs)

            self.rpe = rpe
            if rpe:
                self.window_size = [window_size] * 2
                # define a parameter table of relative position bias
                self.relative_position_bias_table = nn.Parameter(
                    torch.zeros(
                        (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
                        self.num_heads,
                    )
                )  # 2*Wh-1 * 2*Ww-1, nH

                # get pair-wise relative position index for each token inside the window
                coords_h = torch.arange(self.window_size[0])
                coords_w = torch.arange(self.window_size[1])
                coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
                coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
                relative_coords = (
                    coords_flatten[:, :, None] - coords_flatten[:, None, :]
                )  # 2, Wh*Ww, Wh*Ww
                relative_coords = relative_coords.permute(
                    1, 2, 0
                ).contiguous()  # Wh*Ww, Wh*Ww, 2
                relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
                relative_coords[:, :, 1] += self.window_size[1] - 1
                # Row offsets are scaled so (row, col) pairs map to unique flat indices.
                relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
                relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
                self.register_buffer("relative_position_index", relative_position_index)
                trunc_normal_(self.relative_position_bias_table, std=0.02)

        def _relative_position_bias(self):
            """Gather the per-head bias as ``(num_heads, Wh*Ww, Wh*Ww)``.

            Only valid when the module was built with ``rpe=True``. Currently not
            called by ``multi_head_attention_forward`` (see the class note).
            """
            window_area = self.window_size[0] * self.window_size[1]
            bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)
            ].view(window_area, window_area, -1)  # Wh*Ww, Wh*Ww, nH
            return bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        def forward(
            self,
            query,
            key,
            value,
            key_padding_mask=None,
            need_weights=False,
            attn_mask=None,
            do_qkv_proj=True,
            do_out_proj=True,
            rpe=True,
        ):
            """Run attention; see ``multi_head_attention_forward`` for the returns."""
            extra = {}
            if not self._qkv_same_embed_dim:
                # Same flags the original passed on this path; the separate
                # weights are None placeholders and are not read downstream.
                extra = dict(
                    use_separate_proj_weight=True,
                    q_proj_weight=self.q_proj_weight,
                    k_proj_weight=self.k_proj_weight,
                    v_proj_weight=self.v_proj_weight,
                )
            return self.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                # With projection, v_proj outputs embed_dim features; without
                # it the raw value keeps its vdim features.
                out_dim=self.embed_dim if do_qkv_proj else self.vdim,
                do_qkv_proj=do_qkv_proj,
                do_out_proj=do_out_proj,
                rpe=rpe,
                **extra,
            )

        def multi_head_attention_forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            embed_dim_to_check: int,
            num_heads: int,
            in_proj_weight: Optional[Tensor],
            in_proj_bias: Optional[Tensor],
            bias_k: Optional[Tensor],
            bias_v: Optional[Tensor],
            add_zero_attn: bool,
            dropout_p: float,
            out_proj_weight: Tensor,
            out_proj_bias: Tensor,
            training: bool = True,
            key_padding_mask: Optional[Tensor] = None,
            need_weights: bool = False,
            attn_mask: Optional[Tensor] = None,
            use_separate_proj_weight: bool = False,
            q_proj_weight: Optional[Tensor] = None,
            k_proj_weight: Optional[Tensor] = None,
            v_proj_weight: Optional[Tensor] = None,
            static_k: Optional[Tensor] = None,
            static_v: Optional[Tensor] = None,
            out_dim: Optional[int] = None,
            do_qkv_proj: bool = True,
            do_out_proj: bool = True,
            rpe=True,
        ) -> Tuple[Tensor, ...]:
            """Scaled dot-product multi-head attention with optional projections.

            Args:
                query: ``(tgt_len, batch, embed_dim)``.
                key, value: ``(src_len, batch, feature)``; ``None`` falls back
                    to ``query``.
                embed_dim_to_check: Must equal ``query``'s last dimension.
                num_heads: Number of heads.
                add_zero_attn: Append one all-zero key/value slot.
                dropout_p: Dropout probability on the attention weights.
                out_proj_weight, out_proj_bias: Output projection parameters.
                training: Enables dropout.
                key_padding_mask: ``(batch, src_len)``; true entries are masked.
                need_weights: Also return head-averaged attention weights.
                attn_mask: 2D ``(tgt_len, src_len)`` or 3D
                    ``(batch * num_heads, tgt_len, src_len)`` mask.
                out_dim: Feature width of the value tensor fed to attention;
                    defaults to ``embed_dim`` when ``None``.
                do_qkv_proj: Apply ``q_proj``/``k_proj``/``v_proj`` (and the
                    query scaling); when false the inputs are used as given.
                do_out_proj: Apply the output projection.
                rpe: Per-call switch for the relative position bias; currently
                    has no effect (see the class note).
                in_proj_weight, in_proj_bias, bias_k, bias_v,
                use_separate_proj_weight, q_proj_weight, k_proj_weight,
                v_proj_weight, static_k, static_v: Accepted for signature
                    compatibility; not used by the computation below.

            Returns:
                ``(attn_output, q, k)`` or, when ``need_weights`` is true,
                ``(attn_output, q, k, weights)``. ``q`` and ``k`` are the
                per-head tensors of shape ``(batch * num_heads, len, head_dim)``
                (``k`` includes the zero slot when ``add_zero_attn`` is set);
                ``weights`` is ``(batch, tgt_len, src_len)``.

            Raises:
                RuntimeError: If ``attn_mask`` has an unsupported rank or size.
                AssertionError: On inconsistent shapes or mask dtypes.
            """
            if not torch.jit.is_scripting():
                tens_ops = (
                    query,
                    key,
                    value,
                    in_proj_weight,
                    in_proj_bias,
                    bias_k,
                    bias_v,
                    out_proj_weight,
                    out_proj_bias,
                )
                if any(type(t) is not Tensor for t in tens_ops) and has_torch_function(
                    tens_ops
                ):
                    # Dispatch to the bound method: there is no module-level
                    # ``multi_head_attention_forward`` in this file, so the bare
                    # name used previously raised NameError. All keyword
                    # arguments are forwarded so the re-dispatched call is
                    # equivalent to this one.
                    return handle_torch_function(
                        self.multi_head_attention_forward,
                        tens_ops,
                        query,
                        key,
                        value,
                        embed_dim_to_check,
                        num_heads,
                        in_proj_weight,
                        in_proj_bias,
                        bias_k,
                        bias_v,
                        add_zero_attn,
                        dropout_p,
                        out_proj_weight,
                        out_proj_bias,
                        training=training,
                        key_padding_mask=key_padding_mask,
                        need_weights=need_weights,
                        attn_mask=attn_mask,
                        use_separate_proj_weight=use_separate_proj_weight,
                        q_proj_weight=q_proj_weight,
                        k_proj_weight=k_proj_weight,
                        v_proj_weight=v_proj_weight,
                        static_k=static_k,
                        static_v=static_v,
                        out_dim=out_dim,
                        do_qkv_proj=do_qkv_proj,
                        do_out_proj=do_out_proj,
                        rpe=rpe,
                    )

            tgt_len, bsz, embed_dim = query.size()
            key = query if key is None else key
            value = query if value is None else value

            assert embed_dim == embed_dim_to_check
            # allow MHA to have different sizes for the feature dimension
            assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

            if out_dim is None:
                out_dim = embed_dim
            head_dim = embed_dim // num_heads
            v_head_dim = out_dim // num_heads
            assert (
                head_dim * num_heads == embed_dim
            ), "embed_dim must be divisible by num_heads"
            scaling = float(head_dim) ** -0.5

            # whether or not use the original query/key/value
            q = self.q_proj(query) * scaling if do_qkv_proj else query
            k = self.k_proj(key) if do_qkv_proj else key
            v = self.v_proj(value) if do_qkv_proj else value

            if attn_mask is not None:
                assert (
                    attn_mask.dtype == torch.float32
                    or attn_mask.dtype == torch.float64
                    or attn_mask.dtype == torch.float16
                    or attn_mask.dtype == torch.uint8
                    or attn_mask.dtype == torch.bool
                ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
                    attn_mask.dtype
                )
                if attn_mask.dtype == torch.uint8:
                    warnings.warn(
                        "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
                    )
                    attn_mask = attn_mask.to(torch.bool)

                if attn_mask.dim() == 2:
                    attn_mask = attn_mask.unsqueeze(0)
                    if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                        raise RuntimeError("The size of the 2D attn_mask is not correct.")
                elif attn_mask.dim() == 3:
                    if list(attn_mask.size()) != [
                        bsz * num_heads,
                        query.size(0),
                        key.size(0),
                    ]:
                        raise RuntimeError("The size of the 3D attn_mask is not correct.")
                else:
                    raise RuntimeError(
                        "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                    )

            # convert ByteTensor key_padding_mask to bool
            if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
                warnings.warn(
                    "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
                )
                key_padding_mask = key_padding_mask.to(torch.bool)

            # Fold the heads into the batch dimension: (batch * heads, len, head_dim).
            q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
            if k is not None:
                k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
            if v is not None:
                v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)

            src_len = k.size(1)

            if key_padding_mask is not None:
                assert key_padding_mask.size(0) == bsz
                assert key_padding_mask.size(1) == src_len

            if add_zero_attn:
                src_len += 1
                k = torch.cat(
                    [
                        k,
                        torch.zeros(
                            (k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device
                        ),
                    ],
                    dim=1,
                )
                v = torch.cat(
                    [
                        v,
                        torch.zeros(
                            (v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device
                        ),
                    ],
                    dim=1,
                )
                if attn_mask is not None:
                    attn_mask = pad(attn_mask, (0, 1))
                if key_padding_mask is not None:
                    key_padding_mask = pad(key_padding_mask, (0, 1))

            attn_output_weights = torch.bmm(q, k.transpose(1, 2))
            assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

            # NOTE(review): relative position embedding is disabled here. The
            # original gathered the bias when ``self.rpe and rpe`` but never
            # added it to the logits, so skipping the gather leaves the outputs
            # unchanged. Restoring it would require
            # src_len == tgt_len == window_size ** 2 and would add
            # ``self._relative_position_bias().unsqueeze(0)`` to the logits
            # viewed as (bsz, num_heads, tgt_len, src_len).

            # Attention weight for the invalid region is -inf
            if attn_mask is not None:
                if attn_mask.dtype == torch.bool:
                    attn_output_weights.masked_fill_(attn_mask, float("-inf"))
                else:
                    attn_output_weights += attn_mask

            if key_padding_mask is not None:
                attn_output_weights = attn_output_weights.view(
                    bsz, num_heads, tgt_len, src_len
                )
                attn_output_weights = attn_output_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    float("-inf"),
                )
                attn_output_weights = attn_output_weights.view(
                    bsz * num_heads, tgt_len, src_len
                )

            # Reweight the attention map before softmax().
            # attn_output_weights: (b*n_head, n, hw)
            attn_output_weights = softmax(attn_output_weights, dim=-1)
            attn_output_weights = dropout(
                attn_output_weights, p=dropout_p, training=training
            )

            attn_output = torch.bmm(attn_output_weights, v)
            assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
            attn_output = (
                attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
            )
            if do_out_proj:
                attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

            if need_weights:
                # average attention weights over heads
                attn_output_weights = attn_output_weights.view(
                    bsz, num_heads, tgt_len, src_len
                )
                return attn_output, q, k, attn_output_weights.sum(dim=1) / num_heads
            else:
                return attn_output, q, k  # additionally return the query and key
    class PadBlock(object):
        """Make the size of a feature map divisible by the local group size.

        Operates on channel-last ``(n, h, w, c)`` tensors: ``pad_if_needed``
        centre-pads H and W up to the next multiple of the group size, and
        ``depad_if_needed`` crops that padding away again.
        """

        def __init__(self, local_group_size=7):
            group = local_group_size
            if not isinstance(group, (tuple, list)):
                group = to_2tuple(group)
            self.lgs = group
            assert len(self.lgs) == 2

        def _missing(self, h, w):
            """Return how many rows / columns are needed to reach a multiple."""
            group_h, group_w = self.lgs[0], self.lgs[1]
            extra_h = math.ceil(h / group_h) * group_h - h
            extra_w = math.ceil(w / group_w) * group_w - w
            return extra_h, extra_w

        def pad_if_needed(self, x, size):
            """Centre-pad ``x`` on H and W; return ``x`` itself if already aligned."""
            _, h, w, _ = size
            extra_h, extra_w = self._missing(h, w)
            if not (extra_h > 0 or extra_w > 0):
                return x
            top = extra_h // 2
            left = extra_w // 2
            # F.pad lists the last dimension first: (C, C, W, W, H, H).
            amounts = (0, 0, left, extra_w - left, top, extra_h - top)
            return F.pad(x, amounts)

        def depad_if_needed(self, x, size):
            """Crop a padded tensor back to the original ``size`` (n, h, w, c)."""
            _, h, w, _ = size
            extra_h, extra_w = self._missing(h, w)
            if not (extra_h > 0 or extra_w > 0):
                return x
            top = extra_h // 2
            left = extra_w // 2
            return x[:, top : top + h, left : left + w, :]
    class LocalPermuteModule(object):
        """Permute a feature map to gather pixels in local groups, and back.

        ``permute`` turns a channel-last ``(n, h, w, c)`` map into
        ``(ph * pw, n * qh * qw, c)``, where ``(ph, pw)`` is the local group
        size and ``(qh, qw)`` the number of groups along H and W.
        ``rev_permute`` is the inverse. H and W must already be multiples of
        the group size.

        The width axis uses ``lgs[1]``. It previously reused ``lgs[0]``, which
        gives the same result for square groups but is wrong for rectangular
        ones.
        """

        def __init__(self, local_group_size=7):
            self.lgs = local_group_size
            if not isinstance(self.lgs, (tuple, list)):
                self.lgs = to_2tuple(self.lgs)
            assert len(self.lgs) == 2

        def permute(self, x, size):
            """Group pixels: ``(n, h, w, c) -> (ph * pw, n * qh * qw, c)``."""
            n, h, w, c = size
            return rearrange(
                x,
                "n (qh ph) (qw pw) c -> (ph pw) (n qh qw) c",
                n=n,
                qh=h // self.lgs[0],
                ph=self.lgs[0],
                qw=w // self.lgs[1],
                pw=self.lgs[1],
                c=c,
            )

        def rev_permute(self, x, size):
            """Undo ``permute``; ``size`` is the target ``(n, h, w, c)``."""
            n, h, w, c = size
            return rearrange(
                x,
                "(ph pw) (n qh qw) c -> n (qh ph) (qw pw) c",
                n=n,
                qh=h // self.lgs[0],
                ph=self.lgs[0],
                qw=w // self.lgs[1],
                pw=self.lgs[1],
                c=c,
            )
    class InterlacedPoolAttention(nn.Module):
        r"""Local-window multi-head self-attention over a 2D token grid.

        The input tokens are reshaped to a ``(B, H, W, C)`` map, padded so H and
        W are multiples of ``window_size``, regrouped so each window becomes one
        attention sequence, attended with ``MHA_``, and finally restored to the
        original ``(B, N, C)`` layout.

        Args:
            embed_dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            window_size (int): Side length of the local attention window.
            rpe (bool): Passed to ``MHA_`` both at construction and on every
                forward call as its relative-position-bias switch.
            **kwargs: Extra keyword arguments forwarded to ``MHA_``.
        """

        def __init__(self, embed_dim, num_heads, window_size=7, rpe=True, **kwargs):
            super(InterlacedPoolAttention, self).__init__()

            self.dim = embed_dim
            self.num_heads = num_heads
            self.window_size = window_size
            self.with_rpe = rpe

            self.attn = MHA_(
                embed_dim, num_heads, rpe=rpe, window_size=window_size, **kwargs
            )
            self.pad_helper = PadBlock(window_size)
            self.permute_helper = LocalPermuteModule(window_size)

        def forward(self, x, H, W, **kwargs):
            """Apply windowed self-attention.

            Args:
                x: Tokens of shape ``(B, N, C)`` with ``N == H * W``.
                H, W: Spatial size of the token grid.
                **kwargs: Forwarded to ``MHA_.forward``.

            Returns:
                Tensor of shape ``(B, N, C)``.
            """
            B, N, C = x.shape
            # reshape (not view) so non-contiguous inputs are accepted as well
            x = x.reshape(B, H, W, C)

            # pad H and W up to a multiple of the window size
            x_pad = self.pad_helper.pad_if_needed(x, x.size())
            # gather each window into its own sequence: (window area, B * windows, C)
            x_permute = self.permute_helper.permute(x_pad, x_pad.size())

            # MHA_ returns (out, q, k) plus the weights when need_weights is
            # forwarded; only the attended features are needed here.
            out = self.attn(
                x_permute, x_permute, x_permute, rpe=self.with_rpe, **kwargs
            )[0]

            # reverse permutation back to the padded (B, H', W', C) map
            out = self.permute_helper.rev_permute(out, x_pad.size())
            # remove the padding added above
            out = self.pad_helper.depad_if_needed(out, x.size())
            return out.reshape(B, N, C)
    def drop_path(x, drop_prob: float = 0.0, training: bool = False):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

        This is the same as the DropConnect impl created for EfficientNet, etc
        networks; however, the original name is misleading as 'Drop Connect' is
        a different form of dropout in a separate paper. See discussion:
        https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
        The layer and argument names use 'drop path' rather than mixing
        DropConnect as a layer name with 'survival rate' as the argument.

        Args:
            x: Input tensor; the first dimension indexes samples.
            drop_prob: Probability of zeroing a whole sample. ``None`` is treated
                like ``0.0``.
            training: Paths are only dropped when this is true.

        Returns:
            ``x`` itself when nothing can be dropped; otherwise a tensor where
            each sample is either zeroed or scaled by ``1 / (1 - drop_prob)``.
        """
        if drop_prob is None or drop_prob == 0.0 or not training:
            return x
        keep_prob = 1 - drop_prob
        # one Bernoulli draw per sample, broadcast over all remaining dims
        # (works with tensors of any rank, not just 2D ConvNets)
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        if keep_prob <= 0.0:
            # Every path is dropped; skip the rescale, which would divide by
            # zero and turn the output into NaN.
            return x * random_tensor
        return x.div(keep_prob) * random_tensor
    class DropPath(nn.Module):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

        Args:
            drop_prob: Probability of dropping a sample's path. The default
                ``None`` disables dropping, so the module acts as an identity.
        """

        def __init__(self, drop_prob=None):
            super(DropPath, self).__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            # With the default drop_prob=None the arithmetic in drop_path would
            # raise TypeError in training mode; treat it as "no drop".
            if self.drop_prob is None:
                return x
            return drop_path(x, self.drop_prob, self.training)

        def extra_repr(self):
            # (Optional) Set the extra information about this module. You can test
            # it by printing an object of this class.
            return "drop_prob={}".format(self.drop_prob)
    816. class MlpDWBN(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        dw_act_layer=nn.GELU,
        drop=0.0,
        conv_cfg=None,
        norm_cfg=dict(type="BN", requires_grad=True),
    ):
        """Build the 1x1 conv -> 3x3 depthwise conv -> 1x1 conv stack.

        Each convolution is paired with an activation and a norm layer.

        Args:
            in_features: Channels of the input.
            hidden_features: Channels between the two 1x1 convolutions;
                falls back to ``in_features`` when falsy.
            out_features: Channels of the output; falls back to
                ``in_features`` when falsy.
            act_layer: Activation class used after both 1x1 convolutions.
            dw_act_layer: Activation class used after the depthwise conv.
            drop: Accepted but not used by this constructor (no dropout
                layer is created).
            conv_cfg: Config passed to ``build_conv_layer``.
            norm_cfg: Config passed to ``build_norm_layer``.
        """
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features

        def pointwise(channels_in, channels_out):
            # 1x1 convolution with bias
            return build_conv_layer(
                conv_cfg,
                channels_in,
                channels_out,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
            )

        def norm(channels):
            # build_norm_layer returns (name, layer); only the layer is kept
            return build_norm_layer(norm_cfg, channels)[1]

        # Submodules are created in the original order so parameter
        # initialisation and state_dict layout stay the same.
        self.fc1 = pointwise(in_features, hidden)
        self.act1 = act_layer()
        self.norm1 = norm(hidden)

        # depthwise 3x3: one group per channel
        self.dw3x3 = build_conv_layer(
            conv_cfg,
            hidden,
            hidden,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden,
        )
        self.act2 = dw_act_layer()
        self.norm2 = norm(hidden)

        self.fc2 = pointwise(hidden, out)
        self.act3 = act_layer()
        self.norm3 = norm(out)
    865. def forward(self, x, H, W):
    866. if len(x.shape) == 3:
    867. B, N, C = x.shape
    868. if N == (H * W + 1):
    869. cls_tokens = x[:, 0, :]
    870. x_ = x[:, 1:, :].permute(0, 2, 1).contiguous().reshape(B, C, H, W)
    871. else:
    872. x_ = x.permute(0, 2, 1).contiguous().reshape(B, C, H, W)
    873. x_ = self.fc1(x_)
    874. x_ = self.norm1(x_)
    875. x_ = self.act1(x_)
    876. x_ = self.dw3x3(x_)
    877. x_ = self.norm2(x_)
    878. x_ = self.act2(x_)
    879. # x_ = self.drop(x_)
    880. x_ = self.fc2(x_)
    881. x_ = self.norm3(x_)
    882. x_ = self.act3(x_)
    883. # x_ = self.drop(x_)
    884. x_ = x_.reshape(B, C, -1).permute(0, 2, 1).contiguous()
    885. if N == (H * W + 1):
    886. x = torch.cat((cls_tokens.unsqueeze(1), x_), dim=1)
    887. else:
    888. x = x_
    889. return x
    890. elif len(x.shape) == 4:
    891. x = self.fc1(x)
    892. x = self.norm1(x)
    893. x = self.act1(x)
    894. x = self.dw3x3(x)
    895. x = self.norm2(x)
    896. x = self.act2(x)
    897. x = self.drop(x)
    898. x = self.fc2(x)
    899. x = self.norm3(x)
    900. x = self.act3(x)
    901. x = self.drop(x)
    902. return x
    903. else:
    904. raise RuntimeError("Unsupported input shape: {}".format(x.shape))
    class GeneralTransformerBlock(nn.Module):
        """Transformer block operating on (B, C, H, W) feature maps.

        The input is flattened to a token sequence, passed through an
        ``InterlacedPoolAttention`` branch and an ``MlpDWBN`` branch (each with
        a pre-norm, a residual connection and stochastic depth), then reshaped
        back to (B, C, H, W).

        Args:
            inplanes: Channel count of the input; used for the attention
                module, the first norm and the MLP input.
            planes: Channel count used for the second norm and the MLP output.
                The residual sums and the final reshape in ``forward`` reuse
                the input channel count, so this appears to need to equal
                ``inplanes``.
            num_heads: Number of heads passed to the attention module.
            window_size: Window size passed to the attention module.
            mlp_ratio: Hidden width of the MLP as a multiple of ``inplanes``.
            qkv_bias: Accepted but not used by this class.
            qk_scale: Accepted but not used by this class.
            drop: Dropout probability forwarded to the MLP.
            attn_drop: Dropout probability forwarded to the attention module.
            drop_path: Stochastic-depth probability; values <= 0 disable it.
            act_layer: Activation class forwarded to the MLP.
            norm_layer: Factory for the two token-wise normalisation layers.
            conv_cfg: Conv config forwarded to the MLP.
            norm_cfg: Norm config forwarded to the MLP.
        """

        expansion = 1

        def __init__(
            self,
            inplanes,
            planes,
            num_heads,
            window_size=7,
            mlp_ratio=4.0,
            qkv_bias=True,
            qk_scale=None,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            act_layer=nn.GELU,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            conv_cfg=None,
            norm_cfg=dict(type="BN", requires_grad=True),
        ):
            super().__init__()

            # Plain configuration attributes.
            self.dim, self.out_dim = inplanes, planes
            self.num_heads = num_heads
            self.window_size = window_size
            self.mlp_ratio = mlp_ratio
            self.conv_cfg = conv_cfg
            self.norm_cfg = norm_cfg

            # Sub-modules are created in the same order as before so that
            # registration and parameter initialisation order are unchanged.
            self.attn = InterlacedPoolAttention(
                self.dim,
                num_heads=num_heads,
                window_size=window_size,
                dropout=attn_drop,
            )
            self.norm1 = norm_layer(self.dim)
            self.norm2 = norm_layer(self.out_dim)

            if drop_path > 0.0:
                self.drop_path = DropPath(drop_path)
            else:
                self.drop_path = nn.Identity()

            ffn_settings = dict(
                in_features=self.dim,
                hidden_features=int(self.dim * mlp_ratio),
                out_features=self.out_dim,
                act_layer=act_layer,
                dw_act_layer=act_layer,
                drop=drop,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
            )
            self.mlp = MlpDWBN(**ffn_settings)

        def forward(self, x):
            """Run attention and MLP branches on ``x`` and return a tensor of the same shape."""
            batch, channels, height, width = x.size()

            # (B, C, H, W) -> (B, H*W, C) token sequence.
            tokens = x.view(batch, channels, -1).permute(0, 2, 1).contiguous()

            # Attention branch with residual connection.
            attn_out = self.attn(self.norm1(tokens), height, width)
            tokens = tokens + self.drop_path(attn_out)

            # Feed-forward branch with residual connection.
            ffn_out = self.mlp(self.norm2(tokens), height, width)
            tokens = tokens + self.drop_path(ffn_out)

            # (B, H*W, C) -> (B, C, H, W).
            restored = tokens.permute(0, 2, 1).contiguous()
            return restored.view(batch, channels, height, width)
    # Quick smoke test, executed at module level (so it also runs on import):
    # push a random (2, 78, 48, 64) feature map through one block built with
    # inplanes=78, planes=78 and 3 heads.
    a = torch.randn(2,78,48, 64)
    b = GeneralTransformerBlock(78,78,3)
    c = b(a)
    # The block reshapes its output back to the input layout, so the printed
    # shape should match that of `a`.
    print('c',c.shape)

  • 相关阅读:
    重点用能单位能耗在线监测接入端系统
    select在socket中的server多路复用
    机器学习中 TP FP TN FN的概念
    缓存相关问题
    用python混合检索 + 重排序改善
    奥拉帕尼人血清白蛋白HSA纳米粒|陶扎色替卵清白蛋白OVA纳米粒|来他替尼小鼠血清白蛋白MSA纳米粒(试剂)
    python 实现两个文本文件内容去重
    C语言详细知识点复习(上)
    最长上升子序列
    算法——二叉树应用
  • 原文地址:https://blog.csdn.net/zouxiaolv/article/details/127732836