• center loss pytorch实现总结


    实现1

    参考
    https://github.com/KaiyangZhou/pytorch-center-loss

    
    # raw implement
    import torch
    import torch.nn as nn
    
    class CenterLoss(nn.Module):
        """Center loss.

        Penalizes the squared Euclidean distance between each sample's feature
        vector and the learnable center of its ground-truth class.

        Reference:
        Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

        Args:
            num_classes (int): number of classes.
            feat_dim (int): feature dimension.
            use_gpu (bool): if True, place the learnable centers on the GPU.
        """
        def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
            super(CenterLoss, self).__init__()
            self.num_classes = num_classes
            self.feat_dim = feat_dim
            self.use_gpu = use_gpu

            # One learnable center per class, shape (num_classes, feat_dim).
            if self.use_gpu:
                self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
            else:
                self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

        def forward(self, x, labels):
            """Compute the center loss for a batch.

            Args:
                x: feature matrix with shape (batch_size, feat_dim).
                labels: ground truth labels with shape (batch_size,).

            Returns:
                Scalar tensor: sum of squared sample-to-own-center distances
                divided by batch_size.
            """
            batch_size = x.size(0)
            # ||x_i||^2 + ||c_j||^2 broadcast into a (batch_size, num_classes) matrix.
            distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                      torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
            # Subtract 2 * x @ centers^T to complete the squared Euclidean distance.
            # FIX: the positional form addmm_(beta, alpha, mat1, mat2) is deprecated
            # and removed in current PyTorch; use keyword beta/alpha instead.
            distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

            classes = torch.arange(self.num_classes).long()
            if self.use_gpu: classes = classes.cuda()
            # Boolean mask selecting, in each row, the column of the true class.
            labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
            mask = labels.eq(classes.expand(batch_size, self.num_classes))

            dist = distmat * mask.float()
            # Clamp guards against tiny negative round-off from addmm_ and overflow.
            loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

            return loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    # scJoint
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable
    import numpy as np
    
    
    class CenterLoss(nn.Module):
        """Center loss (scJoint variant).

        Unlike the classic formulation, this version takes a LIST of embedding
        batches with a parallel list of label tensors, uses the Euclidean
        distance (square root applied), and sums the per-batch means.

        Args:
            num_classes (int): number of classes.
            feat_dim (int): feature dimension.
            use_gpu (bool): if True, place the learnable centers on the GPU.
        """
        def __init__(self, num_classes=20, feat_dim=64, use_gpu=True):
            super(CenterLoss, self).__init__()
            self.num_classes = num_classes
            self.feat_dim = feat_dim
            self.use_gpu = use_gpu

            # One learnable center per class, shape (num_classes, feat_dim).
            if self.use_gpu:
                self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
            else:
                self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

        def forward(self, embeddings, labels):
            """Sum of mean masked center distances over each embedding batch.

            Args:
                embeddings: list of feature tensors, each (batch_size, feat_dim).
                labels: list of label tensors, each (batch_size,).
            """
            center_loss = 0
            for i, x in enumerate(embeddings):
                label = labels[i].long()
                batch_size = x.size(0)
                # ||x_i||^2 + ||c_j||^2 broadcast into (batch_size, num_classes).
                distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                          torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
                # FIX: deprecated positional addmm_(beta, alpha, ...) replaced with
                # the keyword form required by current PyTorch.
                distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
                # FIX: clamp before sqrt — addmm_ round-off can leave tiny negative
                # entries whose square root would be NaN.
                distmat = torch.sqrt(distmat.clamp(min=1e-12))

                classes = torch.arange(self.num_classes).long()
                if self.use_gpu: classes = classes.cuda()
                # Boolean mask selecting, in each row, the column of the true class.
                label = label.unsqueeze(1).expand(batch_size, self.num_classes)
                mask = label.eq(classes.expand(batch_size, self.num_classes))

                dist = distmat * mask.float()
                # Mean is taken over the full (batch_size, num_classes) matrix.
                center_loss += torch.mean(dist.clamp(min=1e-12, max=1e+12))

            # NOTE: the original divided by len(embeddings); with a single embedding
            # set (length 1 in this article) the division is a no-op, so it was dropped.
            return center_loss
        
    
    # Fix the torch RNG so the randomly initialized centers are reproducible.
    # NOTE(review): numpy's RNG is NOT seeded here, so `embedding`/`label`
    # still vary between runs even though the centers are fixed — confirm
    # whether that is intended.
    torch.manual_seed(42)
    num_class=10  # number of classes
    num_feature=64  # feature (embedding) dimension
    num_sample=256  # number of samples in the batch


    # Instantiate the scJoint-style CenterLoss on CPU.
    center_loss=CenterLoss(num_classes=num_class,feat_dim=num_feature,use_gpu=False)
    # Random features and integer labels in [0, num_class).
    embedding=np.random.randn(num_sample,num_feature)
    label=np.random.randint(0,num_class,size=num_sample)

    # The scJoint forward expects parallel LISTS of tensors (one per embedding set).
    embeddings=[torch.FloatTensor(embedding)]
    labels= [torch.LongTensor(label)]

    print(center_loss(embeddings,labels))
    
    #print(center_loss.centers)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58

    结果如下
    在这里插入图片描述

    实现2(好理解)

    class CenterLoss2(nn.Module):
        """Center loss, direct formulation.

        Gathers each sample's class center by label, takes the Euclidean
        distance to it, and averages the clamped distances over the batch.
        """
        def __init__(self, num_class=10, num_feature=2):
            super(CenterLoss2, self).__init__()
            self.num_class = num_class
            self.num_feature = num_feature
            # One learnable center per class, shape (num_class, num_feature).
            self.centers = nn.Parameter(torch.randn(self.num_class, self.num_feature))

        def forward(self, x, labels):
            print("Centerloss2")
            # Look up the center assigned to each sample via fancy indexing.
            picked = self.centers[labels]
            # Per-sample squared distance, then its square root (Euclidean norm).
            squared = (x - picked).pow(2).sum(dim=-1)
            distances = squared.sqrt()
            # Guard against extreme values, then average over the batch.
            return distances.clamp(min=1e-12, max=1e+12).mean(dim=-1)
    
    
    # Same torch seed as above, so CenterLoss2's centers are initialized with
    # the same values as the first model's centers.
    torch.manual_seed(42)
    center_loss2=CenterLoss2(num_class=num_class,num_feature=num_feature)

    # Plain tensors this time: CenterLoss2 takes a single batch, not a list.
    embeddings=torch.FloatTensor(embedding)
    labels= torch.LongTensor(label)

    # Divide by num_class: implementation 1 averages over the full
    # (batch, num_class) matrix while CenterLoss2 averages only over the batch,
    # so their results differ by exactly a factor of num_class.
    print(center_loss2(embeddings,labels)/num_class)
    #print(center_loss2.centers)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27

    结果如下
    在这里插入图片描述可以看到结果一模一样,就不用管第一种实现了

  • 相关阅读:
    数字孪生发展阶段报告与背景规划,面临数字孪生机遇。
    [附源码]计算机毕业设计SpringBoot仓储综合管理系统
    R语言使用aov函数执行单因素方差分析、使用TukeyHSD函数分析单因素方差分析的结果并解读TukeyHSD函数的输出结果
    【MySQL】CRUD (增删改查) 基础
    LeetCode 73. 矩阵置零(java实现)
    开发者配置项、开发者选项自定义
    AAOS CarPowerManager
    全网最新的jmeter压测话,只想快速教会你用Jmeter编写脚本进行压测
    冒泡排序和鸡尾酒排序和快速排序
    vue3父组件提交校验多个子组件
  • 原文地址:https://blog.csdn.net/qq_45759229/article/details/126917939