    [Few-shot NER] Few-NERD: An N-way K-shot Dataset and Methods for Named Entity Recognition

    Preface:
      Named entity recognition (NER) is an important task in information extraction, with broad application prospects in both academia and industry. However, current NER systems depend heavily on large amounts of finely annotated data, which makes it hard for them to keep up with rapid iteration and fast-moving business requirements. How to reach good performance in a new domain with only a handful of labeled examples has therefore become a hot topic, especially in academia. This leads to a new research problem: few-shot named entity recognition (few-shot NER).

      This post introduces a recent benchmark work on few-shot NER, which provides a meta-learning-style NER dataset and baselines.

    Key points:

    • A new few-shot NER dataset and benchmark is proposed, which surpasses existing datasets in corpus size, number of entities, and number of entity types;
    • Entity types are organized into a coarse-grained and a fine-grained level, and an $N$-way $K{\sim}2K$-shot sampling scheme for few-shot NER is proposed.

    Brief information:

    1. Dataset name: Few-NERD
    2. Venue: ACL 2021
    3. Area: natural language processing, information extraction
    4. Topic: few-shot named entity recognition
    5. Keywords: Few-shot Learning, NER, Prototypical Learning, Metric Learning
    6. GitHub / project page: https://ningding97.github.io/fewnerd/
    7. Paper PDF: https://aclanthology.org/2021.acl-long.248.pdf

    1. Motivation

    • Deep models (conventional neural networks, pre-trained language models, etc.) can reach good performance on fully supervised NER when annotated data are abundant;
    • However, few-shot NER, where unseen entity types come with only a few examples, is closer to real-world scenarios, and there is currently no dataset specifically designed for it;
    • Earlier datasets (e.g., OntoNotes, CoNLL'03, WNUT'17) contain too few coarse-grained entity types, while unseen entities in real scenarios are usually fine-grained.

    2. Task Definition

      Conventional classification tasks operate at the sentence level, so their few-shot setting can follow the episode-based $N$-way $K$-shot rule: each few-shot learning episode contains only $N$ classes, and each class contains only $K$ sentences.

    Let us first recap the conventional episode-based training procedure, taking the Prototypical Network as an example:

    • First, randomly sample an episode. Each episode consists of a support set and a query set. The support set contains $N$ classes with $K$ sentences per class, i.e., $N \times K$ examples in total. The query set is less constrained: it can contain a single sentence, several sentences, or also follow the $N$-way $K$-shot rule;
    • For each episode, compute prototype vectors on the support set. For the $K$ sentences of each class, obtain their sentence representations and average them to get that class's prototype, so the support set yields $N$ prototypes.
    • For each example in the query set, compute the classification loss against its annotated class. Query examples are labeled during training, so we take the query sentence representation, use its distances to the $N$ prototypes as the predicted logits, and optimize a cross-entropy objective.

    At test time we only have a labeled support set and an unlabeled query set. Inference first computes the prototypes from the support set, then computes each query's distance to every prototype and predicts the nearest one. A minimal sentence-level sketch of one training episode is given below.
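
    The following is a minimal, self-contained sketch of one sentence-level prototypical-network episode (not yet the token-level Few-NERD baseline); the encoder, shapes, and variable names are illustrative assumptions.

    import torch
    import torch.nn.functional as F

    def episode_step(encoder, support_x, support_y, query_x, query_y, n_way):
        """support_x: [N*K, ...], support_y: [N*K] with labels 0..N-1,
        query_x: [Q, ...], query_y: [Q]."""
        support_emb = encoder(support_x)     # [N*K, hidden]
        query_emb = encoder(query_x)         # [Q, hidden]

        # one prototype per class: the mean of its K support embeddings
        prototypes = torch.stack(
            [support_emb[support_y == c].mean(0) for c in range(n_way)])  # [N, hidden]

        # negative squared Euclidean distance to each prototype as the logit
        logits = -torch.cdist(query_emb, prototypes) ** 2                 # [Q, N]
        loss = F.cross_entropy(logits, query_y)
        pred = logits.argmax(dim=-1)
        return loss, pred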

      Unlike sentence classification, however, NER is a token-level task that labels every token in a sequence, so the conventional $N$-way $K$-shot rule cannot be applied directly. The paper therefore redefines how episodes are constructed for few-shot NER.
    Few-shot NER definition
    Challenges:

    • NER is classification at the token level rather than the sentence level, and a single sentence may contain entities of several different types. Sampling nevertheless has to be done at the sentence level, because the sentence context determines the entity types.

    However, in the sequence labeling problem like NER, a sentence may contain multiple entities from different classes. And it is imperative to sample examples in sentence-level since contextual information is crucial for sequence labeling problems, especially for NER. Thus the sampling is more difficult than conventional classification tasks like relation extraction.

    • For example, in a 5-way 5-shot setting, the sampled sentences would have to contain exactly 5 entity types with exactly 5 entities each. Such precise sampling is very hard; in other words, a randomly sampled episode can hardly be guaranteed to contain exactly 5 distinct classes with exactly 5 entities per class.

    For example, when it comes to a 5-way 5-shot setting, if the support set already had 4 classes with 5 examples and 1 class with 4 examples, the next sampled sentence must only contain the specific one entity to strictly meet the requirement of 5-way 5-shot.

    • The paper therefore proposes an $N$-way $K{\sim}2K$-shot rule: given $N$ entity type classes, it only requires that each class is supported by between $K$ and $2K$ entities, which relaxes the hard constraint on the exact number of entities per class.

    For example, with $N=5$ and $K=5$, the per-class entity counts could be 8, 9, 5, 7, and 5, and every entity appearing in the sampled sentences must belong to one of these 5 classes.

      To satisfy this rule, the authors use a greedy sampling algorithm, illustrated in the figure below; a simplified sketch of such a sampler follows.
    [Figure: greedy sampling algorithm for N-way K~2K-shot episodes] https://img-blog.csdnimg.cn/13919b282bf547e6a05dc3376a2d7d97.png
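
    The sketch below shows one plausible greedy sampler for the $K{\sim}2K$ constraint; the data structures (each example carries a Counter of its entity types) and the exact accept/reject logic are simplified assumptions, not the repo's FewshotSampler implementation.

    import random
    from collections import Counter

    def sample_episode(examples, target_classes, K, max_tries=10000):
        """examples: list of (sentence, Counter{entity_type: count}).
        Returns indices whose per-class entity counts all lie in [K, 2K]."""
        counts = Counter()
        chosen = []
        for _ in range(max_tries):
            idx = random.randrange(len(examples))
            if idx in chosen:
                continue
            _, ent_counts = examples[idx]
            # reject sentences with entity types outside the N target classes
            if any(t not in target_classes for t in ent_counts):
                continue
            # reject sentences that would push some class above 2K entities
            if any(counts[t] + c > 2 * K for t, c in ent_counts.items()):
                continue
            chosen.append(idx)
            counts.update(ent_counts)
            # stop once every target class has at least K entities
            if all(counts[t] >= K for t in target_classes):
                return chosen
        return None  # could not build a valid episode within max_tries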

    3. Few-NERD: A Large-Scale, Multi-Grained Few-shot NER Benchmark

      Following the few-shot NER task definition above, the authors built a new dataset, Few-NERD, through manual annotation plus automatic processing. Three points stand out:

    • Following FIGER, the schema contains 8 coarse-grained and 66 fine-grained entity types;
    • The corpus comes from the English Wikipedia dumps; for each fine-grained entity type, 1,000 paragraphs were randomly selected for manual annotation, with an average of 61.3 tokens per paragraph;
    • 70 annotators and 10 experts were recruited for annotation and quality checking.

      The data distribution of Few-NERD is as follows:
    [Figure: Few-NERD statistics compared with existing NER datasets]
    The dataset contains more than 180k sentences, over 4.6 million tokens, nearly 500k entities, and 66 entity types. None of the existing datasets was built for the few-shot scenario in the first place, so Few-NERD is the first dataset tailored specifically to few-shot NER.

      To verify that the constructed dataset is sound, the authors ran a simple analysis, shown below:
    [Figure: pairwise similarity heatmap of the 66 fine-grained entity types]
    Computing pairwise similarities among the 66 fine-grained entity types shows that fine-grained types belonging to the same coarse-grained type are more similar to each other, which suggests transferability among them, while types from different coarse-grained types differ considerably. A small sketch of such a pairwise-similarity computation is given below.
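
    In the sketch, representing each entity type by the mean embedding of its mentions and using cosine similarity is an assumption made for illustration, not the paper's exact procedure.

    import torch
    import torch.nn.functional as F

    def type_similarity_matrix(type_embeddings):
        """type_embeddings: [num_types, hidden], one mean vector per entity type."""
        normed = F.normalize(type_embeddings, dim=-1)
        return normed @ normed.t()       # [num_types, num_types] cosine similarities

    sims = type_similarity_matrix(torch.randn(66, 768))
    print(sims.shape)                    # torch.Size([66, 66])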

      Based on the dataset, three benchmarks are defined: a supervised setting, an intra setting, and an inter setting:

    • Few-NERD (SUP): the standard supervised setting. The corpus is randomly split into 70% training, 10% validation, and 20% test data, and all three splits contain all 66 fine-grained entity types;
    • Few-shot NER: the training, validation, and test sets are split by entity type, so that each split contains only a subset of the entity types and no entity type is shared across splits. Specifically:
      (1) Few-NERD (INTRA): split by coarse-grained entity type. For example, the training set covers People, MISC, Art, and Product; the validation set covers Event and Building; and the test set covers ORG and LOC. Since different coarse-grained types are only weakly related, this setting is challenging;
      (2) Few-NERD (INTER): split inside each coarse-grained type. Within every coarse-grained type, 60% of its fine-grained types are randomly assigned to the training set and 20% each to the validation and test sets (see the sketch below). In this setting every split covers all coarse-grained types, and the benchmark tests generalization across fine-grained entity types.
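
    Below is a minimal sketch of such an INTER-style split; the grouping of fine-grained types under their coarse-grained types is passed in as a dictionary, which is an illustrative assumption.

    import random

    def inter_split(fine_by_coarse, seed=0):
        """fine_by_coarse: dict {coarse_type: [fine_type, ...]}.
        Returns train/dev/test lists of fine-grained types (roughly 60/20/20 per coarse type)."""
        rng = random.Random(seed)
        train, dev, test = [], [], []
        for coarse, fine_types in fine_by_coarse.items():
            fine_types = fine_types[:]
            rng.shuffle(fine_types)
            n = len(fine_types)
            n_train, n_dev = int(0.6 * n), int(0.2 * n)
            train += fine_types[:n_train]
            dev += fine_types[n_train:n_train + n_dev]
            test += fine_types[n_train + n_dev:]
        return train, dev, test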

      The authors provide pre-processed data for these three benchmarks; the corresponding data distributions are shown in the figure:
    [Figure: statistics of the Few-NERD (SUP), (INTRA), and (INTER) benchmarks]
    The format of a single episode is shown below:

    {
    	"support": {
    		"word": [
    			["averostra", ",", "or", "``", "bird", "snouts", "''", ",", "is", "a", "clade", "that", "includes", "most", "theropod", "dinosaurs", "that", "have", "a", "promaxillary", "fenestra", "(", "``", "fenestra", "promaxillaris", "``", ")", ",", "an", "extra", "opening", "in", "the", "front", "outer", "side", "of", "the", "maxilla", ",", "the", "bone", "that", "makes", "up", "the", "upper", "jaw", "."], 
    			["since", "that", "time", ",", "the", "squadron", "made", "several", "extended", "indian", "ocean", ",", "mediterranean", "sea", ",", "and", "north", "atlantic", "deployments", "as", "part", "of", "cvw-1", "/", "cv-66", ",", "until", "the", "decommissioning", "of", "uss", "``", "america", "''", "in", "1996", "."], 
    			...
    			], 
    		"label": [
    			["O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "other-biologything", "other-biologything", "O", "O", "other-biologything", "other-biologything", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "other-biologything", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"], 
    			["O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "product-ship", "O", "product-ship", "O", "O", "O", "O", "O", "product-ship", "product-ship", "product-ship", "product-ship", "O", "O", "O"], 
    			...
    			]
    	}, 
    	"query": {
    		"word": [["the", "final", "significant", "change", "in", "the", "life", "of", "the", "coco", "2", "(", "models", "26-3134b", ",", "26-3136b", ",", "and", "26-3127b", ";", "16", "kb", "standard", ",", "16", "kb", "extended", ",", "and", "64", "kb", "extended", "respectively", ")", "was", "to", "use", "the", "enhanced", "vdg", ",", "the", "mc6847t1", ",", "allowing", "lowercase", "characters", "and", "changing", "the", "text", "screen", "border", "color", "."], 
    		...
    		], 
    		"label": [["O", "O", "O", "O", "O", "O", "O", "O", "O", "product-software", "product-software", "O", "O", "product-software", "O", "product-software", "O", "O", "product-software", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "product-software", "O", "O", "product-software", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"], 
    		...
    		]
    	}, 
    	"types": ["other-biologything", "building-airport", "location-island", "product-ship", "product-software"]
    }
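
    As a quick sanity check, the snippet below (an illustrative assumption about usage, not part of the repo) reads one episode from such a .jsonl file and counts the entity mentions of each type in the support set; under the sampling rule these counts should fall in [K, 2K].

    import json

    def count_support_entities(jsonl_path):
        with open(jsonl_path, encoding='utf-8') as f:
            episode = json.loads(f.readline())       # one line = one episode
        counts = {t: 0 for t in episode['types']}
        for labels in episode['support']['label']:
            prev = 'O'
            for tag in labels:
                # labels use plain type names (IO scheme): a new mention starts
                # whenever the tag is not 'O' and differs from the previous tag
                if tag != 'O' and tag != prev:
                    counts[tag] += 1
                prev = tag
        return counts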
    

    4. Baselines and Code

      Here we only look at the few-shot setting. Since it still follows a rule similar to $N$-way $K$-shot, metric-learning methods such as Prototypical Networks can be applied. The details are:

    • Sample an episode consisting of a support set and a query set, both following the $N$-way $K{\sim}2K$-shot rule;
    • For the support set, the number of sentences is not fixed, but all sentences are still fed into an encoder (e.g., BERT) to obtain a representation for every token of every sentence;
    • Since the tag of every token is known (BIO-style tagging), group the token vectors of each class according to their tags and average them to compute the class prototypes;
    • For every token of every query sentence, compute its distance to each prototype and compute the loss with cross-entropy.

      The authors open-sourced the dataset and a code framework at https://ningding97.github.io/fewnerd/. We walk through the core code below:

    (1) train_demo.py: the main entry point.
    (2) data_loader.py: data processing and loading.
      By default, after reading the raw dataset, it uses the sampling algorithm to randomly draw episodes that satisfy the $N$-way $K{\sim}2K$-shot rule. The sampling-based dataset class is shown below:

    import os
    import numpy as np
    import torch
    import torch.utils.data as data
    # Sample and FewshotSampler are helper classes defined elsewhere in the
    # repo's data_loader.py and are omitted from this excerpt.

    class FewShotNERDatasetWithRandomSampling(data.Dataset):
        """
        Fewshot NER Dataset
        """
        def __init__(self, filepath, tokenizer, N, K, Q, max_length, ignore_label_id=-1):
            if not os.path.exists(filepath):
                print("[ERROR] Data file does not exist!")
                assert(0)
            self.class2sampleid = {} # indices of the samples that contain each entity type class
            self.N = N
            self.K = K
            self.Q = Q
            self.tokenizer = tokenizer
            self.samples, self.classes = self.__load_data_from_file__(filepath) # load all samples and entity classes in this split
            self.max_length = max_length
            self.sampler = FewshotSampler(N, K, Q, self.samples, classes=self.classes) # used to sample one episode task
            self.ignore_label_id = ignore_label_id
    
        def __insert_sample__(self, index, sample_classes):
            for item in sample_classes:
                if item in self.class2sampleid:
                    self.class2sampleid[item].append(index)
                else:
                    self.class2sampleid[item] = [index]
        
        def __load_data_from_file__(self, filepath):
            # load the dataset from a local file
            samples = [] # all samples
            classes = [] # all entity type classes that appear
            with open(filepath, 'r', encoding='utf-8')as f:
                lines = f.readlines()
            samplelines = []
            index = 0 # index of the current sample
            for line in lines:
                line = line.strip()
                if line:
                    samplelines.append(line)
                else:
                    sample = Sample(samplelines)
                    samples.append(sample)
                    sample_classes = sample.get_tag_class() # entity type classes appearing in this sample
                    self.__insert_sample__(index, sample_classes)
                    classes += sample_classes
                    samplelines = []
                    index += 1
            if samplelines: # handle the last sample in the file
                sample = Sample(samplelines)
                samples.append(sample)
                sample_classes = sample.get_tag_class()
                self.__insert_sample__(index, sample_classes)
                classes += sample_classes
                samplelines = []
                index += 1
            classes = list(set(classes))
            return samples, classes
    
        def __get_token_label_list__(self, sample):
            tokens = []
            labels = []
            for word, tag in zip(sample.words, sample.normalized_tags):
                word_tokens = self.tokenizer.tokenize(word)
                if word_tokens:
                    tokens.extend(word_tokens)
                    # Use the real label id for the first token of the word, and padding ids for the remaining tokens
                    word_labels = [self.tag2label[tag]] + [self.ignore_label_id] * (len(word_tokens) - 1)
                    labels.extend(word_labels)
            return tokens, labels
    
    
        def __getraw__(self, tokens, labels):
            # tokenize and build input ids, attention mask and text mask
            # get tokenized word list, attention mask, text mask (mask [CLS], [SEP] as well), tags
            
            # split into chunks of length (max_length-2)
            # 2 is for special tokens [CLS] and [SEP]
            tokens_list = []
            labels_list = []
            while len(tokens) > self.max_length - 2:
                tokens_list.append(tokens[:self.max_length-2])
                tokens = tokens[self.max_length-2:]
                labels_list.append(labels[:self.max_length-2])
                labels = labels[self.max_length-2:]
            if tokens:
                tokens_list.append(tokens)
                labels_list.append(labels)
    
            # add special tokens and get masks
            indexed_tokens_list = []
            mask_list = []
            text_mask_list = []
            for i, tokens in enumerate(tokens_list):
                # token -> ids
                tokens = ['[CLS]'] + tokens + ['[SEP]']
                indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)
            
                # padding
                while len(indexed_tokens) < self.max_length:
                    indexed_tokens.append(0)
                indexed_tokens_list.append(indexed_tokens)
    
                # mask
                mask = np.zeros((self.max_length), dtype=np.int32)
                mask[:len(tokens)] = 1
                mask_list.append(mask)
    
                # text mask, also mask [CLS] and [SEP]
                text_mask = np.zeros((self.max_length), dtype=np.int32)
                text_mask[1:len(tokens)-1] = 1
                text_mask_list.append(text_mask)
    
                assert len(labels_list[i]) == len(tokens) - 2, print(labels_list[i], tokens)
            return indexed_tokens_list, mask_list, text_mask_list, labels_list
    
        def __additem__(self, index, d, word, mask, text_mask, label):
            d['index'].append(index)
            d['word'] += word
            d['mask'] += mask
            d['label'] += label
            d['text_mask'] += text_mask
    
        def __populate__(self, idx_list, savelabeldic=False):
            '''
            populate samples into data dict
            set savelabeldic=True if you want to save label2tag dict
            'index': sample_index
            'word': tokenized word ids
            'mask': attention mask in BERT
            'label': NER labels
            'sentence_num': number of sentences in this set (a batch contains multiple sets)
            'text_mask': 0 for special tokens and paddings, 1 for real text
            '''
            dataset = {'index': [], 'word': [], 'mask': [], 'label': [], 'sentence_num': [], 'text_mask': []}
            for idx in idx_list:
                tokens, labels = self.__get_token_label_list__(self.samples[idx])
                word, mask, text_mask, label = self.__getraw__(tokens, labels) # BERT tokenization: build input_ids, attention_mask, ...
                word = torch.tensor(word).long()
                mask = torch.tensor(np.array(mask)).long()
                text_mask = torch.tensor(np.array(text_mask)).long()
                self.__additem__(idx, dataset, word, mask, text_mask, label)
            dataset['sentence_num'] = [len(dataset['word'])]
            if savelabeldic:
                dataset['label2tag'] = [self.label2tag]
            return dataset
    
        def __getitem__(self, index):
            # each item is one sampled episode task
            target_classes, support_idx, query_idx = self.sampler.__next__() # the sampler draws one episode
            # add 'O' and make sure 'O' is labeled 0
            distinct_tags = ['O'] + target_classes
            self.tag2label = {tag: idx for idx, tag in enumerate(distinct_tags)}
            self.label2tag = {idx: tag for idx, tag in enumerate(distinct_tags)}
            # support_set (input features): {'index': [], 'word': [], 'mask': [], 'label': [], 'sentence_num': [], 'text_mask': []}
            support_set = self.__populate__(support_idx) # build features (input ids, attention masks, ...) from the sampled indices
            query_set = self.__populate__(query_idx, savelabeldic=True)
            return support_set, query_set
        
        def __len__(self):
            return 100000
    

      Since random sampling yields different episodes on every run, the authors also release pre-sampled episode data so that baselines can be compared fairly; this pre-processed data has been used for fair comparison in follow-up research. The code that directly reads the pre-processed episodes is:

    class FewShotNERDataset(FewShotNERDatasetWithRandomSampling):
        def __init__(self, filepath, tokenizer, max_length, ignore_label_id=-1):
            if not os.path.exists(filepath):
                print("[ERROR] Data file does not exist!")
                assert(0)
            self.class2sampleid = {}
            self.tokenizer = tokenizer
            self.samples = self.__load_data_from_file__(filepath)
            self.max_length = max_length
            self.ignore_label_id = ignore_label_id
        
        def __load_data_from_file__(self, filepath):
            with open(filepath)as f:
                lines = f.readlines()
            for i in range(len(lines)):
                lines[i] = json.loads(lines[i].strip())
            return lines
        
        def __additem__(self, d, word, mask, text_mask, label):
            d['word'] += word
            d['mask'] += mask
            d['label'] += label
            d['text_mask'] += text_mask
        
        def __get_token_label_list__(self, words, tags):
            tokens = []
            labels = []
            for word, tag in zip(words, tags):
                word_tokens = self.tokenizer.tokenize(word)
                if word_tokens:
                    tokens.extend(word_tokens)
                    # Use the real label id for the first token of the word, and padding ids for the remaining tokens
                    word_labels = [self.tag2label[tag]] + [self.ignore_label_id] * (len(word_tokens) - 1)
                    labels.extend(word_labels)
            return tokens, labels
    
        def __populate__(self, data, savelabeldic=False):
            '''
            populate samples into data dict
            set savelabeldic=True if you want to save label2tag dict
            'word': tokenized word ids
            'mask': attention mask in BERT
            'label': NER labels
            'sentence_num': number of sentences in this set (a batch contains multiple sets)
            'text_mask': 0 for special tokens and paddings, 1 for real text
            '''
            dataset = {'word': [], 'mask': [], 'label':[], 'sentence_num':[], 'text_mask':[] }
            for i in range(len(data['word'])):
                tokens, labels = self.__get_token_label_list__(data['word'][i], data['label'][i])
                word, mask, text_mask, label = self.__getraw__(tokens, labels)
                word = torch.tensor(word).long()
                mask = torch.tensor(mask).long()
                text_mask = torch.tensor(text_mask).long()
                self.__additem__(dataset, word, mask, text_mask, label)
            dataset['sentence_num'] = [len(dataset['word'])]
            if savelabeldic:
                dataset['label2tag'] = [self.label2tag]
            return dataset
    
        def __getitem__(self, index):
            sample = self.samples[index]
            target_classes = sample['types']
            support = sample['support']
            query = sample['query']
            # add 'O' and make sure 'O' is labeled 0
            distinct_tags = ['O'] + target_classes
            self.tag2label = {tag: idx for idx, tag in enumerate(distinct_tags)}
            self.label2tag = {idx: tag for idx, tag in enumerate(distinct_tags)}
            support_set = self.__populate__(support)
            query_set = self.__populate__(query, savelabeldic=True)
            return support_set, query_set
    
        def __len__(self):
            return len(self.samples)
    
    

    (3) Collator: the collate callback for data loading.
      It converts the loaded items into model inputs, e.g., the very common "list of dicts to dict of lists" flipping (see the comments in the code below).

    def collate_fn(data):
        '''
        The dataloader groups items into a batch; this function merges the items in one batch.
        Raw items in a batch are stored as list({'word': [..], ..}, ...),
        so they need to be converted into {'word': [[..], ..], ..}

        e.g. [{'word': [1, 2, 3]}, {'word': [4, 5, 6]}]
        ->
        {'word': [[1, 2, 3], [4, 5, 6]]}
        '''
    
        batch_support = {'word': [], 'mask': [], 'label': [], 'sentence_num':[], 'text_mask':[]}
        batch_query = {'word': [], 'mask': [], 'label': [], 'sentence_num':[], 'label2tag':[], 'text_mask':[]}
        support_sets, query_sets = zip(*data)
    
        for i in range(len(support_sets)):
            for k in batch_support:
                batch_support[k] += support_sets[i][k]
            for k in batch_query:
                batch_query[k] += query_sets[i][k]
        for k in batch_support:
            if k != 'label' and k != 'sentence_num':
                batch_support[k] = torch.stack(batch_support[k], 0)
        for k in batch_query:
            if k !='label' and k != 'sentence_num' and k!= 'label2tag':
                batch_query[k] = torch.stack(batch_query[k], 0)
        batch_support['label'] = [torch.tensor(tag_list).long() for tag_list in batch_support['label']]
        batch_query['label'] = [torch.tensor(tag_list).long() for tag_list in batch_query['label']]
        return batch_support, batch_query
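
    A hypothetical usage sketch (not taken verbatim from the repo's train_demo.py): every DataLoader "batch" is a batch of episodes merged by collate_fn; the file path, tokenizer, and hyperparameters below are illustrative assumptions.

    from torch.utils.data import DataLoader

    train_dataset = FewShotNERDataset('episode-data/inter/train_5_5.jsonl',   # illustrative path
                                      tokenizer, max_length=64)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,
                              collate_fn=collate_fn)
    batch_support, batch_query = next(iter(train_loader))
    print(batch_support['word'].shape)    # [stacked support sentences in the batch, max_length]
    print(batch_support['sentence_num'])  # one entry per episode, e.g. 4 entries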
    
    

      One point worth emphasizing: under the $N$-way $K$-shot training mode, the batch size no longer counts sentences but episodes, so an extra field, sentence_num, is needed to locate each episode inside the batch. In more detail:

    • The model input is split into two sets, support and query. Each set holds a number of sentences: the support/query sentences of all episodes in a batch are simply stacked together.
    • To know which sentences belong to one episode, sentence_num records the number of sentences of the $i$-th episode, so that each episode can be retrieved later.

    For example, suppose a batch contains 2 episodes whose support sets have 5 and 7 sentences respectively. Then support['word'] holds 12 sentences and support['sentence_num'] holds two elements, 5 and 7. When prototypes are computed later, the sentences of each episode must be sliced out individually, as in the sketch below.
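
    A minimal sketch of how sentence_num recovers each episode from the stacked batch; the tensor shapes here are illustrative assumptions.

    import torch

    support_emb = torch.randn(12, 32, 768)   # stacked BERT outputs for 12 support sentences
    sentence_num = [5, 7]                    # episode 0 has 5 sentences, episode 1 has 7

    offset = 0
    for i, n in enumerate(sentence_num):
        episode_emb = support_emb[offset:offset + n]   # sentences belonging to episode i
        print(i, episode_emb.shape)                    # (5, 32, 768), then (7, 32, 768)
        offset += n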

    (4) Encoder
      Few-NERD and the methods compared later all use the BERT-base-uncased model (https://huggingface.co/bert-base-uncased). The code is shown below:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import math
    import numpy as np
    import os
    from torch import optim
    from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification, RobertaModel, RobertaTokenizer, RobertaForSequenceClassification
    
    class BERTWordEncoder(nn.Module):
    
        def __init__(self, pretrain_path): 
            nn.Module.__init__(self)
            self.bert = BertModel.from_pretrained(pretrain_path)
    
        def forward(self, words, masks):
            outputs = self.bert(words, attention_mask=masks, output_hidden_states=True, return_dict=True)
            #outputs = self.bert(inputs['word'], attention_mask=inputs['mask'], output_hidden_states=True, return_dict=True)
            # use the sum of the last 4 layers
            last_four_hidden_states = torch.cat([hidden_state.unsqueeze(0) for hidden_state in outputs['hidden_states'][-4:]], 0)
            del outputs
            word_embeddings = torch.sum(last_four_hidden_states, 0) # [num_sent, number_of_tokens, 768]
            return word_embeddings
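
    A hypothetical usage sketch of the encoder above (not from the repo): the embedding of each token is the sum of BERT's last four hidden layers.

    import torch
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    encoder = BERTWordEncoder('bert-base-uncased')

    batch = tokenizer(['few-shot NER with prototypes'],
                      return_tensors='pt', padding='max_length', max_length=16)
    with torch.no_grad():
        emb = encoder(batch['input_ids'], batch['attention_mask'])
    print(emb.shape)   # torch.Size([1, 16, 768])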
    
    

    (5) The main model
      Taking the Prototypical Network as an example, the code is shown below; the inline comments explain the details.

    import sys
    sys.path.append('..')
    import util
    import torch
    from torch import autograd, optim, nn
    from torch.autograd import Variable
    from torch.nn import functional as F
    
    class Proto(util.framework.FewShotNERModel):
        
        def __init__(self, word_encoder, dot=False, ignore_index=-1):
            util.framework.FewShotNERModel.__init__(self, word_encoder, ignore_index=ignore_index)
            self.drop = nn.Dropout()
            self.dot = dot
    
        def __dist__(self, x, y, dim):
            if self.dot:
                return (x * y).sum(dim)
            else:
                return -(torch.pow(x - y, 2)).sum(dim)
    
        def __batch_dist__(self, S, Q, q_mask):
            # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim]
            assert Q.size()[:2] == q_mask.size()
            Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim]
            return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2)
    
        def __get_proto__(self, embedding, tag, mask):
            proto = []
            embedding = embedding[mask==1].view(-1, embedding.size(-1))
            tag = torch.cat(tag, 0)
            assert tag.size(0) == embedding.size(0)
            for label in range(torch.max(tag)+1):
                proto.append(torch.mean(embedding[tag==label], 0))
            proto = torch.stack(proto)
            return proto, embedding
    
        def forward(self, support, query):
            '''
            support: Inputs of the support set.
            query: Inputs of the query set.
            N: Num of classes
            K: Num of instances for each class in the support set
            Q: Num of instances in the query set
    
            support/query = {'index': [], 'word': [], 'mask': [], 'label': [], 'sentence_num': [], 'text_mask': []}
            '''
            # feed the support set and the query set into BERT to obtain token representations
            support_emb = self.word_encoder(support['word'], support['mask']) # [num_sent, number_of_tokens, 768]
            query_emb = self.word_encoder(query['word'], query['mask']) # [num_sent, number_of_tokens, 768]
            support_emb = self.drop(support_emb)
            query_emb = self.drop(query_emb)
    
            # Prototypical Networks
            logits = []
            current_support_num = 0
            current_query_num = 0
            assert support_emb.size()[:2] == support['mask'].size()
            assert query_emb.size()[:2] == query['mask'].size()
    
            for i, sent_support_num in enumerate(support['sentence_num']): # iterate over the sampled N-way K-shot episodes in the batch
                sent_query_num = query['sentence_num'][i]
                # Calculate prototype for each class
                # a batch holds several episodes, so the slice current_support_num:current_support_num+sent_support_num
                # marks which sentences in the stacked tensors belong to the current episode
                support_proto, embedding = self.__get_proto__(
                    support_emb[current_support_num:current_support_num+sent_support_num], 
                    support['label'][current_support_num:current_support_num+sent_support_num], 
                    support['text_mask'][current_support_num: current_support_num+sent_support_num])
                # calculate distance to each prototype
                logits.append(self.__batch_dist__(
                    support_proto, 
                    query_emb[current_query_num:current_query_num+sent_query_num],
                    query['text_mask'][current_query_num: current_query_num+sent_query_num])) # [num_of_query_tokens, class_num]
                current_query_num += sent_query_num
                current_support_num += sent_support_num
            logits = torch.cat(logits, 0) # for every query token, scores over the support-set classes
            _, pred = torch.max(logits, 1) # predict the prototype class with the highest score
    
            return logits, pred, embedding
    

      The authors also implement NN-Shot and Struct-Shot; see the paper and the GitHub repo for details. A rough sketch of the NN-Shot idea follows.
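
    A minimal sketch of the NN-Shot idea (token-level nearest-neighbor classification); the shapes and names are illustrative assumptions, and the real implementation is in the GitHub repo.

    import torch

    def nnshot_predict(support_emb, support_labels, query_emb):
        """support_emb: [S, hidden], support_labels: [S], query_emb: [Q, hidden]."""
        dists = torch.cdist(query_emb, support_emb)   # [Q, S] pairwise distances
        nearest = dists.argmin(dim=-1)                # index of the closest support token
        return support_labels[nearest]                # [Q] predicted labels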

    5. Current Experimental Comparisons

      As of this writing (June 28, 2022), several works at EMNLP 2021, AAAI, and ACL 2022 have evaluated on this dataset. The current comparisons can be found on Papers with Code (the Few-NERD INTRA and INTER leaderboards) and are shown in the figure:
    [Figure: current leaderboard results on Few-NERD]

    6. Summary

      Few-NERD is a fairly new benchmark. Before it was proposed, few-shot NER datasets were usually constructed by re-sampling a few popular supervised datasets, and since different authors used different construction procedures, comparisons between models were not fair. Few-NERD provides a fairer evaluation benchmark and, at the same time, introduces a principled sampling rule for few-shot NER.

      That said, the current evaluation protocols and models still have some open problems:

    • The "O" label problem: Few-NERD is still a sequence-labeling dataset in which every token receives a tag, so each sentence contains a large number of "O" tokens, which can interfere with the model. Some follow-up work addresses this issue;
    • Label dependencies between tokens: sequence labeling has to model dependencies in the output, e.g., a B tag must appear at the first position of an entity, which calls for something like Viterbi decoding. However, because the classes differ from episode to episode, a CRF cannot be applied directly. Some work has tried to handle label dependencies in the few-shot setting; a generic decoding sketch is given below.
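
    Below is a generic Viterbi decoding sketch over per-token emission scores, in the spirit of StructShot-style post-hoc decoding; the transition matrix is assumed to be given (e.g., estimated from abstract tag statistics), and this is not the exact method of the Few-NERD paper.

    import torch

    def viterbi_decode(emissions, log_transitions):
        """emissions: [seq_len, num_tags] log-scores;
        log_transitions[i, j] = log p(tag_t = j | tag_{t-1} = i)."""
        seq_len, num_tags = emissions.shape
        score = emissions[0].clone()              # best log-score ending in each tag
        backpointers = []
        for t in range(1, seq_len):
            # score[i] + transition[i, j] + emission[t, j], maximized over previous tag i
            total = score.unsqueeze(1) + log_transitions + emissions[t].unsqueeze(0)
            score, best_prev = total.max(dim=0)
            backpointers.append(best_prev)
        best_last = int(score.argmax())
        path = [best_last]
        for best_prev in reversed(backpointers):
            path.append(int(best_prev[path[-1]]))
        return list(reversed(path))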
  • Original article: https://blog.csdn.net/qq_36426650/article/details/125501070