• 分类网络-类别不均衡问题之FocalLoss


    在这里插入图片描述
    训练和测试代码如下(完整代码来自《CNN从搭建到部署实战》):
    train.py

    import torch
    import torchvision
    import time
    import argparse
    import importlib
    from loss import FocalLoss
    
    
    def parse_args():
        """Parse the training script's options: --batch_size, --num_epochs and --model."""
        parser = argparse.ArgumentParser('training')
        parser.add_argument('--batch_size', default=128, type=int, help='batch size in training')
        parser.add_argument('--num_epochs', default=5, type=int, help='number of epoch in training')
        parser.add_argument('--model',  default='lenet', help='model name [default: lenet]')
        return parser.parse_args()
    
    if __name__ == '__main__':
        # Train the selected model on the MNIST PNG folders and save it after every epoch.
        args = parse_args()
        batch_size, num_epochs = args.batch_size, args.num_epochs
        # The module models.<name> is resolved at run time; its module-level `net` is used below.
        model = importlib.import_module('models.' + args.model)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        net = model.net.to(device)
        # Loss criterion (the article later swaps this line for FocalLoss).
        loss = torch.nn.CrossEntropyLoss()

        # The MLP is trained with plain SGD; every other model uses Adam.
        if args.model == 'mlp':
            optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
        else:
            optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

        train_path = r'./Datasets/mnist_png/training'
        test_path = r'./Datasets/mnist_png/testing'

        # Images are read as single-channel tensors; some models get a resize step appended.
        transform_list = [torchvision.transforms.Grayscale(num_output_channels=1), torchvision.transforms.ToTensor()]
        if args.model == 'alexnet' or args.model == 'vgg':
            transform_list.append(torchvision.transforms.Resize(size=224))
        if args.model == 'googlenet' or args.model == 'resnet':
            transform_list.append(torchvision.transforms.Resize(size=96))
        transform = torchvision.transforms.Compose(transform_list)

        train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform)
        test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transform)

        train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
        test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        for epoch in range(num_epochs):
            # Per-epoch accumulators: summed batch loss, correct counts, sample counts.
            train_l, train_acc, test_acc = 0.0, 0.0, 0.0
            m, n, batch_count = 0, 0, 0
            start = time.time()

            for X, y in train_iter:
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                batch_loss = loss(y_hat, y)
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
                train_l += batch_loss.cpu().item()
                train_acc += (y_hat.argmax(dim=1) == y).sum().cpu().item()
                m += y.shape[0]
                batch_count += 1

            # Switch to evaluation mode once for the whole test pass instead of
            # toggling eval/train around every single batch.
            net.eval()
            with torch.no_grad():
                for X, y in test_iter:
                    X, y = X.to(device), y.to(device)
                    test_acc += (net(X).argmax(dim=1) == y).sum().cpu().item()
                    n += y.shape[0]
            net.train()  # back to training mode for the next epoch

            print('epoch %d, loss %.6f, train acc %.3f, test acc %.3f, time %.1fs'% (epoch, train_l / batch_count, train_acc / m, test_acc / n, time.time() - start))
            # The whole module object is saved (not just a state_dict) after each epoch,
            # overwriting the previous file.
            torch.save(net, args.model+".pth")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68

    test.py

    import cv2
    import torch
    import argparse
    import importlib
    from pathlib import Path
    import torchvision.transforms.functional
    
    
    def parse_args():
        """Parse the testing script's single option, --model.

        :return: the argparse namespace; ``model`` defaults to ``'lenet'``
        """
        parser = argparse.ArgumentParser('testing')
        parser.add_argument('--model',  default='lenet', help='model name [default: lenet]')
        return parser.parse_args()
    
    
    if __name__ == '__main__':
        # Measure the saved model's accuracy on the digit-6 test images only.
        args = parse_args()
        # NOTE(review): the model module is still imported even though its `net` is
        # no longer used - presumably the checkpoint needs the class definitions to
        # unpickle; confirm against the models package.
        importlib.import_module('models.' + args.model)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # The checkpoint holds the whole module object. map_location lets a file
        # written on a GPU machine load on a CPU-only one.
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which rejects whole-module pickles; pass
        # weights_only=False there, and only for checkpoints you produced yourself.
        net = torch.load(args.model+'.pth', map_location=device)
        net = net.to(device)
        net.eval()

        # The target size depends only on the model name, so decide it once.
        if args.model == 'alexnet' or args.model == 'vgg':
            resize_to = (224, 224)
        elif args.model == 'googlenet' or args.model == 'resnet':
            resize_to = (96, 96)
        else:
            resize_to = None

        with torch.no_grad():
            imgs_path = Path(r"./Datasets/mnist_png/testing/6/").glob("*")
            acc = 0
            count = 0
            for img_path in imgs_path:
                img = cv2.imread(str(img_path), 0)  # flag 0: read as grayscale
                if resize_to is not None:
                    img = cv2.resize(img, resize_to)
                img_tensor = torchvision.transforms.functional.to_tensor(img)
                img_tensor = torch.unsqueeze(img_tensor, 0)  # add the batch dimension
                predicted = net(img_tensor.to(device)).argmax(dim=1).item()
                if predicted == 6:
                    acc += 1
                count += 1
        # Fail with a clear message instead of a ZeroDivisionError on an empty folder.
        if count == 0:
            raise SystemExit('no test images found under ./Datasets/mnist_png/testing/6/')
        print(acc/count)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40

    数据集为mnist手写数字识别,其中训练集中数字0~9的数量分别为:0(5923张),1(6472张),2(5985张),3(6131张),4(5842张),5(5421张),6(5918张),7(6265张),8(5851张),9(5949张), 测试集中数字0~9的数量分别为:0(980张),1(1135张),2(1032张),3(1010张),4(982张),5(892张),6(958张),7(1028张),8(974张),9(1009张)。可见各个类别的数量基本上平衡。测试代码仅测试数字6的准确率,因为后面我们要改变训练集中数字6的数量来进行对比。为了节省时间,仅训练5个epoch。
    训练结果:

    epoch 0, loss 1.443379, train acc 0.529, test acc 0.877, time 23.4s
    epoch 1, loss 0.314123, train acc 0.913, test acc 0.939, time 22.1s
    epoch 2, loss 0.174050, train acc 0.949, test acc 0.960, time 21.9s
    epoch 3, loss 0.122714, train acc 0.963, test acc 0.971, time 21.8s
    epoch 4, loss 0.096798, train acc 0.971, test acc 0.975, time 21.8s
    
    • 1
    • 2
    • 3
    • 4
    • 5

    测试结果:

    0.9780793319415448
    
    • 1

    现在将训练集中数字6的数量减少到59张(原来的1/100),来模拟某个类别的数据不平衡的情况。
    训练结果:

    epoch 0, loss 2.200247, train acc 0.131, test acc 0.373, time 20.8s
    epoch 1, loss 0.579792, train acc 0.840, test acc 0.855, time 20.5s
    epoch 2, loss 0.177890, train acc 0.950, test acc 0.872, time 20.3s
    epoch 3, loss 0.128251, train acc 0.963, test acc 0.880, time 20.5s
    epoch 4, loss 0.103937, train acc 0.969, test acc 0.888, time 20.7s
    
    • 1
    • 2
    • 3
    • 4
    • 5

    测试结果:

    0.04801670146137787
    
    • 1

    可以看到,训练过程中统计的整体测试集准确率下降了约9个百分点(0.975→0.888),而数字6的单类测试准确率直接下降了约93个百分点(0.978→0.048),惨不忍睹。

    引入FocalLoss模块(参考 https://github.com/QunBB/DeepLearning/blob/main/trick/unbalance/loss_pt.py ):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from typing import List, Optional, Union
    
    
    class FocalLoss(nn.Module):
        """Focal loss for binary and multi-class classification.

        Each sample contributes ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``, where
        ``p_t`` is the probability the model assigns to the sample's true class and
        ``alpha_t`` is the weight of that class. ``forward`` returns the batch mean.
        """

        def __init__(self, alpha: Union[List[float], float], gamma: Optional[int] = 2, with_logits: Optional[bool] = True):
            """
            :param alpha: weight of each class, as a list indexed by class id, or a
                single number. In the binary case a single number weights the
                positive class and ``1 - alpha`` weights the negative class.
            :param gamma: exponent applied to ``(1 - p_t)``
            :param with_logits: True when ``input`` holds raw logits (softmax or
                sigmoid is applied here); False when it already holds probabilities
            """
            super(FocalLoss, self).__init__()
            self.gamma = gamma
            # as_tensor + reshape treats ints, floats and sequences uniformly.
            # torch.FloatTensor(alpha) would read an int as a *size* and hand back
            # uninitialized memory.
            self.alpha = torch.as_tensor(alpha, dtype=torch.float32).reshape(-1)
            self.smooth = 1e-8
            self.with_logits = with_logits

        def _binary_class(self, input, target):
            """Per-sample focal loss for a single-output (sigmoid) classifier."""
            prob = torch.sigmoid(input) if self.with_logits else input
            # Align the labels with the predictions so the arithmetic below is
            # element-wise even when input is [bs, 1] and target is [bs].
            target = target.to(device=prob.device, dtype=prob.dtype).reshape(prob.shape)

            alpha = self.alpha.to(prob.device)
            if alpha.numel() == 1:
                alpha_pos, alpha_neg = alpha, 1.0 - alpha
            elif alpha.numel() == 2:
                alpha_neg, alpha_pos = alpha[0], alpha[1]
            else:
                raise ValueError('binary focal loss expects alpha with 1 or 2 values, got %d' % alpha.numel())
            alpha_t = target * alpha_pos + (1.0 - target) * alpha_neg

            # Probability of the true class. Built out-of-place: an in-place
            # "+=" here would break autograd on the sigmoid output and would
            # modify the caller's tensor when with_logits is False.
            p_t = target * prob + (1.0 - target) * (1.0 - prob) + self.smooth
            # The clamp keeps the base non-negative after smoothing, so a
            # fractional gamma cannot produce NaN.
            focus = torch.pow(torch.clamp(1.0 - p_t, min=0.0), self.gamma)
            return -alpha_t * focus * torch.log(p_t)

        def _multiple_class(self, input, target):
            """Per-sample focal loss for a [bs, num_classes] prediction."""
            prob = F.softmax(input, dim=1) if self.with_logits else input

            alpha = self.alpha.to(target.device)
            # Look up each sample's class weight; a single alpha applies to all classes.
            alpha_t = alpha.gather(0, target) if alpha.numel() > 1 else alpha

            # Probability of the true class, smoothed to avoid log(0).
            p_t = prob.gather(1, target.view(-1, 1)).view(-1) + self.smooth
            focus = torch.pow(torch.clamp(1.0 - p_t, min=0.0), self.gamma)
            return -alpha_t * focus * torch.log(p_t)

        def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            """
            :param input: predictions; [bs, num_classes] selects the multi-class
                path, a 1-D tensor or a trailing dimension of 1 selects the binary path
            :param target: class ids of shape [bs]
            :return: mean focal loss over the batch
            """
            if input.dim() > 1 and input.shape[-1] != 1:
                loss = self._multiple_class(input, target)
            else:
                loss = self._binary_class(input, target)

            return loss.mean()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52

    并将train.py的第26行(即 loss = torch.nn.CrossEntropyLoss() 这一行)修改成

        loss = FocalLoss([1, 1, 1, 1, 1, 1, 100, 1, 1, 1])
    
    • 1

    其中列表的数字代表10个类别的权重值。
    训练结果:

    epoch 0, loss 2.045273, train acc 0.137, test acc 0.467, time 20.7s
    epoch 1, loss 0.510476, train acc 0.810, test acc 0.907, time 21.3s
    epoch 2, loss 0.148246, train acc 0.922, test acc 0.941, time 21.1s
    epoch 3, loss 0.099026, train acc 0.944, test acc 0.953, time 21.2s
    epoch 4, loss 0.075481, train acc 0.954, test acc 0.959, time 21.3s
    
    • 1
    • 2
    • 3
    • 4
    • 5

    测试结果:

    0.9196242171189979
    
    • 1

    对比看出,FocalLoss可以有效缓解类别不均衡问题(当然并不能完全消除,有足够平衡的高质量数据集肯定更好啦~)。

  • 相关阅读:
    pandas使用str函数和contains函数删除dataframe中单个指定字符串数据列包含特定字符串列表中的其中任何一个字符串的数据行
    java-php-python-ssm文献管理平台计算机毕业设计
    基于 VMware workstation 16 安装 Linux CentOS 8 操作系统(超详细教程)
    d2-crud-plus 使用小技巧(六)—— 表单下拉选择 行样式 溢出时显示异常优化
    第 4 章 串(串的堆分配存储实现)
    微服务节流控制:Eureka中服务速率限制的精妙配置
    【银角大王——Django课程——创建项目+部门表的基本操作】
    Docker命令速查
    SiC,GaN驱动优选驱动方案SiLM5350系列SiLM5350MDDCM-DG 带米勒钳位Clamp保护功能 单通道隔离栅极驱动器
    迷宫生成与路径规划算法-Python3.8-附Github代码
  • 原文地址:https://blog.csdn.net/taifyang/article/details/133954625