码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略


    Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

    目录

    trl的简介

    1、亮点

    2、PPO是如何工作的:PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

    trl的安装

    trl的使用方法

    1、基础用法

    (1)、如何使用库中的SFTTrainer

    (2)、如何使用库中的RewardTrainer

    (3)、如何使用库中的PPOTrainer

    2、进阶用法

    LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

    LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略


    trl的简介

              TRL - Transformer Reinforcement Learning使用强化学习的全栈Transformer语言模型。trl 是一个全栈库,其中我们提供一组工具,用于通过强化学习训练Transformer语言模型和稳定扩散模型,从监督微调步骤(SFT)到奖励建模步骤(RM)再到近端策略优化(PPO)步骤。该库建立在Hugging Face 的 transformers 库之上。因此,可以通过 transformers 直接加载预训练语言模型。目前,大多数解码器架构和编码器-解码器架构都得到支持。请参阅文档或示例/文件夹,以查看示例代码片段以及如何运行这些工具。

    GitHub地址:GitHub - huggingface/trl: Train transformer language models with reinforcement learning.

    1、亮点

    >> SFTTrainer:一个轻量级且友好的围绕transformer Trainer的包装器,可以在自定义数据集上轻松微调语言模型或适配器。

    >> RewardTrainer: transformer Trainer的一个轻量级包装,可以轻松地微调人类偏好的语言模型(Reward Modeling)。

    >> potrainer:用于语言模型的PPO训练器,它只需要(查询、响应、奖励)三元组来优化语言模型。

    >> AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个转换器模型,每个令牌有一个额外的标量输出,可以用作强化学习中的值函数。

    >> 示例:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j减少毒性,Stack-Llama示例等。

    2、PPO是如何工作的:PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

    通过PPO对语言模型进行微调大致包括三个步骤:

    Rollout

    Rollout(展开):语言模型基于查询生成响应或继续,查询可以是句子的开头。

    Evaluation

    Evaluation(评估):使用一个函数、模型、人类反馈或它们的组合来评估查询和响应。重要的是,此过程应为每个查询/响应对产生一个标量值。

    Optimization

    Optimization(优化):这是最复杂的部分。在优化步骤中,使用查询/响应对来计算序列中token的对数概率。这是通过训练的模型和一个参考模型(通常是微调之前的预训练模型)来完成的。两个输出之间的KL-散度被用作附加奖励信号,以确保生成的响应不会偏离参考语言模型太远。然后,使用PPO训练主动语言模型。

    这个过程在下面的示意图中说明。

    trl的安装

    pip install trl

    trl的使用方法

    1、基础用法

    (1)、如何使用库中的SFTTrainer

    以下是如何使用库中的SFTTrainer的基本示例。SFTTrainer是用于轻松微调语言模型或适配器的transformers Trainer的轻量包装器。

    1. # imports
    2. from datasets import load_dataset
    3. from trl import SFTTrainer
    4. # get dataset
    5. dataset = load_dataset("imdb", split="train")
    6. # get trainer
    7. trainer = SFTTrainer(
    8. "facebook/opt-350m",
    9. train_dataset=dataset,
    10. dataset_text_field="text",
    11. max_seq_length=512,
    12. )
    13. # train
    14. trainer.train()

    (2)、如何使用库中的RewardTrainer

    以下是如何使用库中的RewardTrainer的基本示例。RewardTrainer是用于轻松微调奖励模型或适配器的transformers Trainer的包装器。

    1. # imports
    2. from transformers import AutoModelForSequenceClassification, AutoTokenizer
    3. from trl import RewardTrainer
    4. # load model and dataset - dataset needs to be in a specific format
    5. model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
    6. tokenizer = AutoTokenizer.from_pretrained("gpt2")
    7. ...
    8. # load trainer
    9. trainer = RewardTrainer(
    10. model=model,
    11. tokenizer=tokenizer,
    12. train_dataset=dataset,
    13. )
    14. # train
    15. trainer.train()

    (3)、如何使用库中的PPOTrainer

    以下是如何使用库中的PPOTrainer的基本示例。基于查询,语言模型创建响应,然后进行评估。评估可以是人工干预或另一个模型的输出。

    1. # imports
    2. import torch
    3. from transformers import AutoTokenizer
    4. from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
    5. from trl.core import respond_to_batch
    6. # get models
    7. model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
    8. model_ref = create_reference_model(model)
    9. tokenizer = AutoTokenizer.from_pretrained('gpt2')
    10. # initialize trainer
    11. ppo_config = PPOConfig(
    12. batch_size=1,
    13. )
    14. # encode a query
    15. query_txt = "This morning I went to the "
    16. query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
    17. # get model response
    18. response_tensor = respond_to_batch(model, query_tensor)
    19. # create a ppo trainer
    20. ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
    21. # define a reward for response
    22. # (this could be any reward such as human feedback or output from another model)
    23. reward = [torch.tensor(1.0)]
    24. # train model for one step with ppo
    25. train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

    2、进阶用法

    LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

    https://yunyaniu.blog.csdn.net/article/details/133865725

    LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略

    https://yunyaniu.blog.csdn.net/article/details/133873621

  • 相关阅读:
    SpringBoot项目如何优雅的实现操作日志记录
    SpringBoot打包的两种方式 - jar方式 和 war 方式
    11.10 知识总结(数据的增删改查、如何创建表关系、Django框架的请求生命周期流程图)
    Redis事务、pub/sub、PipeLine-管道、benchmark性能测试详解
    深入浅出PyTorch——主要模块和基础实战
    安信可IDE(AiThinker_IDE)编译ESP8266工程方法
    科技资讯|微软AR眼镜新专利曝光,可拆卸电池解决续航焦虑
    如何流畅进入Github
    免费低代码平台,助企业高效管理任务
    机房动环监控系统有哪些告警功能,机房动环监控系统是什么?
  • 原文地址:https://blog.csdn.net/qq_41185868/article/details/133865134
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号