当前位置: 首页 > news >正文

强化学习_06_pytorch-PPO2实践(ALE/Breakout-v5)

一、环境适当调整

  1. 数据收集:RecordEpisodeStatistics
  2. 进行起始跳过n帧:baseSkipFrame
  3. 一条生命结束记录为done:EpisodicLifeEnv
  4. 得分处理成0或1:ClipRewardEnv
  5. 叠帧: FrameStack
    • 图像环境的基本操作,方便CNN捕捉智能体的行动
  6. 向量空间reset处理修复
    • gym.vector.SyncVectorEnv: 原始代码中的reset是随机的
    • 继承重写的spSyncVectorEnv方法,支持每个向量的环境的seed一致,利于同一seed下环境的训练

class spSyncVectorEnv(gym.vector.SyncVectorEnv):"""step_await _terminateds reset"""def __init__(self,env_fns: Iterable[Callable[[], Env]],observation_space: Space = None,action_space: Space = None,copy: bool = True,random_reset: bool = False,seed: int = None):super().__init__(env_fns, observation_space, action_space, copy)self.random_reset = random_resetself.seed = seeddef step_wait(self) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:"""Steps through each of the environments returning the batched results.Returns:The batched environment step results"""observations, infos = [], {}for i, (env, action) in enumerate(zip(self.envs, self._actions)):(observation,self._rewards[i],self._terminateds[i],self._truncateds[i],info,) = env.step(action)if self._terminateds[i]:old_observation, old_info = observation, infoif self.random_reset:observation, info = env.reset(seed=np.random.randint(0, 999999))else:observation, info = env.reset() if self.seed is None else env.reset(seed=self.seed) info["final_observation"] = old_observationinfo["final_info"] = old_infoobservations.append(observation)infos = self._add_info(infos, info, i)self.observations = concatenate(self.single_observation_space, observations, self.observations)return (deepcopy(self.observations) if self.copy else self.observations,np.copy(self._rewards),np.copy(self._terminateds),np.copy(self._truncateds),infos,)

二、pytorch实践

2.1 智能体构建与训练

详细可见 Github: test_ppo_atari.Breakout_v5_ppo2_test

调整向量环境的reset 之后

  • 支持actor, criticor用同一个cnn层提取特征(PPOSharedCNN)
  • eps进行了调小->eps=0.165,希望更新的策略范围更小一些;
  • 关闭学习率衰减
  • 进行不同ent_coef的尝试: 稍微大一点,增加agent的探索;
    • ent_coef=0.015 & batch_size=256+128batch 陡降-回升慢
    • ent_coef=0.025 & batch_size=256 陡降回升-最终reward=311
    • ent_coef=0.05 & batch_size=256 -最终PPO2__AtariEnv instance__20241029__2217 reward=416
    • ent_coef=0.05 & batch_size=256+128
    • ent_coef=0.1 & batch_size=256 提升过于平缓
      在这里插入图片描述
      在这里插入图片描述
env_name = 'ALE/Breakout-v5' 
env_name_str = env_name.replace('/', '-')
gym_env_desc(env_name)
print("gym.__version__ = ", gym.__version__ )
path_ = os.path.dirname(__file__)
num_envs = 12
episod_life = True
clip_reward = True
resize_inner_area = True # True
env_pool_flag = False # True
seed = 202404
envs = spSyncVectorEnv([make_atari_env(env_name, skip=4, episod_life=episod_life, clip_reward=clip_reward, ppo_train=True, max_no_reward_count=120, resize_inner_area=resize_inner_area) for _ in range(num_envs)],random_reset=False,seed=202404
)
dist_type = 'norm'
cfg = Config(envs, save_path=os.path.join(path_, "test_models" ,f'PPO2_{env_name_str}-2'),  seed=202404,num_envs=num_envs,episod_life=episod_life,clip_reward=clip_reward,resize_inner_area=resize_inner_area,env_pool_flag=env_pool_flag,# 网络参数 Atria-CNN + MLPactor_hidden_layers_dim=[512, 256], critic_hidden_layers_dim=[512, 128], # agent参数actor_lr=4.5e-4,   gamma=0.99,# 训练参数num_episode=3600,  off_buffer_size=128,  max_episode_steps=128, PPO_kwargs={'cnn_flag': True,'clean_rl_cnn': True,'share_cnn_flag': True,'continue_action_flag': False,'lmbda': 0.95,'eps':  0.165,  # 0.165'k_epochs': 4,  #  update_epochs'sgd_batch_size': 512,  'minibatch_size': 256, 'act_type': 'relu','dist_type': dist_type,'critic_coef': 1.0, # 1.0'ent_coef': 0.05, 'max_grad_norm': 0.5,  'clip_vloss': True,'mini_adv_norm': True,'anneal_lr': False,'num_episode': 3600,}
)
minibatch_size = cfg.PPO_kwargs['minibatch_size']
max_grad_norm = cfg.PPO_kwargs['max_grad_norm']
cfg.trail_desc = f"actor_lr={cfg.actor_lr},minibatch_size={minibatch_size},max_grad_norm={max_grad_norm},hidden_layers={cfg.actor_hidden_layers_dim}",
agent = PPO2(state_dim=cfg.state_dim,actor_hidden_layers_dim=cfg.actor_hidden_layers_dim,critic_hidden_layers_dim=cfg.critic_hidden_layers_dim,action_dim=cfg.action_dim,actor_lr=cfg.actor_lr,critic_lr=cfg.critic_lr,gamma=cfg.gamma,PPO_kwargs=cfg.PPO_kwargs,device=cfg.device,reward_func=None
)
agent.train()
ppo2_train(envs, agent, cfg, wandb_flag=True, wandb_project_name=f"PPO2-{env_name_str}-NEW",train_without_seed=False, test_ep_freq=cfg.off_buffer_size * 10, online_collect_nums=cfg.off_buffer_size,test_episode_count=10, add_max_step_reward_flag=False,play_func='ppo2_play',ply_env=ply_env
)

2.2 训练出的智能体观测

最后将训练的最好的网络拿出来进行观察


env = make_atari_env(env_name, skip=4, episod_life=episod_life, clip_reward=clip_reward, ppo_train=True, max_no_reward_count=120, resize_inner_area=resize_inner_area, render_mode='human')()
ppo2_play(env, agent, cfg, episode_count=2, play_without_seed=False, render=True, ppo_train=True)

在这里插入图片描述

http://www.lryc.cn/news/478992.html

相关文章:

  • 《JVM第8课》垃圾回收算法
  • SpringBoot整合Freemarker(二)
  • element plus el-form自定义验证输入框为纯数字函数
  • Android笔记(三十一):Deeplink失效问题
  • 图神经网络初步实验
  • 创建线程时传递参数给线程
  • 兴业严选|美国总统都是不良资产出身 法拍市场是否将大众化
  • C#-拓展方法
  • 加锁失效,非锁之过,加之错也|京东零售供应链库存研发实践
  • vue3 传值的几种方式
  • AUTOSAR CP NVRAM Manager规范导读
  • 2024阿里云CTF Web writeup
  • 软件著作权申请教程(超详细)(2024新版)软著申请
  • 三维测量与建模笔记 - 3.2 直接线性变换法标定DLT
  • Whisper AI视频(音频)转文本
  • 全网最详细RabbitMQ教学包括如何安装环境【RabbitMQ】RabbitMQ + Spring Boot集成零基础入门(基础篇)
  • esp32记录一次错误
  • Moonshine - 新型开源ASR(语音识别)模型,体积小,速度快,比OpenAI Whisper快五倍 本地一键整合包下载
  • java-web-苍穹外卖-day1:软件开发步骤简化版+后端环境搭建
  • 一个国产 API 开源项目,在 ProductHunt 杀疯了...
  • 斗破QT编程入门系列之二:认识Qt:编写一个HelloWorld程序(四星斗师)
  • 木马病毒相关知识
  • 用 Python 写了一个天天酷跑(附源码)
  • 【网络-交换机】生成树协议、环路检测
  • C++ 中的 JSON 序列化和反序列化:结构体与枚举类型的处理
  • MySQL 批量删除海量数据的几种方法
  • 【docker入门】docker的安装
  • 单例模式五种写法
  • 解析静态链接
  • 前端基础-html-注册界面