• 【论文阅读】Prototypical Networks for Few-shot Learning


    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


    前言

    本文结合论文youtube上的视频[Few-shot learning][2.2] Prototypical Networks: intuition, algorithm, pytorch code来整理一下对prototypical networks在few-shot领域的理解。


    一、论文

    摘要

    问题:少镜头分类问题(在只给定少量实例的情况下,分类器必须推广到未看到的新类)。
    提出的解决方案:Prototypical Networks学习一个度量空间,在该空间中,可以通过计算得到每个类的原型表示的距离来执行分类。
    优点:与最近少镜头学习方法相比,它们反映了一种更简单的归纳偏差,这种偏差在优先数据的状态下是有益的,取得了出色的结果。
    分析:我们表明一些简单的设计决策可以比最近涉及复杂架构选择和元学习的方法产生实质性的改进。
    扩展:扩展到了0样本学习,在CU-Birds dataset中获得了最先进的结果。

    方法

    在这里插入图片描述
    这是在度量空间中,左边是few-shot,是计算每个类的embedded支持例的平均值得到ck。右边是zero-shot,通过embedding类别元数据vk生成的。在每一种情况下,embedded查询点是通过softmax对类原型的距离进行分类。pφ(y = k|x) ∝ exp(−d(fφ(x), ck)).

    ck中心就是每一个类通过embedding函数得到的支持点的平均值。就是embedding相当于一个有很多维的一个空间中的一个点(我觉得类似特征提取得到得特征,这些特征得到的相当于一个高维空间中得坐标,每个类的支持点坐标不一定相同但是相近,它们的平均值可以近似看作这个类在这个高维空间中聚类的那个中心点)。
    就是属于哪个类的概率p的计算是通过softmax函数得到的。p(y=k|x)是到自己true类别的距离的相反数的exp()比到其他类别距离的相反数的exp的和。loss就是-log(p(y=k|x))。
    下面是loss的计算。
    在这里插入图片描述
    距离:距离计算有很多公式,对于一类特定的距离函数,称为正则布雷格曼散度[4],原型网络算法等效于对具有指数族密度的支持集执行混合密度估计。
    原型计算可以从支持集上的硬聚类来看,每个类一个聚类,每个支持点分配给其相应的类聚类。对于布雷格曼散度,已经表明[4],达到到其指定点的最小距离的聚类代表是聚类均值。因此,当使用布雷格曼散度时,公式(1)中的原型计算在给定支持集标签的情况下产生最优聚类代表。所以他才取的均值。
    后面就是对指数组混合模型的一些数学公式,我暂时看不懂。

    重新解释为线性模型
    当我们使用欧几里得距离 d(z, z′) = ‖z − z′‖2 时,方程 (2) 中的模型等效于具有特定参数化的线性模型 [21]。若要查看此内容,请展开指数中的项:
    − ‖ f φ ( x ) − c k ‖ 2 = − f φ ( x ) T f φ ( x ) + 2 c k T f φ ( x ) − c k T c k −‖fφ(x) − c_k‖2 = −fφ(x)^Tfφ(x) + 2c^T_k fφ(x) − c^T_k c_k fφ(x)ck‖2=fφ(x)Tfφ(x)+2ckTfφ(x)ckTck
    等式中的第一项相对于类k是常数,所以他就变成线性的函数了。
    2 c k T f φ ( x ) − c k T c k = w k T f φ ( x ) + b k , w h e r e w k = 2 c k a n d b k = − c k T c k 2c^T_k fφ(x) − c^T_k ck = w^T_k fφ(x) + b_k, where w_k = 2c_k and b_k = −c^T_k c_k 2ckTfφ(x)ckTck=wkTfφ(x)+bk,wherewk=2ckandbk=ckTck

    与匹配网络比较:原型网络与匹配网络在少数镜头情况下不同,在单镜头场景中具有等效性。
    设计选择:Distance metrics, Episode composition

    二、视频

    在这里插入图片描述
    先讲了聚类算法是怎么进行的。
    在这里插入图片描述
    然后讲了prototype的运行方式。对一个3-way 5-shot任务来说,他有五个支持图片,每个支持图片进入到一个编码器生成zi,这些zi做平均mean得到ci。3 way一次有三类。查询图像经过相同的编码器得到za,计算与这三个zi的距离,经过softmax函数,得到属于每一个类的概率。
    在这里插入图片描述
    然后对loss的计算过程进行梳理。
    在这里插入图片描述
    在这里插入图片描述
    这里是伪代码。我感觉主要有两个步骤一个根据支持点得到z_proto(详细一点就是让所有输入通过网络得到z,根据每一类的支持点的z取平均得到每一类的z_proto),第二步计算距离,得到loss(用距离函数计算z_query和z_proto的距离,使用softmax函数得到x_query属于每一类的概率,然后根据query的标签计算loss)。
    在这里插入图片描述
    列出了prototype Networks的优缺点。


    总结

    原型网络的简单性和有效性使其成为少镜头学习的有前途的方法。

  • 相关阅读:
    Python 如何使用装饰器(decorators)
    java中集合的List
    【6. 操作系统—虚拟内存管理技术页面置换算法】
    android 禁止拖动桌面时钟小组件
    小猪APP分发:一站式托管服务,轻松玩转应用市场
    【思维构造】Effects of Anti Pimples—CF1877D
    vue3渲染函数(h函数)的变化
    无硫防静电手指套:高科技产业的纯净与安全新选择
    关于接口|常见电商API接口种类|接口数据类型|接口请求方法
    【OpenCV实现图像阈值处理】
  • 原文地址:https://blog.csdn.net/goodenough5/article/details/133338665