
Commit e8c3d8e

modify code
1 parent 45fecd3 commit e8c3d8e

File tree

2 files changed: +13 −11 lines changed


seg_lstm.py

Lines changed: 12 additions & 11 deletions
@@ -23,7 +23,7 @@ def __init__(self):
     self.tag_count = 4
     self.concat_embed_size = self.window_size * self.embed_size
     self.vocab_size = constant.VOCAB_SIZE
-    self.alpha = 0.05
+    self.alpha = 0.1
     self.lam = 0.0001
     self.eta = 0.02
     self.dropout_rate = 0.2
@@ -36,13 +36,13 @@ def __init__(self):
     self.sess = tf.Session()
     self.optimizer = tf.train.GradientDescentOptimizer(self.alpha)
     self.x = tf.placeholder(self.dtype, shape=[1, None, self.concat_embed_size])
-    self.embeddings = tf.Variable(
-      tf.truncated_normal([self.vocab_size, self.embed_size], stddev=-1.0 / math.sqrt(self.embed_size),
-                          dtype=self.dtype), dtype=self.dtype, name='embeddings')
+    #self.embeddings = tf.Variable(
+    #  tf.truncated_normal([self.vocab_size, self.embed_size], stddev=-1.0 / math.sqrt(self.embed_size),
+    #                      dtype=self.dtype), dtype=self.dtype, name='embeddings')
+    self.embeddings = tf.Variable(np.load('corpus/lstm/embeddings.npy'), dtype=self.dtype, name='embeddings')
     self.w = tf.Variable(
       tf.truncated_normal([self.tags_count, self.hidden_units], stddev=1.0 / math.sqrt(self.concat_embed_size),
-                          dtype=self.dtype),
-      dtype=self.dtype, name='w')
+                          dtype=self.dtype), dtype=self.dtype, name='w')
     self.b = tf.Variable(tf.zeros([self.tag_count, 1], dtype=self.dtype), dtype=self.dtype, name='b')
     self.A = tf.Variable(tf.random_uniform([self.tag_count, self.tag_count], -0.05, 0.05, dtype=self.dtype),
                          dtype=self.dtype, name='A')
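
The random embedding initialization above is replaced by a matrix loaded from corpus/lstm/embeddings.npy. A minimal sketch of producing a compatible file, assuming it must hold a float array of shape [vocab_size, embed_size]; the sizes and random values here are placeholders, not real pretrained vectors:

import os
import numpy as np

# Hypothetical sizes; the real ones come from constant.VOCAB_SIZE and the model config.
vocab_size = 4000
embed_size = 50

os.makedirs('corpus/lstm', exist_ok=True)
# Placeholder vectors standing in for actual pretrained embeddings.
pretrained = np.random.uniform(-0.05, 0.05, (vocab_size, embed_size))
np.save('corpus/lstm/embeddings.npy', pretrained)

The saved array is then wrapped directly in the trainable variable, as the added line above does with tf.Variable(np.load(...)).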
@@ -53,17 +53,18 @@ def __init__(self):
     self.update_A_op = self.A.assign((1 - self.lam) * (tf.add(self.A, self.alpha * self.Ap)))
     self.update_init_A_op = self.init_A.assign((1 - self.lam) * (tf.add(self.init_A, self.alpha * self.init_Ap)))
     self.sentence_holder = tf.placeholder(tf.int32, shape=[None, self.window_size])
-    self.lookup_op = tf.nn.embedding_lookup(self.embeddings, self.sentence_holder)
+    self.lookup_op = tf.nn.embedding_lookup(self.embeddings, self.sentence_holder).reshape([-1,self.concat_embed_size])
     self.indices = tf.placeholder(tf.int32, shape=[None, 2])
     self.shape = tf.placeholder(tf.int32, shape=[2])
     self.values = tf.placeholder(self.dtype, shape=[None])
     self.map_matrix_op = tf.sparse_to_dense(self.indices, self.shape, self.values, validate_indices=False)
     self.map_matrix = tf.placeholder(self.dtype, shape=[self.tag_count, None])
     self.lstm = tf.contrib.rnn.LSTMCell(self.hidden_units)
     self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.x, dtype=self.dtype)
+    #self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.x, dtype=self.dtype)
     tf.global_variables_initializer().run(session=self.sess)
     self.word_scores = tf.matmul(self.w, tf.transpose(self.lstm_output[0])) + self.b
-    self.loss_scores = tf.multiply(self.map_matrix, self.word_scores)
+    self.loss_scores = tf.reduce_sum(tf.multiply(self.map_matrix, self.word_scores),0)
     self.loss = tf.reduce_sum(self.loss_scores)
     self.lstm_variable = [v for v in tf.global_variables() if v.name.startswith('rnn')]
     self.params = [self.w, self.b] + self.lstm_variable
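
With map_matrix built as a one-hot [tag_count, sentence_length] mask, tf.multiply zeroes every tag score except the target tag at each position, and summing over axis 0 leaves one selected score per position; that is also why loss_scores can be indexed with a single position in grad_embed below. A small numpy sketch of the same masking, using made-up scores (word_scores is [tag_count, sentence_length] here):

import numpy as np

tag_count, sentence_length = 4, 3
word_scores = np.arange(tag_count * sentence_length, dtype=float).reshape(tag_count, sentence_length)

# One-hot mask selecting, say, tags [2, 0, 3] for the three positions.
map_matrix = np.zeros((tag_count, sentence_length))
map_matrix[[2, 0, 3], range(sentence_length)] = 1.0

loss_scores = np.sum(map_matrix * word_scores, axis=0)  # one score per position -> [ 6.  1. 11.]
loss = loss_scores.sum()                                # analogue of tf.reduce_sum(self.loss_scores)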
@@ -72,9 +73,8 @@ def __init__(self):
     self.embedp = tf.placeholder(self.dtype, shape=[None, self.embed_size])
     self.embed_index = tf.placeholder(tf.int32, shape=[None])
     self.update_embed_op = tf.scatter_update(self.embeddings, self.embed_index, self.embedp)
-    self.sentence_length = 1
     self.sentence_index = 0
-    self.grad_embed = tf.gradients(self.loss_scores[:, self.sentence_index], self.x)
+    self.grad_embed = tf.gradients(self.loss_scores[self.sentence_index], self.x)
     self.saver = tf.train.Saver(self.params + [self.embeddings, self.A, self.init_A], max_to_keep=100)

   def model(self, embeds):
@@ -123,6 +123,7 @@ def train_sentence(self, sentence, tags, length):
                   feed_dict={self.x: np.expand_dims(sentence_embeds, 0), self.map_matrix: sentence_matrix})
     self.sess.run(self.regularization)

+    '''
     # update word embeddings
     self.sentence_length = length
@@ -136,7 +137,7 @@ def train_sentence(self, sentence, tags, length):
                     feed_dict={
                       self.embedp: sentence_update_embed.reshape([self.window_size, self.embed_size]),
                       self.embed_index: sentence[index]})
-
+    '''
     # update the transition matrix
     A_update, init_A_update, update_init = self.gen_update_A(tags, current_tags)
     if update_init:
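
The reworked lookup_op flattens each character window of embeddings into one concatenated feature of length concat_embed_size = window_size * embed_size. A minimal numpy sketch of that flattening, with assumed sizes; in the TF1 graph the same reshape would normally go through tf.reshape on the embedding_lookup result:

import numpy as np

window_size, embed_size = 3, 5
concat_embed_size = window_size * embed_size
num_chars = 7

# Shape of an embedding lookup over a [num_chars, window_size] index matrix.
looked_up = np.random.randn(num_chars, window_size, embed_size)

# Flatten each window into a single concatenated vector, as lookup_op now intends.
flat = looked_up.reshape(-1, concat_embed_size)  # shape: (num_chars, concat_embed_size)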

test.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ def test_seg_dnn():
   #print(seg.seg('小明来自南京师范大学'))
   #print(seg.seg('小明是上海理工大学的学生'))
   print(seg.seg('小明来自南京师范大学'))
+  print(seg.seg('小明是上海理工大学的学生'))
   test(seg,'tmp/lstm-model0.ckpt')
   # print(seq)
   # cal_val(seq)
