• Mapping a sequence processed by a transformers tokenizer (e.g. BPE) back to the input sequence


    I have written before about generating embeddings with Huggingface's transformers, for example in this post: 怎样通过预训练的Transformers的模型得到一个Sentence的Representation_蛐蛐蛐的博客-CSDN博客

    Back then, however, I was mainly generating an embedding for a whole sequence and rarely needed to extract the embedding of an individual token. When I tried, it turned out to be more complicated than I had imagined, so I am summarizing it here.

    Because methods such as BPE are typically applied, the sequence fed into the transformer encoder is not the same length as the real input sequence, and there is no strict correspondence between the two. So to compute the embedding of a single word in the original sequence, other APIs are needed. This has been discussed on StackOverflow as well: tokenize - Mapping huggingface tokens to original input text - Stack Overflow

    After reading it carefully, though, I found that the existing answers are not quite correct. For example, say I want to find the correspondence for the following sequence:

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained('roberta-large', do_lower_case=True)
    example = "push r15 push r14 mov r15 , r8 push r13 push r12 mov r12 , rdi"
    encoded = tokenizer(example)
    print(encoded['input_ids'])
    print(len(encoded['input_ids']))
    print(encoded.tokens())
    print(len(encoded.tokens()))
    print(encoded.word_ids())
    print(len(encoded.word_ids()))

    The output is:

    [0, 41935, 910, 996, 1920, 910, 1570, 32924, 910, 996, 2156, 910, 398, 1920, 910, 1558, 1920, 910, 1092, 32924, 910, 1092, 2156, 910, 7506, 2]
    26
    ['<s>', 'push', 'Ġr', '15', 'Ġpush', 'Ġr', '14', 'Ġmov', 'Ġr', '15', 'Ġ,', 'Ġr', '8', 'Ġpush', 'Ġr', '13', 'Ġpush', 'Ġr', '12', 'Ġmov', 'Ġr', '12', 'Ġ,', 'Ġr', 'di', '</s>']
    26
    [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 22, None]
    26

    My input above is a string of assembly code, used here purely for convenience of the example. In short, tokenizer(InputSentence) returns a BatchEncoding object, and this object's methods give us the information we need. From the output above, encoded['input_ids'] is the tensor input actually fed to the transformer encoder, .tokens() returns the tokens after tokenization, and .word_ids() returns word indices as the tokenizer itself sees words, which do not line up with the whitespace-separated words of the input sequence (note there are 23 word ids here, but only 16 whitespace-separated input words). So how do we map back to the original input sequence? I tried many approaches and found that the only one that works is the token_to_chars function. The API documentation is here: Tokenizer
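
    To make this concrete, the CharSpan returned by token_to_chars can be sliced directly out of the original string. Here is a small check I added for illustration (the indices match the output shown further below):

    span = encoded.token_to_chars(2)       # token index 2 is 'Ġr'
    print(span)                            # CharSpan(start=5, end=6)
    print(example[span.start:span.end])    # 'r' -- the first character of "r15"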

    For example, in the code below I enumerate the various cases:

    for token_index in range(len(encoded.tokens())):
        this_token = encoded.word_ids()[token_index]
        if this_token is not None:
            print('###########################')
            print(token_index)
            print(encoded.token_to_chars(token_index))
            print(encoded.token_to_word(token_index))
            char_span = encoded.token_to_chars(token_index)
            print('...')
            for char_index in range(char_span.start, char_span.end):
                print(encoded.char_to_word(char_index))
                print(encoded.char_to_token(char_index))
            print('...')
            print('###########################')

    The corresponding output is:

    ###########################
    1
    CharSpan(start=0, end=4)
    0
    ...
    0
    1
    0
    1
    0
    1
    0
    1
    ...
    ###########################
    ###########################
    2
    CharSpan(start=5, end=6)
    1
    ...
    1
    2
    ...
    ###########################
    ###########################
    3
    CharSpan(start=6, end=8)
    2
    ...
    2
    3
    2
    3
    ...
    ###########################
    ###########################
    4
    CharSpan(start=9, end=13)
    3
    ...
    3
    4
    3
    4
    3
    4
    3
    4
    ...
    ###########################
    ###########################
    5
    CharSpan(start=14, end=15)
    4
    ...
    4
    5
    ...
    ###########################
    ###########################
    6
    CharSpan(start=15, end=17)
    5
    ...
    5
    6
    5
    6
    ...
    ###########################
    ###########################
    7
    CharSpan(start=18, end=21)
    6
    ...
    6
    7
    6
    7
    6
    7
    ...
    ###########################
    ###########################
    8
    CharSpan(start=22, end=23)
    7
    ...
    7
    8
    ...
    ###########################
    ###########################
    9
    CharSpan(start=23, end=25)
    8
    ...
    8
    9
    8
    9
    ...
    ###########################
    ###########################
    10
    CharSpan(start=26, end=27)
    9
    ...
    9
    10
    ...
    ###########################
    ###########################
    11
    CharSpan(start=28, end=29)
    10
    ...
    10
    11
    ...
    ###########################
    ###########################
    12
    CharSpan(start=29, end=30)
    11
    ...
    11
    12
    ...
    ###########################
    ###########################
    13
    CharSpan(start=31, end=35)
    12
    ...
    12
    13
    12
    13
    12
    13
    12
    13
    ...
    ###########################
    ###########################
    14
    CharSpan(start=36, end=37)
    13
    ...
    13
    14
    ...
    ###########################
    ###########################
    15
    CharSpan(start=37, end=39)
    14
    ...
    14
    15
    14
    15
    ...
    ###########################
    ###########################
    16
    CharSpan(start=40, end=44)
    15
    ...
    15
    16
    15
    16
    15
    16
    15
    16
    ...
    ###########################
    ###########################
    17
    CharSpan(start=45, end=46)
    16
    ...
    16
    17
    ...
    ###########################
    ###########################
    18
    CharSpan(start=46, end=48)
    17
    ...
    17
    18
    17
    18
    ...
    ###########################
    ###########################
    19
    CharSpan(start=49, end=52)
    18
    ...
    18
    19
    18
    19
    18
    19
    ...
    ###########################
    ###########################
    20
    CharSpan(start=53, end=54)
    19
    ...
    19
    20
    ...
    ###########################
    ###########################
    21
    CharSpan(start=54, end=56)
    20
    ...
    20
    21
    20
    21
    ...
    ###########################
    ###########################
    22
    CharSpan(start=57, end=58)
    21
    ...
    21
    22
    ...
    ###########################
    ###########################
    23
    CharSpan(start=59, end=60)
    22
    ...
    22
    23
    ...
    ###########################
    ###########################
    24
    CharSpan(start=60, end=62)
    22
    ...
    22
    24
    22
    24
    ...
    ###########################

    As you can see, only token_to_chars returns a CharSpan, and that span marks a position in the original input sequence. The others, such as char_to_token, just go around in circles within the encoded sequence, which is not at all what the API documentation suggests. With this function, we can build two mappings to obtain the correspondence between the words of the original input sequence and the tokens after encoding:

    corpora_records = example.split(' ')
    word_2_char_mapping = {}
    char_cursor = 0
    for ind in range(len(corpora_records)):
        if len(corpora_records[ind]) > 0:  # a trailing space would produce an empty record
            start = char_cursor
            end = char_cursor + len(corpora_records[ind])
            word_2_char_mapping[ind] = [start, end]
        char_cursor = char_cursor + len(corpora_records[ind]) + 1  # account for the whitespace
    print(word_2_char_mapping)
    word_2_token_mapping = {}
    for token_index in range(len(encoded.tokens())):
        this_token = encoded.word_ids()[token_index]  # None for special tokens
        if this_token is not None:
            char_span = encoded.token_to_chars(token_index)
            for each_word in word_2_char_mapping:
                start = word_2_char_mapping[each_word][0]
                end = word_2_char_mapping[each_word][1]
                if char_span.start >= start and char_span.end <= end:
                    # print(encoded.tokens()[token_index])  # check the results to make sure our mapping is correct
                    # print('--->')
                    # print(corpora_records[each_word])
                    if each_word in word_2_token_mapping:
                        word_2_token_mapping[each_word].append(token_index)
                    else:
                        word_2_token_mapping[each_word] = [token_index]
    print(word_2_token_mapping)

    The corresponding output is:

    [0, 41935, 910, 996, 1920, 910, 1570, 32924, 910, 996, 2156, 910, 398, 1920, 910, 1558, 1920, 910, 1092, 32924, 910, 1092, 2156, 910, 7506, 2]
    26
    ['<s>', 'push', 'Ġr', '15', 'Ġpush', 'Ġr', '14', 'Ġmov', 'Ġr', '15', 'Ġ,', 'Ġr', '8', 'Ġpush', 'Ġr', '13', 'Ġpush', 'Ġr', '12', 'Ġmov', 'Ġr', '12', 'Ġ,', 'Ġr', 'di', '</s>']
    26
    [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 22, None]
    26
    {0: [0, 4], 1: [5, 8], 2: [9, 13], 3: [14, 17], 4: [18, 21], 5: [22, 25], 6: [26, 27], 7: [28, 30], 8: [31, 35], 9: [36, 39], 10: [40, 44], 11: [45, 48], 12: [49, 52], 13: [53, 56], 14: [57, 58], 15: [59, 62]}
    {0: [1], 1: [2, 3], 2: [4], 3: [5, 6], 4: [7], 5: [8, 9], 6: [10], 7: [11, 12], 8: [13], 9: [14, 15], 10: [16], 11: [17, 18], 12: [19], 13: [20, 21], 14: [22], 15: [23, 24]}

    To keep the presentation short, the long prints from before have been omitted. As you can see, the resulting word_2_token_mapping is exactly correct. Finally, here is all the code in one place so that it can be checked and verified quickly:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('roberta-large', do_lower_case=True)
    example = "push r15 push r14 mov r15 , r8 push r13 push r12 mov r12 , rdi"
    encoded = tokenizer(example)
    print(encoded['input_ids'])
    print(len(encoded['input_ids']))
    print(encoded.tokens())
    print(len(encoded.tokens()))
    print(encoded.word_ids())
    print(len(encoded.word_ids()))
    # for token_index in range(len(encoded.tokens())):
    #     this_token = encoded.word_ids()[token_index]
    #     if this_token is not None:
    #         print('###########################')
    #         print(token_index)
    #         print(encoded.token_to_chars(token_index))
    #         print(encoded.token_to_word(token_index))
    #         char_span = encoded.token_to_chars(token_index)
    #         print('...')
    #         for char_index in range(char_span.start, char_span.end):
    #             print(encoded.char_to_word(char_index))
    #             print(encoded.char_to_token(char_index))
    #         print('...')
    #         print('###########################')
    corpora_records = example.split(' ')
    word_2_char_mapping = {}
    char_cursor = 0
    for ind in range(len(corpora_records)):
        if len(corpora_records[ind]) > 0:  # a trailing space would produce an empty record
            start = char_cursor
            end = char_cursor + len(corpora_records[ind])
            word_2_char_mapping[ind] = [start, end]
        char_cursor = char_cursor + len(corpora_records[ind]) + 1  # account for the whitespace
    print(word_2_char_mapping)
    word_2_token_mapping = {}
    for token_index in range(len(encoded.tokens())):
        this_token = encoded.word_ids()[token_index]  # None for special tokens
        if this_token is not None:
            char_span = encoded.token_to_chars(token_index)
            for each_word in word_2_char_mapping:
                start = word_2_char_mapping[each_word][0]
                end = word_2_char_mapping[each_word][1]
                if char_span.start >= start and char_span.end <= end:
                    # print(encoded.tokens()[token_index])  # check the results to make sure our mapping is correct
                    # print('--->')
                    # print(corpora_records[each_word])
                    if each_word in word_2_token_mapping:
                        word_2_token_mapping[each_word].append(token_index)
                    else:
                        word_2_token_mapping[each_word] = [token_index]
    print(word_2_token_mapping)
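
    As a quick sanity check (my own addition, not part of the code above), every word can be reassembled from its mapped tokens. This assumes the vocabulary marks word-initial spaces with 'Ġ', as RoBERTa's byte-level BPE does:

    # Rebuild each word from its sub-word tokens and compare with the split input.
    for word_index, token_ids in word_2_token_mapping.items():
        pieces = [encoded.tokens()[t].replace('Ġ', '') for t in token_ids]
        assert ''.join(pieces) == corpora_records[word_index]
    print('mapping verified')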

    After careful checking and testing, I found the answers on StackOverflow to be not very reliable. That said, it was from those answers that I learned that token_to_chars could be used. That is all for this short summary.
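
    One more aside: AutoTokenizer returns a fast (Rust-backed) tokenizer for roberta-large, and fast tokenizers can also expose the character spans in a single call via return_offsets_mapping. This is an alternative route to the same information, not what the code above uses:

    # Each entry of offset_mapping is a (char_start, char_end) pair into `example`;
    # special tokens such as <s> and </s> are reported as (0, 0).
    encoded_offsets = tokenizer(example, return_offsets_mapping=True)
    print(encoded_offsets['offset_mapping'])

    And since the point of the whole exercise is per-word embeddings, here is a minimal sketch of how word_2_token_mapping could be used for that. It assumes mean-pooling over each word's sub-word tokens, which is one common choice rather than the only one:

    import torch
    from transformers import AutoModel

    model = AutoModel.from_pretrained('roberta-large')
    model.eval()
    inputs = tokenizer(example, return_tensors='pt')
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state[0]  # shape: (seq_len, hidden_size)
    # Mean-pool the hidden states of each word's sub-word tokens.
    word_embeddings = {
        word: hidden[token_ids].mean(dim=0)
        for word, token_ids in word_2_token_mapping.items()
    }
    print(word_embeddings[1].shape)  # torch.Size([1024]) for roberta-large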

• Original article: https://blog.csdn.net/qysh123/article/details/126438203