当前位置: 首页 > news >正文

24/8/17算法笔记 策略梯度reinforce算法

import gym
from matplotlib import pyplot as plt
%matplotlib inline#创建环境
env = gym.make('CartPole-v0')
env.reset()#打印游戏
def show():plt.imshow(env.render(mode = 'rgb_array'))plt.show()
show()

定义网络模型

import torch
#定义模型
model = torch.nn.Sequential(torch.nn.Linear(4,128),torch.nn.ReLU(),torch.nn.Linear(128,2),torch.nn.Softmax(dim=1),
)
model(torch.randn(2,4))

定义动作函数

import random
#得到一个动作
def get_action(state):state = torch.FloatTensor(state).reshape(1,4)#[1,4]->[1,2]prob = model(state)#根据概率选择一个动作action = random.choice(range(2),weights = prob[0].tolist(),k=1)[0]
#这行代码从 0 到 1(包含)的整数范围内选择一个元素作为动作,选择的概率由 prob[0] 列表中元素的值决定。return action

获取一局游戏数据

def get_data():states = []rewards = []actions = []#初始化游戏state = env.reset()#玩到游戏结束为止over = Falsewhile not over:#根据当前状态得到一个动作action = get_action(state)#执行动作,得到反馈next_state,reward,over,_ = env.step(action)#记录数据样本states.append(state)rewards.append(reward)actions.append(action)#更新游戏状态,开始下一个动作state = next.statereturn states,rewards,actions

测试函数

from IPython import displaydef test(play):#初始化游戏state = env.reset()#记录反馈值的和,这个值越大越好reward_sum=0#玩到游戏结束为止over = False while not over:#根据当前状态得到一个动作action = get_action(state)#执行动作,得到反馈state,reward,over,_ = env.state(action)reward_sum += reward#打印动画if play and random.random()<0.2:#跳帧display.clear_output(wait=True) #用于清除 Jupyter Notebook 单元格的输出。show()return reward_sum

训练函数

 def train():optimizer = torch.optim.Adam(model.parameters(),lr = 1e-3)#玩N局游戏,得到数据states,rewards,actions = get_data()optimizer.zero_grad()#反馈的和,初始化为0reward_sum = 0#从最后一步算起for i in reversed(range(len(states))):#反馈的和,从最后一步的反馈开始计算#每往前一步,>>和<<都衰减0.02,然后再加上当前的反馈reward_sum*=0.98reward_sum+=rewards[i]#重新计算对应动作的概率state = torch.FloatTensor(states[i]).reshape(1,4)#[1,4]->[1,2]prob = model(state)#[1,2]->scalapron = pron[0,actions[i]]#根据求导公式,符号取反是因为这里是求loss,所以优化方向相反loss =-prob.log()*reward_sum#累积梯度loss.backward(retain_graph=True)optimizer.step()if epoch%100==0:test_result = sum([test(play=False) for _ in range(10)])/10print(epoch,test_result)
http://www.lryc.cn/news/428656.html

相关文章:

  • 【Linux学习】Linux开发工具——vim
  • 【2025校招】4399 NLP算法工程师笔试题
  • 数据库原理--关系1
  • 【人工智能】AI工程化是将人工智能技术转化为实际应用、创造实际价值的关键步骤
  • 《C语言实现各种排序算法》
  • 【888题竞赛篇】第五题,2023ICPC澳门-传送(Teleportation)
  • javascript写一个页码器-SAAS本地化及未来之窗行业应用跨平台架构
  • 微信小程序如何自定义一个组件
  • 【数学建模备赛】Ep05:斯皮尔曼spearman相关系数
  • MATLAB进行神经网络建模的案例
  • 每天一个数据分析题(四百八十九)- 主成分分析与因子分析
  • Java RPC、Go RPC、Node RPC、Python RPC 之间的互相调用
  • 国外代理IP选择:IP池的大小有何影响
  • 手机谷歌浏览器怎么用
  • Button窗口部件
  • PCIe学习笔记(25)
  • 8.20
  • centos7.9系统安装talebook个人书库
  • ES高级查询Query DSL查询详解、term术语级别查询、全文检索、highlight高亮
  • 关于Blender云渲染农场,你应该知道的一切!
  • Obsidian如何安装插件
  • Nginx服务器申请及配置免费SSL证书
  • STM32CubeMX 配置串口通信 HAL库
  • GitHub的未来:在微软领导下保持独立与AI发展的平衡
  • RGB与YUV格式详解
  • JS获取当前浏览器名称
  • 学习计算机网络(五)——ICMP协议
  • request.getRequestURI()与request.getRequestURL()的区别
  • 3154. 到达第 K 级台阶的方案数(24.8.20)
  • 如何使用docker打包后端项目并部署到阿里云k8s集群上