66
77import  constant 
88from  transform_data_lstm  import  TransformDataLSTM 
9+ from  seg_base  import  SegBase 
910
1011
11- class  SegLSTM :
12+ class  SegLSTM ( SegBase ) :
1213 def  __init__ (self ):
14+  SegBase .__init__ (self )
1315 self .dtype  =  tf .float32 
1416 self .skip_window_left  =  constant .LSTM_SKIP_WINDOW_LEFT 
1517 self .skip_window_right  =  constant .LSTM_SKIP_WINDOW_RIGHT 
@@ -21,6 +23,7 @@ def __init__(self):
2123 trans  =  TransformDataLSTM ()
2224 self .words_batch  =  trans .words_batch 
2325 self .tags_batch  =  trans .labels_batch 
26+  self .dictionary  =  trans .dictionary 
2427 self .vocab_size  =  constant .VOCAB_SIZE 
2528 self .alpha  =  0.02 
2629 self .lam  =  0.001 
@@ -68,15 +71,16 @@ def model(self, embeds):
6871 return  path 
6972
7073 def  train_exe (self ):
74+  saver  =  tf .train .Saver ([self .embeddings , self .A , self .init_A ].extend (self .params ), max_to_keep = 100 )
7175 self .sess .graph .finalize ()
7276 last_time  =  time .time ()
73-  saver  =  tf .train .Saver ([self .embeddings , self .A ,self .init_A ].extend (self .params ), max_to_keep = 100 )
7477 for  sentence_index , (sentence , tags ) in  enumerate (zip (self .words_batch , self .tags_batch )):
7578 self .train_sentence (sentence , tags , len (tags ))
7679 if  sentence_index  %  500  ==  0 :
80+  print (sentence_index )
7781 print (time .time () -  last_time )
7882 last_time  =  time .time ()
79- 83+ saver . save ( self . sess ,  'tmp/lstm-model%d.ckpt' % 0 ) 
8084
8185 def  train_sentence (self , sentence , tags , length ):
8286 sentence_embeds  =  self .sess .run (self .lookup_op , feed_dict = {self .sentence_holder : sentence }).reshape (
@@ -109,7 +113,7 @@ def train_sentence(self, sentence, tags, length):
109113 # 更新词向量 
110114 embed_index  =  sentence [update_index ]
111115 for  i  in  range (update_length ):
112-  embed  =  np .expand_dims (np .expand_dims (update_embed [:, i ], 0 ),0 )
116+  embed  =  np .expand_dims (np .expand_dims (update_embed [:, i ], 0 ),0 )
113117 grad  =  self .sess .run (self .grad_embed , feed_dict = {self .x_plus : embed ,
114118 self .map_matrix : np .expand_dims (sentence_matrix [:, i ], 1 )})[0 ]
115119
@@ -147,41 +151,6 @@ def gen_update_A(correct_tags, current_tags):
147151
148152 return  A_update , init_A_update , update_init 
149153
150-  def  viterbi (self , emission , A , init_A , return_score = False ):
151-  """ 
152-  维特比算法的实现,所有输入和返回参数均为numpy数组对象 
153-  :param emission: 发射概率矩阵,对应于本模型中的分数矩阵,4*length 
154-  :param A: 转移概率矩阵,4*4 
155-  :param init_A: 初始转移概率矩阵,4 
156-  :param return_score: 是否返回最优路径的分值,默认为False 
157-  :return: 最优路径,若return_score为True,返回最优路径及其对应分值 
158-  """ 
159- 160-  length  =  emission .shape [1 ]
161-  path  =  np .ones ([4 , length ], dtype = np .int32 ) *  - 1 
162-  corr_path  =  np .zeros ([length ], dtype = np .int32 )
163-  path_score  =  np .ones ([4 , length ], dtype = np .float64 ) *  (np .finfo ('f' ).min  /  2 )
164-  path_score [:, 0 ] =  init_A  +  emission [:, 0 ]
165- 166-  for  pos  in  range (1 , length ):
167-  for  t  in  range (4 ):
168-  for  prev  in  range (4 ):
169-  temp  =  path_score [prev ][pos  -  1 ] +  A [prev ][t ] +  emission [t ][pos ]
170-  if  temp  >=  path_score [t ][pos ]:
171-  path [t ][pos ] =  prev 
172-  path_score [t ][pos ] =  temp 
173- 174-  max_index  =  np .argmax (path_score [:, - 1 ])
175-  corr_path [length  -  1 ] =  max_index 
176-  for  i  in  range (length  -  1 , 0 , - 1 ):
177-  max_index  =  path [max_index ][i ]
178-  corr_path [i  -  1 ] =  max_index 
179-  if  return_score :
180-  return  corr_path , path_score [max_index , :]
181-  else :
182-  return  corr_path 
if __name__ == '__main__':
    # Build the LSTM segmenter and run the full training pass.
    SegLSTM().train_exe()
0 commit comments