@@ -35,50 +35,57 @@ def __init__(self):
     # model definition and initialization
     self.sess = tf.Session()
     self.optimizer = tf.train.GradientDescentOptimizer(self.alpha)
+    # self.optimizer = tf.train.AdamOptimizer(self.alpha)
     self.x = tf.placeholder(self.dtype, shape=[1, None, self.concat_embed_size])
-    #self.embeddings = tf.Variable(
+    # self.embeddings = tf.Variable(
     #   tf.truncated_normal([self.vocab_size, self.embed_size], stddev=-1.0 / math.sqrt(self.embed_size),
-    #   dtype=self.dtype), dtype=self.dtype, name='embeddings')
-    self.embeddings = tf.Variable(np.load('corpus/lstm/embeddings.npy'), dtype=self.dtype, name='embeddings')
+    #   dtype=self.dtype), name='embeddings')
+    self.embeddings = tf.Variable(
+      tf.random_uniform([self.vocab_size, self.embed_size], -1.0 / math.sqrt(self.embed_size),
+                        1.0 / math.sqrt(self.embed_size), dtype=self.dtype), name='embeddings')
+    # self.embeddings = tf.Variable(np.load('corpus/lstm/embeddings.npy'), dtype=self.dtype, name='embeddings')
     self.w = tf.Variable(
       tf.truncated_normal([self.tag_count, self.hidden_units], stddev=1.0 / math.sqrt(self.concat_embed_size),
-                          dtype=self.dtype), dtype=self.dtype, name='w')
-    self.b = tf.Variable(tf.zeros([self.tag_count, 1], dtype=self.dtype), dtype=self.dtype, name='b')
-    self.A = tf.Variable(tf.random_uniform([self.tag_count, self.tag_count], -0.05, 0.05, dtype=self.dtype),
-                         dtype=self.dtype, name='A')
+                          dtype=self.dtype), name='w')
+    self.b = tf.Variable(tf.zeros([self.tag_count, 1], dtype=self.dtype), name='b')
+    self.A = tf.Variable(tf.random_uniform([self.tag_count, self.tag_count], -0.05, 0.05, dtype=self.dtype), name='A')
     self.Ap = tf.placeholder(self.dtype, shape=self.A.get_shape())
-    self.init_A = tf.Variable(tf.random_uniform([self.tag_count], -0.05, 0.05, dtype=self.dtype), dtype=self.dtype,
-                              name='init_A')
+    self.init_A = tf.Variable(tf.random_uniform([self.tag_count], -0.05, 0.05, dtype=self.dtype), name='init_A')
     self.init_Ap = tf.placeholder(self.dtype, shape=self.init_A.get_shape())
-    self.update_A_op = self.A.assign((1 - self.lam) * (tf.add(self.A, self.alpha * self.Ap)))
-    self.update_init_A_op = self.init_A.assign((1 - self.lam) * (tf.add(self.init_A, self.alpha * self.init_Ap)))
+    self.update_A_op = self.A.assign(tf.add((1 - self.alpha * self.lam) * self.A, self.alpha * self.Ap))
+    self.update_init_A_op = self.init_A.assign(
+      tf.add((1 - self.alpha * self.lam) * self.init_A, self.alpha * self.init_Ap))
     self.sentence_holder = tf.placeholder(tf.int32, shape=[None, self.window_size])
-    self.lookup_op = tf.nn.embedding_lookup(self.embeddings, self.sentence_holder).reshape([-1, self.concat_embed_size])
+    self.lookup_op = tf.reshape(tf.nn.embedding_lookup(self.embeddings, self.sentence_holder),
+                                [-1, 1, self.concat_embed_size])
     self.indices = tf.placeholder(tf.int32, shape=[None, 2])
     self.shape = tf.placeholder(tf.int32, shape=[2])
     self.values = tf.placeholder(self.dtype, shape=[None])
     self.map_matrix_op = tf.sparse_to_dense(self.indices, self.shape, self.values, validate_indices=False)
     self.map_matrix = tf.placeholder(self.dtype, shape=[self.tag_count, None])
     self.lstm = tf.contrib.rnn.LSTMCell(self.hidden_units)
-    self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.x, dtype=self.dtype)
-    #self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.x, dtype=self.dtype)
+    # self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.x, dtype=self.dtype)
+    self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.lookup_op, dtype=self.dtype,
+                                                              time_major=True)
     tf.global_variables_initializer().run(session=self.sess)
-    self.word_scores = tf.matmul(self.w, tf.transpose(self.lstm_output[0])) + self.b
-    self.loss_scores = tf.reduce_sum(tf.multiply(self.map_matrix, self.word_scores), 0)
-    self.loss = tf.reduce_sum(self.loss_scores)
+    self.word_scores = tf.matmul(self.w, tf.transpose(self.lstm_output[:, -1, :])) + self.b
+    self.loss_scores = tf.reduce_sum(tf.multiply(self.map_matrix, self.word_scores), 0)
     self.lstm_variable = [v for v in tf.global_variables() if v.name.startswith('rnn')]
     self.params = [self.w, self.b] + self.lstm_variable
+    self.loss = tf.reduce_sum(self.loss_scores) + tf.contrib.layers.apply_regularization(
+      tf.contrib.layers.l2_regularizer(self.lam), self.params + [self.embeddings])
     self.regularization = list(map(lambda p: tf.assign_sub(p, self.lam * p), self.params))
-    self.train = self.optimizer.minimize(self.loss, var_list=self.params)
+    self.train = self.optimizer.minimize(self.loss, var_list=self.params + [self.embeddings])
+    # tf.global_variables_initializer().run(session=self.sess)
     self.embedp = tf.placeholder(self.dtype, shape=[None, self.embed_size])
     self.embed_index = tf.placeholder(tf.int32, shape=[None])
     self.update_embed_op = tf.scatter_update(self.embeddings, self.embed_index, self.embedp)
     self.sentence_index = 0
     self.grad_embed = tf.gradients(self.loss_scores[self.sentence_index], self.x)
     self.saver = tf.train.Saver(self.params + [self.embeddings, self.A, self.init_A], max_to_keep=100)

-  def model(self, embeds):
-    scores = self.sess.run(self.word_scores, feed_dict={self.x: np.expand_dims(embeds, 0)})
+  def model(self, sentence):
+    scores = self.sess.run(self.word_scores, feed_dict={self.sentence_holder: sentence})
     path = self.viterbi(scores, self.A.eval(self.sess), self.init_A.eval(self.sess))
     return path

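The model() method above delegates decoding to a viterbi() helper that this diff does not touch. For orientation, here is a minimal sketch of such a decoder for this setup, assuming scores has shape [tag_count, length], A[i, j] scores the transition from tag i to tag j, and init_A scores the first tag; this is an illustration, not the repository's implementation:

import numpy as np

def viterbi(scores, A, init_A):
  # scores: [tag_count, length]; A: [tag_count, tag_count]; init_A: [tag_count]
  tag_count, length = scores.shape
  dp = np.zeros([tag_count, length])               # best path score ending in tag t at position i
  back = np.zeros([tag_count, length], dtype=int)  # backpointers
  dp[:, 0] = init_A + scores[:, 0]
  for i in range(1, length):
    cand = dp[:, i - 1, None] + A                  # [prev_tag, cur_tag] candidate scores
    back[:, i] = np.argmax(cand, axis=0)           # best previous tag for each current tag
    dp[:, i] = cand[back[:, i], np.arange(tag_count)] + scores[:, i]
  path = [int(np.argmax(dp[:, -1]))]               # best final tag, then walk the backpointers
  for i in range(length - 1, 0, -1):
    path.append(int(back[path[-1], i]))
  return np.array(path[::-1])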
@@ -88,18 +95,19 @@ def train_exe(self):
     for i in range(10):
       for sentence_index, (sentence, tags) in enumerate(zip(self.words_batch, self.tags_batch)):
         self.train_sentence(sentence, tags, len(tags))
-        if sentence_index > 0 and sentence_index % 500 == 0:
+        if sentence_index > 0 and sentence_index % 1000 == 0:
           print(sentence_index)
           print(time.time() - last_time)
           last_time = time.time()
-          print(self.cal_loss(sentence_index - 500, sentence_index))
+          # print(self.cal_loss(sentence_index - 500, sentence_index))
           print(self.sess.run(self.init_A))
       self.saver.save(self.sess, 'tmp/lstm-model%d.ckpt' % i)

   def train_sentence(self, sentence, tags, length):
-    sentence_embeds = self.sess.run(self.lookup_op, feed_dict={self.sentence_holder: sentence}).reshape(
-      [length, self.concat_embed_size])
-    current_tags = self.model(sentence_embeds)
+    # sentence_embeds = self.sess.run(self.lookup_op, feed_dict={self.sentence_holder: sentence}).reshape(
+    #   [length, self.concat_embed_size])
+    # print(sentence_embeds.shape)
+    current_tags = self.model(sentence)
     diff_tags = np.subtract(tags, current_tags)
     update_index = np.where(diff_tags != 0)[0]
     update_length = len(update_index)
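train_sentence now feeds word indices straight into the graph and keeps only the disagreement positions (update_index) in play. The indices/values pair that later feeds map_matrix_op is built elsewhere in the file; one plausible construction, with the signs inferred from the loss (minimizing loss should lower the score of the predicted tag and raise the score of the gold tag), is sketched below. The helper name and the sign convention are assumptions:

def gen_sparse_map(tags, current_tags, update_index):
  # Hypothetical: sparse entries for a [tag_count, length] map matrix with
  # +1 on the wrongly predicted tag and -1 on the gold tag per disagreement.
  indices, values = [], []
  for pos in update_index:
    indices.append([int(current_tags[pos]), int(pos)])  # predicted tag: pushed down
    values.append(1.0)
    indices.append([int(tags[pos]), int(pos)])          # gold tag: pushed up
    values.append(-1.0)
  return indices, values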
@@ -119,9 +127,10 @@ def train_sentence(self, sentence, tags, length):
                                                self.values: sparse_values})

     # update parameters
-    self.sess.run(self.train,
-                  feed_dict={self.x: np.expand_dims(sentence_embeds, 0), self.map_matrix: sentence_matrix})
-    self.sess.run(self.regularization)
+    # self.sess.run(self.train,
+    #               feed_dict={self.x: np.expand_dims(sentence_embeds, 0), self.map_matrix: sentence_matrix})
+    self.sess.run(self.train, feed_dict={self.sentence_holder: sentence, self.map_matrix: sentence_matrix})
+    # self.sess.run(self.regularization)

     '''
     # update word embeddings
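The ''' block that opens here is the old manual word-vector update, cut off by this diff; after the change, the embeddings are trained by the optimizer instead, since self.embeddings was added to var_list above. For contrast, a sketch of what such a manual update looks like, using the grad_embed and update_embed_op nodes defined in __init__ (the feed shapes and the helper itself are assumptions, not the original code):

import numpy as np

def manual_embed_update(model, sentence, sentence_embeds, sentence_matrix, pos):
  # Hypothetical reconstruction: gradient of the pos-th position's loss w.r.t.
  # the input window, one SGD step, then scatter the adjusted vectors back
  # into the embedding table via the scatter_update op.
  grad = model.sess.run(model.grad_embed,
                        feed_dict={model.x: np.expand_dims(sentence_embeds, 0),
                                   model.map_matrix: sentence_matrix})[0][0]
  new_embeds = sentence_embeds - model.alpha * grad  # [length, concat_embed_size]
  model.sess.run(model.update_embed_op,
                 feed_dict={model.embed_index: sentence[pos],  # word ids in the window
                            model.embedp: new_embeds[pos].reshape([-1, model.embed_size])})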
@@ -144,10 +153,9 @@ def train_sentence(self, sentence, tags, length):
     self.sess.run(self.update_init_A_op, feed_dict={self.init_Ap: init_A_update})
     self.sess.run(self.update_A_op, {self.Ap: A_update})

-  @staticmethod
-  def gen_update_A(correct_tags, current_tags):
-    A_update = np.zeros([4, 4], dtype=np.float32)
-    init_A_update = np.zeros([4], dtype=np.float32)
+  def gen_update_A(self, correct_tags, current_tags):
+    A_update = np.zeros([self.tag_count, self.tag_count], dtype=np.float32)
+    init_A_update = np.zeros([self.tag_count], dtype=np.float32)
     before_corr = correct_tags[0]
     before_curr = current_tags[0]
     update_init = False
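The hunk above stops before the loop that fills A_update and init_A_update. A typical perceptron-style continuation increments the transition counts along the gold tag sequence and decrements them along the predicted one; the following is an assumed reconstruction for orientation, not the file's actual code:

# Assumed continuation of gen_update_A: reward gold transitions,
# penalize predicted ones; handle the first tag via init_A_update.
if before_corr != before_curr:
  init_A_update[before_corr] += 1
  init_A_update[before_curr] -= 1
  update_init = True
for pos in range(1, len(correct_tags)):
  curr_corr, curr_curr = correct_tags[pos], current_tags[pos]
  if (before_corr, curr_corr) != (before_curr, curr_curr):
    A_update[before_corr, curr_corr] += 1
    A_update[before_curr, curr_curr] -= 1
  before_corr, before_curr = curr_corr, curr_curr
return A_update, init_A_update, update_init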
@@ -171,21 +179,22 @@ def cal_loss(self, start, end):
     A = self.A.eval(session=self.sess)
     init_A = self.init_A.eval(session=self.sess)
     for sentence_index, (sentence, tags) in enumerate(zip(self.words_batch[start:end], self.tags_batch[start:end])):
-      sentence_embeds = self.sess.run(self.lookup_op, feed_dict={self.sentence_holder: sentence}).reshape(
-        [len(sentence), self.concat_embed_size])
-      sentence_score = self.sess.run(self.word_scores, feed_dict={self.x: np.expand_dims(sentence_embeds, 0)})
+      sentence_score = self.sess.run(self.word_scores, feed_dict={self.sentence_holder: sentence})
       loss += self.cal_sentence_loss(tags, sentence_score, A, init_A)
     return loss

-  def seg(self, sentence, model_path='tmp/lstm-model0.ckpt'):
+  def seg(self, sentence, model_path='tmp/lstm-model0.ckpt', debug=False):
     self.saver.restore(self.sess, model_path)
     seq = self.index2seq(self.sentence2index(sentence))
     sentence_embeds = tf.nn.embedding_lookup(self.embeddings, seq).eval(session=self.sess).reshape(
       [len(sentence), self.concat_embed_size])
-    sentence_scores = self.sess.run(self.word_scores, feed_dict={self.x: np.expand_dims(sentence_embeds, 0)})
+    sentence_scores = self.sess.run(self.word_scores, feed_dict={self.sentence_holder: seq})
     init_A_val = self.init_A.eval(session=self.sess)
     A_val = self.A.eval(session=self.sess)
-    print(A_val)
+    if debug:
+      print(A_val)
+      # print(sentence_embeds[1])
+      print(sentence_scores.T)
     current_tags = self.viterbi(sentence_scores, A_val, init_A_val)
     return self.tags2words(sentence, current_tags), current_tags

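Finally, seg() converts the Viterbi tag sequence back to words via tags2words(), which is also outside this diff. Given that tag_count is 4 (see the old gen_update_A), a standard S/B/M/E labeling is the likely scheme; a sketch under that assumption, where the numeric encoding 0=S, 1=B, 2=M, 3=E is a guess:

def tags2words(sentence, tags):
  # Hypothetical S/B/M/E decoder; the 0=S, 1=B, 2=M, 3=E encoding is assumed.
  words, current = [], ''
  for ch, tag in zip(sentence, tags):
    if tag == 0:          # S: single-character word
      words.append(ch)
    elif tag in (1, 2):   # B or M: word continues
      current += ch
    else:                 # E: word ends here
      words.append(current + ch)
      current = ''
  if current:             # flush any dangling partial word
    words.append(current)
  return words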