Commit 3534e1c

author: xyliao
committed: finish gym
1 parent c48c020 commit 3534e1c

4 files changed: +154 -2 lines changed

‎README.md‎

Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ Learn Deep Learning with PyTorch
 
 - Chapter 7: Deep Reinforcement Learning
     - [Q Learning](https://github.com/SherlockLiao/code-of-learn-deep-learning-with-pytorch/blob/master/chapter7_RL/q-learning-intro.ipynb)
-    - Open AI gym
+    - [Open AI gym](https://github.com/SherlockLiao/code-of-learn-deep-learning-with-pytorch/blob/master/chapter7_RL/open_ai_gym.ipynb)
     - [Deep Q-networks](https://github.com/SherlockLiao/code-of-learn-deep-learning-with-pytorch/blob/master/chapter7_RL/dqn.ipynb)
 
 - Chapter 8: Advanced PyTorch

‎chapter7_RL/dqn.ipynb‎

Lines changed: 5 additions & 1 deletion

@@ -9,6 +9,8 @@
     "\n",
     "A very simple fix is to use deep learning for this problem, which gave rise to a new kind of network called Deep Q Networks: Q learning combined with a neural network. For every state we can use the network to compute the value of each action, so we no longer need to build a table; updating the network is also more efficient than updating a table, and obtaining results is faster.\n",
     "\n",
+    "![](https://ws4.sinaimg.cn/large/006tKfTcgy1fni66at6jbj30xo0g1jut.jpg)\n",
+    "\n",
     "Below we use the CartPole environment from open ai gym to implement a simple DQN."
    ]
   },
@@ -400,7 +402,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Plotting the reward curve, we can see that the reward keeps growing, which means our agent is learning better and better. We can also watch how the agent actually plays: gym provides a visualization, but it cannot be shown inside the notebook, so we can run `dqn.py` to see a rendered video of the agent playing."
+    "Plotting the reward curve, we can see that the reward keeps growing, which means our agent is learning better and better. We can also watch how the agent actually plays: gym provides a visualization, but it cannot be shown inside the notebook, so we can run `dqn.py` to see a rendered video of the agent playing.\n",
+    "\n",
+    "Also, here we only used a simple multi-layer network as the DQN; its input is the cart position, pole angle and so on. We could of course use a more general input, for example an image of each state. That approach is more general and almost identical to implement: only the network structure needs to change, and gym can also return each screen frame. See the official PyTorch [example](http://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html#) for details."
    ]
   }
  ],
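As a companion to the markdown above, a minimal sketch of the kind of fully connected Q-network the notebook describes for CartPole; the class name, layer sizes, and hidden width are illustrative assumptions, not the notebook's actual architecture:

```python
import torch
import torch.nn as nn

class QNetwork(nn.Module):
    """Maps a CartPole observation (4 numbers) to one Q-value per action (2 actions)."""
    def __init__(self, n_obs=4, n_actions=2, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_obs, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x):
        return self.net(x)

q_net = QNetwork()
obs = torch.randn(1, 4)                   # stand-in for a CartPole observation
action = q_net(obs).argmax(dim=1).item()  # greedy action: highest predicted Q-value
```

Replacing the Q table with such a network is what lets the same code scale to state spaces too large to enumerate.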

‎chapter7_RL/mount-car.py‎

Lines changed: 86 additions & 0 deletions (new file)

import numpy as np

import gym

n_states = 40  # discretize each observation dimension into 40 bins
iter_max = 10000

initial_lr = 1.0  # initial learning rate
min_lr = 0.003
gamma = 1.0
t_max = 10000
eps = 0.02


def run_episode(env, policy=None, render=False):
    obs = env.reset()
    total_reward = 0
    step_idx = 0
    for _ in range(t_max):
        if render:
            env.render()
        if policy is None:  # no policy given: sample a random action
            action = env.action_space.sample()
        else:
            a, b = obs_to_state(env, obs)
            action = policy[a][b]
        obs, reward, done, _ = env.step(action)
        total_reward += gamma ** step_idx * reward
        step_idx += 1
        if done:
            break
    return total_reward


def obs_to_state(env, obs):
    """
    Map the continuous observation to a discrete state.
    """
    env_low = env.observation_space.low
    env_high = env.observation_space.high
    env_dx = (env_high - env_low) / n_states
    a = int((obs[0] - env_low[0]) / env_dx[0])
    b = int((obs[1] - env_low[1]) / env_dx[1])
    return a, b


if __name__ == '__main__':
    env_name = 'MountainCar-v0'
    env = gym.make(env_name)
    env.seed(0)
    np.random.seed(0)
    print('----- using Q Learning -----')
    q_table = np.zeros((n_states, n_states, 3))
    for i in range(iter_max):
        obs = env.reset()
        total_reward = 0
        # eta: the learning rate shrinks as training goes on
        eta = max(min_lr, initial_lr * (0.85 ** (i // 100)))
        for j in range(t_max):
            x, y = obs_to_state(env, obs)
            if np.random.uniform(0, 1) < eps:  # explore: take a random action with probability eps
                action = np.random.choice(env.action_space.n)
            else:
                logits = q_table[x, y, :]
                logits_exp = np.exp(logits)
                probs = logits_exp / np.sum(logits_exp)  # softmax over the three actions
                action = np.random.choice(env.action_space.n, p=probs)  # sample an action by probability
            obs, reward, done, _ = env.step(action)
            total_reward += reward
            # update the q table
            x_, y_ = obs_to_state(env, obs)
            q_table[x, y, action] = q_table[x, y, action] + eta * (
                reward + gamma * np.max(q_table[x_, y_, :]) -
                q_table[x, y, action])
            if done:
                break
        if i % 100 == 0:
            print('Iteration #%d -- Total reward = %d.' % (i + 1,
                                                           total_reward))
    solution_policy = np.argmax(q_table, axis=2)  # for every state, take the action with the largest q value
    solution_policy_scores = [
        run_episode(env, solution_policy, False) for _ in range(100)
    ]
    print("Average score of solution = ", np.mean(solution_policy_scores))
    # Animate it
    run_episode(env, solution_policy, True)

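For reference, the update inside the inner loop above is the standard tabular Q-learning rule:

$$
Q(s, a) \;\leftarrow\; Q(s, a) + \eta \Big( r + \gamma \max_{a'} Q(s', a') - Q(s, a) \Big)
$$

where `eta` is the decayed learning rate, `gamma` the discount factor, and `q_table[x, y, action]` plays the role of Q(s, a) for the discretized state (x, y).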
‎chapter7_RL/open_ai_gym.ipynb‎

Lines changed: 62 additions & 0 deletions (new file)

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Introduction to Gym\n",
    "In the previous section we walked through a simple reinforcement learning example, and it showed that building a reinforcement learning environment is tedious and costs a lot of time. Here we can use an open-source toolkit called gym, developed by open ai.\n",
    "\n",
    "The library offers all kinds of game environments, from simple grid worlds to Doom, where you can drop in your own AI to play. The name gym is also fitting: picture a crowd of AIs working out in a gym, honing their skills.\n",
    "\n",
    "It is also very easy to use; install it by typing the following commands in a terminal.\n",
    "\n",
    "```\n",
    "# from the GitHub source\n",
    "git clone https://github.com/openai/gym\n",
    "cd gym\n",
    "pip install -e .[all]\n",
    "\n",
    "# or install the gym package directly\n",
    "pip install gym[all]\n",
    "```\n",
    "\n",
    "We can visit this page to see the [environments gym includes and their descriptions](https://github.com/openai/gym/wiki)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On the environment page above we can see that gym ships with many built-in environments. Using the q learning covered earlier, we can try a small gym example, [mountain car](https://github.com/openai/gym/wiki/MountainCar-v0). In mountain car we observe the car's position in the environment, i.e. its coordinates, and the actions we can take are to push left or right.\n",
    "\n",
    "To use q learning we must build a q table, but the state space here is continuous and uncountable, so we need to discretize the continuous space, splitting each observed coordinate evenly into many bins. For the concrete implementation, run `mount-car.py` and look at the result.\n",
    "\n",
    "After running it you can see that the q table converges very slowly and the reward hardly changes; it takes a long time before the car reaches the goal. At this point we need a more powerful weapon, and that is the deep q network."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
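To complement the installation notes above, a minimal sketch of the basic gym interaction loop, assuming the classic pre-0.26 gym API used throughout this chapter (where `env.reset()` returns only the observation and `env.step()` returns four values):

```python
import gym

env = gym.make('MountainCar-v0')   # any environment id from the gym wiki works here
obs = env.reset()                  # initial observation: (position, velocity) for MountainCar

for _ in range(200):
    env.render()                         # opens a window showing the car
    action = env.action_space.sample()   # random action: push left, do nothing, or push right
    obs, reward, done, info = env.step(action)
    if done:                             # episode ends at the goal or after the step limit
        obs = env.reset()

env.close()
```

A random policy almost never reaches the goal here, which is exactly why the notebook moves on to Q learning and then DQN.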
