当前位置: 首页 > news >正文

【FreeRL】我的深度学习库构建思想

文章目录

  • 前言
    • 参考
    • python环境
    • 效果
    • 已复现结果
  • 综述
    • DQN.py(主要)
      • 算法实现
      • 参数修改
      • 细节实现
      • 显示训练,保存训练
    • Buffer.py
    • evaluate.py
    • learning_curves


前言

代码实现在:https://github.com/wild-firefox/FreeRL
欢迎star

参考

  • 动手学强化学习
  • elegentRL
  • DRL-code-pytorch
  • easy-rl
  • maddpg-pettingzoo-pytorch
  • 深度强化学习
  • reinforcement-learning-algorithm
  • DRL-Pytorch
  • cleanRL

目的是写出像TD3作者那样简单易懂的DRL代码,
由于参考了ElegentRL和Easy的库,from easy to elegent 故起名为freeRL,
free也是希望写出的代码可以随意的,自由的从此代码移植到自己的代码上。

python环境

python 3.11.9
torch 2.3.1+cu121
gymnasium[all] 0.29.1
pygame 0.25.2 # 这个版本和gymnasium[all]0.29.1兼容

效果

在参数没有精细调整的情况下,在大多数的环境已能适用。
用DQN算法在LunarLander-v2环境下训练500个轮次的3个seed的效果:线为均值,阴影为方差
在这里插入图片描述
用 seed = 0 训练的模型评估,评估100个不同的seed的结果。
在这里插入图片描述
随机选择其中一个seed的结果,渲染环境。
在这里插入图片描述

已复现结果

1.DQN
2.DQN_Double
3.DQN_Dueling
4.DQN_PER
5.DQN_Noisy
6.DQN_N_Step
7.DQN_Categorical
8.DQN_Rainbow

其中:
1 实现在DQN_file/DQN.py
2-8 实现在DQN_file/DQN_with_tricks.py

在这里插入图片描述

综述

为了便于对算法的理解和改动,我将一个整体的算法训练和评估分离开来。

DQN_file
├── learning_curves
│   ├── env_name_1
│	│   ├── DQN_3_seed.npy
│   │   └── DQN.png
│   └── env_name_2
├── results
│   ├── env_name_1
│	│	├── DQN_1
│	│	│	├── DQN_seed_0.npy
│	│	│	├── DQN.pt
│	│	│	├── evaluate.gif
│	│	│	├── evaluate.png
│	│	│	└── events.out.tfevents.
│	│	├── DQN_2
│	│	└── DQN_3
│   └── env_name_2
├── plot_learning_curves.py
├── evaluate.py
├── Buffer.py
└── DQN.py

首先看最下面几个具体的py文件
1.evaluate.py 实现评估。
2.plot_learning_curves.py实现多个seed的学习曲线的绘制和算法比较。
3.DQN.py 实现算法。
4.Buffer.py 实现经验池,经验池基本通用。

以DQN.py为算法.py举例

DQN.py(主要)

建议边打开github上DQN.py的代码边看。

算法实现

一个深度强化学习算法分三个部分实现:
1.Agent类:包括actor、critic、target_actor、target_critic、actor_optimizer、critic_optimizer、
2.DQN算法类:包括select_action,learn、save、load等方法,为具体的算法细节实现
3.main函数:实例化DQN类,主要参数的设置,训练、测试、保存模型等

这三个部分均在DQN.py里实现。

参数修改

参数修改 改三处:
1.MLP的hidden (此参数往往在第一部分开头实现)
2.main中args
3.dis_to_con中的离散转连续空间维度(针对无法转成连续域的算法,例:DQN)

对于1.需要单独修改的理由
hidden的层数和个数容易变化,且RL的许多的算法创新实现在MLP(Qnet,Actor,Critic处)会有新增参数。
对于2.
args 为主要的参数,算法独有或共有或保存位置的修改。
对于3.
主要针对DQN只能对离散环境适用,不能对连续环境适用,进行的转换。
将动作分配成多维的离散动作,使得算法可以适用,相对的,在采样环境时,需要将离散的动作转换成连续的动作。

基本的参数没有精细调整,这里DQN使用离散环境MountainCar-v0为基准来调整参数,以此能收敛为目标了,后发现此参数可以适用大多数其他环境,但不是全部。
使用MountainCar-v0的理由:环境的目标是到达最高的山峰,但环境中还有个次高的山峰,个人认为可以很好拟合出梯度中的次优解。

细节实现

1.对于不同的算法的实现,在代码中给出论文链接和不同实现。
2.在RL中使用常用的,通用的pytorch代码,易懂。见:【深度强化学习】常常使用的pytorch代码
3.区分env的terminated,truncated
4.区分训练时用的action(例:(-1,1))和env能接受的action_(例:(-3,3))
(区分3,4两点对于收敛有很大帮助。)
5.区分环境采样过程和训练过程,以提高算法的拓展性。
6.以max_episodes为终止条件,但是训练以step为最小单位。

显示训练,保存训练

1.训练时,使用tensorboard来显示实时的学习曲率。

在DQN_file(算法)文件夹下,D:FreeRL/DQN_file 终端里输入:
tensorboard --logdir=results/env_name
在跳出的http://localhost:6008/ 按住ctrl点击进入就行。

tensorboard保存的文件events.out.tfevents.和模型的位置一致。

保存模型的频率设置为总回合的1/4。

2.在results文件夹下,不同环境为文件夹名下,在算法(或算法+trick)为文件夹名里,(results/env_name/DQN_1)保存模型文件(DQN.pt)及其训练时每个episode的return值,以不同seed为区分(DQN_seed_0.npy)(此npy用于后续画学习曲率)

每进行一次训练文件夹后面的数DQN_n,n+1。

Buffer.py

在创建buffer时直接使用zeros来创建,比使用deque来创建在最后使用python基本数据再转成numpy再转成tensor速度要快。
这里使用numpy实现来使它更快一点。(参考elegentrl)
在这里插入图片描述

其他一些buffer的实现,都实现在此。

evaluate.py

实现对模型的评估,可设定评估的轮次数,设定是否保存渲染环境gif。

这里seed的设定值须与训练的seed值不同。
由于gymnasium可以设定env的seed。这里将环境的seed值设定为当前遍历的轮次,以实现seed的改变。
在gymnasium中,如果有实现任务所达到的return值,在画评估图时,以此为基线。

环境gif的保存,则是随机挑选其中一个回合进行保存。

此代码所得到的evaluate.png,evaluate.gif均保存在模型所在位置。(results/env/DQN_1/下)

(上述效果的最后两个图)

learning_curves

1.将不同的results/env/algorithm_trick_n下的DQN_seed_n.npy绘制成一个学习曲线
以均值为线,阴影为方差。
2.将比较的多个seed的episode_return 另保存为DQN_3_seed.npy方便后续比较。
3.可以选择是否比较此算法的其他trick算法。

可以设置seed_num大小,取决于你在环境的测试中,实验了几次不同的seed大小,这里仅使用seed =
0,10,100来进行绘制,当然也可以只进行一个seed的绘制。(这里有进行平滑处理,可以设置)

生成的学习曲线图为DQN.py 和保存的DQN_3_seed.npy保存在learning_curves/env/下

(上述效果的第一张图为学习曲线图,已复现的结果为比较图)

http://www.lryc.cn/news/436410.html

相关文章:

  • Docker部署nginx容器无法访问80端口
  • Python语言开发学习之使用Python预测天气
  • minio实现大文件断点续传
  • Qt绘制动态仪表(模仿汽车仪表指针、故障灯)
  • 【视频教程】GEE遥感云大数据在林业中的应用与典型案例实践
  • 【时时三省】c语言例题----华为机试题<字符串排序>
  • 基于vue框架的城市体育运动交流平台15s43(程序+源码+数据库+调试部署+开发环境)系统界面在最后面。
  • 2024年软件测试经典大厂面试题(全3套)【包含答案】
  • What is Node.JS and its Pros and Cons
  • TestCraft - GPT支持的测试想法生成器和自动化测试生成器
  • FreeRTOS内部机制学习04(任务通知和软件定时器)
  • 华为eNSP :WLAN的配置
  • 中国大数据产业的融资热潮来袭,哪些领域最受资本青睐?
  • Unity数据持久化 之 使用Excel.DLL读写Excel表格
  • Linux系统:chown命令
  • Unity3D ARPG(动作角色扮演游戏)设计与实现详解
  • Qt实现登录界面
  • big.LITTLE
  • 汤臣倍健,三七互娱,得物,顺丰,快手,游卡,oppo,康冠科技,途游游戏,埃科光电25秋招内推
  • 再谈c++模板
  • 9.11 codeforces Div 2
  • 二级菜单的两种思路(完成部分)
  • 【机器学习导引】ch2-模型评估与选择
  • 二开ihoneyBakFileScan备份扫描
  • leetcode21. 合并两个有序链表
  • 搭建 WordPress 及常见问题与解决办法
  • 《ORANGE‘s 一个操作系统的实现》--保护模式进阶
  • 【可变参模板】可变参类模板
  • Linux 递归删除大量的文件
  • 设计一个算法,找出由str1和str2所指向两个链表共同后缀的起始位置