Knowledge Distillation Tutorial
The distillation loss function used here only accepts two inputs of the same dimensionality, so we need to make them match before they reach the loss. We therefore apply an average pooling layer to the teacher's flattened output after its convolutional layers, so that its dimensionality matches the student's.

The original teacher model (only the forward function differs between the two versions, so only it is shown):

```python
class DeepNN(nn.Module):
    # __init__ (the features / classifier definitions) is unchanged and omitted here
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
```
The teacher model with pooling added:

```python
class ModifiedDeepNNCosine(nn.Module):
    # __init__ is unchanged and omitted here
    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        # Average-pool the flattened features so their length matches the student's
        flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)
        return x, flattened_conv_output_after_pooling
```
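To make the dimension change concrete, the snippet below is a minimal shape check rather than tutorial code. It assumes the tutorial's CIFAR-10 setup, where the teacher's flattened convolutional output (length 2048) is twice as long as the student's (length 1024); the batch size of 128 is an arbitrary placeholder.

```python
import torch
import torch.nn.functional as F

teacher_flat = torch.randn(128, 2048)   # stand-in for the teacher's flattened_conv_output
student_flat = torch.randn(128, 1024)   # stand-in for the student's flattened_conv_output

# avg_pool1d with kernel_size=2 averages each pair of adjacent values along the
# last dimension, halving its length: (128, 2048) -> (128, 1024).
teacher_pooled = F.avg_pool1d(teacher_flat, 2)

print(teacher_pooled.shape)                        # torch.Size([128, 1024])
print(teacher_pooled.shape == student_flat.shape)  # True
```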
The original student model:

```python
class LightNN(nn.Module):
    # __init__ is unchanged and omitted here
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
```
The updated student model:

```python
class ModifiedLightNNCosine(nn.Module):
    # __init__ is unchanged and omitted here
    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        # The student's flattened features are already the right size, so they are returned as-is
        return x, flattened_conv_output
```
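Since both modified models now return a (logits, flattened features) pair, the training loop for this run can feed the two feature vectors into a CosineEmbeddingLoss and add the usual cross-entropy on the logits. Below is a minimal sketch of one such loop, not the tutorial's full train_cosine_loss implementation; the names teacher, student, train_loader and device, the Adam optimizer, and the 0.25/0.75 loss weights are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical names: `teacher`, `student`, `train_loader` and `device` are
# assumed to be defined elsewhere (ModifiedDeepNNCosine / ModifiedLightNNCosine
# instances, a CIFAR-10 DataLoader, and "cuda" or "cpu").
cosine_loss = nn.CosineEmbeddingLoss()
ce_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(student.parameters(), lr=1e-3)

teacher.eval()    # the teacher stays frozen during distillation
student.train()

for inputs, labels in train_loader:
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()

    with torch.no_grad():
        _, teacher_hidden = teacher(inputs)            # pooled flattened conv output
    student_logits, student_hidden = student(inputs)   # same length as teacher_hidden

    # A target of ones tells CosineEmbeddingLoss to maximize the cosine
    # similarity between the two hidden representations.
    hidden_rep_loss = cosine_loss(student_hidden, teacher_hidden,
                                  torch.ones(inputs.size(0), device=device))
    label_loss = ce_loss(student_logits, labels)

    # Illustrative weighting of the two terms; treat these as tunable.
    loss = 0.25 * hidden_rep_loss + 0.75 * label_loss
    loss.backward()
    optimizer.step()
```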

Teacher accuracy: 75.60%
Student accuracy without teacher: 70.41%
Student accuracy with CE + KD: 70.19%
Student accuracy with CE + CosineLoss: 70.87%
Student accuracy with CE + RegressorMSE: 71.40%