Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 0f541c9

Browse files
implement basic train logic
1 parent 37db075 commit 0f541c9

File tree

1 file changed

+128
-26
lines changed

1 file changed

+128
-26
lines changed

‎seg_lstm.py‎

Lines changed: 128 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,53 +2,150 @@
22
import tensorflow as tf
33
import numpy as np
44
import math
5+
import time
6+
7+
import constant
8+
from transform_data_lstm import TransformDataLSTM
9+
510

611
class SegLSTM:
712
def __init__(self):
8-
self.dtype = tf.float64
9-
self.skip_window_left = 1
10-
self.skip_window_right = 1
13+
self.dtype = tf.float32
14+
self.skip_window_left = constant.LSTM_SKIP_WINDOW_LEFT
15+
self.skip_window_right = constant.LSTM_SKIP_WINDOW_RIGHT
1116
self.window_size = self.skip_window_left + self.skip_window_right + 1
1217
self.embed_size = 50
1318
self.hidden_units = 100
1419
self.tag_count = 4
1520
self.concat_embed_size = self.window_size * self.embed_size
16-
self.words_batch = None
17-
self.tags_batch = None
18-
self.vocab_size = 4000
21+
trans = TransformDataLSTM()
22+
self.words_batch = trans.words_batch
23+
self.tags_batch = trans.labels_batch
24+
self.vocab_size = constant.VOCAB_SIZE
25+
self.alpha = 0.02
26+
self.lam = 0.001
1927
self.sess = tf.Session()
20-
self.x = tf.placeholder(tf.int32, shape=[self.concat_embed_size, None])
28+
self.optimizer = tf.train.GradientDescentOptimizer(self.alpha)
29+
self.x = tf.placeholder(self.dtype, shape=[self.concat_embed_size, None])
30+
self.x_plus = tf.placeholder(self.dtype, shape=[1, None, self.concat_embed_size])
2131
self.embeddings = tf.Variable(
2232
tf.random_uniform([self.vocab_size, self.embed_size], -1.0 / math.sqrt(self.embed_size),
2333
1.0 / math.sqrt(self.embed_size),
24-
dtype=tf.float64), dtype=tf.float64, name='embeddings')
25-
self.w = tf.Variable(tf.zeros([self.tag_count,self.hidden_units]),dtype=tf.float32)
26-
self.b = tf.Variable(tf.zeros([self.tag_count]),dtype=tf.float32)
27-
self.lstm = tf.contrib.rnn.BasicLSTMCell(self.hidden_units,reuse=True)
28-
self.A = tf.Variable(tf.zeros([self.tag_count,self.tag_count]),dtype=tf.float32)
29-
self.Ap = tf.placeholder(tf.float32,shape=self.A.get_shape())
30-
self.init_A = tf.Variable(tf.zeros([self.tag_count]),dtype=tf.float32)
31-
self.init_Ap = tf.placeholder(tf.float32,shape=self.init_A.get_shape())
34+
dtype=self.dtype), dtype=self.dtype, name='embeddings')
35+
self.w = tf.Variable(tf.zeros([self.tag_count, self.hidden_units]), dtype=self.dtype)
36+
self.b = tf.Variable(tf.zeros([self.tag_count, 1]), dtype=self.dtype)
37+
self.A = tf.Variable(tf.zeros([self.tag_count, self.tag_count]), dtype=self.dtype)
38+
self.Ap = tf.placeholder(self.dtype, shape=self.A.get_shape())
39+
self.init_A = tf.Variable(tf.zeros([self.tag_count]), dtype=self.dtype)
40+
self.init_Ap = tf.placeholder(self.dtype, shape=self.init_A.get_shape())
41+
self.update_A_op = (1 - self.lam) * self.A.assign_add(self.alpha * self.Ap)
42+
self.update_init_A_op = (1 - self.lam) * self.init_A.assign_add(self.alpha * self.init_Ap)
3243
self.sentence_holder = tf.placeholder(tf.int32, shape=[None, self.window_size])
3344
self.lookup_op = tf.nn.embedding_lookup(self.embeddings, self.sentence_holder)
45+
self.indices = tf.placeholder(tf.int32, shape=[None, 2])
46+
self.shape = tf.placeholder(tf.int32, shape=[2])
47+
self.values = tf.placeholder(self.dtype, shape=[None])
48+
self.map_matrix_op = tf.sparse_to_dense(self.indices, self.shape, self.values, validate_indices=False)
49+
self.map_matrix = tf.placeholder(self.dtype, shape=[self.tag_count, None], name='mm')
50+
self.lstm = tf.contrib.rnn.LSTMCell(self.hidden_units)
51+
self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.x_plus, dtype=self.dtype)
52+
tf.global_variables_initializer().run(session=self.sess)
53+
self.word_scores = tf.matmul(self.w, tf.transpose(self.lstm_output[0])) + self.b
54+
self.loss = tf.reduce_sum(tf.multiply(self.map_matrix, self.word_scores))
55+
self.lstm_variable = [v for v in tf.global_variables() if v.name.startswith('rnn')]
56+
self.params = [self.w, self.b]
57+
self.params.extend(self.lstm_variable)
58+
self.regularization = list(map(lambda p: tf.assign_sub(p, self.lam * p), self.params))
59+
self.train = self.optimizer.minimize(self.loss, var_list=self.params)
60+
self.embedp = tf.placeholder(self.dtype, shape=[None, self.embed_size])
61+
self.embed_index = tf.placeholder(tf.int32, shape=[None])
62+
self.update_embed_op = tf.scatter_update(self.embeddings, self.embed_index, self.embedp)
63+
self.grad_embed = tf.gradients(tf.multiply(self.map_matrix, self.word_scores), self.x_plus)
3464

35-
def model(self,input):
36-
output, out_state = tf.nn.dynamic_rnn(self.lstm, input, dtype=tf.float32)
37-
with tf.variable_scope("rnn"):
38-
tf.initialize_local_variables()
39-
path = self.viterbi(output,self.A.eval(),self.init_A.eval())
65+
def model(self, embeds):
66+
scores = self.sess.run(self.word_scores, feed_dict={self.x_plus: np.expand_dims(embeds.T, 0)})
67+
path = self.viterbi(scores, self.A.eval(self.sess), self.init_A.eval(self.sess))
4068
return path
4169

42-
def train(self):
43-
pass
70+
def train_exe(self):
71+
self.sess.graph.finalize()
72+
last_time = time.time()
73+
saver = tf.train.Saver([self.embeddings, self.A,self.init_A].extend(self.params), max_to_keep=100)
74+
for sentence_index, (sentence, tags) in enumerate(zip(self.words_batch, self.tags_batch)):
75+
self.train_sentence(sentence, tags, len(tags))
76+
if sentence_index % 500 == 0:
77+
print(time.time() - last_time)
78+
last_time = time.time()
4479

45-
def train_sentence(self,sentence,tags,length):
46-
sentence_embeds = self.sess.run(self.lookup_op,feed_dict={self.sentence_holder:sentence}).reshape(
47-
[length, self.concat_embed_size])
80+
81+
def train_sentence(self, sentence, tags, length):
82+
sentence_embeds = self.sess.run(self.lookup_op, feed_dict={self.sentence_holder: sentence}).reshape(
83+
[length, self.concat_embed_size]).T
4884
current_tags = self.model(sentence_embeds)
85+
diff_tags = np.subtract(tags, current_tags)
86+
update_index = np.where(diff_tags != 0)[0]
87+
update_length = len(update_index)
88+
89+
if update_length == 0:
90+
return
91+
92+
update_tags_pos = tags[update_index]
93+
update_tags_neg = current_tags[update_index]
94+
95+
update_embed = sentence_embeds[:, update_index]
96+
sparse_indices = np.stack(
97+
[np.concatenate([update_tags_pos, update_tags_neg], axis=-1), np.tile(np.arange(update_length), [2])], axis=-1)
98+
99+
sparse_values = np.concatenate([-1 * np.ones(update_length), np.ones(update_length)])
100+
output_shape = [self.tag_count, update_length]
101+
sentence_matrix = self.sess.run(self.map_matrix_op,
102+
feed_dict={self.indices: sparse_indices, self.shape: output_shape,
103+
self.values: sparse_values})
104+
# 更新参数
105+
self.sess.run(self.train,
106+
feed_dict={self.x_plus: np.expand_dims(update_embed.T, 0), self.map_matrix: sentence_matrix})
107+
self.sess.run(self.regularization)
49108

109+
# 更新词向量
110+
embed_index = sentence[update_index]
111+
for i in range(update_length):
112+
embed = np.expand_dims(np.expand_dims(update_embed[:, i], 0),0)
113+
grad = self.sess.run(self.grad_embed, feed_dict={self.x_plus: embed,
114+
self.map_matrix: np.expand_dims(sentence_matrix[:, i], 1)})[0]
50115

116+
sentence_update_embed = (embed + self.alpha * grad) * (1 - self.lam)
117+
self.embeddings = self.sess.run(self.update_embed_op,
118+
feed_dict={
119+
self.embedp: sentence_update_embed.reshape([self.window_size, self.embed_size]),
120+
self.embed_index: embed_index[i, :]})
51121

122+
# 更新转移矩阵
123+
A_update, init_A_update, update_init = self.gen_update_A(tags, current_tags)
124+
if update_init:
125+
self.sess.run(self.update_init_A_op, feed_dict={self.init_Ap: init_A_update})
126+
self.sess.run(self.update_A_op, {self.Ap: A_update})
127+
128+
@staticmethod
129+
def gen_update_A(correct_tags, current_tags):
130+
A_update = np.zeros([4, 4], dtype=np.float32)
131+
init_A_update = np.zeros([4], dtype=np.float32)
132+
before_corr = correct_tags[0]
133+
before_curr = current_tags[0]
134+
update_init = False
135+
136+
if before_corr != before_curr:
137+
init_A_update[before_corr] += 1
138+
init_A_update[before_curr] -= 1
139+
update_init = True
140+
141+
for _, (corr_tag, curr_tag) in enumerate(zip(correct_tags[1:], current_tags[1:])):
142+
if corr_tag != curr_tag or before_corr != before_curr:
143+
A_update[before_corr, corr_tag] += 1
144+
A_update[before_curr, curr_tag] -= 1
145+
before_corr = corr_tag
146+
before_curr = curr_tag
147+
148+
return A_update, init_A_update, update_init
52149

53150
def viterbi(self, emission, A, init_A, return_score=False):
54151
"""
@@ -82,4 +179,9 @@ def viterbi(self, emission, A, init_A, return_score=False):
82179
if return_score:
83180
return corr_path, path_score[max_index, :]
84181
else:
85-
return corr_path
182+
return corr_path
183+
184+
185+
if __name__ == '__main__':
186+
seg = SegLSTM()
187+
seg.train_exe()

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /