Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d6305bb

Browse files
committed
코드 정리
1 parent cd03c21 commit d6305bb

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

‎10 - DQN/model.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ def __init__(self, session, width, height, n_action):
3636
# 손실값을 계산하는데 사용할 입력값입니다. train 함수를 참고하세요.
3737
self.input_Y = tf.placeholder(tf.float32, [None])
3838

39-
self.Q_value = self._build_network('main')
39+
self.Q = self._build_network('main')
4040
self.cost, self.train_op = self._build_op()
4141

4242
# 학습을 더 잘 되게 하기 위해,
4343
# 손실값 계산을 위해 사용하는 타겟(실측값)의 Q value를 계산하는 네트웍을 따로 만들어서 사용합니다
44-
self.target_Q_value = self._build_network('target')
44+
self.target_Q = self._build_network('target')
4545

4646
def _build_network(self, name):
4747
with tf.variable_scope(name):
@@ -50,15 +50,15 @@ def _build_network(self, name):
5050
model = tf.contrib.layers.flatten(model)
5151
model = tf.layers.dense(model, 512, activation=tf.nn.relu)
5252

53-
Q_value = tf.layers.dense(model, self.n_action, activation=None)
53+
Q = tf.layers.dense(model, self.n_action, activation=None)
5454

55-
return Q_value
55+
return Q
5656

5757
def _build_op(self):
5858
# DQN 의 손실 함수를 구성하는 부분입니다. 다음 수식을 참고하세요.
5959
# Perform a gradient descent step on (y_j-Q(ð_j,a_j;θ))^2
6060
one_hot = tf.one_hot(self.input_A, self.n_action, 1.0, 0.0)
61-
Q_value = tf.reduce_sum(tf.multiply(self.Q_value, one_hot), axis=1)
61+
Q_value = tf.reduce_sum(tf.multiply(self.Q, one_hot), axis=1)
6262
cost = tf.reduce_mean(tf.square(self.input_Y - Q_value))
6363
train_op = tf.train.AdamOptimizer(1e-6).minimize(cost)
6464

@@ -78,7 +78,7 @@ def update_target_network(self):
7878
self.session.run(copy_op)
7979

8080
def get_action(self):
81-
Q_value = self.session.run(self.Q_value,
81+
Q_value = self.session.run(self.Q,
8282
feed_dict={self.input_X: [self.state]})
8383

8484
action = np.argmax(Q_value[0])
@@ -124,10 +124,9 @@ def train(self):
124124
# 게임 플레이를 저장한 메모리에서 배치 사이즈만큼을 샘플링하여 가져옵니다.
125125
state, next_state, action, reward, terminal = self._sample_memory()
126126

127-
# 학습시 다음 상태를 만들어 낸 Q value를 입력값으로
128-
# 타겟 네트웍의 Q value를 실측값으로하여 학습합니다
129-
Q_value = self.session.run(self.target_Q_value,
130-
feed_dict={self.input_X: next_state})
127+
# 학습시 다음 상태를 타겟 네트웍에 넣어 target Q value를 구합니다
128+
target_Q_value = self.session.run(self.target_Q,
129+
feed_dict={self.input_X: next_state})
131130

132131
# DQN 의 손실 함수에 사용할 핵심적인 값을 계산하는 부분입니다. 다음 수식을 참고하세요.
133132
# if episode is terminates at step j+1 then r_j
@@ -138,7 +137,7 @@ def train(self):
138137
if terminal[i]:
139138
Y.append(reward[i])
140139
else:
141-
Y.append(reward[i] + self.GAMMA * np.max(Q_value[i]))
140+
Y.append(reward[i] + self.GAMMA * np.max(target_Q_value[i]))
142141

143142
self.session.run(self.train_op,
144143
feed_dict={

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /