• Knowledge Distillation in Practice: Using CoatNet to Distill ResNet


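    Abstract

    Knowledge distillation (KD) transfers the knowledge contained in an already trained model into another model. Hinton et al. first proposed the idea (extracting "dark knowledge") in "Distilling the Knowledge in a Neural Network": soft targets produced by a teacher network (complex, but with superior accuracy) become part of the total loss and guide the training of a student network (compact, low complexity, better suited to inference deployment), which achieves knowledge transfer. Paper: https://arxiv.org/pdf/1503.02531.pdf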

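    The distillation process

    Knowledge distillation uses a teacher-student setup, in which the teacher provides the "knowledge" and the student receives it. The process has two stages:

    • Training the original model: train the teacher model, Net-T for short. It is relatively complex and may also be an ensemble of several separately trained models. There is no restriction on the teacher's architecture, parameter count, or whether it is an ensemble; the only requirement is that for an input X it produces an output Y which, after a softmax, gives the probability of each class.
    • Training the compact model: train the student model, Net-S for short. It is a single model with fewer parameters and a simpler structure. Likewise, for an input X it produces an output Y which, after a softmax, gives the probability of each class.

    The teacher has a strong learning capacity and can transfer what it has learned to the weaker student, improving the student's generalization. The large, heavy but accurate teacher is never deployed; it only acts as a mentor. The model that actually goes online to serve predictions is the small, lightweight student.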
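    To make the softmax step above concrete, here is a minimal sketch in plain Python. The temperature parameter comes from Hinton's paper and matches the temp argument used later in this post; the logits are made-up values chosen only for illustration.

```python
import math

def softmax(logits, temperature=1.0):
    """Convert raw scores to probabilities; a higher temperature gives a flatter distribution."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for one image over three classes.
teacher_logits = [6.0, 2.0, 1.0]

hard = softmax(teacher_logits)                    # T = 1: almost one-hot
soft = softmax(teacher_logits, temperature=7.0)   # T = 7: softened targets

print([round(p, 3) for p in hard])
print([round(p, 3) for p in soft])
```

    With the higher temperature, the non-maximum classes receive noticeably more probability mass. That relative ranking of the "wrong" classes is the extra information the student learns from.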


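    Final results

    Results first. The teacher network is coatnet_2 and the student network is ResNet18, as shown in the table below.

    Network          Epochs   ACC
    coatnet_2        50       92%
    ResNet18         50       86%
    ResNet18 + KD    50       89%

    Under identical conditions, adding knowledge distillation raises ResNet18's accuracy by 3 percentage points, a substantial gain. The accuracy curves for all three runs are plotted in the results comparison section at the end of the post.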

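    Data preparation

    The data is the plant seedlings dataset I used in an earlier image classification post. First split it into a training set and a validation set by running: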

    import glob
    import os
    import shutil
    
    image_list=glob.glob('data1/*/*.png')
    print(image_list)
    file_dir='data'
    # Recreate the output directory from scratch so stale files from an earlier split are removed
    if os.path.exists(file_dir):
        shutil.rmtree(file_dir)
    os.makedirs(file_dir)
    
    from sklearn.model_selection import train_test_split
    trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
    train_dir='train'
    val_dir='val'
    train_root=os.path.join(file_dir,train_dir)
    val_root=os.path.join(file_dir,val_dir)
    for file in trainval_files:
        file_class=file.replace("\\","/").split('/')[-2]
        file_name=file.replace("\\","/").split('/')[-1]
        file_class=os.path.join(train_root,file_class)
        if not os.path.isdir(file_class):
            os.makedirs(file_class)
        shutil.copy(file, file_class + '/' + file_name)
    
    for file in val_files:
        file_class=file.replace("\\","/").split('/')[-2]
        file_name=file.replace("\\","/").split('/')[-1]
        file_class=os.path.join(val_root,file_class)
        if not os.path.isdir(file_class):
            os.makedirs(file_class)
        shutil.copy(file, file_class + '/' + file_name)
    
    
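    After running the split, it can be worth checking how many images each class ended up with, since the split above is random rather than stratified. The helper below is a hypothetical addition, not part of the original code; it is demonstrated on a throwaway directory so it runs anywhere.

```python
import os
import tempfile

def count_images_per_class(root):
    """Return {class_name: number_of_files} for an ImageFolder-style directory."""
    counts = {}
    for class_name in sorted(os.listdir(root)):
        class_dir = os.path.join(root, class_name)
        if os.path.isdir(class_dir):
            counts[class_name] = len(os.listdir(class_dir))
    return counts

# Self-contained demo; in this post the real roots would be 'data/train' and 'data/val'.
demo_root = tempfile.mkdtemp()
for class_name, n_files in [("class_a", 3), ("class_b", 1)]:
    os.makedirs(os.path.join(demo_root, class_name))
    for i in range(n_files):
        open(os.path.join(demo_root, class_name, f"{i}.png"), "w").close()

demo_counts = count_images_per_class(demo_root)
print(demo_counts)
```

    Calling count_images_per_class('data/train') and count_images_per_class('data/val') after the split should show every class present in both folders.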

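    Teacher network

    The teacher is coatnet_2, a fairly large network; the saved model is about 200 MB. Trained for 50 epochs, the best checkpoint reaches about 92% accuracy.

    Steps

    Create teacher_train.py and add the following code:

    Import the required libraries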

    import torch.optim as optim
    import torch
    import torch.nn as nn
    import torch.nn.parallel
    import torch.utils.data
    import torch.utils.data.distributed
    import torchvision.transforms as transforms
    from torchvision import datasets
    from torch.autograd import Variable
    from model.coatnet import coatnet_2
    
    import json
    import os
    

    Define the training and validation functions

    
    def train(model, device, train_loader, optimizer, epoch):
        model.train()
        sum_loss = 0
        total_num = len(train_loader.dataset)
        print(total_num, len(train_loader))
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = Variable(data).to(device), Variable(target).to(device)
            output = model(data)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print_loss = loss.data.item()
            sum_loss += print_loss
            if (batch_idx + 1) % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                           100. * (batch_idx + 1) / len(train_loader), loss.item()))
        ave_loss = sum_loss / len(train_loader)
        print('epoch:{},loss:{}'.format(epoch, ave_loss))
    
    Best_ACC=0
    # Validation: evaluate on the validation set and save the best checkpoint so far
    @torch.no_grad()
    def val(model, device, test_loader):
        global Best_ACC
        model.eval()
        test_loss = 0
        correct = 0
        total_num = len(test_loader.dataset)
        print(total_num, len(test_loader))
        with torch.no_grad():
            for data, target in test_loader:
                data, target = Variable(data).to(device), Variable(target).to(device)
                output = model(data)
                loss = criterion(output, target)
                _, pred = torch.max(output.data, 1)
                correct += torch.sum(pred == target)
                print_loss = loss.data.item()
                test_loss += print_loss
            correct = correct.data.item()
            acc = correct / total_num
            avgloss = test_loss / len(test_loader)
            if acc > Best_ACC:
                torch.save(model, file_dir + '/' + 'best.pth')
                Best_ACC = acc
            print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                avgloss, correct, len(test_loader.dataset), 100 * acc))
            return acc
    

    Define global parameters

    if __name__ == '__main__':
        # Create the directory for saved checkpoints (val() reads file_dir as a global)
        file_dir = 'CoatNet'
        os.makedirs(file_dir, exist_ok=True)

        # Global hyperparameters
        modellr = 1e-4
        BATCH_SIZE = 16
        EPOCHS = 50
        DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    

    Image preprocessing and augmentation

        # Data preprocessing and augmentation
        transform = transforms.Compose([
            transforms.RandomRotation(10),
            transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    
        ])
        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
        ])
    
    
    
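    As a small worked example of what the Normalize step does: after ToTensor scales pixel values to [0, 1], each channel is mapped to (value - mean) / std. The sketch below reproduces that arithmetic in plain Python with the mean and std values from this post; the helper names and the mid-gray test pixel are hypothetical.

```python
# Per-channel statistics copied from the transforms in this post.
MEAN = [0.44127703, 0.4712498, 0.43714803]
STD = [0.18507297, 0.18050247, 0.16784933]

def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Apply (value - mean) / std per channel to one RGB pixel already scaled to [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]

def denormalize_pixel(rgb, mean=MEAN, std=STD):
    """Invert normalize_pixel, e.g. before displaying an augmented image."""
    return [v * s + m for v, m, s in zip(rgb, mean, std)]

mid_gray = [0.5, 0.5, 0.5]
normalized = normalize_pixel(mid_gray)
restored = denormalize_pixel(normalized)

print([round(v, 4) for v in normalized])
print([round(v, 4) for v in restored])
```

    A pixel equal to the dataset mean maps to zero in every channel, which is why the same statistics must be used for training, validation, and later inference.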

    Load the data

    Use PyTorch's default way of reading image data.

        # Read the datasets
        dataset_train = datasets.ImageFolder('data/train', transform=transform)
        dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
        with open('class.txt', 'w') as file:
            file.write(str(dataset_train.class_to_idx))
        with open('class.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(dataset_train.class_to_idx))
        # Build the data loaders
        train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
    

    Set up the model and loss

        # Instantiate the model and move it to the GPU (val() and train() read criterion as a global)
        criterion = nn.CrossEntropyLoss()
    
        model_ft = coatnet_2()
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, 12)
        model_ft.to(DEVICE)
        # Use the simple, robust Adam optimizer with a low learning rate
        optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
        cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
        # Training loop
        val_acc_list= {}
        for epoch in range(1, EPOCHS + 1):
            train(model_ft, DEVICE, train_loader, optimizer, epoch)
            cosine_schedule.step()
            acc=val(model_ft, DEVICE, test_loader)
            val_acc_list[epoch]=acc
            with open('result.json', 'w', encoding='utf-8') as file:
                file.write(json.dumps(val_acc_list))
        torch.save(model_ft, 'CoatNet/model_final.pth')
    
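    It may help to see what learning rates the cosine schedule above produces. The sketch below evaluates the closed-form expression given in the PyTorch documentation for CosineAnnealingLR, using this post's settings (lr=1e-4, T_max=20, eta_min=1e-9). It assumes the scheduler is stepped once per epoch, as in the loop above; the function name is hypothetical.

```python
import math

def cosine_annealing_lr(epoch, base_lr=1e-4, t_max=20, eta_min=1e-9):
    """Closed-form cosine annealing: eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

# Sample the schedule at a few epochs across the 50-epoch run.
schedule = {epoch: cosine_annealing_lr(epoch) for epoch in (0, 10, 20, 30, 40, 50)}
for epoch, lr in schedule.items():
    print(f"epoch {epoch:2d}: lr = {lr:.3e}")
```

    Because T_max is 20 while training runs for 50 epochs, the rate falls to eta_min at epoch 20, climbs back to the base rate by epoch 40, and then decays again. Setting T_max to the number of epochs would give a single monotonic decay instead.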

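    With the code above in place, the Teacher network can be trained.

    Student network

    The student is ResNet18, a fairly small network; the saved model is about 40 MB. Trained for 50 epochs, the best checkpoint reaches about 86% accuracy.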
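    The size quoted above can be sanity-checked with a little arithmetic. The sketch assumes the commonly reported parameter count for torchvision's resnet18 with its default 1000-class head (11,689,512) and 4 bytes per float32 weight; the constant and helper names are hypothetical.

```python
RESNET18_IMAGENET_PARAMS = 11_689_512  # assumed: torchvision resnet18 with a 1000-class head
FC_IN_FEATURES = 512                   # input width of ResNet18's final fully connected layer

def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

# Swap the 1000-class head for the 12-class head used in this post.
params_12_classes = (RESNET18_IMAGENET_PARAMS
                     - linear_params(FC_IN_FEATURES, 1000)
                     + linear_params(FC_IN_FEATURES, 12))
size_mib = params_12_classes * 4 / 2**20  # float32 weights only

print(params_12_classes, round(size_mib, 1))
```

    The saved file is slightly larger than this estimate because it also stores batch-norm running statistics and serialization overhead, but the order of magnitude agrees with the figure above.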

    Steps

    Create student_train.py and add the following code:

    Import the required libraries

    import torch.optim as optim
    import torch
    import torch.nn as nn
    import torch.nn.parallel
    import torch.utils.data
    import torch.utils.data.distributed
    import torchvision.transforms as transforms
    from torchvision import datasets
    from torch.autograd import Variable
    from torchvision.models.resnet import resnet18
    
    import json
    import os
    

    Define the training and validation functions

    
    def train(model, device, train_loader, optimizer, epoch):
        model.train()
        sum_loss = 0
        total_num = len(train_loader.dataset)
        print(total_num, len(train_loader))
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = Variable(data).to(device), Variable(target).to(device)
            output = model(data)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print_loss = loss.data.item()
            sum_loss += print_loss
            if (batch_idx + 1) % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                           100. * (batch_idx + 1) / len(train_loader), loss.item()))
        ave_loss = sum_loss / len(train_loader)
        print('epoch:{},loss:{}'.format(epoch, ave_loss))
    
    Best_ACC=0
    # Validation: evaluate on the validation set and save the best checkpoint so far
    @torch.no_grad()
    def val(model, device, test_loader):
        global Best_ACC
        model.eval()
        test_loss = 0
        correct = 0
        total_num = len(test_loader.dataset)
        print(total_num, len(test_loader))
        with torch.no_grad():
            for data, target in test_loader:
                data, target = Variable(data).to(device), Variable(target).to(device)
                output = model(data)
                loss = criterion(output, target)
                _, pred = torch.max(output.data, 1)
                correct += torch.sum(pred == target)
                print_loss = loss.data.item()
                test_loss += print_loss
            correct = correct.data.item()
            acc = correct / total_num
            avgloss = test_loss / len(test_loader)
            if acc > Best_ACC:
                torch.save(model, file_dir + '/' + 'best.pth')
                Best_ACC = acc
            print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                avgloss, correct, len(test_loader.dataset), 100 * acc))
            return acc
    

    Define global parameters

    if __name__ == '__main__':
        # Create the directory for saved checkpoints (val() reads file_dir as a global)
        file_dir = 'resnet'
        os.makedirs(file_dir, exist_ok=True)

        # Global hyperparameters
        modellr = 1e-4
        BATCH_SIZE = 16
        EPOCHS = 50
        DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    

    Image preprocessing and augmentation

        # Data preprocessing and augmentation
        transform = transforms.Compose([
            transforms.RandomRotation(10),
            transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    
        ])
        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
        ])
    
    
    

    Load the data

    Use PyTorch's default way of reading image data.

        # Read the datasets
        dataset_train = datasets.ImageFolder('data/train', transform=transform)
        dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
        with open('class.txt', 'w') as file:
            file.write(str(dataset_train.class_to_idx))
        with open('class.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(dataset_train.class_to_idx))
        # Build the data loaders
        train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
    

    Set up the model and loss

        # Instantiate the model and move it to the GPU (val() and train() read criterion as a global)
        criterion = nn.CrossEntropyLoss()
        model_ft = resnet18()
        print(model_ft)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, 12)
        model_ft.to(DEVICE)
        # Use the simple, robust Adam optimizer with a low learning rate
        optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
        cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
        # Training loop
        val_acc_list= {}
        for epoch in range(1, EPOCHS + 1):
            train(model_ft, DEVICE, train_loader, optimizer, epoch)
            cosine_schedule.step()
            acc=val(model_ft, DEVICE, test_loader)
            val_acc_list[epoch]=acc
            with open('result_student.json', 'w', encoding='utf-8') as file:
                file.write(json.dumps(val_acc_list))
        torch.save(model_ft, 'resnet/model_final.pth')
    

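    With the code above in place, the Student network can be trained.

    Distilling the student network

    The student is again ResNet18, this time distilled from the Teacher network. After 50 epochs of training the final accuracy is 89%.

    Steps

    Create student_kd_train.py and add the following code:

    Import the required libraries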

    import torch.optim as optim
    import torch
    import torch.nn as nn
    import torch.nn.functional as F  # required by distillation() below
    import torch.nn.parallel
    import torch.utils.data
    import torch.utils.data.distributed
    import torchvision.transforms as transforms
    from torchvision import datasets
    from torch.autograd import Variable
    from torchvision.models.resnet import resnet18
    
    import json
    import os
    

    Define the distillation function

    def distillation(y, labels, teacher_scores, temp, alpha):
        return nn.KLDivLoss()(F.log_softmax(y / temp, dim=1), F.softmax(teacher_scores / temp, dim=1)) * (
                temp * temp * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)
    
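    To show what the distillation loss combines, here is a per-sample sketch in plain Python: a KL term between the temperature-softened teacher and student distributions, weighted by alpha * T^2, plus an ordinary cross-entropy term weighted by (1 - alpha). It is not a drop-in replacement for the function above. That function multiplies the KL term by an extra factor of 2.0 and calls nn.KLDivLoss() with its default reduction, which averages over every element instead of per sample, so the numbers will differ. The logits are made-up values.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, label, temp=7.0, alpha=0.7):
    """alpha * T^2 * KL(teacher || student) on softened outputs + (1 - alpha) * cross-entropy."""
    p_teacher = softmax(teacher_logits, temp)
    p_student = softmax(student_logits, temp)
    kl = sum(t * (math.log(t) - math.log(s)) for t, s in zip(p_teacher, p_student))
    ce = -math.log(softmax(student_logits)[label])  # hard-label loss at T = 1
    return alpha * temp * temp * kl + (1.0 - alpha) * ce

# Hypothetical logits for one sample whose true class is 0.
teacher = [6.0, 2.0, 1.0]
student_far = [1.0, 3.0, 0.5]

loss_far = kd_loss(student_far, teacher, label=0)
loss_match = kd_loss(teacher, teacher, label=0)
print(round(loss_far, 4), round(loss_match, 4))
```

    When the student reproduces the teacher's logits exactly, the KL term vanishes and only the weighted cross-entropy remains. The T^2 factor follows Hinton's paper, where it keeps the gradient scale of the soft term comparable as the temperature changes.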

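    Define the training and validation functions

    
    # Training: one epoch of distillation (teacher_model is a global loaded in the main block)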
    def train(model, device, train_loader, optimizer, epoch):
        model.train()
        sum_loss = 0
        total_num = len(train_loader.dataset)
        print(total_num, len(train_loader))
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            teacher_output = teacher_model(data)  # teacher logits for the same batch
            teacher_output = teacher_output.detach()  # stop gradients from flowing back into the teacher
            loss = distillation(output, target, teacher_output, temp=7.0, alpha=0.7)  # train the student against the teacher's soft targets and the labels
    
            loss.backward()
            optimizer.step()
            print_loss = loss.data.item()
            sum_loss += print_loss
            if (batch_idx + 1) % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                           100. * (batch_idx + 1) / len(train_loader), loss.item()))
        ave_loss = sum_loss / len(train_loader)
        print('epoch:{},loss:{}'.format(epoch, ave_loss))
    
    Best_ACC=0
    # Validation: evaluate on the validation set and save the best checkpoint so far
    @torch.no_grad()
    def val(model, device, test_loader):
        global Best_ACC
        model.eval()
        test_loss = 0
        correct = 0
        total_num = len(test_loader.dataset)
        print(total_num, len(test_loader))
        with torch.no_grad():
            for data, target in test_loader:
                data, target = Variable(data).to(device), Variable(target).to(device)
                output = model(data)
                loss = criterion(output, target)
                _, pred = torch.max(output.data, 1)
                correct += torch.sum(pred == target)
                print_loss = loss.data.item()
                test_loss += print_loss
            correct = correct.data.item()
            acc = correct / total_num
            avgloss = test_loss / len(test_loader)
            if acc > Best_ACC:
                torch.save(model, file_dir + '/' + 'best.pth')
                Best_ACC = acc
            print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                avgloss, correct, len(test_loader.dataset), 100 * acc))
            return acc
    

    Define global parameters

    if __name__ == '__main__':
        # Create the directory for saved checkpoints (val() reads file_dir as a global)
        file_dir = 'resnet_kd'
        os.makedirs(file_dir, exist_ok=True)

        # Global hyperparameters
        modellr = 1e-4
        BATCH_SIZE = 16
        EPOCHS = 50
        DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    

    Image preprocessing and augmentation

        # Data preprocessing and augmentation
        transform = transforms.Compose([
            transforms.RandomRotation(10),
            transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    
        ])
        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
        ])
    
    
    

    Load the data

    Use PyTorch's default way of reading image data.

        # Read the datasets
        dataset_train = datasets.ImageFolder('data/train', transform=transform)
        dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
        with open('class.txt', 'w') as file:
            file.write(str(dataset_train.class_to_idx))
        with open('class.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(dataset_train.class_to_idx))
        # Build the data loaders
        train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
    

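    Set up the models and loss

        # Instantiate the student and move it to the GPU
        criterion = nn.CrossEntropyLoss()
        model_ft = resnet18()
        print(model_ft)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, 12)
        model_ft.to(DEVICE)
        # Load the best teacher checkpoint saved by teacher_train.py; train() reads teacher_model as a global.
        # Recent PyTorch releases may need weights_only=False here to unpickle a full model object.
        teacher_model = torch.load('CoatNet/best.pth', map_location=DEVICE)
        teacher_model.to(DEVICE)
        teacher_model.eval()  # the teacher only provides targets, so keep it in inference mode
        # Use the simple, robust Adam optimizer with a low learning rate
        optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
        cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
        # Training loop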
        val_acc_list= {}
        for epoch in range(1, EPOCHS + 1):
            train(model_ft, DEVICE, train_loader, optimizer, epoch)
            cosine_schedule.step()
            acc=val(model_ft, DEVICE, test_loader)
            val_acc_list[epoch]=acc
            with open('result_kd.json', 'w', encoding='utf-8') as file:
                file.write(json.dumps(val_acc_list))
        torch.save(model_ft, 'resnet_kd/model_final.pth')
    

    With the code above in place, distillation training can begin.

    Comparing results

    Load the saved results and plot the accuracy curves.

    import numpy as np
    from matplotlib import pyplot as plt
    import json
    teacher_file='result.json'
    student_file='result_student.json'
    student_kd_file='result_kd.json'
    def read_json(file):
        with open(file, 'r', encoding='utf8') as fp:
            json_data = json.load(fp)
            print(json_data)
        return json_data
    
    teacher_data=read_json(teacher_file)
    student_data=read_json(student_file)
    student_kd_data=read_json(student_kd_file)
    
    
    x =[int(x) for x in  list(dict(teacher_data).keys())]
    print(x)
    
    plt.plot(x, list(teacher_data.values()), label='teacher')
    plt.plot(x,list(student_data.values()), label='student without KD')
    plt.plot(x, list(student_kd_data.values()), label='student with KD')
    
    plt.title('Test accuracy')
    plt.legend()
    
    plt.show()
    
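    Besides the plot, the same result files can be reduced to one best-epoch figure per run, which is how numbers like those in the results table can be read off. The helper below is a hypothetical addition and the demo accuracies are made up; it only assumes the {epoch: accuracy} layout that the training scripts write.

```python
import json

def best_epoch(acc_by_epoch):
    """Return (epoch, accuracy) for the highest accuracy in a {epoch: acc} mapping."""
    epoch = max(acc_by_epoch, key=acc_by_epoch.get)
    return int(epoch), acc_by_epoch[epoch]

# Round-trip through JSON to mimic the saved files, where epoch keys become strings.
demo = json.loads(json.dumps({1: 0.41, 2: 0.63, 3: 0.58}))
print(best_epoch(demo))
```

    Applied to the dictionaries returned by read_json above, this gives the best validation accuracy of the teacher, the plain student, and the distilled student.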

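    Summary

    Knowledge distillation is a widely used way to compress lightweight models and improve their accuracy. This post walked through a simple example of distilling a Student network with a Teacher network.

    Code and dataset used in this post: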

    https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/86947893

  • Original article: https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/127787791