• LLM - Model、Data、Training、Generate Agruments 超参解析


    目录

    一.引言

    二.常用参数

    ◆ ModelArguments

    ◆ DataArguments

    ◆ TrainingArguments

    ◆ GeneratingArguments

    三.代码实现

    ◆ Python 代码

    ◆ Shell 代码

    四.总结


    一.引言

    LLM 相关训练框架都会引入 ModelArguments、DataArguments、TrainingArguments、GeneratingArguments 并通过 Transformer.HfArgumentParser 进行整合,实现了两行代码处理训练全程的参数问题。

    ModelArguments - 模型参数

    DataArguments - 数据集参数

    TrainingArguments - 训练参数

    GeneratingArguments - 生成参数

    二.常用参数

    ◆ ModelArguments

    1. @dataclass
    2. class ModelArguments:
    3. model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")

    ModelArguments 主要存储模型加载与配置的相关参数,一般还有以下参数,大家可以自定义:

    参数名称默认类型含义
    model_name_or_pathNonestr模型地址或名称
    cache_dirNonestr缓存地址
    use_fast_tokenizerFalsebool使用快速 tokenizer
    padding_sideleftstr模型 pad 选择
    quantization_bitNoneint量化 bit 选择
    compute_typeNonetorch.dtype模型参数类型
    checkpoint_dirNonestr微调参数地址
    modeNonestrreward、lora
    plot_lossFalsebool打印训练 Loss

    ◆ DataArguments

    1. @dataclass
    2. class DataArguments:
    3. data_path: str = field(
    4. default=None, metadata={"help": "Path to the training data."}
    5. )

    DataArguments 主要负责数据集相关参数,数据集通过 dataset 构成,通常包含下述参数:

    参数名称默认类型含义
    data_pathNonestr数据集地址
    process_numNoneint并行处理
    max_source_length512intsource 最大长度
    max_target_length512inttarget 最大长度
    max_samplesNoneint最大样本数
    ignore_pad_tokenNoneintloss 计算是否忽略
    prompt_templateNonestr样本生成 prompt 模板

    ◆ TrainingArguments

    1. @dataclass
    2. class TrainingArguments(transformers.TrainingArguments):
    3. cache_dir: Optional[str] = field(default=None)
    4. optim: str = field(default="adamw_torch")
    5. model_max_length: int = field(
    6. default=512,
    7. metadata={
    8. "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
    9. },
    10. )
    11. use_lora: bool = field(default=False)
    12. output_dir: str = field(default="")

    TrainingArguments 主要存储模型微调,训练相关的参数:

    参数名称默认类型含义
    finetuning_typelorastr微调类型
    lora_targetq_proj,v_projstr微调 Layer
    lora_rank8intlora 降维维度
    lora_alpha32.0floatlora 微调比例因子
    lora_dropout0.1floatdropout 比例
    num_hidden_layers32intDecode 数量
    num_layer_trainable3intfreeze layer 数量
    name_module_trainablemlpstrfreeze 训练层选择
    output_dirNonestr模型输出地址

    ◆ GeneratingArguments

    1. @dataclass
    2. class GeneratingArguments:
    3. do_sample: Optional[bool] = field(
    4. default=True,
    5. metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
    6. )

    GeneratingArguments 主要负责 model generate 生成的配置:

    参数名称默认类型含义
    do_sampleTruebool采样或贪心
    temperature0.95float调整下一个 token 的概率
    top_p0.7floattoken 概率 top 区间
    top_k50inttoken 词库数量
    num_beams1intbeam search 数量
    max_lengthNoneint最大生成 token 数
    max_new_tokens512int最多新 toekn 生成数
    repatition_penalty1.0float重复惩罚
    length_penalty1.0float长度惩罚

    之前单独整理了生成的参数和代码,可以参考: LLM - model batch generate 生成文本

    三.代码实现

    ◆ Python 代码

    1. from typing import Optional
    2. from dataclasses import dataclass, field
    3. import transformers
    4. ...
    5. 添加上述的 Argument Class
    6. ...
    7. if __name__ == '__main__':
    8. parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, GeneratingArguments))
    9. model_args, data_args, training_args, generate_args = parser.parse_args_into_dataclasses()
    10. print(model_args)
    11. print(data_args)
    12. print(training_args)
    13. print(generate_args)

    两行搞定多类参数,参数对应属性使用 args.xxx 调用即可。

    Shell 代码

    1. #!/bin/bash
    2. python GetConfigByArgs.py \
    3. --report_to "none" \
    4. --data_path "data/belle_chat_ramdon_10k.json" \
    5. --model_name_or_path "baichuan-inc/Baichuan2-7B-Base" \
    6. --output_dir "output" \
    7. --model_max_length 512 \
    8. --num_train_epochs 4 \
    9. --per_device_train_batch_size 16 \
    10. --gradient_accumulation_steps 1 \
    11. --save_strategy epoch \
    12. --learning_rate 2e-5 \
    13. --lr_scheduler_type constant \
    14. --adam_beta1 0.9 \
    15. --adam_beta2 0.98 \
    16. --adam_epsilon 1e-8 \
    17. --max_grad_norm 1.0 \
    18. --weight_decay 1e-4 \
    19. --warmup_ratio 0.0 \
    20. --logging_steps 1 \
    21. --gradient_checkpointing True \
    22. --deepspeed ds_config.json \
    23. --bf16 False \
    24. --tf32 False

    通过 -- 传递我们需要的参数即可。

    四.总结

    这个没啥总结的了,就是觉得写法比较优雅,后面自己的脚本也可以借用。

  • 相关阅读:
    Java SE 13 新增特性
    聊一聊DTM子事务屏障功能之SQL Server版
    大数据ClickHouse进阶(六):Distributed引擎深入了解
    django自动生成问卷表的软件的设计与实现毕业设计源码291138
    【电子学会】2023年05月图形化三级 -- 数星星
    一个诡异的 Pulsar InterruptedException 异常
    不用定时器,实现鼠标长悬浮和鼠标长按监听
    Python 网络爬取的时候使用那种框架
    ESP8266-Arduino编程实例-MLX90615红外测温仪驱动
    搭建rtmp流媒体服务器的步骤
  • 原文地址:https://blog.csdn.net/BIT_666/article/details/132755841