当前位置: 首页 > news >正文

动手学强化学习 第 11 章 TRPO 算法(TRPOContinuous) 训练代码

基于 Hands-on-RL/第11章-TRPO算法.ipynb at main · boyu-ai/Hands-on-RL · GitHub

理论 TRPO 算法

在原书代码的基础上,修复了运行时出现的警告和报错

运行环境

Debian GNU/Linux 12
Python 3.9.19
torch 2.0.1
gym 0.26.2

运行代码

TRPOContinuous.py

#!/usr/bin/env python
import copy

import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

import rl_utils
class ValueNet(torch.nn.Module):
    """State-value network: one hidden ReLU layer and a scalar output head."""

    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the value estimate for each row of ``x``."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class PolicyNetContinuous(torch.nn.Module):
    """Gaussian policy network producing a mean and a standard deviation."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_std = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return ``(mu, std)`` of the Gaussian action distribution."""
        x = F.relu(self.fc1(x))
        # tanh bounds the mean to [-1, 1]; the factor 2.0 rescales it to [-2, 2].
        mu = 2.0 * torch.tanh(self.fc_mu(x))
        # softplus keeps the standard deviation positive.
        std = F.softplus(self.fc_std(x))
        return mu, std


class TRPOContinuous:
    """TRPO algorithm for continuous action spaces.

    Args:
        hidden_dim: Width of the hidden layer of both networks.
        state_space: Object whose ``shape[0]`` is the state dimension.
        action_space: Object whose ``shape[0]`` is the action dimension.
        lmbda: Passed to ``rl_utils.compute_advantage`` together with
            ``gamma`` (presumably the GAE lambda -- confirm in rl_utils).
        kl_constraint: Upper bound on the mean KL divergence between the old
            and the updated policy.
        alpha: Shrink factor of the backtracking line search
            (step ``i`` is scaled by ``alpha ** i``).
        critic_lr: Learning rate of the Adam optimizer of the critic.
        gamma: Discount factor used in the TD target.
        device: Torch device on which networks and batches are placed.
    """

    def __init__(self, hidden_dim, state_space, action_space, lmbda,
                 kl_constraint, alpha, critic_lr, gamma, device):
        state_dim = state_space.shape[0]
        action_dim = action_space.shape[0]
        # Only the critic has an optimizer; the actor is updated directly by
        # writing a new parameter vector (see ``policy_learn``).
        self.actor = PolicyNetContinuous(state_dim, hidden_dim,
                                         action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda
        self.kl_constraint = kl_constraint
        self.alpha = alpha
        self.device = device

    def take_action(self, state):
        """Sample an action for a single state.

        Returns a one-element list. ``Tensor.item()`` is used, so this only
        works when the action dimension is 1.
        """
        state = torch.tensor(np.array([state]),
                             dtype=torch.float).to(self.device)
        # No gradient is needed for acting; skip building the autograd graph.
        with torch.no_grad():
            mu, std = self.actor(state)
        action_dist = torch.distributions.Normal(mu, std)
        action = action_dist.sample()
        return [action.item()]

    def hessian_matrix_vector_product(self, states, old_action_dists, vector,
                                      damping=0.1):
        """Return ``H @ vector + damping * vector``.

        ``H`` is the Hessian of the mean KL divergence between
        ``old_action_dists`` and the current policy with respect to the actor
        parameters. It is never formed explicitly: the product is obtained by
        differentiating ``grad(KL) . vector``.
        """
        mu, std = self.actor(states)
        new_action_dists = torch.distributions.Normal(mu, std)
        kl = torch.mean(
            torch.distributions.kl.kl_divergence(old_action_dists,
                                                 new_action_dists))
        # create_graph=True so the gradient itself can be differentiated.
        kl_grad = torch.autograd.grad(kl,
                                      self.actor.parameters(),
                                      create_graph=True)
        kl_grad_vector = torch.cat([grad.view(-1) for grad in kl_grad])
        kl_grad_vector_product = torch.dot(kl_grad_vector, vector)
        grad2 = torch.autograd.grad(kl_grad_vector_product,
                                    self.actor.parameters())
        grad2_vector = torch.cat(
            [grad.contiguous().view(-1) for grad in grad2])
        return grad2_vector + damping * vector

    def conjugate_gradient(self, grad, states, old_action_dists):
        """Approximately solve ``H x = grad`` with at most 10 CG iterations."""
        x = torch.zeros_like(grad)
        r = grad.clone()
        p = grad.clone()
        rdotr = torch.dot(r, r)
        for _ in range(10):
            Hp = self.hessian_matrix_vector_product(states, old_action_dists,
                                                    p)
            alpha = rdotr / torch.dot(p, Hp)
            x += alpha * p
            r -= alpha * Hp
            new_rdotr = torch.dot(r, r)
            # Stop early once the squared residual is negligible.
            if new_rdotr < 1e-10:
                break
            beta = new_rdotr / rdotr
            p = r + beta * p
            rdotr = new_rdotr
        return x

    def compute_surrogate_obj(self, states, actions, advantage, old_log_probs,
                              actor):
        """Return the mean of ``ratio * advantage`` under ``actor``.

        ``ratio`` is ``exp(log_prob_new - old_log_probs)`` of ``actions``.
        """
        mu, std = actor(states)
        action_dists = torch.distributions.Normal(mu, std)
        log_probs = action_dists.log_prob(actions)
        ratio = torch.exp(log_probs - old_log_probs)
        return torch.mean(ratio * advantage)

    def line_search(self, states, actions, advantage, old_log_probs,
                    old_action_dists, max_vec):
        """Backtracking line search along ``max_vec``.

        Tries step sizes ``alpha ** i`` for ``i = 0..14`` and returns the first
        parameter vector that both increases the surrogate objective and keeps
        the mean KL divergence below ``kl_constraint``. Returns the current
        parameters when no step is accepted.
        """
        old_para = torch.nn.utils.convert_parameters.parameters_to_vector(
            self.actor.parameters())
        old_obj = self.compute_surrogate_obj(states, actions, advantage,
                                             old_log_probs, self.actor)
        # One scratch copy is enough: every trial overwrites all of its
        # parameters, so copying the actor again per iteration is wasted work.
        new_actor = copy.deepcopy(self.actor)
        for i in range(15):
            coef = self.alpha ** i
            new_para = old_para + coef * max_vec
            torch.nn.utils.convert_parameters.vector_to_parameters(
                new_para, new_actor.parameters())
            mu, std = new_actor(states)
            new_action_dists = torch.distributions.Normal(mu, std)
            kl_div = torch.mean(
                torch.distributions.kl.kl_divergence(old_action_dists,
                                                     new_action_dists))
            new_obj = self.compute_surrogate_obj(states, actions, advantage,
                                                 old_log_probs, new_actor)
            if new_obj > old_obj and kl_div < self.kl_constraint:
                return new_para
        return old_para

    def policy_learn(self, states, actions, old_action_dists, old_log_probs,
                     advantage):
        """Update the actor parameters with one TRPO step."""
        surrogate_obj = self.compute_surrogate_obj(states, actions, advantage,
                                                   old_log_probs, self.actor)
        grads = torch.autograd.grad(surrogate_obj, self.actor.parameters())
        obj_grad = torch.cat([grad.view(-1) for grad in grads]).detach()
        # Search direction x solving H x = g via conjugate gradient.
        descent_direction = self.conjugate_gradient(obj_grad, states,
                                                    old_action_dists)
        Hd = self.hessian_matrix_vector_product(states, old_action_dists,
                                                descent_direction)
        # Largest step along the direction: sqrt(2 * kl_constraint / (x^T H x));
        # 1e-8 guards against division by zero.
        max_coef = torch.sqrt(2 * self.kl_constraint /
                              (torch.dot(descent_direction, Hd) + 1e-8))
        new_para = self.line_search(states, actions, advantage, old_log_probs,
                                    old_action_dists,
                                    descent_direction * max_coef)
        torch.nn.utils.convert_parameters.vector_to_parameters(
            new_para, self.actor.parameters())

    def update(self, transition_dict):
        """Run one critic update and one TRPO policy update.

        ``transition_dict`` must provide the keys ``'states'``, ``'actions'``,
        ``'rewards'``, ``'next_states'`` and ``'dones'``.
        """
        states = torch.tensor(np.array(transition_dict['states']),
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(np.array(transition_dict['next_states']),
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        rewards = (rewards + 8.0) / 8.0  # rescale rewards to ease training
        td_target = rewards + self.gamma * self.critic(next_states) * (1 -
                                                                       dones)
        td_delta = td_target - self.critic(states)
        advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,
                                               td_delta.cpu()).to(self.device)
        mu, std = self.actor(states)
        # Detached snapshot of the policy that collected the data.
        old_action_dists = torch.distributions.Normal(mu.detach(),
                                                      std.detach())
        old_log_probs = old_action_dists.log_prob(actions)
        # mse_loss already averages over the batch (default reduction).
        critic_loss = F.mse_loss(self.critic(states), td_target.detach())
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        self.policy_learn(states, actions, old_action_dists, old_log_probs,
                          advantage)


num_episodes = 2000
# Hyperparameters for training TRPO on the Pendulum task.
hidden_dim = 128
gamma = 0.9
lmbda = 0.9
critic_lr = 1e-2
kl_constraint = 0.00005
alpha = 0.5
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
env_name = 'Pendulum-v1'

# Train and plot only when executed as a script, so importing this module
# has no side effects.
if __name__ == "__main__":
    env = gym.make(env_name)
    env.reset(seed=0)
    torch.manual_seed(0)
    agent = TRPOContinuous(hidden_dim, env.observation_space, env.action_space,
                           lmbda, kl_constraint, alpha, critic_lr, gamma,
                           device)
    return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)

    # Raw return of every episode.
    episodes_list = list(range(len(return_list)))
    plt.plot(episodes_list, return_list)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('TRPO on {}'.format(env_name))
    plt.show()

    # Same curve smoothed by rl_utils.moving_average (second argument 9).
    mv_return = rl_utils.moving_average(return_list, 9)
    plt.plot(episodes_list, mv_return)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('TRPO on {}'.format(env_name))
    plt.show()

rl_utils.py 的代码请参考:

动手学强化学习 第 11 章 TRPO 算法 训练代码-CSDN博客

http://www.lryc.cn/news/411659.html

相关文章:

  • 数量关系模块
  • 滑模面、趋近律设计过程详解(滑模控制)
  • SQL Server 端口配置
  • 同一窗口还是新窗口打开链接更利于SEO优化
  • kafka 安装
  • 消息队列中间件 - Kafka:高效数据流处理的引擎
  • el-table表格动态合并相同数据单元格(可指定列+自定义合并)
  • 复习Nginx
  • nvm:Node.js 版本管理工具
  • springboot校园商店配送系统-计算机毕业设计源码68448
  • 【Redis 初阶】客户端(C++ 使用样例列表)
  • 【STM32】STM32单片机入门
  • 学生信息管理系统(Python+PySimpleGUI+MySQL)
  • Java8.0标准之重要特性及用法实例(十九)
  • Linux系统中,`buffer`和`cache` 区别
  • python创建进度条的两个手搓方法
  • JAVA—面向对象编程基础
  • 【计算机视觉学习之CV2图像操作实战:车道识别1】
  • 动态之美:Laravel动态路由参数的实现艺术
  • Python练手小项目
  • 苹果手机通讯录恢复教程?3招速成指南
  • python爬虫入门(五)之Re解析
  • 可靠的图纸加密软件,七款图纸加密软件推荐
  • 【每日一题】【最短路】【BFS】小红走矩阵 “葡萄城杯”牛客周赛 Round 53 F题 C++
  • 无线磁吸充电宝哪个牌子值得入手?什么牌子磁吸充电宝性价比高?
  • 互联网摸鱼日报(2024-08-01)
  • Alpla003经典的价量背离的因子在可转债列表里的因子分析(附python代码)
  • 进阶理解——typeof 、instanceof
  • 不同类型的生物反应器在支架成熟过程中具有哪些特点和应用?
  • 8. Spring Ai之入门到精通(超级详细)