
Experience replay: a strategy for training reinforcement learning

Experience Replay (a technique for training DQN)


Advantages: stored experience can be reused many times; consecutive transitions are strongly correlated, and sampling uniformly at random from the replay buffer breaks up this correlation.

Hyperparameter: the capacity (maximum length) of the replay buffer.
\begin{aligned}
&\bullet\ \text{Find }\mathbf{w}\text{ by minimizing }L(\mathbf{w})=\frac{1}{T}\sum_{t=1}^{T}\frac{\delta_{t}^{2}}{2}. \\
&\bullet\ \text{Stochastic gradient descent (SGD):} \\
&\quad\bullet\ \text{Randomly sample a transition }(s_i,a_i,r_i,s_{i+1})\text{ from the buffer.} \\
&\quad\bullet\ \text{Compute the TD error }\delta_i. \\
&\quad\bullet\ \text{Stochastic gradient: }\mathbf{g}_{i}=\frac{\partial\,\delta_{i}^{2}/2}{\partial\mathbf{w}}=\delta_{i}\cdot\frac{\partial Q(s_{i},a_{i};\mathbf{w})}{\partial\mathbf{w}} \\
&\quad\bullet\ \text{SGD update: }\mathbf{w}\leftarrow\mathbf{w}-\alpha\cdot\mathbf{g}_i.
\end{aligned}


Note: in practice, minibatch SGD is usually used: several transitions are sampled at a time and the stochastic gradient is averaged over the minibatch; a minimal sketch of this update is given below.
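As a sketch of what the minibatch TD loss might look like in PyTorch (not from the original post): the network q_net, the discount factor gamma, and the use of the same network for the TD target (no separate target network) are all assumptions made here for illustration; the agent.compute_loss call used later presumably does something similar.

import torch

def compute_td_loss(q_net, gamma, states, actions, rewards, dones, next_states):
    """Minibatch TD loss L(w) = mean(delta_i^2 / 2); q_net and gamma are assumed."""
    # Q(s_i, a_i; w) for the actions that were actually taken
    q_values = q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # TD target r_i + gamma * max_a Q(s_{i+1}, a; w); no bootstrapping on terminal states
        next_q = q_net(next_states).max(dim=1).values
        td_target = rewards + gamma * (1.0 - dones) * next_q
    delta = q_values - td_target              # TD error delta_i
    return 0.5 * (delta ** 2).mean()          # L(w) = mean of delta_i^2 / 2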
A reference implementation of the replay buffer:

import numpy as np
from dataclasses import dataclass, field


@dataclass
class ReplayBuffer:
    maxsize: int
    size: int = 0
    state: list = field(default_factory=list)
    action: list = field(default_factory=list)
    next_state: list = field(default_factory=list)
    reward: list = field(default_factory=list)
    done: list = field(default_factory=list)

    def push(self, state, action, reward, done, next_state):
        """
        :param state: current state
        :param action: action taken
        :param reward: reward received
        :param done: episode-termination flag
        :param next_state: next state
        :return:
        """
        if self.size < self.maxsize:
            self.state.append(state)
            self.action.append(action)
            self.reward.append(reward)
            self.done.append(done)
            self.next_state.append(next_state)
        else:
            # Buffer is full: overwrite the oldest entries in circular fashion.
            position = self.size % self.maxsize
            self.state[position] = state
            self.action[position] = action
            self.reward[position] = reward
            self.done[position] = done
            self.next_state[position] = next_state
        self.size += 1

    def sample(self, n):
        # Sample uniformly at random from the filled part of the buffer.
        total_number = self.size if self.size < self.maxsize else self.maxsize
        indices = np.random.randint(total_number, size=n)
        state = [self.state[i] for i in indices]
        action = [self.action[i] for i in indices]
        reward = [self.reward[i] for i in indices]
        done = [self.done[i] for i in indices]
        next_state = [self.next_state[i] for i in indices]
        return state, action, reward, done, next_state
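A quick, self-contained usage check (the dummy 4-dimensional state is only for illustration):

import numpy as np

buffer = ReplayBuffer(maxsize=1000)
for _ in range(5):
    s = np.zeros(4, dtype=np.float32)          # dummy state
    buffer.push(s, action=0, reward=1.0, done=False, next_state=s)
states, actions, rewards, dones, next_states = buffer.sample(n=2)
print(len(states), actions)                    # -> 2 [0, 0]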

The code used during training is as follows.

Collected transitions are pushed into the buffer:

# state, action, reward, done flag, next state
replay_buffer.push(state, action, reward, done, next_state)

During training, sample a minibatch from the buffer and compute the loss:

bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
bs = torch.tensor(bs, dtype=torch.float32)
ba = torch.tensor(ba, dtype=torch.long)
br = torch.tensor(br, dtype=torch.float32)
bd = torch.tensor(bd, dtype=torch.float32)
bns = torch.tensor(bns, dtype=torch.float32)

loss = agent.compute_loss(bs, ba, br, bd, bns)
loss.backward()
optimizer.step()
optimizer.zero_grad()
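For completeness, the surrounding interaction loop (not shown in the original post) typically alternates between acting in the environment, pushing the resulting transition, and running the update above. The sketch below assumes a gym-style env with reset/step, an agent with a select_action method, and an args namespace; all of these names are placeholders.

state = env.reset()
for step in range(args.max_steps):
    action = agent.select_action(state)             # e.g. epsilon-greedy (assumed interface)
    next_state, reward, done, _ = env.step(action)  # assumed gym-style step signature
    replay_buffer.push(state, action, reward, done, next_state)
    state = env.reset() if done else next_state
    if replay_buffer.size >= args.batch_size:
        # run the sample / compute_loss / backward / step block shown above
        ...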