Preface:
Named entity recognition (NER) is a core task in information extraction, with broad application prospects in both academia and industry. However, current NER models depend heavily on large amounts of finely annotated data, which makes them hard to adapt to fast iteration and rapidly evolving business needs. How to achieve good results in a new domain with very little labeled data has therefore become a hot research topic, especially in academia. This brings us to the research problem of this post: few-shot named entity recognition (few-shot NER).
This post introduces a recent benchmark work on few-shot NER, which presents a meta-learning-oriented NER dataset together with baselines.
Key information:
| No. | Attribute | Value |
|---|---|---|
| 1 | Dataset | Few-NERD |
| 2 | Venue | ACL 2021 |
| 3 | Field | Natural language processing, information extraction |
| 4 | Topic | Few-shot named entity recognition |
| 5 | Keywords | Few-shot Learning, NER, Prototypical Learning, Metric Learning |
| 6 | Code / project page | https://ningding97.github.io/fewnerd/ |
| 7 | Paper PDF | https://aclanthology.org/2021.acl-long.248.pdf |
Traditional classification tasks operate at the sentence level, so the few-shot setting can be defined with the episode-based $N$-way $K$-shot protocol: each few-shot learning process (episode) involves only $N$ classes, with $K$ sentences per class.
Let us first review the conventional episode-based training procedure, taking the Prototypical Network as an example:
- First, randomly sample an episode. Each episode consists of a support set and a query set. The support set contains $N$ classes with $K$ sentences per class, i.e., $N \times K$ samples in total. The query set is less constrained: it may contain a single sentence, several sentences, or also follow the $N$-way $K$-shot rule;
- For each episode, compute prototype vectors on the support set. For each class, encode its $K$ sentences and average the sentence representations to obtain that class's prototype; the support set thus yields $N$ prototypes.
- For each sample in the query set, compute the classification loss against its annotated class. Since query examples are labeled during training, we encode each query sentence, take its distances to the $N$ prototypes as the predicted logits, and optimize with cross-entropy.
At test time, only a labeled support set and an unlabeled query set are available. Inference proceeds by first computing the prototypes from the support set, then measuring each query's distance to every prototype and predicting the nearest one.
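To make this procedure concrete, here is a minimal, self-contained sketch of a single prototypical episode over pre-computed sentence embeddings. This is not the authors' implementation; the function name and tensor shapes are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def prototypical_episode(support_emb, support_labels, query_emb, query_labels=None):
    """One N-way K-shot episode on pre-computed embeddings.

    support_emb:    [N*K, D] support sentence embeddings
    support_labels: [N*K]    class ids in {0, ..., N-1}
    query_emb:      [Q, D]   query sentence embeddings
    query_labels:   [Q]      gold labels (training only)
    """
    n_classes = int(support_labels.max()) + 1
    # prototype = mean of the K support embeddings of each class -> [N, D]
    prototypes = torch.stack(
        [support_emb[support_labels == c].mean(dim=0) for c in range(n_classes)])
    # negative squared Euclidean distance to each prototype as logits -> [Q, N]
    logits = -torch.cdist(query_emb, prototypes).pow(2)
    preds = logits.argmax(dim=1)  # nearest prototype wins
    loss = F.cross_entropy(logits, query_labels) if query_labels is not None else None
    return preds, loss

# toy usage: 3-way 2-shot with 4 query sentences, embedding dim 8
support = torch.randn(6, 8)
s_labels = torch.tensor([0, 0, 1, 1, 2, 2])
query = torch.randn(4, 8)
q_labels = torch.tensor([0, 1, 2, 1])
preds, loss = prototypical_episode(support, s_labels, query, q_labels)
```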
Unlike sentence-level classification, however, NER is token-level: it performs sequence labeling over every token, so the traditional $N$-way $K$-shot protocol cannot be applied directly. The paper therefore redefines the episode construction rule.

Few-shot NER definition:
Challenge:
However, in a sequence labeling problem like NER, a sentence may contain multiple entities from different classes. And it is imperative to sample examples at the sentence level, since contextual information is crucial for sequence labeling problems, especially for NER. Thus the sampling is more difficult than in conventional classification tasks like relation extraction.
For example, in a 5-way 5-shot setting, if the support set already has 4 classes with 5 examples each and 1 class with 4 examples, the next sampled sentence must contain only that one specific entity class to strictly meet the 5-way 5-shot requirement.
The paper therefore relaxes this to an $N$-way $K$~$2K$-shot rule. For example, with $N=5$ and $K=5$, the per-class sample counts may be 8, 9, 5, 7, 5 (each between $K$ and $2K$), and all entities appearing in the sampled sentences must belong to these 5 classes.
To satisfy this rule, the paper implements a sampling algorithm, shown in the figure below:

Based on this few-shot NER task definition, the authors built the new dataset Few-NERD through a combination of manual annotation and machine processing.
The data distribution of Few-NERD is as follows:

The dataset contains over 180K sentences, more than 4.6M tokens, nearly 500K entities, and 66 entity types. Since existing NER datasets were not designed for the few-shot scenario, Few-NERD is the first dataset tailored specifically for few-shot NER.
To verify that the constructed dataset is sound, some simple analyses were carried out, as shown below:

Computing pairwise similarities among the 66 fine-grained entity types shows that fine-grained types sharing the same coarse-grained type are more similar to each other, and hence transferable, while types under different coarse-grained types differ markedly.
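As an illustration of this kind of analysis, pairwise similarity between type representations can be computed as in the sketch below (random placeholder embeddings; the paper derives type representations from the actual data):

```python
import torch
import torch.nn.functional as F

# placeholder: one representation vector per fine-grained type (66 types)
type_emb = torch.randn(66, 128)
normed = F.normalize(type_emb, dim=1)
sim = normed @ normed.T  # [66, 66] pairwise cosine similarity matrix
```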
Based on the constructed dataset, three benchmarks are proposed: a standard supervised setting, Few-NERD (INTRA), and Few-NERD (INTER). In INTRA, the entity types of train/dev/test belong to different coarse-grained types (no coarse-grained overlap across splits), while in INTER only the fine-grained types are disjoint and coarse-grained types may be shared across splits.
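Since fine-grained type names encode their coarse-grained type as a prefix (e.g., "product-ship"), an INTRA-style split can be sketched by grouping types by that prefix (illustration only, not the official split script):

```python
from collections import defaultdict

# Fine-grained types are named "coarse-fine"; an INTRA-style split assigns
# whole coarse-grained groups to different train/dev/test splits.
def group_by_coarse(fine_types):
    groups = defaultdict(list)
    for t in fine_types:
        groups[t.split('-')[0]].append(t)
    return dict(groups)

print(group_by_coarse(["product-ship", "product-software",
                       "location-island", "building-airport"]))
# {'product': ['product-ship', 'product-software'],
#  'location': ['location-island'], 'building': ['building-airport']}
```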
The authors release preprocessed versions of the three benchmarks; their data distributions are shown below:

The data format of a single episode is as follows:
{
"support": {
"word": [
["averostra", ",", "or", "``", "bird", "snouts", "''", ",", "is", "a", "clade", "that", "includes", "most", "theropod", "dinosaurs", "that", "have", "a", "promaxillary", "fenestra", "(", "``", "fenestra", "promaxillaris", "``", ")", ",", "an", "extra", "opening", "in", "the", "front", "outer", "side", "of", "the", "maxilla", ",", "the", "bone", "that", "makes", "up", "the", "upper", "jaw", "."],
["since", "that", "time", ",", "the", "squadron", "made", "several", "extended", "indian", "ocean", ",", "mediterranean", "sea", ",", "and", "north", "atlantic", "deployments", "as", "part", "of", "cvw-1", "/", "cv-66", ",", "until", "the", "decommissioning", "of", "uss", "``", "america", "''", "in", "1996", "."],
...
],
"label": [
["O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "other-biologything", "other-biologything", "O", "O", "other-biologything", "other-biologything", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "other-biologything", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"],
["O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "product-ship", "O", "product-ship", "O", "O", "O", "O", "O", "product-ship", "product-ship", "product-ship", "product-ship", "O", "O", "O"],
...
]
},
"query": {
"word": [["the", "final", "significant", "change", "in", "the", "life", "of", "the", "coco", "2", "(", "models", "26-3134b", ",", "26-3136b", ",", "and", "26-3127b", ";", "16", "kb", "standard", ",", "16", "kb", "extended", ",", "and", "64", "kb", "extended", "respectively", ")", "was", "to", "use", "the", "enhanced", "vdg", ",", "the", "mc6847t1", ",", "allowing", "lowercase", "characters", "and", "changing", "the", "text", "screen", "border", "color", "."],
...
],
"label": [["O", "O", "O", "O", "O", "O", "O", "O", "O", "product-software", "product-software", "O", "O", "product-software", "O", "product-software", "O", "O", "product-software", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "product-software", "O", "O", "product-software", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"],
...
]
},
"types": ["other-biologything", "building-airport", "location-island", "product-ship", "product-software"]
}
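Below is a small sketch for loading and inspecting one such preprocessed episode. The file name is a placeholder; each line of the released episode files is one JSON object of the shape shown above.

```python
import json

# placeholder file name; one episode per line, each with "support", "query", "types"
with open("test_5_1.jsonl", encoding="utf-8") as f:
    episode = json.loads(f.readline())

print(episode["types"])                 # the N sampled entity classes
print(len(episode["support"]["word"]))  # number of support sentences
print(len(episode["query"]["word"]))    # number of query sentences
```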
For now we focus only on the few-shot setting. Since it still follows an $N$-way $K$-shot-like protocol, metric-learning methods such as Prototypical Networks can be applied. The details are as follows:
The authors open-sourced the dataset and a code framework; see https://ningding97.github.io/fewnerd/. Below we walk through the core code:
(1) train_demo.py: the main entry script
(2) data_loader.py: data processing and loading
By default, after reading the raw dataset, the sampling algorithm randomly draws episodes that satisfy the $N$-way $K$~$2K$-shot rule. The sampling code is shown below:
class FewShotNERDatasetWithRandomSampling(data.Dataset):
"""
Fewshot NER Dataset
"""
def __init__(self, filepath, tokenizer, N, K, Q, max_length, ignore_label_id=-1):
if not os.path.exists(filepath):
print("[ERROR] Data file does not exist!")
assert(0)
        self.class2sampleid = {}  # sample indices associated with each entity type class
self.N = N
self.K = K
self.Q = Q
self.tokenizer = tokenizer
        self.samples, self.classes = self.__load_data_from_file__(filepath)  # all samples and classes in the dataset
        self.max_length = max_length
        self.sampler = FewshotSampler(N, K, Q, self.samples, classes=self.classes)  # draws one episode task at a time
self.ignore_label_id = ignore_label_id
def __insert_sample__(self, index, sample_classes):
for item in sample_classes:
if item in self.class2sampleid:
self.class2sampleid[item].append(index)
else:
self.class2sampleid[item] = [index]
def __load_data_from_file__(self, filepath):
        # load the dataset from disk
        samples = []  # all samples
        classes = []  # all entity type classes involved
with open(filepath, 'r', encoding='utf-8')as f:
lines = f.readlines()
samplelines = []
        index = 0  # index of the current sample
for line in lines:
line = line.strip()
if line:
samplelines.append(line)
else:
sample = Sample(samplelines)
samples.append(sample)
                sample_classes = sample.get_tag_class()  # all entity type classes in this sample
self.__insert_sample__(index, sample_classes)
classes += sample_classes
samplelines = []
index += 1
        if samplelines:  # handle the last sample in the file
sample = Sample(samplelines)
samples.append(sample)
sample_classes = sample.get_tag_class()
self.__insert_sample__(index, sample_classes)
classes += sample_classes
samplelines = []
index += 1
classes = list(set(classes))
return samples, classes
def __get_token_label_list__(self, sample):
tokens = []
labels = []
for word, tag in zip(sample.words, sample.normalized_tags):
word_tokens = self.tokenizer.tokenize(word)
if word_tokens:
tokens.extend(word_tokens)
# Use the real label id for the first token of the word, and padding ids for the remaining tokens
word_labels = [self.tag2label[tag]] + [self.ignore_label_id] * (len(word_tokens) - 1)
labels.extend(word_labels)
return tokens, labels
def __getraw__(self, tokens, labels):
        # tokenize and build input ids, attention mask and segment ids
# get tokenized word list, attention mask, text mask (mask [CLS], [SEP] as well), tags
# split into chunks of length (max_length-2)
# 2 is for special tokens [CLS] and [SEP]
tokens_list = []
labels_list = []
while len(tokens) > self.max_length - 2:
tokens_list.append(tokens[:self.max_length-2])
tokens = tokens[self.max_length-2:]
labels_list.append(labels[:self.max_length-2])
labels = labels[self.max_length-2:]
if tokens:
tokens_list.append(tokens)
labels_list.append(labels)
# add special tokens and get masks
indexed_tokens_list = []
mask_list = []
text_mask_list = []
for i, tokens in enumerate(tokens_list):
# token -> ids
tokens = ['[CLS]'] + tokens + ['[SEP]']
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)
# padding
while len(indexed_tokens) < self.max_length:
indexed_tokens.append(0)
indexed_tokens_list.append(indexed_tokens)
# mask
mask = np.zeros((self.max_length), dtype=np.int32)
mask[:len(tokens)] = 1
mask_list.append(mask)
# text mask, also mask [CLS] and [SEP]
text_mask = np.zeros((self.max_length), dtype=np.int32)
text_mask[1:len(tokens)-1] = 1
text_mask_list.append(text_mask)
assert len(labels_list[i]) == len(tokens) - 2, print(labels_list[i], tokens)
return indexed_tokens_list, mask_list, text_mask_list, labels_list
def __additem__(self, index, d, word, mask, text_mask, label):
d['index'].append(index)
d['word'] += word
d['mask'] += mask
d['label'] += label
d['text_mask'] += text_mask
def __populate__(self, idx_list, savelabeldic=False):
'''
populate samples into data dict
set savelabeldic=True if you want to save label2tag dict
'index': sample_index
'word': tokenized word ids
'mask': attention mask in BERT
'label': NER labels
'sentence_num': number of sentences in this set (a batch contains multiple sets)
'text_mask': 0 for special tokens and paddings, 1 for real text
'''
dataset = {'index': [], 'word': [], 'mask': [], 'label': [], 'sentence_num': [], 'text_mask': []}
for idx in idx_list:
tokens, labels = self.__get_token_label_list__(self.samples[idx])
            word, mask, text_mask, label = self.__getraw__(tokens, labels)  # BERT tokenization; build input_ids, attention_mask, ...
word = torch.tensor(word).long()
mask = torch.tensor(np.array(mask)).long()
text_mask = torch.tensor(np.array(text_mask)).long()
self.__additem__(idx, dataset, word, mask, text_mask, label)
dataset['sentence_num'] = [len(dataset['word'])]
if savelabeldic:
dataset['label2tag'] = [self.label2tag]
return dataset
def __getitem__(self, index):
        # each item is the data of one freshly sampled episode task
        target_classes, support_idx, query_idx = self.sampler.__next__()  # sample one episode task
# add 'O' and make sure 'O' is labeled 0
distinct_tags = ['O'] + target_classes
self.tag2label = {tag: idx for idx, tag in enumerate(distinct_tags)}
self.label2tag = {idx: tag for idx, tag in enumerate(distinct_tags)}
        # support_set (akin to input features): {'index': [], 'word': [], 'mask': [], 'label': [], 'sentence_num': [], 'text_mask': []}
        support_set = self.__populate__(support_idx)  # build features (input_ids, attention_mask, ...) from the sampled indices
query_set = self.__populate__(query_idx, savelabeldic=True)
return support_set, query_set
    def __len__(self):
        return 100000  # effectively unbounded: each __getitem__ call samples a fresh episode
Because each sampling run yields different episodes, the authors also provide preprocessed episode data so that baselines can be compared fairly; this data is used for apples-to-apples comparison in follow-up research. The code that directly reads the preprocessed episodes is as follows:
class FewShotNERDataset(FewShotNERDatasetWithRandomSampling):
def __init__(self, filepath, tokenizer, max_length, ignore_label_id=-1):
if not os.path.exists(filepath):
print("[ERROR] Data file does not exist!")
assert(0)
self.class2sampleid = {}
self.tokenizer = tokenizer
self.samples = self.__load_data_from_file__(filepath)
self.max_length = max_length
self.ignore_label_id = ignore_label_id
def __load_data_from_file__(self, filepath):
with open(filepath)as f:
lines = f.readlines()
for i in range(len(lines)):
lines[i] = json.loads(lines[i].strip())
return lines
def __additem__(self, d, word, mask, text_mask, label):
d['word'] += word
d['mask'] += mask
d['label'] += label
d['text_mask'] += text_mask
def __get_token_label_list__(self, words, tags):
tokens = []
labels = []
for word, tag in zip(words, tags):
word_tokens = self.tokenizer.tokenize(word)
if word_tokens:
tokens.extend(word_tokens)
# Use the real label id for the first token of the word, and padding ids for the remaining tokens
word_labels = [self.tag2label[tag]] + [self.ignore_label_id] * (len(word_tokens) - 1)
labels.extend(word_labels)
return tokens, labels
def __populate__(self, data, savelabeldic=False):
'''
populate samples into data dict
set savelabeldic=True if you want to save label2tag dict
'word': tokenized word ids
'mask': attention mask in BERT
'label': NER labels
'sentence_num': number of sentences in this set (a batch contains multiple sets)
'text_mask': 0 for special tokens and paddings, 1 for real text
'''
dataset = {'word': [], 'mask': [], 'label':[], 'sentence_num':[], 'text_mask':[] }
for i in range(len(data['word'])):
tokens, labels = self.__get_token_label_list__(data['word'][i], data['label'][i])
word, mask, text_mask, label = self.__getraw__(tokens, labels)
word = torch.tensor(word).long()
mask = torch.tensor(mask).long()
text_mask = torch.tensor(text_mask).long()
self.__additem__(dataset, word, mask, text_mask, label)
dataset['sentence_num'] = [len(dataset['word'])]
if savelabeldic:
dataset['label2tag'] = [self.label2tag]
return dataset
def __getitem__(self, index):
sample = self.samples[index]
target_classes = sample['types']
support = sample['support']
query = sample['query']
# add 'O' and make sure 'O' is labeled 0
distinct_tags = ['O'] + target_classes
self.tag2label = {tag: idx for idx, tag in enumerate(distinct_tags)}
self.label2tag = {idx: tag for idx, tag in enumerate(distinct_tags)}
support_set = self.__populate__(support)
query_set = self.__populate__(query, savelabeldic=True)
return support_set, query_set
def __len__(self):
return len(self.samples)
(3) Collator: the batch collation callback
It mainly converts the loaded data into model inputs, e.g., the very common list-of-dicts to dict-of-lists flip (see the code comments below).
def collate_fn(data):
'''
    the dataloader yields a batch; this function post-processes the items within one batch
    raw items in a batch are stored as list({'word': [..], ..}, ...)
    so they must be flipped into {'word': [[..], ..], ..}
e.g [{'word': [1, 2, 3]}, {'word': [4, 5, 6]}]
->
{'word': [[1, 2, 3], [4, 5, 6]]}
'''
batch_support = {'word': [], 'mask': [], 'label': [], 'sentence_num':[], 'text_mask':[]}
batch_query = {'word': [], 'mask': [], 'label': [], 'sentence_num':[], 'label2tag':[], 'text_mask':[]}
support_sets, query_sets = zip(*data)
for i in range(len(support_sets)):
for k in batch_support:
batch_support[k] += support_sets[i][k]
for k in batch_query:
batch_query[k] += query_sets[i][k]
for k in batch_support:
if k != 'label' and k != 'sentence_num':
batch_support[k] = torch.stack(batch_support[k], 0)
for k in batch_query:
if k !='label' and k != 'sentence_num' and k!= 'label2tag':
batch_query[k] = torch.stack(batch_query[k], 0)
batch_support['label'] = [torch.tensor(tag_list).long() for tag_list in batch_support['label']]
batch_query['label'] = [torch.tensor(tag_list).long() for tag_list in batch_query['label']]
return batch_support, batch_query
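Putting the pieces together, here is a hedged usage sketch of how the dataset and collator are typically wired up. The data path and hyperparameters are placeholders, and it assumes the remaining helpers from data_loader.py (Sample, FewshotSampler) are in scope.

```python
from torch.utils.data import DataLoader
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# placeholder path; 5-way 1~2-shot episodes with Q=1 query sentence per class
dataset = FewShotNERDatasetWithRandomSampling(
    'data/intra/train.txt', tokenizer, N=5, K=1, Q=1, max_length=100)
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

batch_support, batch_query = next(iter(loader))
print(batch_support['word'].shape)    # [total_sentences_in_batch, max_length]
print(batch_support['sentence_num'])  # sentences per episode, e.g. [s1, s2, s3, s4]
```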
One point deserves emphasis: under $N$-way $K$-shot training, the batch size no longer counts sentences but episodes, so an extra field is needed to locate each episode inside the flattened batch, namely sentence_num. The details are as follows:
The sentence_num field records the number of sentences in the $i$-th episode, so that each episode can later be retrieved individually. For example, if a batch holds two support sets with 5 and 7 sentences respectively, then support['word'] contains 12 sentences and support['sentence_num'] is [5, 7]. When computing prototypes later, the sentences of each episode must be sliced out separately, as sketched below.
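A tiny sketch of this per-episode slicing with torch.split, which cuts the flattened batch back into episode-sized chunks (shapes are illustrative):

```python
import torch

# support_emb: [total_sentences, max_length, hidden]; sentence_num e.g. [5, 7]
support_emb = torch.randn(12, 100, 768)
sentence_num = [5, 7]
chunks = torch.split(support_emb, sentence_num, dim=0)
print([c.shape[0] for c in chunks])  # [5, 7]
```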
(4) Encoder
Few-NERD and the methods later compared against it all use the BERT-base-uncased model, available at https://huggingface.co/bert-base-uncased. The code is shown below:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import os
from torch import optim
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification, RobertaModel, RobertaTokenizer, RobertaForSequenceClassification
class BERTWordEncoder(nn.Module):
def __init__(self, pretrain_path):
nn.Module.__init__(self)
self.bert = BertModel.from_pretrained(pretrain_path)
def forward(self, words, masks):
outputs = self.bert(words, attention_mask=masks, output_hidden_states=True, return_dict=True)
#outputs = self.bert(inputs['word'], attention_mask=inputs['mask'], output_hidden_states=True, return_dict=True)
# use the sum of the last 4 layers
last_four_hidden_states = torch.cat([hidden_state.unsqueeze(0) for hidden_state in outputs['hidden_states'][-4:]], 0)
del outputs
word_embeddings = torch.sum(last_four_hidden_states, 0) # [num_sent, number_of_tokens, 768]
return word_embeddings
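A brief usage sketch of the encoder above (the toy sentence and max length are arbitrary choices):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
encoder = BERTWordEncoder('bert-base-uncased')

enc = tokenizer(['few-shot ner is hard'], padding='max_length',
                max_length=16, return_tensors='pt')
emb = encoder(enc['input_ids'], enc['attention_mask'])
print(emb.shape)  # torch.Size([1, 16, 768]): sum of the last four hidden layers
```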
(5) Model architecture
Taking the Prototype Network as an example, the code is shown below; the inline comments explain the details.
import sys
sys.path.append('..')
import util
import torch
from torch import autograd, optim, nn
from torch.autograd import Variable
from torch.nn import functional as F
class Proto(util.framework.FewShotNERModel):
def __init__(self, word_encoder, dot=False, ignore_index=-1):
util.framework.FewShotNERModel.__init__(self, word_encoder, ignore_index=ignore_index)
self.drop = nn.Dropout()
self.dot = dot
def __dist__(self, x, y, dim):
if self.dot:
return (x * y).sum(dim)
else:
return -(torch.pow(x - y, 2)).sum(dim)
def __batch_dist__(self, S, Q, q_mask):
# S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim]
assert Q.size()[:2] == q_mask.size()
Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim]
return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2)
def __get_proto__(self, embedding, tag, mask):
proto = []
embedding = embedding[mask==1].view(-1, embedding.size(-1))
tag = torch.cat(tag, 0)
assert tag.size(0) == embedding.size(0)
for label in range(torch.max(tag)+1):
proto.append(torch.mean(embedding[tag==label], 0))
proto = torch.stack(proto)
return proto, embedding
def forward(self, support, query):
'''
support: Inputs of the support set.
query: Inputs of the query set.
N: Num of classes
K: Num of instances for each class in the support set
Q: Num of instances in the query set
support/query = {'index': [], 'word': [], 'mask': [], 'label': [], 'sentence_num': [], 'text_mask': []}
'''
        # feed the support set and the query set through BERT to obtain token representations
support_emb = self.word_encoder(support['word'], support['mask']) # [num_sent, number_of_tokens, 768]
query_emb = self.word_encoder(query['word'], query['mask']) # [num_sent, number_of_tokens, 768]
support_emb = self.drop(support_emb)
query_emb = self.drop(query_emb)
# Prototypical Networks
logits = []
current_support_num = 0
current_query_num = 0
assert support_emb.size()[:2] == support['mask'].size()
assert query_emb.size()[:2] == query['mask'].size()
        for i, sent_support_num in enumerate(support['sentence_num']):  # iterate over the sampled N-way K-shot episodes
sent_query_num = query['sentence_num'][i]
# Calculate prototype for each class
            # a batch holds multiple episodes, so the slice current_support_num:current_support_num+sent_support_num
            # marks which sentences of the input tensor belong to the current N-way K-shot episode
support_proto, embedding = self.__get_proto__(
support_emb[current_support_num:current_support_num+sent_support_num],
support['label'][current_support_num:current_support_num+sent_support_num],
support['text_mask'][current_support_num: current_support_num+sent_support_num])
# calculate distance to each prototype
logits.append(self.__batch_dist__(
support_proto,
query_emb[current_query_num:current_query_num+sent_query_num],
query['text_mask'][current_query_num: current_query_num+sent_query_num])) # [num_of_query_tokens, class_num]
current_query_num += sent_query_num
current_support_num += sent_support_num
        logits = torch.cat(logits, 0)  # each query token's score against every class prototype of the support set
        _, pred = torch.max(logits, 1)  # predict the prototype class with the highest score
return logits, pred, embedding
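The heavy lifting happens in __batch_dist__, which broadcasts every query-token embedding against every prototype in one shot. A shape-only demonstration of the same broadcasting pattern (random tensors, sizes chosen for illustration):

```python
import torch

S = torch.randn(6, 768)   # N+1 prototypes ('O' plus the 5 sampled classes)
Q = torch.randn(40, 768)  # all real text tokens in the query set
# [1, 6, 768] vs [40, 1, 768] -> broadcast to [40, 6, 768], sum over dim 2
logits = -(torch.pow(Q.unsqueeze(1) - S.unsqueeze(0), 2)).sum(2)
print(logits.shape)  # torch.Size([40, 6])
```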
The authors also implement NN-Shot and Struct-Shot; see the original paper and GitHub for details.
As of this writing (June 28, 2022), a number of papers at EMNLP 2021, AAAI, and ACL 2022 have adopted this dataset for evaluation; the current leaderboards can be found at Papers with Code (INTRA) and Papers with Code (INTER). The current comparison is shown below:

Few-NERD is a fairly new benchmark. Before it was proposed, few-shot NER was typically simulated by subsampling a handful of popular supervised datasets, and since everyone built their splits differently, comparisons between models were not fair. Few-NERD provides a fairer evaluation benchmark and, along the way, a principled sampling rule for few-shot NER.
That said, the existing evaluation protocols and proposed models still have some open issues: