• Classification networks: the class-imbalance problem and Focal Loss


    The training and testing code is shown below (the complete code comes from CNN从搭建到部署实战):
    train.py

    import torch
    import torchvision
    import time
    import argparse
    import importlib
    from loss import FocalLoss
    
    
    def parse_args():
        parser = argparse.ArgumentParser('training')
        parser.add_argument('--batch_size', default=128, type=int, help='batch size in training')
        parser.add_argument('--num_epochs', default=5, type=int, help='number of epoch in training')
        parser.add_argument('--model',  default='lenet', help='model name [default: lenet]')
        return parser.parse_args()
    
    
    if __name__ == '__main__':
        args = parse_args()
        batch_size = args.batch_size
        num_epochs = args.num_epochs
        model = importlib.import_module('models.'+args.model) 
            
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        net = model.net.to(device)
    
        loss = torch.nn.CrossEntropyLoss()  # baseline criterion; swapped for FocalLoss later in the article
        
        if args.model == 'mlp':
            optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
        else:
            optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
              
        train_path = r'./Datasets/mnist_png/training'
        test_path = r'./Datasets/mnist_png/testing'
        transform_list = [torchvision.transforms.Grayscale(num_output_channels=1)]
        if args.model == 'alexnet' or args.model == 'vgg':
            transform_list.append(torchvision.transforms.Resize(size=224))
        if args.model == 'googlenet' or args.model == 'resnet':
            transform_list.append(torchvision.transforms.Resize(size=96))
        transform_list.append(torchvision.transforms.ToTensor())  # Resize operates on the PIL image, so it comes before ToTensor
        transform = torchvision.transforms.Compose(transform_list)
    
        train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform)
        test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transform)
    
        train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
        test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    
        for epoch in range(num_epochs):
            train_l, train_acc, test_acc, m, n, batch_count, start = 0.0, 0.0, 0.0, 0, 0, 0, time.time()
            for X, y in train_iter:
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                l = loss(y_hat, y)
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
                train_l += l.cpu().item()
                train_acc += (y_hat.argmax(dim=1) == y).sum().cpu().item()
                m += y.shape[0]
                batch_count += 1
            with torch.no_grad():
                net.eval()  # switch to evaluation mode once for the whole pass
                for X, y in test_iter:
                    test_acc += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                    n += y.shape[0]
                net.train()  # switch back to training mode
            print('epoch %d, loss %.6f, train acc %.3f, test acc %.3f, time %.1fs'% (epoch, train_l / batch_count, train_acc / m, test_acc / n, time.time() - start))
            torch.save(net, args.model+".pth")
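
    Training can then be launched as, for example, python train.py --model lenet --batch_size 128 --num_epochs 5, assuming (as in the original project) a models/ package in which each module exposes its network as net.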
    

    test.py

    import cv2
    import torch
    import argparse
    import importlib
    from pathlib import Path
    import torchvision.transforms.functional
    
    
    def parse_args():
        parser = argparse.ArgumentParser('testing')
        parser.add_argument('--model',  default='lenet', help='model name [default: lenet]')
        return parser.parse_args()
    
    
    if __name__ == '__main__':
        args = parse_args()
        model = importlib.import_module('models.' + args.model) 
            
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        net = torch.load(args.model + '.pth', map_location=device)  # load the full pickled model saved by train.py
        net.eval()
    
        with torch.no_grad():
            imgs_path = Path(r"./Datasets/mnist_png/testing/6/").glob("*")
            acc = 0
            count = 0
            for img_path in imgs_path:
                img = cv2.imread(str(img_path), 0)
                if args.model == 'alexnet' or args.model == 'vgg':  
                    img = cv2.resize(img, (224,224))
                if args.model == 'googlenet' or args.model == 'resnet':
                    img = cv2.resize(img, (96,96))
                img_tensor = torchvision.transforms.functional.to_tensor(img)
                img_tensor = torch.unsqueeze(img_tensor, 0)
                # print(net(img_tensor.to(device)).argmax(dim=1).item())
                if net(img_tensor.to(device)).argmax(dim=1).item() == 6:
                    acc += 1
                count += 1
        print(acc/count)
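
    The script above only measures digit 6. To see how the imbalance affects every class, per-class accuracy over the whole test set can also be computed; the following sketch rebuilds the test DataLoader from train.py (for the lenet case, without the extra Resize) and assumes net and device as loaded above:

    import torch
    import torchvision

    # rebuild the test DataLoader exactly as in train.py
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(num_output_channels=1),
        torchvision.transforms.ToTensor()])
    test_dataset = torchvision.datasets.ImageFolder(r'./Datasets/mnist_png/testing', transform=transform)
    test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

    correct, total = torch.zeros(10), torch.zeros(10)
    with torch.no_grad():
        for X, y in test_iter:
            pred = net(X.to(device)).argmax(dim=1).cpu()
            for c in range(10):
                mask = (y == c)
                correct[c] += (pred[mask] == c).sum()
                total[c] += mask.sum()
    for c in range(10):
        print('class %d: %.4f' % (c, correct[c] / total[c]))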
    

    The dataset is MNIST handwritten digit recognition. In the training set, the counts for digits 0~9 are: 0 (5923), 1 (6472), 2 (5985), 3 (6131), 4 (5842), 5 (5421), 6 (5918), 7 (6265), 8 (5851), 9 (5949); in the test set: 0 (980), 1 (1135), 2 (1032), 3 (1010), 4 (982), 5 (892), 6 (958), 7 (1028), 8 (974), 9 (1009). The classes are essentially balanced. The test script measures accuracy only on digit 6, because we will later change the number of digit-6 images in the training set for comparison. To save time, we train for only 5 epochs.
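
    These per-class counts can be verified directly from the ImageFolder datasets, as in this quick sketch (same dataset paths as train.py):

    from collections import Counter
    import torchvision

    # count images per class index (ImageFolder sorts folder names '0'..'9', so index c is digit c)
    train_dataset = torchvision.datasets.ImageFolder(r'./Datasets/mnist_png/training')
    test_dataset = torchvision.datasets.ImageFolder(r'./Datasets/mnist_png/testing')
    print(sorted(Counter(train_dataset.targets).items()))  # [(0, 5923), (1, 6472), ...]
    print(sorted(Counter(test_dataset.targets).items()))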
    Training results:

    epoch 0, loss 1.443379, train acc 0.529, test acc 0.877, time 23.4s
    epoch 1, loss 0.314123, train acc 0.913, test acc 0.939, time 22.1s
    epoch 2, loss 0.174050, train acc 0.949, test acc 0.960, time 21.9s
    epoch 3, loss 0.122714, train acc 0.963, test acc 0.971, time 21.8s
    epoch 4, loss 0.096798, train acc 0.971, test acc 0.975, time 21.8s
    

    Test result:

    0.9780793319415448
    

    Now reduce the number of digit-6 images in the training set to 59 (1/100 of the original) to simulate an imbalance in a single class.
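
    The article does not show how the subsampling was done; one possible approach is the following destructive sketch (it deletes files, so back up the dataset first):

    import random
    from pathlib import Path

    # randomly keep 59 of the digit-6 training images and delete the rest
    six_dir = Path(r'./Datasets/mnist_png/training/6')
    imgs = sorted(six_dir.glob('*'))
    keep = set(random.sample(imgs, 59))
    for p in imgs:
        if p not in keep:
            p.unlink()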
    Training results:

    epoch 0, loss 2.200247, train acc 0.131, test acc 0.373, time 20.8s
    epoch 1, loss 0.579792, train acc 0.840, test acc 0.855, time 20.5s
    epoch 2, loss 0.177890, train acc 0.950, test acc 0.872, time 20.3s
    epoch 3, loss 0.128251, train acc 0.963, test acc 0.880, time 20.5s
    epoch 4, loss 0.103937, train acc 0.969, test acc 0.888, time 20.7s
    

    Test result:

    0.04801670146137787
    

    As you can see, the overall test accuracy drops by about 9 percentage points (0.975 to 0.888), while the accuracy measured on digit 6 collapses by roughly 93 points (0.978 to 0.048), a dismal result.

    Now introduce the FocalLoss module, saved as loss.py so that the "from loss import FocalLoss" in train.py resolves (adapted from https://github.com/QunBB/DeepLearning/blob/main/trick/unbalance/loss_pt.py):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from typing import List, Optional, Union
    
    
    class FocalLoss(nn.Module):
        def __init__(self, alpha: Union[List[float], float], gamma: Optional[float] = 2, with_logits: Optional[bool] = True):
            """
            :param alpha: per-class weights: a list for the multi-class case, or a single float (the positive-class weight) for the binary case
            :param gamma: focusing parameter; larger values down-weight easy, well-classified examples
            :param with_logits: True if the input is raw logits that have not yet passed through softmax/sigmoid
            """
            super(FocalLoss, self).__init__()
            self.gamma = gamma
            self.alpha = torch.FloatTensor([alpha]) if isinstance(alpha, float) else torch.FloatTensor(alpha)
            self.smooth = 1e-8
            self.with_logits = with_logits
    
        def _binary_class(self, input, target):
            prob = torch.sigmoid(input) if self.with_logits else input
            prob = prob.view(-1)
            target = target.float().view(-1)
            alpha = self.alpha.to(target.device)
            # binary focal loss: the positive-class term plus the mirrored negative-class term
            # (the original snippet ignored target here, which would push all probabilities to 1)
            pos_loss = -alpha * torch.pow(torch.sub(1.0, prob), self.gamma) * torch.log(prob + self.smooth)
            neg_loss = -(1.0 - alpha) * torch.pow(prob, self.gamma) * torch.log(torch.sub(1.0, prob) + self.smooth)
            loss = target * pos_loss + (1.0 - target) * neg_loss
            return loss
    
        def _multiple_class(self, input, target):
            prob = F.softmax(input, dim=1) if self.with_logits else input
    
            alpha = self.alpha.to(target.device)
            alpha = alpha.gather(0, target)
    
            target = target.view(-1, 1)
    
            prob = prob.gather(1, target).view(-1) + self.smooth  # avoid nan
            logpt = torch.log(prob)
    
            loss = -alpha * torch.pow(torch.sub(1.0, prob), self.gamma) * logpt
            return loss
    
        def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            """
            :param input: shape [bs, num_classes] (or [bs] / [bs, 1] for the binary case)
            :param target: shape [bs]
            :return: scalar mean focal loss
            """
            if len(input.shape) > 1 and input.shape[-1] != 1:
                loss = self._multiple_class(input, target)
            else:
                loss = self._binary_class(input, target)
    
            return loss.mean()
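
    For the multi-class branch, this implements the focal loss FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), averaged over the batch. A quick sanity check (a minimal sketch): with gamma=0 and uniform alpha, the focal loss reduces to plain cross-entropy:

    import torch
    import torch.nn.functional as F
    from loss import FocalLoss

    logits = torch.randn(8, 10)
    targets = torch.randint(0, 10, (8,))

    # with gamma=0 the modulating factor (1-p)^gamma is 1, and uniform alpha
    # leaves per-class weighting untouched, so this matches cross-entropy
    fl = FocalLoss(alpha=[1.0] * 10, gamma=0)
    print(fl(logits, targets).item())
    print(F.cross_entropy(logits, targets).item())  # nearly identical (up to the 1e-8 smoothing)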
    

    Then replace the loss definition in train.py (loss = torch.nn.CrossEntropyLoss()) with:

        loss = FocalLoss([1, 1, 1, 1, 1, 1, 100, 1, 1, 1])
    

    The numbers in the list are the per-class weights for the 10 digit classes.
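
    Digit 6 is given weight 100, mirroring its 1/100 subsampling. More generally, alpha can be derived from inverse class frequency, as in this sketch (normalized so the most frequent class gets weight 1, one common convention among several):

    from collections import Counter

    # train_dataset is the ImageFolder built in train.py
    counts = Counter(train_dataset.targets)
    max_count = max(counts.values())
    alpha = [max_count / counts[c] for c in range(10)]  # digit 6 comes out near 100 after the 1/100 cut
    loss = FocalLoss(alpha)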
    Training results:

    epoch 0, loss 2.045273, train acc 0.137, test acc 0.467, time 20.7s
    epoch 1, loss 0.510476, train acc 0.810, test acc 0.907, time 21.3s
    epoch 2, loss 0.148246, train acc 0.922, test acc 0.941, time 21.1s
    epoch 3, loss 0.099026, train acc 0.944, test acc 0.953, time 21.2s
    epoch 4, loss 0.075481, train acc 0.954, test acc 0.959, time 21.3s
    

    Test result:

    0.9196242171189979
    

    The comparison shows that Focal Loss can effectively mitigate class imbalance (it does not eliminate the problem entirely, of course; a sufficiently balanced, high-quality dataset is always better).

  • Original article: https://blog.csdn.net/taifyang/article/details/133954625