• CPM:A large-scale generative chinese pre-trained lanuage model


    GitHub - yangjianxin1/CPM: Easy-to-use CPM for Chinese text generation(基于CPM的中文文本生成)Easy-to-use CPM for Chinese text generation(基于CPM的中文文本生成) - GitHub - yangjianxin1/CPM: Easy-to-use CPM for Chinese text generation(基于CPM的中文文本生成)https://github.com/yangjianxin1/CPM论文《CPM: A Large-scale Generative Chinese Pre-trained Language Model》_陈欢伯的博客-CSDN博客1. IntroductionGPT-3含有175B参数使用了570GB的数据进行训练。但大多数语料是基于英文(93%),并且GPT-3的参数没有分布,所以提出了CPM(Chinese Pretrained language Model):包含2.6B参数,使用100GB中文训练数据。CPM可以对接下游任务:对话、文章生成、完形填空、语言理解。随着参数规模的增加,CPM在一些数据集上表现更好,表示大模型在语言生成和理解上面更有效。文章的主要贡献发布了一个CPM:2.6B参数,100GB中文训练https://blog.csdn.net/mark_technology/article/details/118680728文章本身写的非常简单,至于模型结构这块,可以看一下放出来的代码,还挺好用的,我跑一个电商场景的推荐文章生成模型,效果也不错。在生成模型上还是很建议尝试一下CPM,整体采用transformer中的代码实现,比较简洁。

    中文版GPT-3来了?智源、清华发布清源 CPM——以中文为核心的大规模预训练模型

    上面计算时间为使用单块NVIDIA V100 GPU训练的估计时间。

    1.Approach

    1.1 Chinese PLM(pretrained lanuage model)

    上面是CPM的模型参数版本,其中small版本至少我是可以在gtx1080ti上训练的,后面我会添加我的具体训练参数。

    稍微过一下CPM的模型结构,其实就是gpt2的模型:

    1. transformer.wte.weight [30000, 768]
    2. transformer.wpe.weight [1024, 768]
    3. transformer.h.0.ln_1.weight [768]
    4. transformer.h.0.ln_1.bias [768]
    5. transformer.h.0.attn.bias [1, 1, 1024, 1024]
    6. transformer.h.0.attn.masked_bias []
    7. transformer.h.0.attn.c_attn.weight [768, 2304]
    8. transformer.h.0.attn.c_attn.bias [2304]
    9. transformer.h.0.attn.c_proj.weight [768, 768]
    10. transformer.h.0.attn.c_proj.bias [768]
    11. transformer.h.0.ln_2.weight [768]
    12. transformer.h.0.ln_2.bias [768]
    13. transformer.h.0.mlp.c_fc.weight [768, 3072]
    14. transformer.h.0.mlp.c_fc.bias [3072]
    15. transformer.h.0.mlp.c_proj.weight [3072, 768]
    16. transformer.h.0.mlp.c_proj.bias [768]
    17. transformer.h.1.ln_1.weight [768]
    18. transformer.h.1.ln_1.bias [768]
    19. transformer.h.1.attn.bias [1, 1, 1024, 1024]
    20. transformer.h.1.attn.masked_bias []
    21. transformer.h.1.attn.c_attn.weight [768, 2304]
    22. transformer.h.1.attn.c_attn.bias [2304]
    23. transformer.h.1.attn.c_proj.weight [768, 768]
    24. transformer.h.1.attn.c_proj.bias [768]
    25. transformer.h.1.ln_2.weight [768]
    26. transformer.h.1.ln_2.bias [768]
    27. transformer.h.1.mlp.c_fc.weight [768, 3072]
    28. transformer.h.1.mlp.c_fc.bias [3072]
    29. transformer.h.1.mlp.c_proj.weight [3072, 768]
    30. transformer.h.1.mlp.c_proj.bias [768]
    31. transformer.h.2.ln_1.weight [768]
    32. transformer.h.2.ln_1.bias [768]
    33. transformer.h.2.attn.bias [1, 1, 1024, 1024]
    34. transformer.h.2.attn.masked_bias []
    35. transformer.h.2.attn.c_attn.weight [768, 2304]
    36. transformer.h.2.attn.c_attn.bias [2304]
    37. transformer.h.2.attn.c_proj.weight [768, 768]
    38. transformer.h.2.attn.c_proj.bias [768]
    39. transformer.h.2.ln_2.weight [768]
    40. transformer.h.2.ln_2.bias [768]
    41. transformer.h.2.mlp.c_fc.weight [768, 3072]
    42. transformer.h.2.mlp.c_fc.bias [3072]
    43. transformer.h.2.mlp.c_proj.weight [3072, 768]
    44. transformer.h.2.mlp.c_proj.bias [768]
    45. transformer.h.3.ln_1.weight [768]
    46. transformer.h.3.ln_1.bias [768]
    47. transformer.h.3.attn.bias [1, 1, 1024, 1024]
    48. transformer.h.3.attn.masked_bias []
    49. transformer.h.3.attn.c_attn.weight [768, 2304]
    50. transformer.h.3.attn.c_attn.bias [2304]
    51. transformer.h.3.attn.c_proj.weight [768, 768]
    52. transformer.h.3.attn.c_proj.bias [768]
    53. transformer.h.3.ln_2.weight [768]
    54. transformer.h.3.ln_2.bias [768]
    55. transformer.h.3.mlp.c_fc.weight [768, 3072]
    56. transformer.h.3.mlp.c_fc.bias [3072]
    57. transformer.h.3.mlp.c_proj.weight [3072, 768]
    58. transformer.h.3.mlp.c_proj.bias [768]
    59. transformer.h.4.ln_1.weight [768]
    60. transformer.h.4.ln_1.bias [768]
    61. transformer.h.4.attn.bias [1, 1, 1024, 1024]
    62. transformer.h.4.attn.masked_bias []
    63. transformer.h.4.attn.c_attn.weight [768, 2304]
    64. transformer.h.4.attn.c_attn.bias [2304]
    65. transformer.h.4.attn.c_proj.weight [768, 768]
    66. transformer.h.4.attn.c_proj.bias [768]
    67. transformer.h.4.ln_2.weight [768]
    68. transformer.h.4.ln_2.bias [768]
    69. transformer.h.4.mlp.c_fc.weight [768, 3072]
    70. transformer.h.4.mlp.c_fc.bias [3072]
    71. transformer.h.4.mlp.c_proj.weight [3072, 768]
    72. transformer.h.4.mlp.c_proj.bias [768]
    73. transformer.h.5.ln_1.weight [768]
    74. transformer.h.5.ln_1.bias [768]
    75. transformer.h.5.attn.bias [1, 1, 1024, 1024]
    76. transformer.h.5.attn.masked_bias []
    77. transformer.h.5.attn.c_attn.weight [768, 2304]
    78. transformer.h.5.attn.c_attn.bias [2304]
    79. transformer.h.5.attn.c_proj.weight [768, 768]
    80. transformer.h.5.attn.c_proj.bias [768]
    81. transformer.h.5.ln_2.weight [768]
    82. transformer.h.5.ln_2.bias [768]
    83. transformer.h.5.mlp.c_fc.weight [768, 3072]
    84. transformer.h.5.mlp.c_fc.bias [3072]
    85. transformer.h.5.mlp.c_proj.weight [3072, 768]
    86. transformer.h.5.mlp.c_proj.bias [768]
    87. transformer.h.6.ln_1.weight [768]
    88. transformer.h.6.ln_1.bias [768]
    89. transformer.h.6.attn.bias [1, 1, 1024, 1024]
    90. transformer.h.6.attn.masked_bias []
    91. transformer.h.6.attn.c_attn.weight [768, 2304]
    92. transformer.h.6.attn.c_attn.bias [2304]
    93. transformer.h.6.attn.c_proj.weight [768, 768]
    94. transformer.h.6.attn.c_proj.bias [768]
    95. transformer.h.6.ln_2.weight [768]
    96. transformer.h.6.ln_2.bias [768]
    97. transformer.h.6.mlp.c_fc.weight [768, 3072]
    98. transformer.h.6.mlp.c_fc.bias [3072]
    99. transformer.h.6.mlp.c_proj.weight [3072, 768]
    100. transformer.h.6.mlp.c_proj.bias [768]
    101. transformer.h.7.ln_1.weight [768]
    102. transformer.h.7.ln_1.bias [768]
    103. transformer.h.7.attn.bias [1, 1, 1024, 1024]
    104. transformer.h.7.attn.masked_bias []
    105. transformer.h.7.attn.c_attn.weight [768, 2304]
    106. transformer.h.7.attn.c_attn.bias [2304]
    107. transformer.h.7.attn.c_proj.weight [768, 768]
    108. transformer.h.7.attn.c_proj.bias [768]
    109. transformer.h.7.ln_2.weight [768]
    110. transformer.h.7.ln_2.bias [768]
    111. transformer.h.7.mlp.c_fc.weight [768, 3072]
    112. transformer.h.7.mlp.c_fc.bias [3072]
    113. transformer.h.7.mlp.c_proj.weight [3072, 768]
    114. transformer.h.7.mlp.c_proj.bias [768]
    115. transformer.h.8.ln_1.weight [768]
    116. transformer.h.8.ln_1.bias [768]
    117. transformer.h.8.attn.bias [1, 1, 1024, 1024]
    118. transformer.h.8.attn.masked_bias []
    119. transformer.h.8.attn.c_attn.weight [768, 2304]
    120. transformer.h.8.attn.c_attn.bias [2304]
    121. transformer.h.8.attn.c_proj.weight [768, 768]
    122. transformer.h.8.attn.c_proj.bias [768]
    123. transformer.h.8.ln_2.weight [768]
    124. transformer.h.8.ln_2.bias [768]
    125. transformer.h.8.mlp.c_fc.weight [768, 3072]
    126. transformer.h.8.mlp.c_fc.bias [3072]
    127. transformer.h.8.mlp.c_proj.weight [3072, 768]
    128. transformer.h.8.mlp.c_proj.bias [768]
    129. transformer.h.9.ln_1.weight [768]
    130. transformer.h.9.ln_1.bias [768]
    131. transformer.h.9.attn.bias [1, 1, 1024, 1024]
    132. transformer.h.9.attn.masked_bias []
    133. transformer.h.9.attn.c_attn.weight [768, 2304]
    134. transformer.h.9.attn.c_attn.bias [2304]
    135. transformer.h.9.attn.c_proj.weight [768, 768]
    136. transformer.h.9.attn.c_proj.bias [768]
    137. transformer.h.9.ln_2.weight [768]
    138. transformer.h.9.ln_2.bias [768]
    139. transformer.h.9.mlp.c_fc.weight [768, 3072]
    140. transformer.h.9.mlp.c_fc.bias [3072]
    141. transformer.h.9.mlp.c_proj.weight [3072, 768]
    142. transformer.h.9.mlp.c_proj.bias [768]
    143. transformer.h.10.ln_1.weight [768]
    144. transformer.h.10.ln_1.bias [768]
    145. transformer.h.10.attn.bias [1, 1, 1024, 1024]
    146. transformer.h.10.attn.masked_bias []
    147. transformer.h.10.attn.c_attn.weight [768, 2304]
    148. transformer.h.10.attn.c_attn.bias [2304]
    149. transformer.h.10.attn.c_proj.weight [768, 768]
    150. transformer.h.10.attn.c_proj.bias [768]
    151. transformer.h.10.ln_2.weight [768]
    152. transformer.h.10.ln_2.bias [768]
    153. transformer.h.10.mlp.c_fc.weight [768, 3072]
    154. transformer.h.10.mlp.c_fc.bias [3072]
    155. transformer.h.10.mlp.c_proj.weight [3072, 768]
    156. transformer.h.10.mlp.c_proj.bias [768]
    157. transformer.h.11.ln_1.weight [768]
    158. transformer.h.11.ln_1.bias [768]
    159. transformer.h.11.attn.bias [1, 1, 1024, 1024]
    160. transformer.h.11.attn.masked_bias []
    161. transformer.h.11.attn.c_attn.weight [768, 2304]
    162. transformer.h.11.attn.c_attn.bias [2304]
    163. transformer.h.11.attn.c_proj.weight [768, 768]
    164. transformer.h.11.attn.c_proj.bias [768]
    165. transformer.h.11.ln_2.weight [768]
    166. transformer.h.11.ln_2.bias [768]
    167. transformer.h.11.mlp.c_fc.weight [768, 3072]
    168. transformer.h.11.mlp.c_fc.bias [3072]
    169. transformer.h.11.mlp.c_proj.weight [3072, 768]
    170. transformer.h.11.mlp.c_proj.bias [768]
    171. transformer.ln_f.weight [768]
    172. transformer.ln_f.bias [768]
    173. lm_head.weight [30000, 768]

    1.2 data processing

    CPM的词汇表有3w个。丰富的中文训练数据,中文数据其实比较好搞,直接网上爬就可以,git上作为提供了一个作文预训练的模型,在这个预训练模型上finetune效果也不错,我的训练数据大概有7-8w的标题-文本对数据。

    1.3 pr-training details

     lr=1.5x10-4,batch_size=3072,max_len:1024(训练时,输入数据的最大长度),steps=2000(前500轮warmup),optimizer=adam,64*v100训了2周。

    2x1080ti:cpm-small版本,max_len:200,lr=0.00015,batch_size:16,steps:100,adamw。

    transformer=4.6.0

    2.后面是cpm在一些任务上的实验。

  • 相关阅读:
    神经网络常用的训练方式,神经网络训练过程详解
    神经网络的主要内容特点,神经网络的种类和特点
    【Python】安装autopep8包,并在PyCharm中进行配置,以PEP8规范排版代码
    Mac 取消系统更新的红点——强迫症晚期患者
    mybatics 连接池-Druid
    老子云平台会员专业又有性价比!
    Verilog 之 wire与reg 类型的变量
    LeetCode 刷题 [C++] 第73题.矩阵置零
    亚马逊秋季促销指南——如何更好的利用促销?
    基于JavaWeb的大学社团管理系统的设计与实现
  • 原文地址:https://blog.csdn.net/u012193416/article/details/126040727