• ECCV2022细粒度图像检索SEMICON代码学习记录


    代码链接:GitHub - aassxun/SEMICON

    环境配置

    1. # 创建&激活虚拟环境
    2. conda create -n semicon python==3.8.5
    3. conda activate semicon
    4. # 安装相关依赖包 (该 pytorch 为无 gpu 版本)
    5. conda install pytorch==1.10.0 torchvision==0.11.1 torchaudio==0.10.0 cpuonly -c pytorch
    6. pip install numpy==1.19.2
    7. pip install loguru==0.5.3
    8. pip install tqdm==4.54.1
    9. pip install pandas
    10. pip install scipy

    需要将 SEMICON_train.py、SEMICON.py、Hash_mAP.py、baseline_train.py、baseline.py 中的 import models.resnet as resnet 和 from models.resnet import * 改为 import models.resnet_torch as resnet 和 from models.resnet_torch import *

    下载CUB_200_2011数据集

    参考博客:CUB-200-2011鸟类数据集的下载与使用pytorch加载_景唯acr的博客-CSDN博客_cub200-2011

    代码运行

    1)训练

    python run.py --dataset cub-2011 --root /dataset/CUB2011/CUB_200_2011 --max-epoch 30 --gpu 0 --arch semicon --batch-size 16 --max-iter 40 --code-length 12,24,32,48 --lr 2.5e-4 --wd 1e-4 --optim SGD --lr-step 40 --num-samples 2000 --info 'CUB-SEMICON' --momen=0.91

    2)测试

    python run.py --dataset cub-2011 --root /dataset/CUB2011/CUB_200_2011 --gpu 0 --arch test --batch-size 16 --code-length 12,24,32,48 --wd 1e-4 --info 'CUB-SEMICON'

    如果不想使用 gpu,将参数 --gpu 设为False 即可

    代码学习

    1)固定随机种子

    与 YOLO-X 类似,将随机种子进行固定,后续实验将在此固定的随机种子下进行 (如消融实验等),增强了模型的可复现性 (但我觉得也只是仅限于特定的随机数,换另一个随机数可能结果又不一样了)。

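    torch.backends.cudnn.deterministic 和 torch.backends.cudnn.benchmark:前者可以保证每次运行网络的时候相同输入的输出是固定的,后者为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。benchmark 的适用场景是网络结构固定,网络的输入形状(包括 batch size,图片大小,输入的通道)是不变的,其实也就是一般情况下都比较适用。反之,如果卷积层的设置一直变化,将会导致程序不停地做优化,反而会耗费更多的时间。需要注意的是,benchmark=True 时 cuDNN 会在运行时自动挑选卷积算法,不同运行可能选到不同的算法,从而削弱 deterministic 带来的可复现性;若要求严格复现,应将 benchmark 设为 False。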

    参考博客:

    【pytorch】torch.backends.cudnn.deterministic_Xhfei1224的博客-CSDN博客_torch.backends.cudnn.deter

    torch.backends.cudnn.benchmark_Wanderer001的博客-CSDN博客_torch.backends.cudnn.benchmark

    def seed_everything(seed):
        """Seed every random number generator the training script relies on.

        Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA
        devices) and configures cuDNN for reproducible runs.

        Args:
            seed: Integer seed shared by all generators.
        """
        random.seed(seed)
        # NOTE: hash randomisation is fixed at interpreter start-up, so this
        # assignment only influences child processes, not the running one.
        os.environ['PYTHONHASHSEED'] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Covers every visible GPU; silently ignored when CUDA is unavailable.
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # The cuDNN auto-tuner may select different convolution algorithms
        # from run to run, which defeats the purpose of fixing the seed.
        torch.backends.cudnn.benchmark = False


    seed_everything(68)

    2)数据加载

    对应脚本:data/cub_2011.py

    函数 load_data() 会返回 3 个 dataloader:query_dataloader、train_dataloader、retrieval_dataloader

    # Split the dataset into training and test subsets (cached on the class).
    Cub2011.init(root)
    # Build the query, training and retrieval datasets; the data augmentation
    # pipelines they use are defined in data/transform.py.
    query_dataset = Cub2011(root, mode='query', transform=query_transform())
    train_dataset = Cub2011(root, mode='train', transform=train_transform())
    retrieval_dataset = Cub2011(root, mode='retrieval', transform=query_transform())
    class Cub2011(Dataset):
        """CUB-200-2011 images exposed as query / train / retrieval subsets.

        ``Cub2011.init(root)`` must be called once before any instance is
        created: it reads the metadata files under ``root`` and stores the
        file paths and one-hot targets of every subset as class attributes.

        Args:
            root: Dataset root directory (``~`` is expanded).
            mode: One of ``'train'``, ``'query'`` or ``'retrieval'``.
            transform: Optional callable applied to each loaded RGB image.
            loader: Image loader callable kept on the instance as ``self.loader``.

        Raises:
            ValueError: If ``mode`` is not one of the three supported values.
        """

        def __init__(self, root, mode, transform=None, loader=default_loader):
            self.root = os.path.expanduser(root)
            self.transform = transform
            # Keep the loader the caller passed in; it used to be replaced by
            # default_loader unconditionally, making the argument a no-op.
            self.loader = loader

            if mode == 'train':
                self.data = Cub2011.TRAIN_DATA
                self.targets = Cub2011.TRAIN_TARGETS
            elif mode == 'query':
                self.data = Cub2011.QUERY_DATA
                self.targets = Cub2011.QUERY_TARGETS
            elif mode == 'retrieval':
                self.data = Cub2011.RETRIEVAL_DATA
                self.targets = Cub2011.RETRIEVAL_TARGETS
            else:
                # Plain (non-raw) literal so the message has no stray backslash.
                raise ValueError("Invalid arguments: mode, can't load dataset!")

        @staticmethod
        def init(root):
            """Read the metadata files under ``root`` and cache the subsets.

            Populates ``QUERY_*``, ``TRAIN_*`` and ``RETRIEVAL_*`` class
            attributes with image paths and one-hot targets.
            """
            def read_table(filename, value_column):
                # Every metadata file is a space-separated "<img_id> <value>" table.
                return pd.read_csv(os.path.join(root, filename), sep=' ',
                                   names=['img_id', value_column])

            images = read_table('images.txt', 'filepath')
            image_class_labels = read_table('image_class_labels.txt', 'target')
            train_test_split = read_table('train_test_split.txt', 'is_training_img')

            all_data = images.merge(image_class_labels, on='img_id')
            all_data = all_data.merge(train_test_split, on='img_id')
            all_data['filepath'] = 'images/' + all_data['filepath']

            train_data = all_data[all_data['is_training_img'] == 1]
            test_data = all_data[all_data['is_training_img'] == 0]

            # Targets are shifted by one before being one-hot encoded over 200
            # classes (presumably the file stores 1-based labels).
            train_paths = train_data['filepath'].to_numpy()
            train_labels = (train_data['target'] - 1).tolist()

            Cub2011.QUERY_DATA = test_data['filepath'].to_numpy()
            Cub2011.QUERY_TARGETS = encode_onehot((test_data['target'] - 1).tolist(), 200)
            Cub2011.TRAIN_DATA = train_paths
            Cub2011.TRAIN_TARGETS = encode_onehot(train_labels, 200)
            # The retrieval database is built from the training split.
            Cub2011.RETRIEVAL_DATA = train_data['filepath'].to_numpy()
            Cub2011.RETRIEVAL_TARGETS = encode_onehot(train_labels, 200)

        def get_onehot_targets(self):
            """Return this subset's one-hot targets as a float tensor."""
            # assumes encode_onehot returns a NumPy array -- TODO confirm
            return torch.from_numpy(self.targets).float()

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            """Return ``(image, target, idx)`` for the sample at ``idx``."""
            path = os.path.join(self.root, self.data[idx])
            # The context manager closes the file as soon as the RGB copy
            # exists instead of leaving the handle to the garbage collector.
            with Image.open(path) as raw:
                img = raw.convert('RGB')
            if self.transform is not None:
                img = self.transform(img)
            return img, self.targets[idx], idx
    3)网络训练

    主干网络

    这里只使用了 resnet50 的前三个 layer,具体可查看 models/SEMICON.py 中的 ResNet_Backbone 类

    model = ResNet_Backbone(Bottleneck, [3, 4, 6], **kwargs)

    全局/局部转换网络

    class ResNet_Refine(nn.Module):
        """Single residual stage (``layer4``) followed by global average pooling.

        The stage starts from 1024 input channels, is built from ``layer``
        blocks of type ``block`` with stride 2, and ends with a
        ``ChannelTransformer``.

        Args:
            block: Residual block class (must expose ``expansion``).
            layer: Number of residual blocks in the stage.
            is_local: If true, ``forward`` returns ``(feature_map, pooled)``;
                otherwise only the pooled vector.
            num_classes: Accepted for signature compatibility; unused here.
            zero_init_residual: Zero the last BN weight of every residual block.
            groups: Passed through to each block.
            width_per_group: Passed through to each block as ``base_width``.
            norm_layer: Normalisation layer factory; defaults to ``nn.BatchNorm2d``.
        """

        def __init__(self, block, layer, is_local=True, num_classes=1000, zero_init_residual=False,
                     groups=1, width_per_group=64, norm_layer=None):
            super().__init__()
            self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer

            self.inplanes = 1024
            self.dilation = 1
            self.is_local = is_local
            self.groups = groups
            self.base_width = width_per_group

            self.layer4 = self._make_layer(block, 512, layer, stride=2)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

            for module in self.modules():
                if isinstance(module, nn.Conv2d):
                    nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                    nn.init.constant_(module.weight, 1)
                    nn.init.constant_(module.bias, 0)

            # Zero-initialize the last BN in each residual branch, so that the
            # residual branch starts with zeros and each residual block behaves
            # like an identity. This improves the model by 0.2~0.3% according
            # to https://arxiv.org/abs/1706.02677
            if zero_init_residual:
                for module in self.modules():
                    if isinstance(module, Bottleneck):
                        nn.init.constant_(module.bn3.weight, 0)
                    elif isinstance(module, BasicBlock):
                        nn.init.constant_(module.bn2.weight, 0)

        def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
            """Build the residual stage and append a ``ChannelTransformer``."""
            norm_layer = self._norm_layer
            entry_dilation = self.dilation
            if dilate:
                self.dilation *= stride
                stride = 1

            out_planes = planes * block.expansion
            shortcut = None
            if stride != 1 or self.inplanes != out_planes:
                # Projection shortcut for the first block when shapes change.
                shortcut = nn.Sequential(
                    conv1x1(self.inplanes, out_planes, stride),
                    norm_layer(out_planes),
                )

            stages = [block(self.inplanes, planes, stride, shortcut, self.groups,
                            self.base_width, entry_dilation, norm_layer)]
            self.inplanes = out_planes
            stages.extend(
                block(self.inplanes, planes, groups=self.groups,
                      base_width=self.base_width, dilation=self.dilation,
                      norm_layer=norm_layer)
                for _ in range(1, blocks)
            )
            stages.append(ChannelTransformer(out_planes, max(out_planes // 64, 16)))
            return nn.Sequential(*stages)

        def _forward_impl(self, x):
            feature_map = self.layer4(x)
            pooled = torch.flatten(self.avgpool(feature_map), 1)
            if not self.is_local:
                return pooled
            return feature_map, pooled

        def forward(self, x):
            return self._forward_impl(x)

    SEM

    class SEM(nn.Module):
        """Produce three single-channel response maps with progressive suppression.

        After a small residual stage (``layer4``), each of three 1x1-conv
        heads yields one map. Before the second and third heads run, the
        feature map is re-weighted by ``2 - mask`` of the previous response,
        so locations that responded strongly are damped for the next head.

        Args:
            block: Residual block class used inside ``layer4``.
            layer: Number of blocks requested for ``layer4``.
            att_size: Stored on the instance; not otherwise used in this class.
            num_classes: Accepted for signature compatibility; unused here.
            zero_init_residual: Zero the last BN weight of every residual block.
            groups: Passed through to each block.
            width_per_group: Passed through to each block as ``base_width``.
            replace_stride_with_dilation: ``None`` or a 3-element sequence;
                only validated here.
            norm_layer: Normalisation layer factory; defaults to ``nn.BatchNorm2d``.

        Raises:
            ValueError: If ``replace_stride_with_dilation`` is given with a
                length other than 3.
        """

        def __init__(self, block, layer, att_size=4, num_classes=1000, zero_init_residual=False,
                     groups=1, width_per_group=64, replace_stride_with_dilation=None,
                     norm_layer=None):
            super().__init__()
            self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer

            self.inplanes = 1024
            self.dilation = 1
            self.att_size = att_size

            if replace_stride_with_dilation is None:
                # each element in the tuple indicates if we should replace
                # the 2x2 stride with a dilated convolution instead
                replace_stride_with_dilation = [False, False, False]
            if len(replace_stride_with_dilation) != 3:
                raise ValueError("replace_stride_with_dilation should be None "
                                 "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

            self.groups = groups
            self.base_width = width_per_group
            self.layer4 = self._make_layer(block, 512, layer, stride=1)

            def response_head():
                # 1x1 conv -> BN -> ReLU, collapsing the channels to one map.
                return nn.Sequential(
                    conv1x1(self.inplanes, 1),
                    nn.BatchNorm2d(1),
                    nn.ReLU(inplace=True),
                )

            self.feature1 = response_head()
            self.feature2 = response_head()
            self.feature3 = response_head()

            for module in self.modules():
                if isinstance(module, nn.Conv2d):
                    nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                    nn.init.constant_(module.weight, 1)
                    nn.init.constant_(module.bias, 0)

            # Zero-initialize the last BN in each residual branch, so that the
            # residual branch starts with zeros and each residual block behaves
            # like an identity. This improves the model by 0.2~0.3% according
            # to https://arxiv.org/abs/1706.02677
            if zero_init_residual:
                for module in self.modules():
                    if isinstance(module, Bottleneck):
                        nn.init.constant_(module.bn3.weight, 0)
                    elif isinstance(module, BasicBlock):
                        nn.init.constant_(module.bn2.weight, 0)

        def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
            """Build the stage; each extra block first shrinks channels to 1/4.

            ``planes`` and ``dilate`` are accepted but not used by this
            implementation.
            """
            norm_layer = self._norm_layer
            att_expansion = 0.25

            stages = [block(self.inplanes, int(self.inplanes * att_expansion), stride,
                            None, self.groups, self.base_width, self.dilation, norm_layer)]
            for _ in range(1, blocks):
                reduced = int(self.inplanes * att_expansion)
                stages.append(nn.Sequential(
                    conv1x1(self.inplanes, reduced),
                    nn.BatchNorm2d(reduced),
                ))
                self.inplanes = reduced
                stages.append(block(self.inplanes, int(self.inplanes * att_expansion),
                                    groups=self.groups, base_width=self.base_width,
                                    dilation=self.dilation, norm_layer=norm_layer))
            return nn.Sequential(*stages)

        def _mask(self, feature, x):
            """Turn a response map into a per-location weight in ``[0, 2]``.

            Computed without gradient tracking. The standardisation statistics
            are taken over the whole softmax tensor, not per sample.
            """
            with torch.no_grad():
                batch, _, height, width = x.shape
                response = feature.mean(1)
                attn = torch.softmax(response.view(batch, height * width), dim=1)
                std, mean = torch.std_mean(attn)
                # Softened standardisation centred on 1 (original note: "0.15").
                attn = (attn - mean) / (std ** 0.3) + 1
                attn = attn.view((batch, 1, height, width)).clamp(0, 2)
            return attn

        def _forward_impl(self, x):
            # Original author's shape note: bs*64*14*14 (not verified here).
            x = self.layer4(x)
            heads = (self.feature1, self.feature2, self.feature3)
            responses = []
            for position, head in enumerate(heads, start=1):
                response = head(x)
                responses.append(response)
                if position < len(heads):
                    # Damp what this head responded to before the next head.
                    suppress = 2 - self._mask(response, x)
                    x = x.mul(suppress.repeat(1, self.inplanes, 1, 1))
            return torch.cat(responses, dim=1)

        def forward(self, x):
            return self._forward_impl(x)

    ICON

    class ChannelTransformer(nn.Module):
        """Two rounds of attention computed between channels of a feature map.

        The first round groups the channels as ``(num_heads, head_dim)`` and
        adds a residual connection; the second round uses the transposed
        grouping ``(head_dim, num_heads)`` on the normalised, activated result.

        Args:
            dim: Number of input (and output) channels.
            num_heads: Number of channel groups; ``head_dim = dim // num_heads``.
        """

        def __init__(self, dim, num_heads):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = dim // num_heads
            self.scale = self.head_dim ** -0.5

            self.norm = nn.BatchNorm2d(dim)
            self.relu = nn.ReLU(inplace=True)
            self.qkv = nn.Conv2d(dim, dim * 3, 1, groups=num_heads)
            self.qkv2 = nn.Conv2d(dim, dim * 3, 1, groups=self.head_dim)

        @staticmethod
        def _signed_sqrt_softmax(scores):
            # Signed square root compresses large scores before the softmax.
            compressed = torch.sign(scores) * torch.sqrt(torch.abs(scores) + 1e-5)
            return compressed.softmax(dim=-1)

        def forward(self, x):
            B, C, H, W = x.shape
            heads, width = self.num_heads, self.head_dim

            # Round 1: channels viewed as (heads, width).
            first = self.qkv(x).reshape(B, 3, heads, width, H * W).transpose(0, 1)
            q, k, v = first[0], first[1], first[2]
            attn = self._signed_sqrt_softmax((q @ k.transpose(-2, -1)) * self.scale)
            mixed = (attn @ v).reshape(B, C, H, W) + x
            # Regroup the channel order from (heads, width) to (width, heads).
            mixed = mixed.reshape(B, heads, width, H, W).transpose(1, 2).reshape(B, C, H, W)

            y = self.norm(mixed)
            # self.relu is in-place, so `y` itself holds the activated values
            # afterwards; the final residual below therefore adds the ReLU output.
            x = self.relu(y)

            # Round 2: channels viewed as (width, heads).
            second = self.qkv2(x).reshape(B, 3, width, heads, H * W).transpose(0, 1)
            q, k, v = second[0], second[1], second[2]
            attn = self._signed_sqrt_softmax((q @ k.transpose(-2, -1)) * (heads ** -0.5))
            out = (attn @ v).reshape(B, width, heads, H, W).transpose(1, 2).reshape(B, C, H, W)
            return out + y

    损失函数

    class ADSH_Loss(nn.Module):
        """Sum of a similarity-fitting term and a quantization term.

        Args:
            code_length: Hash code length; scales the similarity target and
                normalises both terms.
            gamma: Weight of the quantization term.
        """

        def __init__(self, code_length, gamma):
            super().__init__()
            self.code_length = code_length
            self.gamma = gamma

        def forward(self, F, B, S, omega):
            """Return ``(loss, hash_loss, quantization_loss)``.

            Args:
                F: Network outputs for the sampled items (rows).
                B: Codes of all database items (rows).
                S: Similarity matrix compared against ``F @ B.t()``.
                omega: Row indices into ``B`` matching the rows of ``F``.
            """
            pair_count = F.shape[0] * B.shape[0]

            similarity_gap = self.code_length * S - F @ B.t()
            hash_loss = similarity_gap.pow(2).sum() / pair_count / self.code_length * 12

            code_gap = F - B[omega, :]
            quantization_loss = code_gap.pow(2).sum() / pair_count * self.gamma / self.code_length * 12

            return hash_loss + quantization_loss, hash_loss, quantization_loss

    4)网络测试

    def valid(query_dataloader, train_dataloader, retrieval_dataloader, code_length, args):
        """Print the retrieval mAP for one code length from saved artefacts.

        Loads the model weights, query/database codes and targets stored under
        ``./checkpoints/<args.info>/<code_length>/`` and evaluates
        ``evaluate.mean_average_precision`` on them.

        Args:
            query_dataloader: Unused; kept for signature compatibility.
            train_dataloader: Unused; kept for signature compatibility.
            retrieval_dataloader: Unused; kept for signature compatibility.
            code_length: Hash code length selecting the checkpoint sub-folder.
            args: Namespace providing ``num_classes``, ``device``, ``info``
                and ``topk``.
        """
        num_classes, att_size, feat_size = args.num_classes, 1, 2048
        ckpt_dir = './checkpoints/' + args.info + '/' + str(code_length) + '/'

        def load(filename):
            # map_location lets checkpoints saved on a GPU be read on a
            # CPU-only machine (assumes args.device is a torch.device or a
            # device string -- TODO confirm).
            return torch.load(ckpt_dir + filename, map_location=args.device)

        # The model is restored exactly as before, although the metric below
        # is computed purely from the stored codes.
        model = SEMICON.semicon(code_length=code_length, num_classes=num_classes, att_size=att_size,
                                feat_size=feat_size, device=args.device, pretrained=True)
        model.to(args.device)
        model.load_state_dict(load('model.pkl'), strict=False)
        model.eval()

        query_code = load('query_code.t').to(args.device)

        # Keep the loaded targets in a local instead of overwriting the
        # dataset's get_onehot_targets method. The stored object may be the
        # tensor itself or a callable that returns it; accept both.
        query_targets = load('query_targets.t')
        if callable(query_targets):
            query_targets = query_targets()
        query_targets = query_targets.to(args.device)

        B = load('database_code.t').to(args.device)
        retrieval_targets = load('database_targets.t').to(args.device)

        mAP = evaluate.mean_average_precision(
            query_code,
            B,
            query_targets,
            retrieval_targets,
            args.device,
            args.topk,
        )
        print("Code_Length: " + str(code_length), end="; ")
        print('[mAP:{:.5f}]'.format(mAP))

    5)网络结构 (onnx model)

  • 相关阅读:
    web页面之间的3种关系
    分享:互信息在对比学习中的应用
    实体-联系模型--E-R图
    Django+Celery框架自动化定时任务开发
    VisualStudio(VS)设置程序的版本信息(C-C++)
    学习笔记-Power-Linux
    微星迫击炮b660m使用intel arc a750/770显卡功耗优化方法
    Python语言学习:Python语言学习之逻辑控制语句(if语句&for语句&while语句&range语句&with语句)的简介、案例应用之详细攻略
    JVM学习第一天
    【产线故障】线上接口请求过慢如何排查?
  • 原文地址:https://blog.csdn.net/qq_38964360/article/details/126922099