参考
https://github.com/KaiyangZhou/pytorch-center-loss
# raw implement
import torch
import torch.nn as nn
class CenterLoss(nn.Module):
    """Center loss.

    Penalizes the squared Euclidean distance between each sample's feature
    vector and the learnable center of its ground-truth class.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
        use_gpu (bool): place the class centers on CUDA when True.
    """

    def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu
        # One learnable center per class, shape (num_classes, feat_dim).
        centers = torch.randn(self.num_classes, self.feat_dim)
        if self.use_gpu:
            centers = centers.cuda()
        self.centers = nn.Parameter(centers)

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).

        Returns:
            Scalar tensor: sum over the batch of each sample's squared
            distance to its class center, divided by batch_size.
        """
        batch_size = x.size(0)
        # ||x_i - c_j||^2 = ||x_i||^2 + ||c_j||^2 - 2 * x_i . c_j, evaluated
        # for every (sample, class) pair as a (batch_size, num_classes) matrix.
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # Fix: the positional addmm_(beta, alpha, mat1, mat2) overload is
        # deprecated and removed in modern PyTorch; use keyword arguments.
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
        classes = torch.arange(self.num_classes).long()
        if self.use_gpu:
            classes = classes.cuda()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        # One-hot mask: for each sample, keep only the column of its own class.
        mask = labels.eq(classes.expand(batch_size, self.num_classes))
        dist = distmat * mask.float()
        # Clamp guards against negative round-off and runaway magnitudes.
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return loss
# scJoint
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
class CenterLoss(nn.Module):
    """Center loss as used in scJoint.

    Unlike the classic formulation, this variant takes the square root (an
    L2, not squared-L2, distance) and accepts a *list* of embedding batches,
    summing the per-batch mean masked distance.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
        use_gpu (bool): place the class centers on CUDA when True.
    """

    def __init__(self, num_classes=20, feat_dim=64, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu
        # One learnable center per class, shape (num_classes, feat_dim).
        centers = torch.randn(self.num_classes, self.feat_dim)
        if self.use_gpu:
            centers = centers.cuda()
        self.centers = nn.Parameter(centers)

    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: sequence of feature matrices, each (batch_size, feat_dim).
            labels: sequence of label tensors, each (batch_size,).

        Returns:
            Scalar tensor: sum over the sequence of the mean of each masked
            (batch_size, num_classes) distance matrix.
        """
        center_loss = 0
        for i, x in enumerate(embeddings):
            label = labels[i].long()
            batch_size = x.size(0)
            # Pairwise squared distances ||x_i - c_j||^2 via the expansion
            # ||x||^2 + ||c||^2 - 2 x.c  ->  (batch_size, num_classes).
            distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                      torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
            # Fix: positional addmm_(beta, alpha, mat1, mat2) is deprecated and
            # removed in modern PyTorch; use keyword arguments.
            distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
            # Fix: floating-point round-off can leave tiny negative entries,
            # which would make sqrt() emit NaN; clamp to zero first.
            distmat = torch.sqrt(distmat.clamp(min=0))
            classes = torch.arange(self.num_classes).long()
            if self.use_gpu:
                classes = classes.cuda()
            label = label.unsqueeze(1).expand(batch_size, self.num_classes)
            mask = label.eq(classes.expand(batch_size, self.num_classes))
            dist = distmat * mask.float()
            center_loss += torch.mean(dist.clamp(min=1e-12, max=1e+12))
        # NOTE: the original divided by len(embeddings) here; with a single
        # batch in the list the division is a no-op, so it was left out.
        return center_loss
# Sanity check: run the scJoint-style CenterLoss on random CPU data.
torch.manual_seed(42)
num_class, num_feature, num_sample = 10, 64, 256
center_loss = CenterLoss(num_classes=num_class, feat_dim=num_feature, use_gpu=False)
# Build one batch of random embeddings with random integer class labels.
embedding = np.random.randn(num_sample, num_feature)
label = np.random.randint(0, num_class, size=num_sample)
embeddings = [torch.FloatTensor(embedding)]
labels = [torch.LongTensor(label)]
print(center_loss(embeddings, labels))
# print(center_loss.centers)
结果如下
class CenterLoss2(nn.Module):
    """Gather-based center loss: index each sample's center directly by its
    label instead of building the full (batch, classes) distance matrix."""

    def __init__(self, num_class=10, num_feature=2):
        super(CenterLoss2, self).__init__()
        self.num_class = num_class
        self.num_feature = num_feature
        # One learnable center per class, shape (num_class, num_feature).
        self.centers = nn.Parameter(torch.randn(self.num_class, self.num_feature))

    def forward(self, x, labels):
        print("Centerloss2")
        # Select each sample's own class center, then take the L2 distance.
        picked = self.centers.index_select(0, labels)
        diff = x - picked
        dist = torch.sqrt(diff.pow(2).sum(dim=-1))
        # Clamp away degenerate values, then average over the batch.
        return torch.clamp(dist, min=1e-12, max=1e+12).mean(dim=-1)
# Compare against the list-based implementation above on the same data
# (same seed, so the initial centers match).
torch.manual_seed(42)
center_loss2 = CenterLoss2(num_class=num_class, num_feature=num_feature)
embeddings = torch.FloatTensor(embedding)
labels = torch.LongTensor(label)
print(center_loss2(embeddings, labels) / num_class)
# print(center_loss2.centers)
结果如下
可以看到结果一模一样,就不用管第一种实现了