• 狗都能看懂的Pytorch MAML代码详解


    源码(觉得有用请点star,这对我很重要~)

    maml概念

    首先,我们需要说明的是maml不同于常见的训练方式。以猫狗分类和resnet作为例子,我们将猫狗分类定义为一个task,正常训练一个猫狗分类器,只需要输入猫和狗的图片去训练就好了。所以我们的一个batch中就会有多张猫或者狗的图片,这样训练出来的模型虽说可以预测这张图片是猫还是狗,但要想这个分类器有泛化性,就需要大量猫或狗的图像,而标注大量的数据是要成本的。

    现在我们假设一个场景,我们没有这么多猫狗分类的数据,但我们有其他task的数据。我们需要用少量的图像来训练一个强泛化性的模型。maml的训练方式允许我们用大量别的task的数据来得到一个初始化权重,这个初始化权重具有非常好的鲁棒性,仅用少量数据训练加上或者maml训练的初始化权重就可以达到和正常训练方式的准确率。

    为什么maml能做到这样的效果,请读者移步MAML原理讲解和代码实现。

    maml以task为单位,多个task组成一个batch,为了和正常训练方式区别开来,我们就成为meta-batch。以omniglot为例,如下图所示:

    在这里插入图片描述

    每个task之间都互相独立,都是不同的分类任务。

    数据读取

    这里为大家实现了个MAML数据读取的基类,用户只需要实现get_file_list和get_one_task_data两个函数即可。

    class MAMLDataset(Dataset):
    
        def __init__(self, data_path, batch_size, n_way=10, k_shot=2, q_query=1):
    
            self.file_list = self.get_file_list(data_path)
            self.batch_size = batch_size
            self.n_way = n_way
            self.k_shot = k_shot
            self.q_query = q_query
    
        def get_file_list(self, data_path):
            raise NotImplementedError('get_file_list function not implemented!')
    
        def get_one_task_data(self):
            raise NotImplementedError('get_one_task_data function not implemented!')
    
        def __len__(self):
            return len(self.file_list) // self.batch_size
    
        def __getitem__(self, index):
            return self.get_one_task_data()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    还是以omniglot为例,实现特殊数据集的子类数据读取的方法。

    get_file_list

    此函数要求得到一个所有task文件目录的list。比如一个总的文件夹中,下面有很多不同的task,这里因为omniglot数据命名比较统一,所以实现比较简单。

    get_one_task_data

    此函数要求返回一个task的数据,包括训练集和验证集,以下面代码为例,每次调用get_one_task_data时,返回一个n_way=5分类的task,其中训练集图像和标签的数量各为k_shot=1张,验证集图像和标签的数量各为q_query=1张。

    class OmniglotDataset(MAMLDataset):
        def get_file_list(self, data_path):
            """
            Get all fonts list.
            Args:
                data_path: Omniglot Data path
    
            Returns: fonts list
    
            """
            return [f for f in glob.glob(data_path + "**/character*", recursive=True)]
    
        def get_one_task_data(self):
            """
            Get ones task maml data, include one batch support images and labels, one batch query images and labels.
            Returns: support_data, query_data
    
            """
            img_dirs = random.sample(self.file_list, self.n_way)
            support_data = []
            query_data = []
    
            support_image = []
            support_label = []
            query_image = []
            query_label = []
    
            for label, img_dir in enumerate(img_dirs):
                img_list = [f for f in glob.glob(img_dir + "**/*.png", recursive=True)]
                images = random.sample(img_list, self.k_shot + self.q_query)
    
                # Read support set
                for img_path in images[:self.k_shot]:
                    image = Image.open(img_path)
                    image = np.array(image)
                    image = np.expand_dims(image / 255., axis=0)
                    support_data.append((image, label))
    
                # Read query set
                for img_path in images[self.k_shot:]:
                    image = Image.open(img_path)
                    image = np.array(image)
                    image = np.expand_dims(image / 255., axis=0)
                    query_data.append((image, label))
    
            # shuffle support set
            random.shuffle(support_data)
            for data in support_data:
                support_image.append(data[0])
                support_label.append(data[1])
    
            # shuffle query set
            random.shuffle(query_data)
            for data in query_data:
                query_image.append(data[0])
                query_label.append(data[1])
    
            return np.array(support_image), np.array(support_label), np.array(query_image), np.array(query_label)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58

    在调用Dataset的时候再使用torch的Dataloader包一下就好了,里面batch_size参数为任务的数量。相当于每训练1个step就要训练完这么多个task。

    train_dataset = OmniglotDataset(args.train_data_dir, args.task_num,
                                    n_way=args.n_way, k_shot=args.k_shot, q_query=args.q_query)
    val_dataset = OmniglotDataset(args.val_data_dir, args.val_task_num,
                                  n_way=args.n_way, k_shot=args.k_shot, q_query=args.q_query)
    train_loader = DataLoader(train_dataset, batch_size=args.task_num, shuffle=True, num_workers=args.num_workers)
    val_loader = DataLoader(val_dataset, batch_size=args.val_task_num, shuffle=False, num_workers=args.num_workers)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    模型训练

    代码如下:

    def maml_train(model, support_images, support_labels, query_images, query_labels, inner_step, args, optimizer, is_train=True):
        """
        Train the model using MAML method.
        Args:
            model: Any model
            support_images: several task support images
            support_labels: several  support labels
            query_images: several query images
            query_labels: several query labels
            inner_step: support data training step
            args: ArgumentParser
            optimizer: optimizer
            is_train: whether train
    
        Returns: meta loss, meta accuracy
    
        """
        meta_loss = []
        meta_acc = []
    
        for support_image, support_label, query_image, query_label in zip(support_images, support_labels, query_images, query_labels):
    
            fast_weights = collections.OrderedDict(model.named_parameters())
            for _ in range(inner_step):
                # Update weight
                support_logit = model.functional_forward(support_image, fast_weights)
                support_loss = nn.CrossEntropyLoss().cuda()(support_logit, support_label)
                grads = torch.autograd.grad(support_loss, fast_weights.values(), create_graph=True)
                fast_weights = collections.OrderedDict((name, param - args.inner_lr * grads)
                                                       for ((name, param), grads) in zip(fast_weights.items(), grads))
    
            # Use trained weight to get query loss
            query_logit = model.functional_forward(query_image, fast_weights)
            query_prediction = torch.max(query_logit, dim=1)[1]
    
            query_loss = nn.CrossEntropyLoss().cuda()(query_logit, query_label)
            query_acc = torch.eq(query_label, query_prediction).sum() / len(query_label)
    
            meta_loss.append(query_loss)
            meta_acc.append(query_acc.data.cpu().numpy())
    
        # Zero the gradient
        optimizer.zero_grad()
        meta_loss = torch.stack(meta_loss).mean()
        meta_acc = np.mean(meta_acc)
    
        if is_train:
            meta_loss.backward()
            optimizer.step()
    
        return meta_loss, meta_acc
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51

    support_images, support_labels, query_images, query_labels传入的都是以task为单位的,所以要用一个for循环来进行拆包,注意support_data和query_data数据集来源必须得一致,不能一个数据A task,另一个属于B task。

    拆包完之后,首先进行训练集的训练,我们要注意,此时的训练是不能改动到模型权重,但我们又需要知道它的训练方向,所以我们需要copy出来一个权重,让它执行训练,用这个得到的权重对query_data执行前向传播,以此得到的loss再进行反向传播优化。这个过程很绕,建议多读几遍源码就懂了。

    模型定义

    class Classifier(nn.Module):
        def __init__(self, in_ch, n_way):
            super(Classifier, self).__init__()
            self.conv1 = ConvBlock(in_ch, 64)
            self.conv2 = ConvBlock(64, 64)
            self.conv3 = ConvBlock(64, 64)
            self.conv4 = ConvBlock(64, 64)
            self.logits = nn.Linear(64, n_way)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.conv4(x)
            x = x.view(x.shape[0], -1)
            x = self.logits(x)
    
            return x
    
        def functional_forward(self, x, params):
            x = ConvBlockFunction(x, params[f'conv1.conv2d.weight'], params[f'conv1.conv2d.bias'],
                                  params.get(f'conv1.bn.weight'), params.get(f'conv1.bn.bias'))
            x = ConvBlockFunction(x, params[f'conv2.conv2d.weight'], params[f'conv2.conv2d.bias'],
                                  params.get(f'conv2.bn.weight'), params.get(f'conv2.bn.bias'))
            x = ConvBlockFunction(x, params[f'conv3.conv2d.weight'], params[f'conv3.conv2d.bias'],
                                  params.get(f'conv3.bn.weight'), params.get(f'conv3.bn.bias'))
            x = ConvBlockFunction(x, params[f'conv4.conv2d.weight'], params[f'conv4.conv2d.bias'],
                                  params.get(f'conv4.bn.weight'), params.get(f'conv4.bn.bias'))
    
            x = x.view(x.shape[0], -1)
            x = F.linear(x, params['logits.weight'], params['logits.bias'])
    
            return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33

    模型定义比较简单,maml思想主要是个训练方式,和模型本身无关。但我们在刚刚模型训练的时候有一些特殊操作,所以要定义一个functional_forward,这个函数要求实现和模型一样结构的网络,同时参数输入为:1、图像 2、权重。这样就可以保证得到了loss,但模型权重没有被修改。

  • 相关阅读:
    C++产生未定义的行为的原因分析
    飞桨Paddle动转静@to_static技术设计
    catkin_make编译链接不到libGL.so文件
    [深入研究4G/5G/6G专题-46]: L3信令控制-2-软件功能与流程的切分-DU网元的信令
    vue-cli 初始----安装运行Vue项目
    MySQL使用教程(基础篇03)
    UTF-16编码原理讲解
    通达OA 首页门户工作台
    第十二章 Spring Cloud Config 统一配置中心详解-客户端动态刷新
    Linux编译器-gcc/g++使用&函数库
  • 原文地址:https://blog.csdn.net/weixin_42392454/article/details/126127914