• 原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列4


    在这里插入图片描述


    原型网络进行分类的基本流程

    利用原型网络进行分类,基本流程如下:

    1.对于每一个样本使用编码的方式fφ (),学习到每一个样本的编码表示(信息抽取)。
    2.学习到每一个样本的编码表示之后,对于每一个分类下的所有的样本编码进行求和求取平均的操作,将结果作为分类的原型表示。
    3.当一个新的数据样本被输入到网络中的时候,对于这个样本使用fφ(),生成其编码表示。
    4.计算新的样本的编码表示和每一个分类的原型表示之间的距离情况,通过最下距离来确定查询样本属于哪一个分类。
    5.在计算出所有的分类之间的距离之后,使用softmax的方式将距离转换成概率的形式。

    一、原始代码—计算欧氏距离,设计原型网络(计算原型+开始训练)

    def eucli_tensor(x,y):	#计算两个tensor的欧氏距离,用于loss的计算
    	return -1*torch.sqrt(torch.sum((x-y)*(x-y))).view(1)
    
    class Protonets(object):
    	def __init__(self,input_shape,outDim,Ns,Nq,Nc,log_data,step,trainval=False):
    		#Ns:支持集数量,Nq:查询集数量,Nc:每次迭代所选类数,log_data:模型和类对应的中心所要储存的位置,step:若trainval==True则读取已训练的第step步的模型和中心,trainval:是否从新开始训练模型
    		self.input_shape = input_shape
    		self.outDim = outDim
    		self.batchSize = 1
    		self.Ns = Ns
    		self.Nq = Nq
    		self.Nc = Nc
    		if trainval == False:
    			#若训练一个新的模型,初始化CNN和中心点
    			self.center = {}
    			self.model = CNNnet(input_shape,outDim)
    		else:
    			#否则加载CNN模型和中心点
    			self.center = {}
    			self.model = torch.load(log_data+'model_net_'+str(step)+'.pkl')		#'''修改,存储模型的文件名'''
    			self.load_center(log_data+'model_center_'+str(step)+'.csv')	#'''修改,存储中心的文件名'''
    	
    	def compute_center(self,data_set):	#data_set是一个numpy对象,是某一个支持集,计算支持集对应的中心的点
    		center = 0
    		for i in range(self.Ns):
    			data = np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
    			data = Variable(torch.from_numpy(data))
    			data = self.model(data)[0]	#将查询点嵌入另一个空间
    			if i == 0:
    				center = data
    			else:
    				center += data
    		center /= self.Ns
    		return center
    	
    	def train(self,labels_data,class_number):	#网络的训练
    		#Select class indices for episode
    		class_index = list(range(class_number))
    		random.shuffle(class_index)
    		choss_class_index = class_index[:self.Nc]#选20个类
    		sample = {'xc':[],'xq':[]}
    		for label in choss_class_index:
    			D_set = labels_data[label]
    			#从D_set随机取支持集和查询集
    			support_set,query_set = self.randomSample(D_set)
    			#计算中心点
    			self.center[label] = self.compute_center(support_set)
    			#将中心和查询集存储在list中
    			sample['xc'].append(self.center[label])	#list
    			sample['xq'].append(query_set)
    		#优化器
    		optimizer = torch.optim.Adam(self.model.parameters(),lr=0.001)
    		optimizer.zero_grad()
    		protonets_loss = self.loss(sample)
    		protonets_loss.backward()
    		optimizer.step()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56

    二、每一行代码的详细解释

    def eucli_tensor(x, y):
        return -1 * torch.sqrt(torch.sum((x - y) * (x - y))).view(1)
    
    • 1
    • 2

    这是一个函数,用于计算两个张量(tensor)之间的欧氏距离(Euclidean Distance)。它通过计算两个张量差的平方和的平方根,并乘以-1。最后通过 view(1) 将结果转换成一个形状为 (1,) 的张量。

    class Protonets(object):
        def __init__(self, input_shape, outDim, Ns, Nq, Nc, log_data, step, trainval=False):
            self.input_shape = input_shape
            self.outDim = outDim
            self.batchSize = 1
            self.Ns = Ns
            self.Nq = Nq
            self.Nc = Nc
            if trainval == False:
                self.center = {}
                self.model = CNNnet(input_shape, outDim)
            else:
                self.center = {}
                self.model = torch.load(log_data + 'model_net_' + str(step) + '.pkl')
                self.load_center(log_data + 'model_center_' + str(step) + '.csv')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    这是一个 Protonets 类的定义,它有一个构造函数 __init__,用于初始化类的属性。其中的参数含义如下:

    • input_shape:输入数据的形状。
    • outDim:输出维度。
    • Ns:支持集(support set)的数量。
    • Nq:查询集(query set)的数量。
    • Nc:每次迭代所选类别数。
    • log_data:模型和中心的存储位置。
    • step:训练的步数。
    • trainval:是否重新开始训练模型。

    根据 trainval 的取值,分为两种情况进行初始化:

    1. trainval=False:表示训练一个新的模型。此时,初始化一个空的中心字典 self.center,并创建一个名为 CNNnet 的模型对象 self.model,其输入形状为 input_shape,输出维度为 outDim
    2. trainval=True:表示加载已经训练好的模型和中心。同样,初始化一个空的中心字典 self.center。然后通过 torch.load 加载之前训练保存的模型文件 log_data + 'model_net_' + str(step) + '.pkl',并将其赋给 self.model。接着调用 load_center 方法加载之前训练保存的中心文件 log_data + 'model_center_' + str(step) + '.csv'

    总结

    这段代码是一个用于实现 Protonets 算法的类。

  • 相关阅读:
    MySQL慢查询优化、日志收集定位排查、慢查询sql分析
    【一】情感对话 Towards Emotional Support Dialog Systems 论文阅读
    Java 动态加载字节码
    【Java基础】Java导Excel攻略
    Pulsar Manager和dashboard部署和启用认证
    UE5 ChaosVehicles载具 实现大漂移 (连载四)
    【微信小程序入门到精通】— 微信小程序开发工具的安装
    使用python爬虫语言调用有道翻译实现英中互译(2023实现)
    大语言模型(LLM)综述(七):大语言模型设计应用与未来方向
    14道高频手写JS面试题及答案,巩固你的JS基础
  • 原文地址:https://blog.csdn.net/qlkaicx/article/details/134479313