• 聊聊GLM-4-9B开源模型的微调loss计算


    概述

    Github官方地址:GLM-4

    网上已经有很多关于微调的文章,介绍各种方式下的使用,这里不会赘述。我个人比较关心的是微调时的loss计算逻辑,这点在很多的文章都不会有相关的描述,因为大多数人都是关心如何使用之类的应用层,而不是其具体的底层逻辑,当然咱也说不清太底层的计算。

    可了解其它loss计算的文章:
    再聊多轮对话微调训练格式与长序列训练
    聊聊ChatGLM2与ChatGLM3微调多轮对话的设计逻辑及源码分析
    聊聊大模型多轮对话的训练及优化

    微调

    微调格式:

    [
      {
        "messages": [
          {
            "role": "system",
            "content": "",
            "tools": [
              {
                "name": "",
                "args": {
                  "": ""
                }
              }
            ]
          },
          {
            "role": "user",
            "content": ""
          },
          {
            "role": "assistant",
            "content": ""
          },
          {
            "role": "user",
            "content": ""
          },
          {
            "role": "assistant",
            "content": ""
          },
          {
            "role": "observation",
            "content": ""
          },
          {
            "role": "assistant",
            "content": ""
          },
          {
            "role": "user",
            "content": ""
          },
          {
            "role": "assistant",
            "content": ""
          }
        ]
      }
    ]
    

    微调源码地址:finetune.py
    Loss计算代码:

    def process_batch(
            batch: Mapping[str, Sequence],
            tokenizer: PreTrainedTokenizer,
            max_input_length: int,
            max_output_length: int,
    ) -> dict[str, list]:
        batched_conv = batch['messages']
        batched_input_ids = []
        batched_labels = []
        # batched_conv 是一个数组
        # conv 是数组内的单个 message
        for conv in batched_conv:
            input_ids = [151331, 151333]
            loss_masks = [False, False]
            # conv 是数组内的单个 message
            # message 是 单个role json对象
            for message in conv:
                message = process_message(message)
                # 设置 mask 掩码,只有system,user,observation不参与mask计算,其余的角色参与计算
                loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
                # 获取 input 文本的数字表示(ids)
                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
                # 计算整句的 mask
                new_loss_masks = [loss_mask_val] * len(new_input_ids)
                # 拼接message中的每段json
                input_ids += new_input_ids
                # 拼接message中每段json对应的mask
                loss_masks += new_loss_masks
            # 追加结尾的 token id
            input_ids.append(tokenizer.eos_token_id)
            loss_masks = [False, *loss_masks]
            labels = []
            for input_id, mask in zip(input_ids, loss_masks):
                if mask:
                    # 添加到label,计算loss
                    labels.append(input_id)
                else:
                    # -100 不处理,即ignore_index
                    labels.append(-100)
            max_length = max_input_length + max_output_length + 1
            # 截断
            batched_input_ids.append(input_ids[:max_length])
            batched_labels.append(labels[:max_length])
        return {'input_ids': batched_input_ids, 'labels': batched_labels}
    
    

    注释在代码中已经写明。process_batch方法用于将输入转换为ids,并计算mask(用于Loss计算)。而该方法的调用是在数据集的遍历处理中,即如下所示:

    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
    data_manager = DataManager(data_dir, ft_config.data_config)
    # 数据集拆分遍历
    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    print('train_dataset:', train_dataset)
    

    Loss计算如下图所示:

    总结

    相比较于之前的ChatGLM版本,GLM4开源版本的多轮对话loss计算更恰当且效率也会更高;在其它的开源模型/微调框架中早已支持该种loss计算,如InternLM、XTuner、Firefly等。对于loss格式的类别,可参考XTuner的官方文档说明:dataset_format.md

    原文链接:https://mp.weixin.qq.com/s/0mLCQfpaZr7eEonG4a4Etg

    更多大模型相关的文章,请上个人公众号查阅:
    image

  • 相关阅读:
    设计模式:模板模式和策略模式混合使用
    同样是巡检,巡检系统在不同行业运用大不同
    Python多种方法实现九九乘法表
    2017-04《信息资源管理 02378》真卷,圈定章节考点+统计真题分布
    葫芦娃解析
    Java设计模式之代理模式(一)
    字符串的创建(直接赋值与new的区别)- 字符串常量池
    如何阅读计算机学术文献?
    Python简直是万能的,这5大主要用途你一定要知道!
    深入解析 MySQL binlog
  • 原文地址:https://www.cnblogs.com/zhiyong-ITNote/p/18243420