@@ -88,10 +88,11 @@ def train_exe(self):
8888 for i in range (10 ):
8989 for sentence_index , (sentence , tags ) in enumerate (zip (self .words_batch , self .tags_batch )):
9090 self .train_sentence (sentence , tags , len (tags ))
91- if sentence_index % 500 == 0 :
91+ if sentence_index > 0 and sentence_index % 500 == 0 :
9292 print (sentence_index )
9393 print (time .time () - last_time )
9494 last_time = time .time ()
95+ print (self .cal_loss (sentence_index - 500 ,sentence_index ))
9596 print (self .sess .run (self .init_A ))
9697 self .saver .save (self .sess , 'tmp/lstm-model%d.ckpt' % i )
9798
@@ -164,6 +165,17 @@ def gen_update_A(correct_tags, current_tags):
164165
165166 return A_update , init_A_update , update_init
166167
def cal_loss(self, start, end):
    """Compute the summed sentence-level loss over a slice of the training batch.

    Evaluates the current transition parameters once, then scores each
    sentence in ``words_batch[start:end]`` with the network and accumulates
    the per-sentence loss via ``cal_sentence_loss``.

    Args:
        start: Index of the first sentence of the slice in ``self.words_batch``.
        end: Index one past the last sentence of the slice.

    Returns:
        The accumulated loss (float) over sentences ``[start:end)``.
    """
    loss = 0.0
    # Snapshot the tag-transition matrices once; they are invariant for the
    # duration of this evaluation, so don't re-fetch them per sentence.
    A = self.A.eval(session=self.sess)
    init_A = self.init_A.eval(session=self.sess)
    # Plain zip: the original enumerated the pairs but never used the index
    # (and the index name shadowed the caller's loop variable in train_exe).
    for sentence, tags in zip(self.words_batch[start:end], self.tags_batch[start:end]):
        # Character embeddings for the sentence, flattened to
        # (len(sentence), concat_embed_size) for the scoring network.
        sentence_embeds = self.sess.run(
            self.lookup_op, feed_dict={self.sentence_holder: sentence}
        ).reshape([len(sentence), self.concat_embed_size])
        # Per-word tag scores; expand_dims adds the batch dimension the
        # model input placeholder expects.
        sentence_score = self.sess.run(
            self.word_scores, feed_dict={self.x: np.expand_dims(sentence_embeds, 0)}
        )
        loss += self.cal_sentence_loss(tags, sentence_score, A, init_A)
    return loss
178+ 167179 def seg (self , sentence , model_path = 'tmp/lstm-model0.ckpt' ):
168180 self .saver .restore (self .sess , model_path )
169181 seq = self .index2seq (self .sentence2index (sentence ))
0 commit comments