I have written before about using Huggingface's transformers to generate embeddings, for example in this post: 怎样通过预训练的Transformers的模型得到一个Sentence的Representation_蛐蛐蛐的博客-CSDN博客
Those posts, however, mostly dealt with producing an embedding for a whole sequence and rarely touched on extracting the embedding of an individual token. When I tried it, it turned out to be more involved than I had expected, so here is a summary.
Because tokenizers typically apply subword methods such as BPE, the sequence fed into the transformer encoder is not the same length as the original input sequence, and there is no strict one-to-one correspondence between the two. To compute the embedding of a single word in the original sequence, you therefore need some additional APIs. There is a discussion of this problem on StackOverflow: tokenize - Mapping huggingface tokens to original input text - Stack Overflow
After reading it carefully, though, I found the existing answers are not quite correct. For example, suppose I want to find the correspondence for the following sequence:
- from transformers import AutoTokenizer
- tokenizer = AutoTokenizer.from_pretrained('roberta-large', do_lower_case=True)
- example = "push r15 push r14 mov r15 , r8 push r13 push r12 mov r12 , rdi"
- encoded = tokenizer(example)
- print(encoded['input_ids'])
- print(len(encoded['input_ids']))
- print(encoded.tokens())
- print(len(encoded.tokens()))
- print(encoded.word_ids())
- print(len(encoded.word_ids()))
The output is:
- [0, 41935, 910, 996, 1920, 910, 1570, 32924, 910, 996, 2156, 910, 398, 1920, 910, 1558, 1920, 910, 1092, 32924, 910, 1092, 2156, 910, 7506, 2]
- 26
- ['<s>', 'push', 'Ġr', '15', 'Ġpush', 'Ġr', '14', 'Ġmov', 'Ġr', '15', 'Ġ,', 'Ġr', '8', 'Ġpush', 'Ġr', '13', 'Ġpush', 'Ġr', '12', 'Ġmov', 'Ġr', '12', 'Ġ,', 'Ġr', 'di', '</s>']
- 26
- [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 22, None]
- 26
My input above is a string of assembly code; it is just a convenient example. In short, tokenizer(InputSentence) returns a BatchEncoding object, and the methods of that object give us the information we need. From the output above you can see that encoded['input_ids'] is the id sequence that becomes the tensor input to the transformer encoder, .tokens() returns the tokens after tokenization, and .word_ids() returns, for each token, a word index that follows the tokenizer's own pre-tokenization (in the output above, 'r15' is counted as two words, 'r' and '15') rather than the whitespace-split words of the input sequence. So how do we map tokens back to the original input sequence? I tried quite a few approaches and found that only the token_to_chars function does the job. The API documentation is here: Tokenizer
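To make this concrete, here is a quick check that reuses the encoded object above and groups tokens by their word_ids(); it reassembles the "words" the tokenizer actually works with (this grouping is only for illustration, not something the mapping below depends on):
- # Group tokens by word_ids() to see the tokenizer's own notion of a "word".
- from collections import defaultdict
- groups = defaultdict(list)
- for tok, wid in zip(encoded.tokens(), encoded.word_ids()):
-     if wid is not None:  # skip the special tokens at both ends
-         groups[wid].append(tok)
- print({wid: ''.join(toks).replace('Ġ', '') for wid, toks in groups.items()})
- # 'r15' shows up as two separate words ('r' and '15'), so word_ids() alone cannot
- # recover the whitespace-split words of the original input.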
For example, the following code walks through the various cases:
- for token_index in range(len(encoded.tokens())):
-     this_token = encoded.word_ids()[token_index]
-     if this_token is not None:
-         print('###########################')
-         print(token_index)
-         print(encoded.token_to_chars(token_index))
-         print(encoded.token_to_word(token_index))
-         char_span = encoded.token_to_chars(token_index)
-         print('...')
-         for char_index in range(char_span.start, char_span.end):
-             print(encoded.char_to_word(char_index))
-             print(encoded.char_to_token(char_index))
-         print('...')
-         print('###########################')
The corresponding output is:
- ###########################
- 1
- CharSpan(start=0, end=4)
- 0
- ...
- 0
- 1
- 0
- 1
- 0
- 1
- 0
- 1
- ...
- ###########################
- ###########################
- 2
- CharSpan(start=5, end=6)
- 1
- ...
- 1
- 2
- ...
- ###########################
- ###########################
- 3
- CharSpan(start=6, end=8)
- 2
- ...
- 2
- 3
- 2
- 3
- ...
- ###########################
- ###########################
- 4
- CharSpan(start=9, end=13)
- 3
- ...
- 3
- 4
- 3
- 4
- 3
- 4
- 3
- 4
- ...
- ###########################
- ###########################
- 5
- CharSpan(start=14, end=15)
- 4
- ...
- 4
- 5
- ...
- ###########################
- ###########################
- 6
- CharSpan(start=15, end=17)
- 5
- ...
- 5
- 6
- 5
- 6
- ...
- ###########################
- ###########################
- 7
- CharSpan(start=18, end=21)
- 6
- ...
- 6
- 7
- 6
- 7
- 6
- 7
- ...
- ###########################
- ###########################
- 8
- CharSpan(start=22, end=23)
- 7
- ...
- 7
- 8
- ...
- ###########################
- ###########################
- 9
- CharSpan(start=23, end=25)
- 8
- ...
- 8
- 9
- 8
- 9
- ...
- ###########################
- ###########################
- 10
- CharSpan(start=26, end=27)
- 9
- ...
- 9
- 10
- ...
- ###########################
- ###########################
- 11
- CharSpan(start=28, end=29)
- 10
- ...
- 10
- 11
- ...
- ###########################
- ###########################
- 12
- CharSpan(start=29, end=30)
- 11
- ...
- 11
- 12
- ...
- ###########################
- ###########################
- 13
- CharSpan(start=31, end=35)
- 12
- ...
- 12
- 13
- 12
- 13
- 12
- 13
- 12
- 13
- ...
- ###########################
- ###########################
- 14
- CharSpan(start=36, end=37)
- 13
- ...
- 13
- 14
- ...
- ###########################
- ###########################
- 15
- CharSpan(start=37, end=39)
- 14
- ...
- 14
- 15
- 14
- 15
- ...
- ###########################
- ###########################
- 16
- CharSpan(start=40, end=44)
- 15
- ...
- 15
- 16
- 15
- 16
- 15
- 16
- 15
- 16
- ...
- ###########################
- ###########################
- 17
- CharSpan(start=45, end=46)
- 16
- ...
- 16
- 17
- ...
- ###########################
- ###########################
- 18
- CharSpan(start=46, end=48)
- 17
- ...
- 17
- 18
- 17
- 18
- ...
- ###########################
- ###########################
- 19
- CharSpan(start=49, end=52)
- 18
- ...
- 18
- 19
- 18
- 19
- 18
- 19
- ...
- ###########################
- ###########################
- 20
- CharSpan(start=53, end=54)
- 19
- ...
- 19
- 20
- ...
- ###########################
- ###########################
- 21
- CharSpan(start=54, end=56)
- 20
- ...
- 20
- 21
- 20
- 21
- ...
- ###########################
- ###########################
- 22
- CharSpan(start=57, end=58)
- 21
- ...
- 21
- 22
- ...
- ###########################
- ###########################
- 23
- CharSpan(start=59, end=60)
- 22
- ...
- 22
- 23
- ...
- ###########################
- ###########################
- 24
- CharSpan(start=60, end=62)
- 22
- ...
- 22
- 24
- 22
- 24
- ...
- ###########################
As you can see, only token_to_chars returns a CharSpan, which gives a token's position in the original input string. The other helpers, such as char_to_word and char_to_token, map character positions to the pre-tokenized word indices or token indices, which is not the word-level correspondence I was after and not what I had expected from the API documentation. With token_to_chars, we can build two mappings that connect the words of the original input sequence to the tokens of the encoded sequence:
- corpora_records = example.split(' ')
- word_2_char_mapping = {}
- char_cursor = 0
- for ind in range(len(corpora_records)):
-     if len(corpora_records[ind]) > 0:  # an empty piece from a trailing space is skipped
-         start = char_cursor
-         end = char_cursor + len(corpora_records[ind])
-         word_2_char_mapping[ind] = [start, end]
-     char_cursor = char_cursor + len(corpora_records[ind]) + 1  # account for the white-space length
-
- print(word_2_char_mapping)
-
- word_2_token_mapping = {}
- for token_index in range(len(encoded.tokens())):
-     this_token = encoded.word_ids()[token_index]
-     if this_token is not None:
-         char_span = encoded.token_to_chars(token_index)
-         for each_word in word_2_char_mapping:
-             start = word_2_char_mapping[each_word][0]
-             end = word_2_char_mapping[each_word][1]
-             if char_span.start >= start and char_span.end <= end:
-                 # print(encoded.tokens()[token_index])  # check the results to make sure our mapping is correct
-                 # print('--->')
-                 # print(corpora_records[each_word])
-
-                 if each_word in word_2_token_mapping:
-                     word_2_token_mapping[each_word].append(token_index)
-                 else:
-                     word_2_token_mapping[each_word] = [token_index]
-
- print(word_2_token_mapping)
The corresponding output is:
- [0, 41935, 910, 996, 1920, 910, 1570, 32924, 910, 996, 2156, 910, 398, 1920, 910, 1558, 1920, 910, 1092, 32924, 910, 1092, 2156, 910, 7506, 2]
- 26
- ['<s>', 'push', 'Ġr', '15', 'Ġpush', 'Ġr', '14', 'Ġmov', 'Ġr', '15', 'Ġ,', 'Ġr', '8', 'Ġpush', 'Ġr', '13', 'Ġpush', 'Ġr', '12', 'Ġmov', 'Ġr', '12', 'Ġ,', 'Ġr', 'di', '</s>']
- 26
- [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 22, None]
- 26
- {0: [0, 4], 1: [5, 8], 2: [9, 13], 3: [14, 17], 4: [18, 21], 5: [22, 25], 6: [26, 27], 7: [28, 30], 8: [31, 35], 9: [36, 39], 10: [40, 44], 11: [45, 48], 12: [49, 52], 13: [53, 56], 14: [57, 58], 15: [59, 62]}
- {0: [1], 1: [2, 3], 2: [4], 3: [5, 6], 4: [7], 5: [8, 9], 6: [10], 7: [11, 12], 8: [13], 9: [14, 15], 10: [16], 11: [17, 18], 12: [19], 13: [20, 21], 14: [22], 15: [23, 24]}
For readability, the long per-token prints from before are omitted here. As you can see, the final word_2_token_mapping is exactly right. Finally, here is the complete code in one place so it is easy to check and reproduce:
- from transformers import AutoTokenizer
- tokenizer = AutoTokenizer.from_pretrained('roberta-large', do_lower_case=True)
- example = "push r15 push r14 mov r15 , r8 push r13 push r12 mov r12 , rdi"
- encoded = tokenizer(example)
- print(encoded['input_ids'])
- print(len(encoded['input_ids']))
- print(encoded.tokens())
- print(len(encoded.tokens()))
- print(encoded.word_ids())
- print(len(encoded.word_ids()))
-
- # for token_index in range(len(encoded.tokens())):
- #     this_token = encoded.word_ids()[token_index]
- #     if this_token is not None:
- #         print('###########################')
- #         print(token_index)
- #         print(encoded.token_to_chars(token_index))
- #         print(encoded.token_to_word(token_index))
- #         char_span = encoded.token_to_chars(token_index)
- #         print('...')
- #         for char_index in range(char_span.start, char_span.end):
- #             print(encoded.char_to_word(char_index))
- #             print(encoded.char_to_token(char_index))
- #         print('...')
- #         print('###########################')
-
- corpora_records = example.split(' ')
- word_2_char_mapping = {}
- char_cursor = 0
- for ind in range(len(corpora_records)):
-     if len(corpora_records[ind]) > 0:  # an empty piece from a trailing space is skipped
-         start = char_cursor
-         end = char_cursor + len(corpora_records[ind])
-         word_2_char_mapping[ind] = [start, end]
-     char_cursor = char_cursor + len(corpora_records[ind]) + 1  # account for the white-space length
-
- print(word_2_char_mapping)
-
- word_2_token_mapping = {}
- for token_index in range(len(encoded.tokens())):
-     this_token = encoded.word_ids()[token_index]
-     if this_token is not None:
-         char_span = encoded.token_to_chars(token_index)
-         for each_word in word_2_char_mapping:
-             start = word_2_char_mapping[each_word][0]
-             end = word_2_char_mapping[each_word][1]
-             if char_span.start >= start and char_span.end <= end:
-                 # print(encoded.tokens()[token_index])  # check the results to make sure our mapping is correct
-                 # print('--->')
-                 # print(corpora_records[each_word])
-
-                 if each_word in word_2_token_mapping:
-                     word_2_token_mapping[each_word].append(token_index)
-                 else:
-                     word_2_token_mapping[each_word] = [token_index]
-
- print(word_2_token_mapping)
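With word_2_token_mapping in hand, the remaining step is to actually pull out per-word embeddings. One straightforward option is to run the model and mean-pool the hidden states of each word's tokens; the following is only a minimal sketch under that assumption (mean pooling over roberta-large's last hidden states; taking just the first sub-token would also be reasonable):
- # Minimal sketch: mean-pool each word's token hidden states (one aggregation choice among several).
- import torch
- from transformers import AutoModel
-
- model = AutoModel.from_pretrained('roberta-large')
- model.eval()
- inputs = tokenizer(example, return_tensors='pt')  # same ids as `encoded`, just as tensors
- with torch.no_grad():
-     outputs = model(**inputs)
- hidden = outputs.last_hidden_state[0]  # (sequence_length, hidden_size)
-
- word_embeddings = {}
- for word_index, token_indices in word_2_token_mapping.items():
-     word_embeddings[word_index] = hidden[token_indices].mean(dim=0)
-
- print(corpora_records[1], word_embeddings[1].shape)  # embedding of 'r15'; hidden size is 1024 for roberta-large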
After checking and testing carefully, I found the answers on StackOverflow not very reliable, although it was from those answers that I learned token_to_chars could be used in the first place. That is all for this short summary.
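One small aside: fast tokenizers can also return all the character spans in one go via return_offsets_mapping=True, which should line up with what token_to_chars gives you token by token; a minimal sketch:
- # Sketch: fetch every token's character span in a single call instead of looping over token_to_chars.
- encoded_with_offsets = tokenizer(example, return_offsets_mapping=True)
- print(encoded_with_offsets['offset_mapping'])
- # Each (start, end) pair is a span over the original string; special tokens such as
- # <s> and </s> come back as (0, 0), so they are easy to filter out.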