Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 6773d74

Browse files
+
1 parent 0be5d22 commit 6773d74

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

‎16-19/dp.py‎

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55

66

77
class Network():
8-
def __init__(self, train_batch_size, test_batch_size, pooling_scale):
8+
def __init__(self, train_batch_size, test_batch_size, pooling_scale,
9+
optimizeMethod='adam'):
910
'''
1011
@num_hidden: 隐藏层的节点数量
1112
@batch_size:因为我们要节省内存,所以分批处理数据。每一批的数据量。
1213
'''
14+
self.optimizeMethod = optimizeMethod;
15+
1316
self.train_batch_size = train_batch_size
1417
self.test_batch_size = test_batch_size
1518

@@ -148,9 +151,32 @@ def model(data_flow, train=True):
148151
self.loss += self.apply_regularization(_lambda=5e-4)
149152
self.train_summaries.append(tf.scalar_summary('Loss', self.loss))
150153

154+
# learning rate decay
155+
global_step = tf.Variable(0)
156+
lr = 0.001
157+
dr = 0.99
158+
learning_rate = tf.train.exponential_decay(
159+
learning_rate=lr,
160+
global_step=global_step*self.train_batch_size,
161+
decay_steps=100,
162+
decay_rate=dr,
163+
staircase=True
164+
)
165+
151166
# Optimizer.
152167
with tf.name_scope('optimizer'):
153-
self.optimizer = tf.train.GradientDescentOptimizer(0.0001).minimize(self.loss)
168+
if(self.optimizeMethod=='gradient'):
169+
self.optimizer = tf.train \
170+
.GradientDescentOptimizer(learning_rate) \
171+
.minimize(self.loss)
172+
elif(self.optimizeMethod=='momentum'):
173+
self.optimizer = tf.train \
174+
.MomentumOptimizer(learning_rate, 0.5) \
175+
.minimize(self.loss)
176+
elif(self.optimizeMethod=='adam'):
177+
self.optimizer = tf.train \
178+
.AdamOptimizer(learning_rate) \
179+
.minimize(self.loss)
154180

155181
# Predictions for the training, validation, and test data.
156182
with tf.name_scope('train'):

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /