• Knowledge distillation with the CIFAR10 dataset (with thanks to the blog by 快乐就好_Peng, which this post follows)


    This post is about knowledge distillation. The write-ups I found all used datasets other than CIFAR10, so my hands got itchy and I ran my own experiment on CIFAR10. The result is... decent, and you can really feel that the teacher network does guide the student.

    I also recommend watching 同济子豪兄's explanation of knowledge distillation on Bilibili; it is quite detailed.

    Without further ado, let's get straight to it:

    1. First, define the teacher network. I set its hidden layer to 1200 units, mainly to highlight the huge gap in parameter count between the teacher network and the student network.

    Note: when you run my code you may hit an error related to the if __name__ == "__main__": statement; that is easy to look up, so I won't go into detail here. Because the distillation script has to import the teacher network and run it, I commented out the if __name__ == "__main__": line in teacher_moudle.py and removed the corresponding indentation. The side effect is that the teacher trains whenever the module is imported, and running the teacher file directly in other configurations may throw errors; see the sketch below for the cleaner layout.
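    For reference, the usual way to avoid this whole problem is to keep the class definition at module level, put the training code under the guard, and save the trained weights with torch.save so that other scripts can load them instead of retraining on import. A minimal sketch of that layout (the class body and the file name teacher_cifar10.pt are just placeholders, not the networks used in this post):

    # layout sketch only; the architecture here is a placeholder
    import torch
    import torch.nn as nn

    class TeacherNet(nn.Module):  # defined at module level, so importable elsewhere
        def __init__(self, num_class=10):
            super().__init__()
            self.fc = nn.Linear(3 * 32 * 32, num_class)  # placeholder layer
        def forward(self, x):
            return self.fc(x.flatten(1))

    if __name__ == "__main__":  # runs only when the file is executed directly
        model = TeacherNet()
        # ... training loop goes here ...
        torch.save(model.state_dict(), "teacher_cifar10.pt")  # persist the weights

    # another script can then reuse the trained teacher without retraining:
    #     from teacher_moudle import TeacherNet
    #     teacher = TeacherNet()
    #     teacher.load_state_dict(torch.load("teacher_cifar10.pt"))

    With that said, here is the teacher script as I actually wrote it: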

    # -*- coding:utf-8 -*-
    # @Time : 2022-06-27 13:43
    # @Author : DaFuChen
    # @File : teacher_moudle.py
    # @software: PyCharm
    # if __name__ == "__main__":   # commented out so that distillation.py can import this module
    import torch
    import torch.nn as nn
    import torchvision
    import torchvision.transforms as transforms
    from torchvision.transforms import ToTensor
    from tqdm import tqdm  # tqdm draws a progress bar in the terminal so training progress is visible

    # truncate the result file at the start of a run
    open("test_teacher.txt", "w").close()

    """
    Part 1
    Load the dataset
    """
    # fix the random seed so the experiment can be reproduced
    torch.manual_seed(9)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # let cudnn pick the fastest convolution algorithms
    torch.backends.cudnn.benchmark = True

    # training set
    trainset = torchvision.datasets.CIFAR10(
        root='./datac10',
        train=True,
        transform=ToTensor(),  # convert the PIL images into tensors PyTorch can process
        download=True
    )
    # test set
    testset = torchvision.datasets.CIFAR10(
        root='./datac10',
        train=False,
        transform=transforms.ToTensor(),
        download=True
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=trainset,
        batch_size=32,
        shuffle=True
    )
    testloader = torch.utils.data.DataLoader(
        dataset=testset,
        batch_size=32,
        shuffle=False
    )

    """
    Part 2
    Define the teacher model
    """
    class teacherMoudle(nn.Module):
        def __init__(self, in_channels=3, num_class=10):
            super(teacherMoudle, self).__init__()
            # convolution layers
            self.conv1 = nn.Conv2d(in_channels, 6, 5)
            self.conv2 = nn.Conv2d(6, 16, 5)
            # pooling layer
            self.pool = nn.MaxPool2d(2, 2)
            # ReLU activation
            self.relu = nn.ReLU()
            # dropout layer
            self.dropout = nn.Dropout(p=0.5)
            # fully connected layers; note that the teacher's hidden layer is wide: 1200 units
            self.fc1 = nn.Linear(16 * 5 * 5, 1200)
            self.fc2 = nn.Linear(1200, 600)
            self.fc3 = nn.Linear(600, num_class)

        def forward(self, x):
            x = self.pool(self.relu(self.conv1(x)))
            x = self.pool(self.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = self.fc1(x)
            x = self.dropout(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.dropout(x)
            x = self.relu(x)
            x = self.fc3(x)
            return x

    """
    Part 3
    Train the teacher model
    """
    model = teacherMoudle()
    model = model.to(device)
    # loss function and optimizer
    crossLoss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
    epochs = 6
    for epoch in range(epochs):
        # model.train() enables Batch Normalization and Dropout: BN layers use the
        # mean and variance of each mini-batch, and Dropout randomly drops a subset
        # of connections while the parameters are updated
        model.train()
        running_loss = 0.0
        for image, labels in tqdm(train_loader):
            image = image.to(device)
            labels = labels.to(device)
            # forward pass
            pred = model(image)
            loss = crossLoss(pred, labels)
            running_loss = loss.item()  # track the loss of the most recent batch
            # steps: 1. zero the optimizer  2. backpropagate the loss  3. update the parameters
            optimizer.zero_grad()  # required: clear the gradients from the previous step
            # important: backpropagate through loss here, otherwise the model is never optimized
            loss.backward()
            optimizer.step()  # apply the parameter update

        """
        Part 4
        Evaluate the model
        """
        # in eval mode, layers added to help training (BatchNorm, Dropout, ...) are
        # switched off so the evaluation is not distorted
        model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for image, labels in testloader:
                image = image.to(device)
                labels = labels.to(device)
                pred = model(image)
                predictions = pred.max(1).indices  # predicted class labels
                num_correct += (predictions == labels).sum()
                num_samples += predictions.size(0)
            acc = (num_correct / num_samples).item()
        model.train()
        print('Epoch:{}\t Accuracy:{:.5f}'.format(epoch + 1, acc))
        print('loss : ' + str(running_loss))
        with open("./test_teacher.txt", "a") as f:
            f.write(str(epoch + 1) + " " + str(acc) + " " + str(running_loss) + '\r\n')
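    In case the 16 * 5 * 5 in the flatten step looks like magic: a 32x32 CIFAR10 image becomes 28x28 after the first 5x5 convolution (no padding), 14x14 after pooling, 10x10 after the second convolution, and 5x5 after the second pooling, with 16 channels. A quick shape check (ReLU is omitted since it does not change shapes):

    import torch
    import torch.nn as nn

    conv1 = nn.Conv2d(3, 6, 5)
    conv2 = nn.Conv2d(6, 16, 5)
    pool = nn.MaxPool2d(2, 2)

    x = torch.randn(1, 3, 32, 32)  # one CIFAR10-sized image
    x = pool(conv1(x))             # (1, 6, 14, 14)
    x = pool(conv2(x))             # (1, 16, 5, 5)
    print(x.shape)                 # torch.Size([1, 16, 5, 5]) -> 16 * 5 * 5 = 400 features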

    The training results:

    Epoch    Accuracy               Loss
    1        0.49399998784065247    1.1463124752044678
    2        0.5374000072479248     1.0659717321395874
    3        0.5343999862670898     1.388565182685852
    4        0.5910999774932861     1.176950454711914
    5        0.6060000061988831     1.2053321599960327
    6        0.6089999675750732     1.2049661874771118

    Next is the student network. I gave its hidden layers far fewer neurons than the teacher network's, and as expected its accuracy really does fall short of the teacher's...

    # -*- coding:utf-8 -*-
    # @Time : 2022-06-27 13:44
    # @Author : DaFuChen
    # @File : student_moudle.py
    # @software: PyCharm
    import torchvision.datasets

    if __name__ == "__main__":
        import torch
        import torch.nn as nn
        from tqdm import tqdm
        from torchvision.transforms import ToTensor

        # truncate the result file at the start of a run
        open("test_student.txt", "w").close()

        torch.manual_seed(8)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.backends.cudnn.benchmark = True

        train_set = torchvision.datasets.CIFAR10(
            root="./datac10",
            train=True,
            transform=ToTensor(),
            download=True
        )
        test_set = torchvision.datasets.CIFAR10(
            root="./datac10",
            train=False,
            transform=ToTensor(),
            download=True
        )
        train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=32,
            shuffle=True
        )
        test_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=32,
            shuffle=True
        )

        class student_moudle(nn.Module):
            def __init__(self, in_channels=3, num_class=10):
                super(student_moudle, self).__init__()
                self.conv1 = nn.Conv2d(in_channels, 6, 5)
                self.conv2 = nn.Conv2d(6, 16, 5)
                self.pool = nn.MaxPool2d(2, 2)
                self.relu = nn.ReLU()
                self.dropout = nn.Dropout(p=0.5)
                # fully connected layers; the student's hidden layer has only 64 units,
                # far fewer than the teacher's 1200
                self.fc1 = nn.Linear(16 * 5 * 5, 64)
                self.fc2 = nn.Linear(64, 32)
                self.fc3 = nn.Linear(32, num_class)

            def forward(self, x):
                x = self.pool(self.relu(self.conv1(x)))
                x = self.pool(self.relu(self.conv2(x)))
                x = x.view(-1, 16 * 5 * 5)
                x = self.fc1(x)
                x = self.dropout(x)
                x = self.relu(x)
                x = self.fc2(x)
                x = self.dropout(x)
                x = self.relu(x)
                x = self.fc3(x)
                return x

        model = student_moudle()
        model = model.to(device)
        crossLoss = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
        epochs = 6
        for epoch in range(epochs):
            model.train()
            running_loss = 0.0
            for image, labels in tqdm(train_loader):
                image = image.to(device)
                labels = labels.to(device)
                pred = model(image)
                loss = crossLoss(pred, labels)
                running_loss = loss.item()  # track the loss of the most recent batch
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for image, labels in test_loader:
                    image = image.to(device)
                    labels = labels.to(device)
                    pred = model(image)
                    predictions = pred.max(1).indices
                    num_correct += (predictions == labels).sum()
                    num_samples += predictions.size(0)
                acc = (num_correct / num_samples).item()
            model.train()
            print('Epoch:{}\t Accuracy:{:.5f}'.format(epoch + 1, acc))
            print('loss : ' + str(running_loss))
            with open("./test_student.txt", "a") as f:
                f.write(str(epoch + 1) + " " + str(acc) + " " + str(running_loss) + '\r\n')

    The training results:

    Epoch    Accuracy               Loss
    1        0.3483999967575073     1.81593656539917
    2        0.4001999795436859     1.9031776189804077
    3        0.43629997968673706    1.633713960647583
    4        0.4610999822616577     1.5931921005249023
    5        0.4650999903678894     1.601379632949829
    6        0.47529998421669006    1.547680377960205
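    To make the parameter gap concrete, the two models can be compared directly. A small sketch, assuming the teacherMoudle and student_moudle classes from the two listings above are in scope:

    # assumes teacherMoudle and student_moudle are defined as in the listings above
    def count_params(model):
        return sum(p.numel() for p in model.parameters())

    print(count_params(teacherMoudle()))   # 1,210,682 parameters
    print(count_params(student_moudle()))  # 30,946 parameters, roughly 39x fewer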

    Next comes the distillation network:

    # -*- coding:utf-8 -*-
    # @Time : 2022-06-30 18:49
    # @Author : DaFuChen
    # @File : distillation.py
    # @software: PyCharm
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torchvision
    from torchvision.transforms import ToTensor
    from tqdm import tqdm

    # student_moudle cannot be imported here: in student_moudle.py the class is
    # defined inside the if __name__ == '__main__' block, so the indentation hides
    # it from importing modules
    # import student_moudle
    import teacher_moudle  # importing this module trains the teacher as a side effect

    if __name__ == "__main__":
        # truncate the result file at the start of a run
        open("distillation.txt", "w").close()

        torch.manual_seed(8)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.backends.cudnn.benchmark = True

        train_set = torchvision.datasets.CIFAR10(
            root="./datac10",
            train=True,
            transform=ToTensor(),
            download=True
        )
        test_set = torchvision.datasets.CIFAR10(
            root="./datac10",
            train=False,
            transform=ToTensor(),
            download=True
        )
        train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=32,
            shuffle=True
        )
        test_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=32,
            shuffle=True
        )

        # The student model is defined here rather than imported, because importing
        # a module whose training code runs at import time would train that network
        # as a side effect; that is exactly what happens with teacher_moudle above.
        class student_moudle(nn.Module):
            def __init__(self, in_channels=3, num_class=10):
                super(student_moudle, self).__init__()
                self.conv1 = nn.Conv2d(in_channels, 6, 5)
                self.conv2 = nn.Conv2d(6, 16, 5)
                self.pool = nn.MaxPool2d(2, 2)
                self.relu = nn.ReLU()
                self.dropout = nn.Dropout(p=0.5)
                # fully connected layers; the student's hidden layer has only 64 units,
                # far fewer than the teacher's 1200
                self.fc1 = nn.Linear(16 * 5 * 5, 64)
                self.fc2 = nn.Linear(64, 32)
                self.fc3 = nn.Linear(32, num_class)

            def forward(self, x):
                x = self.pool(self.relu(self.conv1(x)))
                x = self.pool(self.relu(self.conv2(x)))
                x = x.view(-1, 16 * 5 * 5)
                x = self.fc1(x)
                x = self.dropout(x)
                x = self.relu(x)
                x = self.fc2(x)
                x = self.dropout(x)
                x = self.relu(x)
                x = self.fc3(x)
                return x

        # the teacher network is defined here again as well
        class teacherMoudle(nn.Module):
            def __init__(self, in_channels=3, num_class=10):
                super(teacherMoudle, self).__init__()
                # convolution layers
                self.conv1 = nn.Conv2d(in_channels, 6, 5)
                self.conv2 = nn.Conv2d(6, 16, 5)
                # pooling layer
                self.pool = nn.MaxPool2d(2, 2)
                # ReLU activation
                self.relu = nn.ReLU()
                # dropout layer
                self.dropout = nn.Dropout(p=0.5)
                # fully connected layers; the teacher's hidden layer is wide: 1200 units
                self.fc1 = nn.Linear(16 * 5 * 5, 1200)
                self.fc2 = nn.Linear(1200, 600)
                self.fc3 = nn.Linear(600, num_class)

            def forward(self, x):
                x = self.pool(self.relu(self.conv1(x)))
                x = self.pool(self.relu(self.conv2(x)))
                x = x.view(-1, 16 * 5 * 5)
                x = self.fc1(x)
                x = self.dropout(x)
                x = self.relu(x)
                x = self.fc2(x)
                x = self.dropout(x)
                x = self.relu(x)
                x = self.fc3(x)
                return x

        # the already-trained teacher model, produced when teacher_moudle was imported
        teacher_moudle.model.eval()
        # a fresh student model
        # model = student_moudle.student_moudle()
        model = student_moudle()
        model = model.to(device)
        # the teacher has to live on the same device as the images it will predict on;
        # copy the trained weights into the local instance (a freshly constructed
        # teacherMoudle would be randomly initialized)
        teachermoudle = teacherMoudle()
        teachermoudle.load_state_dict(teacher_moudle.model.state_dict())
        teachermoudle = teachermoudle.to(device)
        teachermoudle.eval()
        model.train()

        # distillation temperature
        temp = 7
        # hard loss: ordinary cross entropy against the ground-truth labels
        hard_loss = nn.CrossEntropyLoss()
        # weight of the hard loss in the combined objective
        alpha = 0.3
        # soft loss: KL divergence between the softened student and teacher outputs
        soft_loss = nn.KLDivLoss(reduction='batchmean')
        # optimizer for the student
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        epochs = 6
        for epoch in range(epochs):
            # update the student's weights on the training set
            running_loss = 0.0
            for image, labels in tqdm(train_loader):
                image = image.to(device)
                labels = labels.to(device)
                # the teacher only supplies targets, so no gradients are needed for it
                with torch.no_grad():
                    teacher_preds = teachermoudle(image)
                # student forward pass
                student_preds = model(image)
                # hard loss
                student_loss = hard_loss(student_preds, labels)
                # soft loss on the temperature-softened distributions; note that
                # nn.KLDivLoss expects log-probabilities as its first argument
                distillation_loss = soft_loss(
                    F.log_softmax(student_preds / temp, dim=1),
                    F.softmax(teacher_preds / temp, dim=1)
                )
                # weighted sum of hard loss and soft loss
                loss = alpha * student_loss + (1 - alpha) * distillation_loss
                running_loss = loss.item()  # track the loss of the most recent batch
                # backpropagate and update the student's weights
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # evaluate the student on the test set
            model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for image2, labels2 in test_loader:
                    image2 = image2.to(device)
                    labels2 = labels2.to(device)
                    preds = model(image2)
                    predictions = preds.max(1).indices
                    num_correct += (predictions == labels2).sum()
                    num_samples += predictions.size(0)
                acc = (num_correct / num_samples).item()
            model.train()
            print('Epoch:{}\t Accuracy:{:.5f}'.format(epoch + 1, acc))
            print('loss : ' + str(running_loss))
            with open("./distillation.txt", "a") as f:
                f.write(str(epoch + 1) + " " + str(acc) + " " + str(running_loss) + '\r\n')
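    Before looking at the numbers, it helps to see what the temperature actually does: dividing the logits by temp flattens the softmax distribution, so the teacher's small probabilities on the wrong classes (its "dark knowledge") carry real weight in the soft loss. A toy example with made-up logits:

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[8.0, 2.0, 0.5]])  # made-up teacher logits for 3 classes
    print(F.softmax(logits, dim=1))           # ~[0.9970, 0.0025, 0.0006]: almost one-hot
    print(F.softmax(logits / 7, dim=1))       # ~[0.57, 0.24, 0.19]: soft targets at temp = 7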

    The training results:

    Epoch    Accuracy               Loss
    1        0.3822000026702881     -1.1664808988571167
    2        0.41040000319480896    -1.1804652214050293
    3        0.46219998598098755    -1.2277178764343262
    4        0.4693000018596649     -1.302511215209961
    5        0.4846999943256378     -1.2454404830932617
    6        0.4932999908924103     -1.27120840549469
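    A remark on the loss column: the values are negative because the run that produced them passed F.softmax, rather than F.log_softmax, as the first argument of the KL term. nn.KLDivLoss expects log-probabilities as its input; fed plain probabilities, the quantity it computes is no longer a true KL divergence and can drop below zero. A quick demonstration of the difference:

    import torch
    import torch.nn.functional as F

    kl = torch.nn.KLDivLoss(reduction='batchmean')
    student_logits = torch.tensor([[2.0, 1.0, 0.1]])
    teacher_logits = torch.tensor([[1.5, 1.2, 0.3]])

    # correct usage: the first argument is log-probabilities, so the result is a
    # true KL divergence and is non-negative
    print(kl(F.log_softmax(student_logits, dim=1), F.softmax(teacher_logits, dim=1)))

    # passing probabilities instead of log-probabilities yields a negative number
    print(kl(F.softmax(student_logits, dim=1), F.softmax(teacher_logits, dim=1)))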

    I'm still a sophomore and my understanding of deep learning is not that deep. If there are mistakes above, corrections are welcome; please don't flame me. I posted this purely so we can all learn something new together.

• Original article: https://blog.csdn.net/blockshowtouse/article/details/125551415