Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit ac0e2e5

Browse files
Create windy_gridworld.py
1 parent 7a3cde6 commit ac0e2e5

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import gym
2+
import numpy as np
3+
import sys
4+
from gym.envs.toy_text import discrete
5+
6+
UP = 0
7+
RIGHT = 1
8+
DOWN = 2
9+
LEFT = 3
10+
11+
class WindyGridworldEnv(discrete.DiscreteEnv):
12+
13+
metadata = {'render.modes': ['human', 'ansi']}
14+
15+
def _limit_coordinates(self, coord):
16+
coord[0] = min(coord[0], self.shape[0] - 1)
17+
coord[0] = max(coord[0], 0)
18+
coord[1] = min(coord[1], self.shape[1] - 1)
19+
coord[1] = max(coord[1], 0)
20+
return coord
21+
22+
def _calculate_transition_prob(self, current, delta, winds):
23+
new_position = np.array(current) + np.array(delta) + np.array([-1, 0]) * winds[tuple(current)]
24+
new_position = self._limit_coordinates(new_position).astype(int)
25+
new_state = np.ravel_multi_index(tuple(new_position), self.shape)
26+
is_done = tuple(new_position) == (3, 7)
27+
return [(1.0, new_state, -1.0, is_done)]
28+
29+
def __init__(self):
30+
self.shape = (7, 10)
31+
32+
nS = np.prod(self.shape)
33+
nA = 4
34+
35+
# Wind strength
36+
winds = np.zeros(self.shape)
37+
winds[:,[3,4,5,8]] = 1
38+
winds[:,[6,7]] = 2
39+
40+
# Calculate transition probabilities
41+
P = {}
42+
for s in range(nS):
43+
position = np.unravel_index(s, self.shape)
44+
P[s] = { a : [] for a in range(nA) }
45+
P[s][UP] = self._calculate_transition_prob(position, [-1, 0], winds)
46+
P[s][RIGHT] = self._calculate_transition_prob(position, [0, 1], winds)
47+
P[s][DOWN] = self._calculate_transition_prob(position, [1, 0], winds)
48+
P[s][LEFT] = self._calculate_transition_prob(position, [0, -1], winds)
49+
50+
# We always start in state (3, 0)
51+
isd = np.zeros(nS)
52+
isd[np.ravel_multi_index((3,0), self.shape)] = 1.0
53+
54+
super(WindyGridworldEnv, self).__init__(nS, nA, P, isd)
55+
56+
def render(self, mode='human', close=False):
57+
self._render(mode, close)
58+
59+
def _render(self, mode='human', close=False):
60+
if close:
61+
return
62+
63+
outfile = StringIO() if mode == 'ansi' else sys.stdout
64+
65+
for s in range(self.nS):
66+
position = np.unravel_index(s, self.shape)
67+
# print(self.s)
68+
if self.s == s:
69+
output = " x "
70+
elif position == (3,7):
71+
output = " T "
72+
else:
73+
output = " o "
74+
75+
if position[1] == 0:
76+
output = output.lstrip()
77+
if position[1] == self.shape[1] - 1:
78+
output = output.rstrip()
79+
output += "\n"
80+
81+
outfile.write(output)
82+
outfile.write("\n")

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /