Reinforcement Learning with Code 【Code 2. Tabular Sarsa】

This note records how the author began to learn RL. Both theoretical understanding and code practice are presented. Many materials are referenced, such as Zhao Shiyu's Mathematical Foundation of Reinforcement Learning.
This code refers to Mofan's reinforcement learning course.

Contents

  • Reinforcement Learning with Code 【Code 2. Tabular Sarsa】
    • 2.1 Problem and result
    • 2.2 Environment
    • 2.3 Tabular Sarsa Algorithm
    • 2.4 Run this main
    • 2.5 Check the Q table
    • Reference

2.1 Problem and result

Consider the problem in which a little mouse (the red block) wants to reach the cheese (the yellow circle) while avoiding the traps (the black blocks), as the figure shows.

[Figure: the maze with the mouse (red block), the traps (black blocks) and the cheese (yellow circle)]

This chapter implements the tabular Sarsa algorithm to solve this problem.

2.2 Environment

We use Python's tkinter package to build the environment that the agent interacts with.

import numpy as np
import time
import sys
import tkinter as tk
# if sys.version_info.major == 2:  # check whether the Python version is Python 2
#     import Tkinter as tk
# else:
#     import tkinter as tk

UNIT = 40   # pixels
MAZE_H = 4  # grid height
MAZE_W = 4  # grid width


class Maze(tk.Tk, object):
    def __init__(self):
        super(Maze, self).__init__()
        # action space
        self.action_space = ['up', 'down', 'right', 'left']
        self.n_actions = len(self.action_space)
        # build the GUI
        self.title('Maze env')
        self.geometry('{0}x{1}'.format(MAZE_W * UNIT, MAZE_H * UNIT))   # window size "width x height"
        self._build_maze()

    def _build_maze(self):
        self.canvas = tk.Canvas(self, bg='white',
                                height=MAZE_H * UNIT,
                                width=MAZE_W * UNIT)     # background canvas

        # create grids
        for c in range(UNIT, MAZE_W * UNIT, UNIT):  # column separators
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(UNIT, MAZE_H * UNIT, UNIT):  # row separators
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # origin: center of the first (top-left) cell
        origin = np.array([UNIT/2, UNIT/2])

        # hell1
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - (UNIT/2 - 5), hell1_center[1] - (UNIT/2 - 5),
            hell1_center[0] + (UNIT/2 - 5), hell1_center[1] + (UNIT/2 - 5),
            fill='black')

        # hell2
        hell2_center = origin + np.array([UNIT, UNIT * 2])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - (UNIT/2 - 5), hell2_center[1] - (UNIT/2 - 5),
            hell2_center[0] + (UNIT/2 - 5), hell2_center[1] + (UNIT/2 - 5),
            fill='black')

        # create oval: the terminal (cheese)
        oval_center = origin + np.array([UNIT*2, UNIT*2])
        self.oval = self.canvas.create_oval(
            oval_center[0] - (UNIT/2 - 5), oval_center[1] - (UNIT/2 - 5),
            oval_center[0] + (UNIT/2 - 5), oval_center[1] + (UNIT/2 - 5),
            fill='yellow')

        # create red rect: the agent, initially in the top-left cell
        self.rect = self.canvas.create_rectangle(
            origin[0] - (UNIT/2 - 5), origin[1] - (UNIT/2 - 5),
            origin[0] + (UNIT/2 - 5), origin[1] + (UNIT/2 - 5),
            fill='red')

        # pack all: show the canvas
        self.canvas.pack()

    def get_state(self, rect):
        # convert the coordinate observation to a state tuple
        # use the normalized cell center as the state, such as
        # |(1,1)|(2,1)|(3,1)|...
        # |(1,2)|(2,2)|(3,2)|...
        # |(1,3)|(2,3)|(3,3)|...
        # |....
        x0, y0, x1, y1 = self.canvas.coords(rect)
        x_center = (x0+x1)/2
        y_center = (y0+y1)/2
        state = ((x_center-(UNIT/2))/UNIT + 1, (y_center-(UNIT/2))/UNIT + 1)
        return state

    def reset(self):
        self.update()
        self.after(500)  # delay 500 ms
        self.canvas.delete(self.rect)   # delete the old agent rectangle
        origin = np.array([UNIT/2, UNIT/2])
        self.rect = self.canvas.create_rectangle(
            origin[0] - (UNIT/2 - 5), origin[1] - (UNIT/2 - 5),
            origin[0] + (UNIT/2 - 5), origin[1] + (UNIT/2 - 5),
            fill='red')
        # return observation
        return self.get_state(self.rect)

    def step(self, action):
        # one interaction between the agent and the environment
        s = self.get_state(self.rect)   # current state of the agent
        base_action = np.array([0, 0])
        reach_boundary = False
        if action == self.action_space[0]:   # up
            if s[1] > 1:
                base_action[1] -= UNIT
            else:  # hit the boundary: reward = -1 and stay in place
                reach_boundary = True
        elif action == self.action_space[1]:   # down
            if s[1] < MAZE_H:
                base_action[1] += UNIT
            else:
                reach_boundary = True
        elif action == self.action_space[2]:   # right
            if s[0] < MAZE_W:
                base_action[0] += UNIT
            else:
                reach_boundary = True
        elif action == self.action_space[3]:   # left
            if s[0] > 1:
                base_action[0] -= UNIT
            else:
                reach_boundary = True

        self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent
        s_ = self.get_state(self.rect)  # next state

        # reward function
        if s_ == self.get_state(self.oval):     # reach the terminal
            reward = 1
            done = True
            s_ = 'success'
        elif s_ == self.get_state(self.hell1):  # reach a trap
            reward = -1
            s_ = 'block_1'
            done = False
        elif s_ == self.get_state(self.hell2):
            reward = -1
            s_ = 'block_2'
            done = False
        else:
            reward = 0
            done = False
        if reach_boundary:
            reward = -1
        return s_, reward, done

    def render(self):
        time.sleep(0.15)
        self.update()


if __name__ == '__main__':
    def test():
        for t in range(10):
            s = env.reset()
            print(s)
            while True:
                env.render()
                # random actions, so that an episode can eventually reach the cheese
                # (always moving right would never terminate)
                a = np.random.choice(env.action_space)
                s, r, done = env.step(a)
                print(s)
                if done:
                    break

    env = Maze()
    env.after(100, test)      # call test after a 100 ms delay
    env.mainloop()

This part is important because it includes the design of the reward function, which is as follows:

$$
\text{reward} = \left \{ \begin{aligned} & 1, \quad \text{if reach the cheese} \\ & -1, \quad \text{if reach the trap or reach the boundary} \\ & 0, \quad \text{others} \end{aligned} \right.
$$
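The piecewise reward above can be written as a small standalone function. This is only a sketch for checking the three cases without opening the GUI: the name compute_reward is hypothetical, and it assumes the state labels 'success', 'block_1' and 'block_2' that Maze.step returns.

```python
def compute_reward(next_state, reach_boundary):
    """Return (reward, done) following the piecewise reward above.

    next_state is either a grid tuple such as (2.0, 1.0) or one of the
    labels 'success', 'block_1', 'block_2' used by Maze.step.
    compute_reward itself is a hypothetical helper, not part of the class.
    """
    if reach_boundary:
        return -1, False              # bumped into the wall, episode continues
    if next_state == 'success':
        return 1, True                # reached the cheese, episode ends
    if next_state in ('block_1', 'block_2'):
        return -1, False              # stepped on a trap, episode continues
    return 0, False                   # ordinary move


print(compute_reward('success', False))     # (1, True)
print(compute_reward('block_1', False))     # (-1, False)
print(compute_reward((1.0, 1.0), True))     # (-1, False)
print(compute_reward((2.0, 1.0), False))    # (0, False)
```

Note that only the cheese ends an episode; a trap gives a penalty but the mouse keeps moving.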

We need to explain some functions of the class Maze.

  • First, the function _build_maze creates the initial maze layout.
    In this example we use the center of each grid cell to place each block.
  • Second, the function get_state converts the pixel coordinates of a block to a numerical grid representation such as $(1,1),(1,2),\cdots$.
  • Third, the function reset renews the state, which means placing the mouse back in the original grid cell.
  • Then, in the function step we let the agent interact with the environment for one step and get the reward after the action.
  • Finally, the function render controls updating the window.
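The coordinate conversion done by get_state can be checked without tkinter. The sketch below repeats the same arithmetic on a plain bounding box (x0, y0, x1, y1); the function name coords_to_state is hypothetical, and UNIT = 40 is taken from the environment code.

```python
UNIT = 40  # pixels per cell, same value as in the environment


def coords_to_state(x0, y0, x1, y1, unit=UNIT):
    """Map a canvas bounding box to a 1-based (column, row) grid state."""
    x_center = (x0 + x1) / 2
    y_center = (y0 + y1) / 2
    return ((x_center - unit / 2) / unit + 1,
            (y_center - unit / 2) / unit + 1)


# the agent starts in the top-left cell: box from (5, 5) to (35, 35)
print(coords_to_state(5, 5, 35, 35))       # (1.0, 1.0)
# hell1 is centered at (100, 60): box from (85, 45) to (115, 75)
print(coords_to_state(85, 45, 115, 75))    # (3.0, 2.0)
```

The first component is the column and the second is the row, so (3.0, 2.0) means third column, second row.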

2.3 Tabular Sarsa Algorithm

import numpy as np
import pandas as pd


class RL():
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = actions  # action list
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy  # probability of choosing the greedy action
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append the new state to the Q-table, using the coordinate as the observation
            # (DataFrame.append was removed in pandas 2.0, so pd.concat is used instead)
            self.q_table = pd.concat(
                [
                    self.q_table,
                    pd.DataFrame(
                        data=np.zeros((1, len(self.actions))),
                        columns=self.q_table.columns,
                        index=[state],
                    ),
                ]
            )

    def choose_action(self, observation):
        """Use the epsilon-greedy method to update policy"""
        self.check_state_exist(observation)
        # action selection with the epsilon-greedy algorithm
        if np.random.uniform() < self.epsilon:
            state_action = self.q_table.loc[observation, :]
            # some actions may have the same value, so choose randomly among them;
            # state_action == np.max(state_action) generates a boolean mask
            # choose best action
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # choose random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_):
        pass


class SarsaTable(RL):
    """Implement Sarsa algorithm which is on-policy"""
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'success':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update

We store the Q-table as a pandas DataFrame. The functions are explained as follows.

  • First, the function check_state_exist checks whether a state already exists in the Q-table; if not, we append it. In this way a state is added to the Q-table only once it has been visited.
  • Second, the function choose_action follows the $\epsilon$-greedy policy. In the code, e_greedy is the probability of taking the greedy action, so it plays the role of $1-\epsilon$ in the formula:

$$
\pi(a|s) = \left \{ \begin{aligned} 1 - \frac{\epsilon}{|\mathcal{A}(s)|}(|\mathcal{A}(s)|-1), & \quad \text{for the greedy action} \\ \frac{\epsilon}{|\mathcal{A}(s)|}, & \quad \text{for the other } |\mathcal{A}(s)|-1 \text{ actions} \end{aligned} \right.
$$

  • Third, the function learn updates the q value as the Sarsa algorithm proposes, which relies on the sample $\textcolor{red}{(s_t,a_t,r_{t+1},s_{t+1},a_{t+1})}$. The sample denotes the current state, current action, immediate reward, next state and next action respectively.

$$
\text{Sarsa} : \left \{ \begin{aligned} \textcolor{red}{q_{t+1}(s_t,a_t)} & \textcolor{red}{= q_t(s_t,a_t) - \alpha_t(s_t,a_t) \Big[q_t(s_t,a_t) - (r_{t+1}+ \gamma \ q_t(s_{t+1},a_{t+1})) \Big]} \\ \textcolor{red}{q_{t+1}(s,a)} & \textcolor{red}{= q_t(s,a)}, \quad \text{for all } (s,a) \ne (s_t,a_t) \end{aligned} \right.
$$
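The two formulas in this list can be checked numerically with a few lines of plain Python. The sketch below is not the DataFrame implementation: it assumes a dictionary keyed by (state, action) as the Q-table, the names epsilon_greedy_probs and sarsa_update are hypothetical, and ties for the greedy action are resolved by taking the first maximum.

```python
from collections import defaultdict


def epsilon_greedy_probs(q_values, epsilon):
    """Action probabilities of the epsilon-greedy policy for one state.

    q_values maps each action to its value; the first maximal action is
    treated as the greedy one (an assumption made for this sketch).
    """
    n = len(q_values)
    greedy = max(q_values, key=q_values.get)
    probs = {a: epsilon / n for a in q_values}
    probs[greedy] = 1 - epsilon / n * (n - 1)
    return probs


def sarsa_update(q, s, a, r, s_next, a_next, alpha=0.01, gamma=0.9, terminal=False):
    """Apply one Sarsa update to q[(s, a)] in place and return the new value."""
    target = r if terminal else r + gamma * q[(s_next, a_next)]
    q[(s, a)] = q[(s, a)] - alpha * (q[(s, a)] - target)
    return q[(s, a)]


# epsilon = 0.1 corresponds to e_greedy = 0.9 in the class above
probs = epsilon_greedy_probs({'up': 0.0, 'down': 0.2, 'right': 0.1, 'left': 0.0}, 0.1)
print(probs)

q = defaultdict(float)            # unseen state-action pairs start at 0
q[((2, 1), 'down')] = 0.5
new_value = sarsa_update(q, (1, 1), 'right', 0, (2, 1), 'down')
print(round(new_value, 6))        # 0.0045
```

With four actions and epsilon = 0.1, the greedy action 'down' gets probability close to 0.925 and each other action gets 0.025. The update moves q((1,1), right) from 0 towards the target 0 + 0.9 * 0.5 = 0.45 by a step of 0.01.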

2.4 Run this main

Run the following main script to run all of the code.

from maze_env_custom import Maze
from RL_brain import SarsaTable

MAX_EPISODE = 30


def update():
    for episode in range(MAX_EPISODE):
        # initial observation: the state tuple returned by env.reset()
        observation = env.reset()
        # RL chooses an action from ['up', 'down', 'right', 'left'] based on the observation
        action = RL.choose_action(str(observation))
        while True:
            # refresh env
            env.render()
            # RL takes the action and gets the next observation and reward
            observation_, reward, done = env.step(action)
            action_ = RL.choose_action(str(observation_))
            # RL learns from this transition
            RL.learn(str(observation), action, reward, str(observation_), action_)
            # swap observation and action
            observation = observation_
            action = action_
            # break the while loop at the end of this episode
            if done:
                break
        # show q_table
        print(RL.q_table)
        print('\n')
    # end of game
    print('game over')
    env.destroy()


if __name__ == "__main__":
    env = Maze()
    RL = SarsaTable(env.action_space)
    env.after(100, update)
    env.mainloop()
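To try the training loop without a window, the same Sarsa loop can be run against a GUI-free stand-in for the maze. Everything in the sketch below is an assumption made for illustration: env_step, choose_action and train are hypothetical names, the Q-table is a dictionary instead of a DataFrame, states stay as integer tuples even on traps, and max_steps is a safety cap that the original loop does not have. The layout (traps at (3,2) and (2,3), cheese at (3,3), start at (1,1)) follows the environment code.

```python
import random

ACTIONS = ['up', 'down', 'right', 'left']
MOVES = {'up': (0, -1), 'down': (0, 1), 'right': (1, 0), 'left': (-1, 0)}
SIZE = 4
TRAPS = {(3, 2), (2, 3)}
CHEESE = (3, 3)


def env_step(state, action):
    """Hypothetical stand-in for Maze.step: returns (next_state, reward, done)."""
    x = state[0] + MOVES[action][0]
    y = state[1] + MOVES[action][1]
    if not (1 <= x <= SIZE and 1 <= y <= SIZE):
        return state, -1, False            # hit the boundary, stay in place
    if (x, y) == CHEESE:
        return (x, y), 1, True
    if (x, y) in TRAPS:
        return (x, y), -1, False
    return (x, y), 0, False


def choose_action(q, state, greedy_prob, rng):
    """Greedy with probability greedy_prob (ties broken at random), else uniform."""
    if rng.random() < greedy_prob:
        best = max(q.get((state, a), 0.0) for a in ACTIONS)
        return rng.choice([a for a in ACTIONS if q.get((state, a), 0.0) == best])
    return rng.choice(ACTIONS)


def train(episodes=30, alpha=0.01, gamma=0.9, greedy_prob=0.9, max_steps=10_000, seed=0):
    rng = random.Random(seed)
    q = {}
    for _ in range(episodes):
        s = (1, 1)
        a = choose_action(q, s, greedy_prob, rng)
        for _ in range(max_steps):         # safety cap, not in the original loop
            s_, r, done = env_step(s, a)
            a_ = choose_action(q, s_, greedy_prob, rng)
            target = r if done else r + gamma * q.get((s_, a_), 0.0)
            old = q.get((s, a), 0.0)
            q[(s, a)] = old + alpha * (target - old)   # Sarsa update
            s, a = s_, a_
            if done:
                break
    return q


q = train()
for (state, action), value in sorted(q.items()):
    print(state, action, round(value, 5))
```

The printed values depend on the random seed, so they will not match the Q-table of the GUI run line by line; the sketch is only meant to show the order of the steps: choose an action, step, choose the next action, update, then shift (s, a).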

2.5 Check the Q table

After a long run we can check the Q-table to judge whether the learning is reasonable. The Q-table is as follows:

                      up      down     right          left
(1.0, 1.0) -6.837352e-02 -0.000135 -0.000266 -2.970185e-02
(2.0, 1.0) -4.901299e-02 -0.000334 -0.000484 -6.039572e-04
(2.0, 2.0) -3.988164e-04 -0.049010 -0.038785 -2.737623e-04
block_1     0.000000e+00  0.049010  0.000000  0.000000e+00
(4.0, 2.0) -2.646359e-04  0.001314 -0.019900 -1.000000e-02
(4.0, 1.0) -4.900994e-02  0.000014 -0.010000 -3.128178e-06
(3.0, 1.0) -2.970450e-02 -0.029433 -0.000516 -2.078845e-04
(1.0, 2.0) -4.933690e-04 -0.000374 -0.000951 -3.940947e-02
block_2    -1.979099e-07  0.000000  0.010000 -1.531800e-07
(1.0, 3.0) -3.525635e-04 -0.000056 -0.010000 -3.940439e-02
(1.0, 4.0) -7.194310e-07 -0.010000  0.000591 -1.990000e-02
(2.0, 4.0) -1.000000e-02 -0.019900  0.012381  0.000000e+00
(3.0, 4.0)  1.654862e-01  0.000000  0.000000  0.000000e+00
(4.0, 4.0)  0.000000e+00  0.000000 -0.010000  0.000000e+00
(4.0, 3.0)  0.000000e+00  0.000000  0.000000  5.851985e-02
success     0.000000e+00  0.000000  0.000000  0.000000e+00

For example, at the starting cell, if the mouse moves up or left it hits the boundary and receives a reward of $-1$. Hence the corresponding action values in the Q-table are negative.
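The size of these negative entries can be checked by hand. The sketch below repeats the Sarsa update for a single state-action pair that always receives a reward of -1, under the simplifying assumption that the bootstrap term $\gamma\, q(s_{t+1}, a_{t+1})$ is zero; the learning rate 0.01 is the default of SarsaTable.

```python
alpha = 0.01   # learning rate used by SarsaTable
reward = -1    # reward for hitting the boundary or a trap

q = 0.0
history = []
for _ in range(3):
    # the Sarsa target is r + gamma * q(s', a'); the second term is assumed to be 0 here
    target = reward
    q += alpha * (target - q)
    history.append(q)

print([round(v, 6) for v in history])   # [-0.01, -0.0199, -0.029701]
```

Entries such as -0.010000 and -0.019900 in the table above are consistent with one and two such updates. Values that differ slightly, such as -2.970185e-02, suggest that the bootstrap term was not exactly zero in those updates.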


Reference

Zhao Shiyu's course (Mathematical Foundation of Reinforcement Learning)
Mofan's Reinforcement Learning course
