Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit a783f6b

Browse files
add embedding layer implementation
1 parent f1ba069 commit a783f6b

File tree

1 file changed

+47
-38
lines changed

1 file changed

+47
-38
lines changed

‎seg_lstm.py‎

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -35,50 +35,57 @@ def __init__(self):
3535
# 模型定义和初始化
3636
self.sess = tf.Session()
3737
self.optimizer = tf.train.GradientDescentOptimizer(self.alpha)
38+
# self.optimizer = tf.train.AdamOptimizer(self.alpha)
3839
self.x = tf.placeholder(self.dtype, shape=[1, None, self.concat_embed_size])
39-
#self.embeddings = tf.Variable(
40+
#self.embeddings = tf.Variable(
4041
# tf.truncated_normal([self.vocab_size, self.embed_size], stddev=-1.0 / math.sqrt(self.embed_size),
41-
# dtype=self.dtype), dtype=self.dtype, name='embeddings')
42-
self.embeddings = tf.Variable(np.load('corpus/lstm/embeddings.npy'), dtype=self.dtype, name='embeddings')
42+
# dtype=self.dtype), name='embeddings')
43+
self.embeddings = tf.Variable(
44+
tf.random_uniform([self.vocab_size, self.embed_size], -1.0 / math.sqrt(self.embed_size),
45+
1.0 / math.sqrt(self.embed_size), dtype=self.dtype), name='embeddings')
46+
# self.embeddings = tf.Variable(np.load('corpus/lstm/embeddings.npy'), dtype=self.dtype, name='embeddings')
4347
self.w = tf.Variable(
4448
tf.truncated_normal([self.tags_count, self.hidden_units], stddev=1.0 / math.sqrt(self.concat_embed_size),
45-
dtype=self.dtype), dtype=self.dtype, name='w')
46-
self.b = tf.Variable(tf.zeros([self.tag_count, 1], dtype=self.dtype), dtype=self.dtype, name='b')
47-
self.A = tf.Variable(tf.random_uniform([self.tag_count, self.tag_count], -0.05, 0.05, dtype=self.dtype),
48-
dtype=self.dtype, name='A')
49+
dtype=self.dtype), name='w')
50+
self.b = tf.Variable(tf.zeros([self.tag_count, 1], dtype=self.dtype), name='b')
51+
self.A = tf.Variable(tf.random_uniform([self.tag_count, self.tag_count], -0.05, 0.05, dtype=self.dtype), name='A')
4952
self.Ap = tf.placeholder(self.dtype, shape=self.A.get_shape())
50-
self.init_A = tf.Variable(tf.random_uniform([self.tag_count], -0.05, 0.05, dtype=self.dtype), dtype=self.dtype,
51-
name='init_A')
53+
self.init_A = tf.Variable(tf.random_uniform([self.tag_count], -0.05, 0.05, dtype=self.dtype), name='init_A')
5254
self.init_Ap = tf.placeholder(self.dtype, shape=self.init_A.get_shape())
53-
self.update_A_op = self.A.assign((1 - self.lam) * (tf.add(self.A, self.alpha * self.Ap)))
54-
self.update_init_A_op = self.init_A.assign((1 - self.lam) * (tf.add(self.init_A, self.alpha * self.init_Ap)))
55+
self.update_A_op = self.A.assign(tf.add((1 - self.alpha * self.lam) * self.A, self.alpha * self.Ap))
56+
self.update_init_A_op = self.init_A.assign(
57+
tf.add((1 - self.alpha * self.lam) * self.init_A, self.alpha * self.init_Ap))
5558
self.sentence_holder = tf.placeholder(tf.int32, shape=[None, self.window_size])
56-
self.lookup_op = tf.nn.embedding_lookup(self.embeddings, self.sentence_holder).reshape([-1,self.concat_embed_size])
59+
self.lookup_op = tf.reshape(tf.nn.embedding_lookup(self.embeddings, self.sentence_holder),
60+
[-1, 1, self.concat_embed_size])
5761
self.indices = tf.placeholder(tf.int32, shape=[None, 2])
5862
self.shape = tf.placeholder(tf.int32, shape=[2])
5963
self.values = tf.placeholder(self.dtype, shape=[None])
6064
self.map_matrix_op = tf.sparse_to_dense(self.indices, self.shape, self.values, validate_indices=False)
6165
self.map_matrix = tf.placeholder(self.dtype, shape=[self.tag_count, None])
6266
self.lstm = tf.contrib.rnn.LSTMCell(self.hidden_units)
63-
self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.x, dtype=self.dtype)
64-
#self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.x, dtype=self.dtype)
67+
# self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.x, dtype=self.dtype)
68+
self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.lookup_op, dtype=self.dtype,
69+
time_major=True)
6570
tf.global_variables_initializer().run(session=self.sess)
66-
self.word_scores = tf.matmul(self.w, tf.transpose(self.lstm_output[0])) + self.b
67-
self.loss_scores = tf.reduce_sum(tf.multiply(self.map_matrix, self.word_scores),0)
68-
self.loss = tf.reduce_sum(self.loss_scores)
71+
self.word_scores = tf.matmul(self.w, tf.transpose(self.lstm_output[:, -1, :])) + self.b
72+
self.loss_scores = tf.reduce_sum(tf.multiply(self.map_matrix, self.word_scores), 0)
6973
self.lstm_variable = [v for v in tf.global_variables() if v.name.startswith('rnn')]
7074
self.params = [self.w, self.b] + self.lstm_variable
75+
self.loss = tf.reduce_sum(self.loss_scores) + tf.contrib.layers.apply_regularization(
76+
tf.contrib.layers.l2_regularizer(self.lam), self.params + [self.embeddings])
7177
self.regularization = list(map(lambda p: tf.assign_sub(p, self.lam * p), self.params))
72-
self.train = self.optimizer.minimize(self.loss, var_list=self.params)
78+
self.train = self.optimizer.minimize(self.loss, var_list=self.params + [self.embeddings])
79+
# tf.global_variables_initializer().run(session=self.sess)
7380
self.embedp = tf.placeholder(self.dtype, shape=[None, self.embed_size])
7481
self.embed_index = tf.placeholder(tf.int32, shape=[None])
7582
self.update_embed_op = tf.scatter_update(self.embeddings, self.embed_index, self.embedp)
7683
self.sentence_index = 0
7784
self.grad_embed = tf.gradients(self.loss_scores[self.sentence_index], self.x)
7885
self.saver = tf.train.Saver(self.params + [self.embeddings, self.A, self.init_A], max_to_keep=100)
7986

80-
def model(self, embeds):
81-
scores = self.sess.run(self.word_scores, feed_dict={self.x: np.expand_dims(embeds, 0)})
87+
def model(self, sentence):
88+
scores = self.sess.run(self.word_scores, feed_dict={self.sentence_holder: sentence})
8289
path = self.viterbi(scores, self.A.eval(self.sess), self.init_A.eval(self.sess))
8390
return path
8491

@@ -88,18 +95,19 @@ def train_exe(self):
8895
for i in range(10):
8996
for sentence_index, (sentence, tags) in enumerate(zip(self.words_batch, self.tags_batch)):
9097
self.train_sentence(sentence, tags, len(tags))
91-
if sentence_index >0 and sentence_index % 500 == 0:
98+
if sentence_index >0 and sentence_index % 1000 == 0:
9299
print(sentence_index)
93100
print(time.time() - last_time)
94101
last_time = time.time()
95-
print(self.cal_loss(sentence_index-500,sentence_index))
102+
# print(self.cal_loss(sentence_index-500,sentence_index))
96103
print(self.sess.run(self.init_A))
97104
self.saver.save(self.sess, 'tmp/lstm-model%d.ckpt' % i)
98105

99106
def train_sentence(self, sentence, tags, length):
100-
sentence_embeds = self.sess.run(self.lookup_op, feed_dict={self.sentence_holder: sentence}).reshape(
101-
[length, self.concat_embed_size])
102-
current_tags = self.model(sentence_embeds)
107+
# sentence_embeds = self.sess.run(self.lookup_op, feed_dict={self.sentence_holder: sentence}).reshape(
108+
# [length, self.concat_embed_size])
109+
# print(sentence_embeds.shape)
110+
current_tags = self.model(sentence)
103111
diff_tags = np.subtract(tags, current_tags)
104112
update_index = np.where(diff_tags != 0)[0]
105113
update_length = len(update_index)
@@ -119,9 +127,10 @@ def train_sentence(self, sentence, tags, length):
119127
self.values: sparse_values})
120128

121129
# 更新参数
122-
self.sess.run(self.train,
123-
feed_dict={self.x: np.expand_dims(sentence_embeds, 0), self.map_matrix: sentence_matrix})
124-
self.sess.run(self.regularization)
130+
# self.sess.run(self.train,
131+
# feed_dict={self.x: np.expand_dims(sentence_embeds, 0), self.map_matrix: sentence_matrix})
132+
self.sess.run(self.train, feed_dict={self.sentence_holder: sentence, self.map_matrix: sentence_matrix})
133+
# self.sess.run(self.regularization)
125134

126135
'''
127136
# 更新词向量
@@ -144,10 +153,9 @@ def train_sentence(self, sentence, tags, length):
144153
self.sess.run(self.update_init_A_op, feed_dict={self.init_Ap: init_A_update})
145154
self.sess.run(self.update_A_op, {self.Ap: A_update})
146155

147-
@staticmethod
148-
def gen_update_A(correct_tags, current_tags):
149-
A_update = np.zeros([4, 4], dtype=np.float32)
150-
init_A_update = np.zeros([4], dtype=np.float32)
156+
def gen_update_A(self, correct_tags, current_tags):
157+
A_update = np.zeros([self.tag_count, self.tag_count], dtype=np.float32)
158+
init_A_update = np.zeros([self.tag_count], dtype=np.float32)
151159
before_corr = correct_tags[0]
152160
before_curr = current_tags[0]
153161
update_init = False
@@ -171,21 +179,22 @@ def cal_loss(self, start, end):
171179
A = self.A.eval(session=self.sess)
172180
init_A = self.init_A.eval(session=self.sess)
173181
for sentence_index, (sentence, tags) in enumerate(zip(self.words_batch[start:end], self.tags_batch[start:end])):
174-
sentence_embeds = self.sess.run(self.lookup_op, feed_dict={self.sentence_holder: sentence}).reshape(
175-
[len(sentence), self.concat_embed_size])
176-
sentence_score = self.sess.run(self.word_scores, feed_dict={self.x: np.expand_dims(sentence_embeds, 0)})
182+
sentence_score = self.sess.run(self.word_scores, feed_dict={self.sentence_holder: sentence})
177183
loss += self.cal_sentence_loss(tags, sentence_score, A, init_A)
178184
return loss
179185

180-
def seg(self, sentence, model_path='tmp/lstm-model0.ckpt'):
186+
def seg(self, sentence, model_path='tmp/lstm-model0.ckpt', debug=False):
181187
self.saver.restore(self.sess, model_path)
182188
seq = self.index2seq(self.sentence2index(sentence))
183189
sentence_embeds = tf.nn.embedding_lookup(self.embeddings, seq).eval(session=self.sess).reshape(
184190
[len(sentence), self.concat_embed_size])
185-
sentence_scores = self.sess.run(self.word_scores, feed_dict={self.x: np.expand_dims(sentence_embeds, 0)})
191+
sentence_scores = self.sess.run(self.word_scores, feed_dict={self.sentence_holder: seq})
186192
init_A_val = self.init_A.eval(session=self.sess)
187193
A_val = self.A.eval(session=self.sess)
188-
print(A_val)
194+
if debug:
195+
print(A_val)
196+
# print(sentence_embeds[1])
197+
print(sentence_scores.T)
189198
current_tags = self.viterbi(sentence_scores, A_val, init_A_val)
190199
return self.tags2words(sentence, current_tags), current_tags
191200

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /