Commit 69aa96b

Add files via upload
1 parent ac0e2e5 commit 69aa96b

1 file changed: 133 additions, 0 deletions
import gym
import itertools
import matplotlib
import matplotlib.style
import numpy as np
import pandas as pd
import sys


from collections import defaultdict
from windy_gridworld import WindyGridworldEnv
import plotting

matplotlib.style.use('ggplot')

# Transition probabilities: {0, 1}
# Rewards: {-1, 0, 1}
# States: {S0, S1, S2, ..., SN}
# Actions: {Left, Right, Down, Up}


# Questions
"""
1. Which states are good? (the value function)
2. Which state-action pairs are good? (the Q-value function)

Take these files from GitHub:
https://github.com/reddyprasade/Machine-Learning-with-Scikit-Learn-Python-3.x/tree/master/Reinforcement%20Learning/Q-Learning

1. windy_gridworld.py
2. plotting.py
"""

env = WindyGridworldEnv()

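# A quick sanity check of the environment (a sketch; WindyGridworldEnv in the
# linked repo is assumed to follow the classic gym interface used below):
print("Number of actions:", env.action_space.n)  # discrete actions: Up, Right, Down, Left
print("Start state:", env.reset())               # index of the initial state
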
## Make the $\epsilon$-greedy policy.


def createEpsilonGreedyPolicy(Q, epsilon, num_actions):
    """
    Creates an epsilon-greedy policy based
    on a given Q-function and epsilon.

    Returns a function that takes the state
    as input and returns the probabilities
    of each action as a numpy array whose length
    is the size of the action space (the set of possible actions).
    """
    def policyFunction(state):

        # Spread the exploration probability epsilon
        # uniformly over all actions.
        action_probabilities = np.ones(num_actions,
                                       dtype=float) * epsilon / num_actions

        # The action with the highest Q-value in this state
        # gets the remaining (1 - epsilon) probability mass.
        best_action = np.argmax(Q[state])
        action_probabilities[best_action] += (1.0 - epsilon)
        return action_probabilities

    return policyFunction
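

# Illustrative check of the policy (hypothetical Q table, not part of the
# algorithm): with epsilon = 0.1 and 4 actions, every action gets
# 0.1 / 4 = 0.025 exploration probability, and the greedy action gets
# 0.025 + 0.9 = 0.925.
demo_Q = defaultdict(lambda: np.zeros(4))
demo_Q[0][2] = 1.0                        # pretend action 2 looks best in state 0
demo_policy = createEpsilonGreedyPolicy(demo_Q, 0.1, 4)
print(demo_policy(0))                     # -> [0.025 0.025 0.925 0.025]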


# Build the Q-Learning model.

def qLearning(env, num_episodes, discount_factor=1.0,
              alpha=0.6, epsilon=0.1):
    """
    Q-Learning algorithm: off-policy TD control.
    Finds the optimal greedy policy while
    following an epsilon-greedy behavior policy.
    """

    # Action-value function:
    # a nested dictionary that maps
    # state -> (action -> action-value).
    Q = defaultdict(lambda: np.zeros(env.action_space.n))

    # Keeps track of useful statistics.
    stats = plotting.EpisodeStats(
        episode_lengths=np.zeros(num_episodes),
        episode_rewards=np.zeros(num_episodes))

    # Create an epsilon-greedy policy function
    # for the environment's action space.
    policy = createEpsilonGreedyPolicy(Q, epsilon, env.action_space.n)

    # For every episode:
    for ith_episode in range(num_episodes):

        # Reset the environment to get the first state.
        state = env.reset()

        for t in itertools.count():

            # Get the probabilities of all actions from the current state.
            action_probabilities = policy(state)

            # Choose an action according to
            # the probability distribution.
            action = np.random.choice(np.arange(
                len(action_probabilities)),
                p=action_probabilities)

            # Take the action: get the reward and transition to the next state.
            next_state, reward, done, _ = env.step(action)

            # Update statistics.
            stats.episode_rewards[ith_episode] += reward
            stats.episode_lengths[ith_episode] = t

            # TD update.
            best_next_action = np.argmax(Q[next_state])
            td_target = reward + discount_factor * Q[next_state][best_next_action]
            td_delta = td_target - Q[state][action]
            Q[state][action] += alpha * td_delta

            # done is True if the episode terminated.
            if done:
                break

            state = next_state

    return Q, stats
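

# Worked TD-update example (hypothetical numbers, only to illustrate the
# arithmetic inside qLearning): suppose alpha = 0.6, discount_factor = 1.0,
# reward = -1, max_a' Q[next_state][a'] = -3.0, and Q[state][action] = -5.0.
#   td_target = -1 + 1.0 * (-3.0)       = -4.0
#   td_delta  = -4.0 - (-5.0)           =  1.0
#   Q[state][action] = -5.0 + 0.6 * 1.0 = -4.4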


# Train the model (only 5 episodes here; use more for a usable policy).
Q, stats = qLearning(env, 5)

# Plot important statistics.
plotting.plot_episode_stats(stats)
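
# A greedy policy can be read off the learned Q table (a sketch; the dict
# comprehension below assumes only states actually visited during training
# appear as keys in Q):
greedy_policy = {state: int(np.argmax(action_values))
                 for state, action_values in Q.items()}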
