• 【Few Shot】batch-based 和 episodic-based 两种训练导入输入的区别


    前言

    现在做的情感计算方向数据量还是比较小的,所以想着后面可能会做 few shot 相关的内容,在读论文的时候注意到了这两种训练模式,所以从代码的角度来记录一下二者的不同。

    Batch-based

    这就是和普通的训练一样的数据导入,从数据集中选择一个 batch size 大小的子集,扔到模型里面学习训练。首先是定义一个Dataset,构造函数的关键字参数包含了数据集相关的内容信息,可以自己指定。

    class Dataset:
    	def __init__(self, **kwargs):
    		self.transform = build_transform()
    		self.im_path = kwargs['im_path']
    		self.labels = kwargs['labels']
    		...
    	
    	def __getitem__(self, idx):
    		im = Image.open(self.im_path[idx]).convert("RGB")
    		im = self.transform(im)
    		return im, self.labels[idx]
    	
    	def __len__(self):
    		return len(self.labels)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    然后利用上面建立的 Dataset 类创建一个 dataloader。

    class Dataloader:
    	def __init__(slef, *args, **kwargs):
    		...
    	
    	def get_dataloader(self, **kwargs):
    		transform = build_transform()
    		dataset = Dataset(kwargs)
    		dataloader = torch.utils.data.DataLoader(dataset, kwargs)
    		return dataloader
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    确实和普通的是一样的。

    Episodic-based

    这个就主要用在小样本学习里面了,
    首先定义 Dataset 类,这个和之前的也是一样,直接定义就行了。但是开源代码中在 Dataset 的 transform 中定义了采样支撑集和 query 的函数。

    def extract_episode(n_support, n_query, d):
        # data: N x C x H x W
        n_examples = d["data"].size(0)
    
        if n_query == -1:
            n_query = n_examples - n_support
    
        example_inds = torch.randperm(n_examples)[: (n_support + n_query)]
        support_inds = example_inds[:n_support]
        query_inds = example_inds[n_support:]
    
        xs = d["data"][support_inds]
        xq = d["data"][query_inds]
    
        return {"class": d["class"], "xs": xs, "xq": xq}
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    然后在实例化 Dataset 的时候:

    transform = [
    	...
    	partial(extract_episode, n_support, n_query)
    ]
    dataset = Dataset(transform=transform, ...)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    在参考的开源代码中,他们是定义了一个采样器 sampler,利用这个自定义的采样器在 DataLoader 中进行采样。

    class EpisodicBatchSampler(object):
        def __init__(self, n_classes, n_way, n_episodes):
            self.n_classes = n_classes
            self.n_way = n_way
            self.n_episodes = n_episodes
    
        def __len__(self):
            return self.n_episodes
    
        def __iter__(self):
            for i in range(self.n_episodes):
                yield torch.randperm(self.n_classes)[: self.n_way]
    
    
    dataloader = torch.utils.data.DataLoader(
    		ds, batch_sampler=sampler, num_workers=0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    在采样器的 __iter__ 函数中,就是从总共的类别数中选择 n_way 个类别。这样就采样到了一份训练数据。

    参考代码

    prototypical-networks
    CloserLookFewShot

  • 相关阅读:
    Nginx几种负载均衡方式介绍
    adb shell pm 查询设备应用
    【python】一篇玩转正则表达式
    热释电矢量传感器设计
    IP地址、子网掩码、默认网关介绍及例题计算
    Mybatis中如何传入map参数呢?
    MyBatis学习:按照位置传递参数
    SSH的在线音乐下载网站-JAVA【数据库设计、源码、开题报告】
    基于elementui input完成的输入控件
    14、Java 的方法重写详解
  • 原文地址:https://blog.csdn.net/m0_46304383/article/details/133919312