import gym
from gym import spaces
import numpy as np
import pygame


class RandomWalk2DEnv(gym.Env):
    """Continuous 2-D random-walk environment rendered with pygame.

    The agent starts at the origin of a bounded square and is displaced by a
    continuous ``(dx, dy)`` action on every step. The position is clipped to
    the square, and the episode ends as soon as the agent sits on any edge.
    The per-step reward is the negative Euclidean distance from the origin.

    The old gym API is used: ``reset()`` returns the observation and
    ``step()`` returns ``(observation, reward, done, info)``.
    """

    def __init__(self):
        super().__init__()
        # Inclusive world bounds along each axis.
        self.x_min, self.x_max = -10, 10
        self.y_min, self.y_max = -10, 10

        self.observation_space = spaces.Box(
            np.array([self.x_min, self.y_min], dtype=np.float32),
            np.array([self.x_max, self.y_max], dtype=np.float32),
            dtype=np.float32,
        )
        self.action_space = spaces.Box(
            np.array([-1.0, -1.0], dtype=np.float32),
            np.array([1.0, 1.0], dtype=np.float32),
            dtype=np.float32,
        )

        # Internal position; observations handed to callers are copies.
        self.state = np.array([0.0, 0.0])
        self.step_size = 1.0

        self.screen_size = (800, 800)
        self.scale = 40  # pixels per world unit
        self.origin = np.array([self.x_max, self.y_max])
        self.screen = None
        self._open_display()

        # Trajectory of visited positions, used only for rendering.
        self.path_x = []
        self.path_y = []

    def _open_display(self):
        """Initialise pygame and create the window the episode is drawn in."""
        pygame.init()
        self.screen = pygame.display.set_mode(self.screen_size)
        pygame.display.set_caption("Random Walk 2D")

    def _observation(self):
        """Return the current position as a fresh float32 array.

        ``astype`` copies, so later steps cannot mutate an observation the
        caller is still holding, and the dtype matches ``observation_space``.
        """
        return self.state.astype(np.float32)

    def _to_screen(self, x, y):
        """Convert world coordinates to integer pixel coordinates.

        The world origin maps to the window centre and the y axis is flipped
        because screen y grows downwards.
        """
        px = int(x * self.scale + self.screen_size[0] // 2)
        py = int(self.screen_size[1] // 2 - y * self.scale)
        return px, py

    def reset(self):
        """Move the agent back to the origin and restart the recorded path.

        Returns:
            The initial observation, a float32 array ``[0, 0]``.
        """
        self.state = np.array([0.0, 0.0])
        self.path_x = [float(self.state[0])]
        self.path_y = [float(self.state[1])]
        return self._observation()

    def step(self, action):
        """Apply one displacement and report the outcome.

        Args:
            action: Sequence whose first two entries are the requested
                ``(dx, dy)``; each component is clipped to ``[-1, 1]`` and
                scaled by ``step_size``.

        Returns:
            ``(observation, reward, done, info)`` where ``reward`` is the
            negative distance from the origin, ``done`` is True once the
            agent is on an edge of the square, and ``info`` is an empty dict.
        """
        action = np.clip(np.asarray(action, dtype=np.float64), -1.0, 1.0)
        move = np.array([action[0], action[1]]) * self.step_size

        low = np.array([self.x_min, self.y_min], dtype=np.float64)
        high = np.array([self.x_max, self.y_max], dtype=np.float64)
        # Rebind rather than mutate in place so earlier references stay valid.
        self.state = np.clip(self.state + move, low, high)

        # Clipping yields the exact bound values, so equality is reliable.
        done = bool(np.any(self.state == low) or np.any(self.state == high))

        self.path_x.append(float(self.state[0]))
        self.path_y.append(float(self.state[1]))

        reward = -float(np.linalg.norm(self.state))
        return self._observation(), reward, done, {}

    def render(self, mode='human'):
        """Draw the grid, the path so far, and the start/current markers.

        Also drains the pygame event queue; a window-close event shuts the
        display down and raises ``SystemExit``.
        """
        if self.screen is None:
            # The display was closed earlier; reopen it on demand.
            self._open_display()

        self.screen.fill((255, 255, 255))
        self.draw_grid()

        points = [self._to_screen(x, y)
                  for x, y in zip(self.path_x, self.path_y)]
        for start, end in zip(points, points[1:]):
            pygame.draw.line(self.screen, (255, 0, 0), start, end, 2)

        # The path is empty when render() is called before reset().
        if points:
            pygame.draw.circle(self.screen, (0, 0, 255), points[0], 5)

        current = self._to_screen(self.state[0], self.state[1])
        pygame.draw.circle(self.screen, (0, 0, 0), current, 3)

        pygame.display.flip()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.close()
                raise SystemExit

    def close(self):
        """Shut pygame down if the display is currently open."""
        if self.screen is not None:
            pygame.quit()
            self.screen = None

    def draw_grid(self):
        """Draw a dashed grid with solid lines through the centre axes.

        Dashed lines are drawn at every even, non-zero coordinate. The solid
        axes are drawn last so the dashed lines never paint over them.
        """
        line_color = (200, 200, 200)
        center_line_color = (0, 0, 0)
        line_width = 1
        dash_length = 10
        dash_gap = 5

        center_x = self.screen_size[0] // 2
        center_y = self.screen_size[1] // 2
        left = center_x + self.x_min * self.scale
        right = center_x + self.x_max * self.scale
        top = center_y - self.y_max * self.scale
        bottom = center_y - self.y_min * self.scale

        # Vertical dashed lines.
        for x in range(self.x_min, self.x_max + 1):
            if x != 0 and x % 2 == 0:
                x_pos = center_x + x * self.scale
                self.draw_dashed_line(x_pos, top, x_pos, bottom,
                                      line_color, dash_length, dash_gap)

        # Horizontal dashed lines.
        for y in range(self.y_min, self.y_max + 1):
            if y != 0 and y % 2 == 0:
                y_pos = center_y - y * self.scale
                self.draw_dashed_line(left, y_pos, right, y_pos,
                                      line_color, dash_length, dash_gap)

        # Solid centre axes, only when zero lies inside the world bounds.
        if self.x_min <= 0 <= self.x_max:
            pygame.draw.line(self.screen, center_line_color,
                             (center_x, top), (center_x, bottom), line_width)
        if self.y_min <= 0 <= self.y_max:
            pygame.draw.line(self.screen, center_line_color,
                             (left, center_y), (right, center_y), line_width)

    def draw_dashed_line(self, x1, y1, x2, y2, color, dash_length, dash_gap):
        """Draw a dashed line from ``(x1, y1)`` to ``(x2, y2)``.

        Dashes of ``dash_length`` pixels are separated by ``dash_gap``
        pixels. A final shorter dash is drawn when the segment length is not
        a whole number of periods. Zero-length segments and non-positive
        dash settings draw nothing.
        """
        total_length = float(np.linalg.norm([x2 - x1, y2 - y1]))
        period = dash_length + dash_gap
        if total_length == 0 or dash_length <= 0 or period <= 0:
            return

        # Unit direction vector of the segment.
        ux = (x2 - x1) / total_length
        uy = (y2 - y1) / total_length

        offset = 0.0
        while offset < total_length:
            dash_end = min(offset + dash_length, total_length)
            start = (int(round(x1 + ux * offset)),
                     int(round(y1 + uy * offset)))
            end = (int(round(x1 + ux * dash_end)),
                   int(round(y1 + uy * dash_end)))
            pygame.draw.line(self.screen, color, start, end, 1)
            offset += period
def main():
    """Run one short episode with random actions and live rendering.

    At most 50 steps are taken; the loop stops early when the environment
    reports ``done``. ``env.close()` is always called on the way out,
    including when rendering raises ``SystemExit`` on a window-close event.
    """
    env = RandomWalk2DEnv()
    try:
        env.reset()
        for _ in range(50):
            action = env.action_space.sample()
            _, _, done, _ = env.step(action)
            env.render()
            if done:
                print("智能体触及边缘,回合结束")
                break
    finally:
        env.close()


if __name__ == "__main__":
    main()
