• 利用bert4keras实现多任务学习


    本文给出一个使用bert4keras实现多任务学习的例子

    多任务学习

    广义的讲,只要有多个loss就算Multi-task Learning。可以通过多个相关任务中的训练信息来提升模型的泛化性与表现

    代码实现

    其实多任务实现很简单,主要是在bert4keras框架的帮助下,如何去实现模型架构和设置输入输出。

    我们以实现一个层级性多元标签文本分类(层级性多元标签是什么)。为例子

    举个简单的例子,有一个电视产品,它属于“大家电”,也属于“家用电器”,而“大家电”标签是"家用电器"标签的子类,那么“家用电器”属于一级标签,“大家电” 属于二级标签,这产品所属种类标签是有层级结构。

    我们的模型要实现两个任务 一个用来预测“家用电器”一级标签,一个用来预测“大家电”二级标签。

    1. bert = build_transformer_model(config_path, checkpoint_path)
    2. # 对文本的cls向量分别做两次多分类,得到两个输出level_1_output,level_2_output,分别计算loss。
    3. level_1_cls = Lambda(lambda x: x[:, 0], name='level_1_CLS-token')(bert.output)
    4. level_2_cls = Lambda(lambda x: x[:, 0], name='level_2_CLS-token')(bert.output)
    5. level_1_output = Dense(len(level_1_category),
    6. activation='softmax',
    7. name='level_1_output')(level_1_cls)
    8. level_2_output = Dense(len(level_2_category),
    9. activation='softmax',
    10. name='level_2_output')(level_2_cls)
    11. model = Model(bert.inputs, [level_1_output, level_2_output])
    12. losses = {
    13. "level_1_output": "categorical_crossentropy",
    14. "level_2_output": "categorical_crossentropy",
    15. }
    16. lossWeights = {"level_1_output": 1.0, "level_2_output": 1.0}
    17. model.compile(
    18. loss=losses,
    19. optimizer=Adam(learning_rate), # 用足够小的学习率
    20. loss_weights=lossWeights,
    21. metrics=['accuracy'],
    22. )

    对文本的cls向量分别做两次多分类,得到两个输出level_1_output,level_2_output,分别计算loss。

    这里注意要对层进行命名,方便后面设置损失函数,调整损失权重。

    如果多任务中的每一个的损失都相同,可以只写一个损失代替,不用每个都列出。

    说完了模型结构,就要说数据的传入了

    model = Model(bert.inputs, [level_1_output, level_2_output])

    模型的输入是bert的输入,而输出的是两个label ,即一级标签的label和二级标签的label

    1. class data_generator(DataGenerator):
    2. """数据生成器
    3. """
    4. def __iter__(self, random=False):
    5. batch_token_ids, batch_segment_ids, batch_labels, batch_2_labels, = [], [], [], []
    6. for is_end, (text_1, label_1, label_2) in self.sample(random):
    7. token_ids, segment_ids = tokenizer.encode(text_1, maxlen=maxlen)
    8. batch_token_ids.append(token_ids)
    9. batch_segment_ids.append(segment_ids)
    10. batch_labels.append(label_1)
    11. batch_2_labels.append(label_2)
    12. if len(batch_token_ids) == self.batch_size or is_end:
    13. batch_token_ids = sequence_padding(batch_token_ids)
    14. batch_segment_ids = sequence_padding(batch_segment_ids)
    15. batch_labels = sequence_padding(batch_labels)
    16. batch_2_labels = sequence_padding(batch_2_labels)
    17. yield [batch_token_ids,
    18. batch_segment_ids], [batch_labels, batch_2_labels]
    19. batch_token_ids, batch_segment_ids, batch_labels, batch_2_labels = [], [], [], []

    主要注意yield [batch_token_ids,
                           batch_segment_ids], [batch_labels, batch_2_labels]处的处理就可以,其余按照情况调整即可。

    最后给出一个相关的例子,可以参考

    hgliyuhao/mixup (github.com)

  • 相关阅读:
    PyCharm+PyQT5之三界面与逻辑的分离
    Xmake v2.8.5 发布,支持链接排序和单元测试
    软件设计与体系结构——创建型模式
    JVM——5.类文件结构
    列出连通集
    网络安全深入学习第一课——热门框架漏洞(RCE-代码执行)
    操作系统——文件管理の选择题整理
    Git下载安装及环境配置,解决安装包下载慢问题(详细版)
    springboot文件下载功能开发!
    C++ Tutorials: C++ Language: Other language features: Preprocessor directives
  • 原文地址:https://blog.csdn.net/HGlyh/article/details/126719334