码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • d3rlpy离线强化学习算法库安装及使用


    GitHub - takuseno/d3rlpy: An offline deep reinforcement learning library

    d3rlpy,离线强化学习算法库

    我装在windows下用anaconda,按照官网教程

    conda install -c conda-forge d3rlpy

    第一次安装报错CondaSSLError: OpenSSL appears to be unavailable on this machine

    [报错解决]CondaSSLError: OpenSSL appears to be unavailable on this machine. OpenSSL is required to downl_一件迷途小书童的博客-CSDN博客

    参考这篇文章解决后正常安装没问题,值得注意的是d3rkpy安装时包含cudatoolkit11.几,我在想这个在不同电脑上可能之后会出错,不过后面运行算法时可以选择是否使用GPU

    我是打算用离线强化学习算法,安装后测试,官网上也有测试代码

    1. import d3rlpy
    2. # prepare dataset
    3. dataset, env = d3rlpy.datasets.get_d4rl('hopper-medium-v0')
    4. # prepare algorithm
    5. cql = d3rlpy.algos.CQL(use_gpu=True)
    6. # train
    7. cql.fit(
    8. dataset,
    9. eval_episodes=dataset,
    10. n_epochs=100,
    11. scorers={
    12. 'environment': d3rlpy.metrics.evaluate_on_environment(env),
    13. 'td_error': d3rlpy.metrics.td_error_scorer,
    14. },
    15. )

    看得出来,这接口用起来非常方便啊

    因为我没装d4rl所以肯定是失败了,d4rl数据集查了下资料可能无法装在windows环境下,有点难办。可以使用下面这个在测试,用的是d3rlpy自带用于测试的数据集,也是比较常用的两个环境,具体是在d3rlpy的文档上找到的

    1. import d3rlpy
    2. # prepare dataset
    3. # dataset, env = d3rlpy.datasets.get_d4rl('CartPole-v0')
    4. dataset, env = d3rlpy.datasets.get_pendulum("random")
    5. # prepare algorithm
    6. cql = d3rlpy.algos.CQL(use_gpu=True)
    7. # train
    8. cql.fit(
    9. dataset,
    10. eval_episodes=dataset,
    11. n_epochs=100,
    12. scorers={
    13. 'environment': d3rlpy.metrics.evaluate_on_environment(env),
    14. 'td_error': d3rlpy.metrics.td_error_scorer,
    15. },
    16. )

    资料很充分,d3rlpy文档:d3rlpy.datasets.get_cartpole — d3rlpy documentation

     成功运行:

    如果失败的话可能是下载失败,

    在这找到下载网址,自己下载到本地,改成规定的名字即可,放到对d3rlpy_data文件夹里,再运行时就不需要在线下载了,比如这样

     

    之后回到d4rl,我打算把自己的数据集按照d4rl的格式来编写,但我不打算装d4rl

    可以看到在d3rlpy中读取d4rl的数据集主要是用d4rl中的get_dataset函数,于是我索性把d4rl中这个函数搬到d3rlpy中,其实就是读取h5格式的函数,也挺好移植,主要也就这一段

    1. data_dict = {}
    2. with h5py.File(h5path, 'r') as dataset_file:
    3. for k in tqdm(get_keys(dataset_file), desc="load datafile"):
    4. try: # first try loading as an array
    5. data_dict[k] = dataset_file[k][:]
    6. except ValueError as e: # try loading as a scalar
    7. data_dict[k] = dataset_file[k][()]

    注意还需要

    1. import h5py
    2. from tqdm import tqdm

    和

    1. def get_keys(h5file):
    2. keys = []
    3. def visitor(name, item):
    4. if isinstance(item, h5py.Dataset):
    5. keys.append(name)
    6. h5file.visititems(visitor)
    7. return keys

    至于原先是个类,我感觉好像也不需要,同时还是把在线改掉,直接变成一个绝对位置(这个在d4rl中也可以找到下载的网址)

    h5path = "D:\xxx_project\pycharm\offline_RL\d3rlpy_data\hopper_random.hdf5"

    运行成功

    我考虑下一步制作自己的hdf5格式数据集,及做下自己的gym环境

    甚至不能算是入门,希望没有问题,欢迎指正

  • 相关阅读:
    电网数字孪生解决方案助力智慧电网体系建设
    富文本编辑器——UEditor的使用——基础积累
    OSPFv2特殊区域---Stub区域
    双边滤波算法及例程
    DNS解析
    数据结构——堆排序
    Visual Studio 2022 开发 STM32 单片机 - 环境搭建点亮LED灯
    windows远程连接linux并实现上传下载文件,不需要额外安装任何软件~
    2023年中国青少年近视管理离焦镜片市场零售量、零售额及发展趋势分析[图]
    JAVA this关键词作用
  • 原文地址:https://blog.csdn.net/Already8888/article/details/128173013
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号