Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit c8ac1e4

Browse files
authored
Merge pull request #51 from zhengjxu/master
the dropout probability should be different between train and inference
2 parents 909a8b7 + 43d67e9 commit c8ac1e4

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

‎10 - RNN/02 - Autocomplete.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,16 @@ def make_batch(seq_data):
7070
# 기존처럼 one-hot 인코딩을 사용한다면 입력값의 형태는 [None, n_class] 여야합니다.
7171
Y = tf.placeholder(tf.int32, [None])
7272

73+
# dropout prob for RNN
74+
keep_prob = tf.placeholder(tf.float32, [])
75+
7376
W = tf.Variable(tf.random_normal([n_hidden, n_class]))
7477
b = tf.Variable(tf.random_normal([n_class]))
7578

7679
# RNN 셀을 생성합니다.
7780
cell1 = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
7881
# 과적합 방지를 위한 Dropout 기법을 사용합니다.
79-
cell1 = tf.nn.rnn_cell.DropoutWrapper(cell1, output_keep_prob=0.5)
82+
cell1 = tf.nn.rnn_cell.DropoutWrapper(cell1, output_keep_prob=keep_prob)
8083
# 여러개의 셀을 조합해서 사용하기 위해 셀을 추가로 생성합니다.
8184
cell2 = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
8285

@@ -108,7 +111,9 @@ def make_batch(seq_data):
108111

109112
for epoch in range(total_epoch):
110113
_, loss = sess.run([optimizer, cost],
111-
feed_dict={X: input_batch, Y: target_batch})
114+
feed_dict={X: input_batch,
115+
Y: target_batch,
116+
keep_prob: 0.5})
112117

113118
print('Epoch:', '%04d' % (epoch + 1),
114119
'cost =', '{:.6f}'.format(loss))
@@ -127,7 +132,9 @@ def make_batch(seq_data):
127132
input_batch, target_batch = make_batch(seq_data)
128133

129134
predict, accuracy_val = sess.run([prediction, accuracy],
130-
feed_dict={X: input_batch, Y: target_batch})
135+
feed_dict={X: input_batch,
136+
Y: target_batch,
137+
keep_prob:1})
131138

132139
predict_words = []
133140
for idx, val in enumerate(seq_data):

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /