
Reinforcement Learning Case Reproduction (1) --- MountainCar with Q-learning

1 Setting Up the Environment

1.1 Using the Built-in gym Environment

import gym

# Create environment
env = gym.make("MountainCar-v0", render_mode="human")  # gym 0.26: render_mode is needed for env.render()

episodes = 10
for episode in range(episodes):
    obs, info = env.reset()          # gym 0.26: reset() returns (obs, info)
    done = False
    rewards = 0
    while not done:
        action = env.action_space.sample()
        # gym 0.26: step() returns (obs, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        env.render()
        rewards += reward
    print(rewards)

1.2 Building Your Own Environment (Recommended)

Build the MountainCar environment as described in the article below; a short sketch of the idea follows the link.

Previous article: Reinforcement Learning in Practice (3): Building Your Own Environment with gym (runs on gym 0.26.2), CSDN blog
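For orientation, here is a minimal sketch of what such a custom environment can look like under the gym 0.26 Env API, using the standard MountainCar dynamics. The class name, the registered ID "GridWorld-v0" (chosen to match the training script below), and the omitted rendering are illustrative assumptions; follow the linked article for the full implementation.

import numpy as np
import gym
from gym import spaces

class MountainCarEnv(gym.Env):
    """Minimal MountainCar clone following the gym 0.26 Env API (illustrative sketch)."""

    def __init__(self, render_mode=None):
        self.min_position, self.max_position = -1.2, 0.6
        self.max_speed = 0.07
        self.goal_position = 0.5             # the training script reads env.goal_position
        self.force, self.gravity = 0.001, 0.0025
        self.low = np.array([self.min_position, -self.max_speed], dtype=np.float32)
        self.high = np.array([self.max_position, self.max_speed], dtype=np.float32)
        self.action_space = spaces.Discrete(3)                      # push left / no push / push right
        self.observation_space = spaces.Box(self.low, self.high, dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.state = np.array([self.np_random.uniform(-0.6, -0.4), 0.0], dtype=np.float32)
        return self.state, {}                                       # gym 0.26: (obs, info)

    def step(self, action):
        position, velocity = self.state
        velocity += (action - 1) * self.force + np.cos(3 * position) * (-self.gravity)
        velocity = np.clip(velocity, -self.max_speed, self.max_speed)
        position = np.clip(position + velocity, self.min_position, self.max_position)
        if position == self.min_position and velocity < 0:
            velocity = 0.0
        self.state = np.array([position, velocity], dtype=np.float32)
        terminated = bool(position >= self.goal_position)
        reward = -1.0                                               # -1 per step until the goal is reached
        return self.state, reward, terminated, False, {}            # gym 0.26: 5-tuple

    def render(self):
        pass                                                        # rendering omitted in this sketch

# Register under the ID used by the scripts below
gym.envs.registration.register(id="GridWorld-v0", entry_point=MountainCarEnv, max_episode_steps=200)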

 

2 Training the Model with Q-learning
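The script below discretizes the continuous (position, velocity) observation into a 20x20 grid and runs tabular Q-learning with an epsilon-greedy policy. The new_q line in the code is the standard Q-learning update, with learning rate $\alpha$ = LEARNING_RATE and discount $\gamma$ = DISCOUNT:

$$Q(s, a) \leftarrow (1 - \alpha)\, Q(s, a) + \alpha \left( r + \gamma \max_{a'} Q(s', a') \right)$$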

import gym
import numpy as np

env = gym.make("GridWorld-v0")   # custom MountainCar environment registered as in section 1.2

# Q-Learning settings
LEARNING_RATE = 0.1   # learning rate
DISCOUNT = 0.95       # reward discount factor
EPISODES = 100        # number of training episodes
SHOW_EVERY = 1000     # render every SHOW_EVERY episodes

# Exploration settings
epsilon = 1  # not a constant, going to be decayed
START_EPSILON_DECAYING = 1
END_EPSILON_DECAYING = EPISODES // 2
epsilon_decay_value = epsilon / (END_EPSILON_DECAYING - START_EPSILON_DECAYING)

DISCRETE_OS_SIZE = [20, 20]
discrete_os_win_size = (env.observation_space.high - env.observation_space.low) / DISCRETE_OS_SIZE
print(discrete_os_win_size)


def get_discrete_state(state):
    discrete_state = (state - env.observation_space.low) / discrete_os_win_size
    # we use this tuple to look up the 3 Q values for the available actions in the q_table
    return tuple(discrete_state.astype(np.int64))


q_table = np.random.uniform(low=-2, high=0, size=(DISCRETE_OS_SIZE + [env.action_space.n]))

for episode in range(EPISODES):
    state, _ = env.reset()   # gym 0.26: reset() returns (obs, info)
    discrete_state = get_discrete_state(state)

    if episode % SHOW_EVERY == 0:
        render = True
        print(episode)
    else:
        render = False

    done = False
    while not done:
        if np.random.random() > epsilon:
            # Get action from Q table
            action = np.argmax(q_table[discrete_state])
        else:
            # Get random action
            action = np.random.randint(0, env.action_space.n)

        # gym 0.26: step() returns (obs, reward, terminated, truncated, info)
        new_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        new_discrete_state = get_discrete_state(new_state)

        # If simulation did not end yet after last step - update Q table
        if not done:
            # Maximum possible Q value in next step (for new state)
            max_future_q = np.max(q_table[new_discrete_state])
            # Current Q value (for current state and performed action)
            current_q = q_table[discrete_state + (action,)]
            # And here's our equation for a new Q value for current state and action
            new_q = (1 - LEARNING_RATE) * current_q + LEARNING_RATE * (reward + DISCOUNT * max_future_q)
            # Update Q table with new Q value
            q_table[discrete_state + (action,)] = new_q

        # Simulation ended (for any reason) - if goal position is achieved - update Q value with reward directly
        elif new_state[0] >= env.goal_position:
            q_table[discrete_state + (action,)] = 0
            print("we made it on episode {}".format(episode))

        discrete_state = new_discrete_state
        if render:
            env.render()

    # Decaying is done every episode if the episode number is within the decaying range
    if END_EPSILON_DECAYING >= episode >= START_EPSILON_DECAYING:
        epsilon -= epsilon_decay_value

np.save("q_table.npy", arr=q_table)
env.close()
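After training, the learned table is saved to q_table.npy. A quick sanity check on the saved file might look like the following sketch (assuming MountainCar's three actions, so the expected shape is 20x20x3):

import numpy as np

q_table = np.load("q_table.npy")
print(q_table.shape)                         # expected (20, 20, 3): 20x20 discrete states x 3 actions
greedy_policy = np.argmax(q_table, axis=-1)  # greedy action index for every discrete state
print(greedy_policy)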

3 Testing the Model

import gym
import numpy as np

env = gym.make("GridWorld-v0")   # the same custom environment used for training

# Q-Learning settings
LEARNING_RATE = 0.1
DISCOUNT = 0.95
EPISODES = 10

DISCRETE_OS_SIZE = [20, 20]
discrete_os_win_size = (env.observation_space.high - env.observation_space.low) / DISCRETE_OS_SIZE


def get_discrete_state(state):
    discrete_state = (state - env.observation_space.low) / discrete_os_win_size
    # we use this tuple to look up the 3 Q values for the available actions in the q_table
    return tuple(discrete_state.astype(np.int64))


q_table = np.load(file="q_table.npy")

for episode in range(EPISODES):
    state, _ = env.reset()   # gym 0.26: reset() returns (obs, info)
    discrete_state = get_discrete_state(state)
    rewards = 0

    done = False
    while not done:
        # Always take the greedy action from the Q table
        action = np.argmax(q_table[discrete_state])
        # gym 0.26: step() returns (obs, reward, terminated, truncated, info)
        new_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        new_discrete_state = get_discrete_state(new_state)
        rewards += reward

        # Report success if the episode ended at the goal position
        if done and new_state[0] >= env.goal_position:
            print("we made it on episode {}, rewards {}".format(episode, rewards))

        discrete_state = new_discrete_state
        env.render()

env.close()
