DDC Code Reading Notes


    Paper: 《Deep Domain Confusion: Maximizing for Domain Invariance》

    https://arxiv.org/abs/1412.3474

    Related blog post: [DDC]Deep Domain Confusion: Maximizing for Domain Invariance_HHzdh的博客-CSDN博客

    Source code: https://github.com/guoXuehong77/transferlearning/tree/master/code/deep/DDC_DeepCoral

    Preface (quoted from README.md)

            This codebase is primarily a reproduction of 《Deep CORAL: Correlation Alignment for Deep Domain Adaptation》, but simply replacing the CORAL loss with the MMD loss reproduces DDC.
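            In this repo the swap is just an argument; a minimal sketch, based on the Transfer_Net constructor and the --trans_loss flag shown later in these notes:

    # DDC: use MMD as the transfer loss
    model = models.Transfer_Net(num_class=31, base_net='alexnet', transfer_loss='mmd')
    # Deep CORAL: swap in the CORAL loss instead
    model = models.Transfer_Net(num_class=31, base_net='alexnet', transfer_loss='coral')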

    Network

            The architecture adds an adaptation layer and a domain confusion loss based on maximum mean discrepancy (MMD) to learn a jointly trained representation that is optimized both for classification and for domain invariance. The paper inserts an fc_adapt layer after AlexNet's fc7; the source code actually implements several backbones, including AlexNet and ResNet-18/34/50/101/152.
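            Concretely, the paper's training objective combines the source classification loss with the MMD term:

    $$\mathcal{L} = \mathcal{L}_C(X_L, y) + \lambda \, \mathrm{MMD}^2(X_S, X_T)$$

            where $\mathcal{L}_C$ is the classification loss on the labeled source data $X_L$ and $\lambda$ (the lamb argument in main.py) controls how strongly domain confusion is enforced.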

    Code

    1. backbone.py

            This file defines several classic CNN backbones, each with its final fully connected layer removed: AlexNet, ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152. Pretrained weights are loaded at initialization.

    import numpy as np
    import torch
    import torch.nn as nn
    import torchvision
    from torchvision import models
    from torch.autograd import Variable

    # convnet without the last layer
    class AlexNetFc(nn.Module):
        def __init__(self):
            super(AlexNetFc, self).__init__()
            model_alexnet = models.alexnet(pretrained=True)
            self.features = model_alexnet.features
            self.classifier = nn.Sequential()
            for i in range(6):
                self.classifier.add_module(
                    "classifier" + str(i), model_alexnet.classifier[i])
            self.__in_features = model_alexnet.classifier[6].in_features

        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), 256 * 6 * 6)
            x = self.classifier(x)
            return x

        def output_num(self):
            return self.__in_features

    class ResNet18Fc(nn.Module):
        def __init__(self):
            super(ResNet18Fc, self).__init__()
            model_resnet18 = models.resnet18(pretrained=True)
            self.conv1 = model_resnet18.conv1
            self.bn1 = model_resnet18.bn1
            self.relu = model_resnet18.relu
            self.maxpool = model_resnet18.maxpool
            self.layer1 = model_resnet18.layer1
            self.layer2 = model_resnet18.layer2
            self.layer3 = model_resnet18.layer3
            self.layer4 = model_resnet18.layer4
            self.avgpool = model_resnet18.avgpool
            self.__in_features = model_resnet18.fc.in_features

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            return x

        def output_num(self):
            return self.__in_features

    class ResNet34Fc(nn.Module):
        def __init__(self):
            super(ResNet34Fc, self).__init__()
            model_resnet34 = models.resnet34(pretrained=True)
            self.conv1 = model_resnet34.conv1
            self.bn1 = model_resnet34.bn1
            self.relu = model_resnet34.relu
            self.maxpool = model_resnet34.maxpool
            self.layer1 = model_resnet34.layer1
            self.layer2 = model_resnet34.layer2
            self.layer3 = model_resnet34.layer3
            self.layer4 = model_resnet34.layer4
            self.avgpool = model_resnet34.avgpool
            self.__in_features = model_resnet34.fc.in_features

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            return x

        def output_num(self):
            return self.__in_features

    class ResNet50Fc(nn.Module):
        def __init__(self):
            super(ResNet50Fc, self).__init__()
            model_resnet50 = models.resnet50(pretrained=True)
            self.conv1 = model_resnet50.conv1
            self.bn1 = model_resnet50.bn1
            self.relu = model_resnet50.relu
            self.maxpool = model_resnet50.maxpool
            self.layer1 = model_resnet50.layer1
            self.layer2 = model_resnet50.layer2
            self.layer3 = model_resnet50.layer3
            self.layer4 = model_resnet50.layer4
            self.avgpool = model_resnet50.avgpool
            self.__in_features = model_resnet50.fc.in_features

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            return x

        def output_num(self):
            return self.__in_features

    class ResNet101Fc(nn.Module):
        def __init__(self):
            super(ResNet101Fc, self).__init__()
            model_resnet101 = models.resnet101(pretrained=True)
            self.conv1 = model_resnet101.conv1
            self.bn1 = model_resnet101.bn1
            self.relu = model_resnet101.relu
            self.maxpool = model_resnet101.maxpool
            self.layer1 = model_resnet101.layer1
            self.layer2 = model_resnet101.layer2
            self.layer3 = model_resnet101.layer3
            self.layer4 = model_resnet101.layer4
            self.avgpool = model_resnet101.avgpool
            self.__in_features = model_resnet101.fc.in_features

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            return x

        def output_num(self):
            return self.__in_features

    class ResNet152Fc(nn.Module):
        def __init__(self):
            super(ResNet152Fc, self).__init__()
            model_resnet152 = models.resnet152(pretrained=True)
            self.conv1 = model_resnet152.conv1
            self.bn1 = model_resnet152.bn1
            self.relu = model_resnet152.relu
            self.maxpool = model_resnet152.maxpool
            self.layer1 = model_resnet152.layer1
            self.layer2 = model_resnet152.layer2
            self.layer3 = model_resnet152.layer3
            self.layer4 = model_resnet152.layer4
            self.avgpool = model_resnet152.avgpool
            self.__in_features = model_resnet152.fc.in_features

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            return x

        def output_num(self):
            return self.__in_features

    network_dict = {"alexnet": AlexNetFc,
                    "resnet18": ResNet18Fc,
                    "resnet34": ResNet34Fc,
                    "resnet50": ResNet50Fc,
                    "resnet101": ResNet101Fc,
                    "resnet152": ResNet152Fc}

    2. data_loader.py

            Data loading. Training images are resized to 256×256 and then randomly cropped to 224×224, followed in order by random horizontal flipping, conversion to tensors, and normalization. Test images are resized directly to 224×224, then converted to tensors and normalized.

    from torchvision import datasets, transforms
    import torch

    def load_data(data_folder, batch_size, train, kwargs):
        transform = {
            'train': transforms.Compose(
                [transforms.Resize([256, 256]),
                 transforms.RandomCrop(224),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])]),
            'test': transforms.Compose(
                [transforms.Resize([224, 224]),
                 transforms.ToTensor(),
                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])])
        }
        data = datasets.ImageFolder(root=data_folder, transform=transform['train' if train else 'test'])
        data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, **kwargs, drop_last=True if train else False)
        return data_loader
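            Hypothetical usage, assuming data_folder points at an ImageFolder-style directory with one subfolder per class (the path and batch size below are placeholders):

    from data_loader import load_data

    loader = load_data('/path/to/Office31/amazon', batch_size=32, train=True, kwargs={'num_workers': 1})
    images, labels = next(iter(loader))
    print(images.shape)   # torch.Size([32, 3, 224, 224])
    print(labels.shape)   # torch.Size([32])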

    3. mmd.py

            This is the MMD loss proposed in the paper. MMD (maximum mean discrepancy) is currently the most widely used loss in transfer learning, and in domain adaptation in particular; it measures the distance between two different but related distributions. The distance between the two distributions is defined as

    $$\mathrm{MMD}(X_S, X_T) = \left\| \frac{1}{|X_S|} \sum_{x_s \in X_S} \phi(x_s) - \frac{1}{|X_T|} \sum_{x_t \in X_T} \phi(x_t) \right\|_{\mathcal{H}}$$

            where $\phi(\cdot)$ maps the features into a reproducing kernel Hilbert space. For a line-by-line walkthrough, see: 【代码阅读】最大均值差异(Maximum Mean Discrepancy, MMD)损失函数代码解读(Pytroch版)_Vincent_gc的博客-CSDN博客_mmd损失

    import torch
    import torch.nn as nn

    class MMD_loss(nn.Module):
        def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
            '''
            source: source-domain data (n * len(x))
            target: target-domain data (m * len(y))
            kernel_mul: spacing factor between consecutive kernel bandwidths
            kernel_num: number of Gaussian kernels
            fix_sigma: fixed sigma for the Gaussian kernels (None = estimate from the data)
            '''
            super(MMD_loss, self).__init__()
            self.kernel_num = kernel_num
            self.kernel_mul = kernel_mul
            self.fix_sigma = None
            self.kernel_type = kernel_type

        def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
            # total number of rows; source and target usually have the same size, which simplifies the computation
            n_samples = int(source.size()[0]) + int(target.size()[0])
            total = torch.cat([source, target], dim=0)  # stack source and target along the row dimension
            # replicate total (n+m) times along a new leading dimension
            total0 = total.unsqueeze(0).expand(
                int(total.size(0)), int(total.size(0)), int(total.size(1)))
            # replicate each row of total (n+m) times, i.e. expand every sample into (n+m) copies
            total1 = total.unsqueeze(1).expand(
                int(total.size(0)), int(total.size(0)), int(total.size(1)))
            # pairwise distances: entry (i, j) is the squared L2 distance between rows i and j of total (0 when i == j)
            L2_distance = ((total0 - total1) ** 2).sum(2)
            # choose the bandwidth of the Gaussian kernels
            if fix_sigma:
                bandwidth = fix_sigma
            else:
                bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
            # take kernel_num bandwidths centered on fix_sigma, spaced by factors of kernel_mul
            # (e.g. with fix_sigma = 1 this gives [0.25, 0.5, 1, 2, 4])
            bandwidth /= kernel_mul ** (kernel_num // 2)
            bandwidth_list = [bandwidth * (kernel_mul ** i)
                              for i in range(kernel_num)]
            # the Gaussian kernel: exp(-||x - y||^2 / bandwidth)
            kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                          for bandwidth_temp in bandwidth_list]
            # the final kernel matrix is the sum over all kernels
            return sum(kernel_val)

        def linear_mmd2(self, f_of_X, f_of_Y):
            loss = 0.0
            delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
            loss = delta.dot(delta.T)
            return loss

        def forward(self, source, target):
            if self.kernel_type == 'linear':
                return self.linear_mmd2(source, target)
            elif self.kernel_type == 'rbf':
                batch_size = int(source.size()[0])  # assumes source and target use the same batch size
                kernels = self.guassian_kernel(
                    source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
                # split the kernel matrix into its four blocks
                # note: under torch.no_grad() the resulting loss carries no gradient back into the features
                with torch.no_grad():
                    XX = torch.mean(kernels[:batch_size, :batch_size])
                    YY = torch.mean(kernels[batch_size:, batch_size:])
                    XY = torch.mean(kernels[:batch_size, batch_size:])
                    YX = torch.mean(kernels[batch_size:, :batch_size])
                    loss = torch.mean(XX + YY - XY - YX)  # since usually n == m, the unbiased L-matrix form is skipped
                torch.cuda.empty_cache()
                return loss
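            A quick numerical sanity check of this loss (a toy sketch; the feature dimension is arbitrary): samples from the same distribution should give an MMD near zero, while a mean-shifted batch should give a clearly larger value.

    import torch
    from mmd import MMD_loss

    mmd_loss = MMD_loss(kernel_type='rbf')
    x = torch.randn(64, 256)
    y = torch.randn(64, 256)          # same distribution as x
    z = torch.randn(64, 256) + 3.0    # mean-shifted distribution
    print(mmd_loss(x, y).item())      # close to 0
    print(mmd_loss(x, z).item())      # noticeably larger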

    4. models.py

    import torch.nn as nn
    from Coral import CORAL
    import mmd
    import backbone

    class Transfer_Net(nn.Module):
        def __init__(self, num_class, base_net='resnet50', transfer_loss='mmd', use_bottleneck=True, bottleneck_width=256, width=1024):
            super(Transfer_Net, self).__init__()
            self.base_network = backbone.network_dict[base_net]()
            self.use_bottleneck = use_bottleneck
            self.transfer_loss = transfer_loss
            bottleneck_list = [nn.Linear(self.base_network.output_num(), bottleneck_width),
                               nn.BatchNorm1d(bottleneck_width), nn.ReLU(), nn.Dropout(0.5)]
            self.bottleneck_layer = nn.Sequential(*bottleneck_list)
            classifier_layer_list = [nn.Linear(self.base_network.output_num(), width), nn.ReLU(), nn.Dropout(0.5),
                                     nn.Linear(width, num_class)]
            self.classifier_layer = nn.Sequential(*classifier_layer_list)
            self.bottleneck_layer[0].weight.data.normal_(0, 0.005)
            self.bottleneck_layer[0].bias.data.fill_(0.1)
            for i in range(2):
                self.classifier_layer[i * 3].weight.data.normal_(0, 0.01)
                self.classifier_layer[i * 3].bias.data.fill_(0.0)

        def forward(self, source, target):
            source = self.base_network(source)
            target = self.base_network(target)
            source_clf = self.classifier_layer(source)
            if self.use_bottleneck:
                source = self.bottleneck_layer(source)
                target = self.bottleneck_layer(target)
            transfer_loss = self.adapt_loss(source, target, self.transfer_loss)
            return source_clf, transfer_loss

        def predict(self, x):
            features = self.base_network(x)
            clf = self.classifier_layer(features)
            return clf

        def adapt_loss(self, X, Y, adapt_loss):
            """Compute adaptation loss, currently we support mmd and coral

            Arguments:
                X {tensor} -- source matrix
                Y {tensor} -- target matrix
                adapt_loss {string} -- loss type, 'mmd' or 'coral'. You can add your own loss

            Returns:
                [tensor] -- adaptation loss tensor
            """
            if adapt_loss == 'mmd':
                mmd_loss = mmd.MMD_loss()
                loss = mmd_loss(X, Y)
            elif adapt_loss == 'coral':
                loss = CORAL(X, Y)
            else:
                loss = 0
            return loss

            Taking AlexNet as an example, the base_network defined in models.py simply instantiates the alexnet entry from backbone.py, i.e. conv1-5 plus fc6-7 (the original post showed an architecture diagram of these layers here). Because both the source and target domains pass through these layers with shared weights, they serve as the base_network.

            The bottleneck_layer is the fc_adapt layer described in the paper; as the code above shows, it is a Linear layer (base feature width → bottleneck_width=256) followed by BatchNorm1d, ReLU, and Dropout(0.5).
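            A minimal shape walkthrough of Transfer_Net (a sketch, assuming the repo files are importable and pretrained weights can be downloaded):

    import torch
    import models

    net = models.Transfer_Net(num_class=31, base_net='alexnet')   # defaults: transfer_loss='mmd', bottleneck_width=256
    src = torch.randn(4, 3, 224, 224)
    tgt = torch.randn(4, 3, 224, 224)
    clf, t_loss = net(src, tgt)
    print(clf.shape)   # torch.Size([4, 31]) -- source logits, computed from the 4096-d fc7 features
    print(t_loss)      # scalar MMD between the 256-d bottleneck (fc_adapt) features

            Note that in this implementation the classifier consumes the pre-bottleneck features; only the transfer loss is computed on the fc_adapt output.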

    5. main.py

    (1) Argument settings

    # Command setting
    parser = argparse.ArgumentParser(description='DDC_DCORAL')
    parser.add_argument('--model', type=str, default='alexnet')
    parser.add_argument('--batchsize', type=int, default=2)
    parser.add_argument('--src', type=str, default='amazon')
    parser.add_argument('--tar', type=str, default='webcam')
    parser.add_argument('--n_class', type=int, default=31)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--n_epoch', type=int, default=100)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--decay', type=float, default=5e-4)
    # raw string so the Windows backslashes are not read as escape sequences
    parser.add_argument('--data', type=str, default=r'E:\Myself\Office31\Original_images')
    parser.add_argument('--early_stop', type=int, default=20)
    parser.add_argument('--lamb', type=float, default=10)
    parser.add_argument('--trans_loss', type=str, default='mmd')
    args = parser.parse_args()

    (2) The test function

    Reference: Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别_BEINTHEMEMENT的博客-CSDN博客

            model.eval() disables the training-time behavior of Batch Normalization and Dropout. If the model contains BN or Dropout layers, call model.eval() at test time: BN then uses the mean and variance accumulated over the whole training set, so its statistics stay fixed during testing, and Dropout keeps all connections instead of randomly dropping neurons.
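            A minimal toy illustration of the Dropout half of this (not from the repo):

    import torch
    import torch.nn as nn

    drop = nn.Dropout(p=0.5)
    x = torch.ones(1, 8)

    drop.train()      # training mode: neurons are dropped at random
    print(drop(x))    # roughly half the entries zeroed, survivors scaled by 1/(1-p) = 2

    drop.eval()       # evaluation mode: dropout is a no-op
    print(drop(x))    # all ones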

    def test(model, target_test_loader):
        model.eval()
        test_loss = utils.AverageMeter()
        correct = 0
        criterion = torch.nn.CrossEntropyLoss()
        len_target_dataset = len(target_test_loader.dataset)
        with torch.no_grad():
            for data, target in target_test_loader:
                data, target = data.to(DEVICE), target.to(DEVICE)
                s_output = model.predict(data)
                loss = criterion(s_output, target)
                test_loss.update(loss.item())
                pred = torch.max(s_output, 1)[1]
                correct += torch.sum(pred == target)
        acc = 100. * correct / len_target_dataset
        return acc

    (3) The train function

            model.train() enables the training-time behavior of Batch Normalization and Dropout. If the model contains BN or Dropout layers, call model.train() during training: BN then normalizes with each mini-batch's mean and variance, and Dropout randomly keeps a subset of connections while the parameters are updated.
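            And the BatchNorm half, as a similar toy check (not from the repo):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm1d(3)
    x = torch.randn(8, 3) * 5 + 2

    bn.train()
    y = bn(x)           # normalizes with this batch's mean/var and updates the running stats
    print(y.mean(0))    # ~0 per channel

    bn.eval()
    z = bn(x)           # normalizes with the accumulated running mean/var instead
    print(z.mean(0))    # generally not ~0, since the running stats differ from this batch's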

    def train(source_loader, target_train_loader, target_test_loader, model, optimizer):
        len_source_loader = len(source_loader)
        len_target_loader = len(target_train_loader)
        best_acc = 0
        stop = 0
        for e in range(args.n_epoch):
            stop += 1
            train_loss_clf = utils.AverageMeter()
            train_loss_transfer = utils.AverageMeter()
            train_loss_total = utils.AverageMeter()
            model.train()
            iter_source, iter_target = iter(source_loader), iter(target_train_loader)
            n_batch = min(len_source_loader, len_target_loader)
            criterion = torch.nn.CrossEntropyLoss()
            for _ in range(n_batch):
                data_source, label_source = next(iter_source)  # next(...) replaces the deprecated .next()
                data_target, _ = next(iter_target)
                data_source, label_source = data_source.to(DEVICE), label_source.to(DEVICE)
                data_target = data_target.to(DEVICE)
                optimizer.zero_grad()
                label_source_pred, transfer_loss = model(data_source, data_target)
                clf_loss = criterion(label_source_pred, label_source)
                loss = clf_loss + args.lamb * transfer_loss
                loss.backward()
                optimizer.step()
                train_loss_clf.update(clf_loss.item())
                train_loss_transfer.update(transfer_loss.item())
                train_loss_total.update(loss.item())
            # Test
            acc = test(model, target_test_loader)
            log.append([train_loss_clf.avg, train_loss_transfer.avg, train_loss_total.avg])
            np_log = np.array(log, dtype=float)
            np.savetxt('train_log.csv', np_log, delimiter=',', fmt='%.6f')
            print('Epoch: [{:2d}/{}], cls_loss: {:.4f}, transfer_loss: {:.4f}, total_Loss: {:.4f}, acc: {:.4f}'.format(
                e, args.n_epoch, train_loss_clf.avg, train_loss_transfer.avg, train_loss_total.avg, acc))
            if best_acc < acc:
                best_acc = acc
                stop = 0
            if stop >= args.early_stop:
                break
        print('Transfer result: {:.4f}'.format(best_acc))

    (4) Data loading and main

    # num_workers changed from 4 to 1 so it runs on my machine
    def load_data(src, tar, root_dir):
        folder_src = os.path.join(root_dir, src)
        folder_tar = os.path.join(root_dir, tar)
        source_loader = data_loader.load_data(
            folder_src, args.batchsize, True, {'num_workers': 1})
        target_train_loader = data_loader.load_data(
            folder_tar, args.batchsize, True, {'num_workers': 1})
        target_test_loader = data_loader.load_data(
            folder_tar, args.batchsize, False, {'num_workers': 1})
        return source_loader, target_train_loader, target_test_loader

    if __name__ == '__main__':
        torch.manual_seed(0)
        source_name = "amazon"
        target_name = "webcam"
        print('Src: %s, Tar: %s' % (source_name, target_name))
        print('Backbone: %s, Loss: %s' % (args.model, args.trans_loss))
        source_loader, target_train_loader, target_test_loader = load_data(
            source_name, target_name, args.data)
        model = models.Transfer_Net(
            args.n_class, transfer_loss=args.trans_loss, base_net=args.model).to(DEVICE)
        print(model)
        optimizer = torch.optim.SGD([
            {'params': model.base_network.parameters()},
            {'params': model.bottleneck_layer.parameters(), 'lr': 10 * args.lr},
            {'params': model.classifier_layer.parameters(), 'lr': 10 * args.lr},
        ], lr=args.lr, momentum=args.momentum, weight_decay=args.decay)
        train(source_loader, target_train_loader,
              target_test_loader, model, optimizer)
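            Two things worth noting here: the freshly initialized bottleneck and classifier layers get a 10× higher learning rate than the pretrained backbone (a common fine-tuning heuristic), and source_name/target_name are hard-coded to amazon/webcam, so the --src/--tar flags are effectively ignored. With the argparse defaults above, a run looks like `python main.py --model alexnet --data /path/to/Office31 --trans_loss mmd` (the data path is a placeholder).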

    Experiments

            Still debugging: the code runs end to end, but the results are quite a bit worse than those reported in the paper and on GitHub.

Original post: https://blog.csdn.net/weixin_44855366/article/details/126704430