import math
import time

import numpy as np
import tensorflow as tf

import constant
from transform_data_lstm import TransformDataLSTM
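
# SegLSTM: an LSTM-based sequence tagger for word segmentation. Each character
# is represented by the concatenated embeddings of its context window; the LSTM
# scores the four tags per position, and a learned transition matrix A (plus an
# initial-tag vector init_A) drives Viterbi decoding. Training is
# perceptron-style: only sentences whose decoded path disagrees with the gold
# tags trigger updates.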

class SegLSTM:
  def __init__(self):
    self.dtype = tf.float32
    self.skip_window_left = constant.LSTM_SKIP_WINDOW_LEFT
    self.skip_window_right = constant.LSTM_SKIP_WINDOW_RIGHT
    self.window_size = self.skip_window_left + self.skip_window_right + 1
    self.embed_size = 50
    self.hidden_units = 100
    self.tag_count = 4
    self.concat_embed_size = self.window_size * self.embed_size
    trans = TransformDataLSTM()
    self.words_batch = trans.words_batch
    self.tags_batch = trans.labels_batch
    self.vocab_size = constant.VOCAB_SIZE
    self.alpha = 0.02
    self.lam = 0.001
    self.sess = tf.Session()
    self.optimizer = tf.train.GradientDescentOptimizer(self.alpha)
    self.x = tf.placeholder(self.dtype, shape=[self.concat_embed_size, None])
    self.x_plus = tf.placeholder(self.dtype, shape=[1, None, self.concat_embed_size])
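    # x holds window embeddings column-wise; x_plus is the same data arranged as
    # a single-sentence batch for dynamic_rnn: [batch=1, time, concat_embed_size]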
    self.embeddings = tf.Variable(
      tf.random_uniform([self.vocab_size, self.embed_size], -1.0 / math.sqrt(self.embed_size),
                        1.0 / math.sqrt(self.embed_size),
                        dtype=self.dtype), dtype=self.dtype, name='embeddings')
    self.w = tf.Variable(tf.zeros([self.tag_count, self.hidden_units]), dtype=self.dtype)
    self.b = tf.Variable(tf.zeros([self.tag_count, 1]), dtype=self.dtype)
    self.A = tf.Variable(tf.zeros([self.tag_count, self.tag_count]), dtype=self.dtype)
    self.Ap = tf.placeholder(self.dtype, shape=self.A.get_shape())
    self.init_A = tf.Variable(tf.zeros([self.tag_count]), dtype=self.dtype)
    self.init_Ap = tf.placeholder(self.dtype, shape=self.init_A.get_shape())
    # decayed additive updates, A <- (1 - lam) * (A + alpha * Ap), assigned back
    # to the variables so the decay actually persists
    self.update_A_op = self.A.assign((1 - self.lam) * (self.A + self.alpha * self.Ap))
    self.update_init_A_op = self.init_A.assign((1 - self.lam) * (self.init_A + self.alpha * self.init_Ap))
    self.sentence_holder = tf.placeholder(tf.int32, shape=[None, self.window_size])
    self.lookup_op = tf.nn.embedding_lookup(self.embeddings, self.sentence_holder)
    # sparse description of the +/-1 mask built per sentence in train_sentence
    self.indices = tf.placeholder(tf.int32, shape=[None, 2])
    self.shape = tf.placeholder(tf.int32, shape=[2])
    self.values = tf.placeholder(self.dtype, shape=[None])
    self.map_matrix_op = tf.sparse_to_dense(self.indices, self.shape, self.values, validate_indices=False)
    self.map_matrix = tf.placeholder(self.dtype, shape=[self.tag_count, None], name='mm')
    self.lstm = tf.contrib.rnn.LSTMCell(self.hidden_units)
    self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.x_plus, dtype=self.dtype)
    tf.global_variables_initializer().run(session=self.sess)
    self.word_scores = tf.matmul(self.w, tf.transpose(self.lstm_output[0])) + self.b
    self.loss = tf.reduce_sum(tf.multiply(self.map_matrix, self.word_scores))
    self.lstm_variable = [v for v in tf.global_variables() if v.name.startswith('rnn')]
    self.params = [self.w, self.b]
    self.params.extend(self.lstm_variable)
    self.regularization = list(map(lambda p: tf.assign_sub(p, self.lam * p), self.params))
    self.train = self.optimizer.minimize(self.loss, var_list=self.params)
    self.embedp = tf.placeholder(self.dtype, shape=[None, self.embed_size])
    self.embed_index = tf.placeholder(tf.int32, shape=[None])
    self.update_embed_op = tf.scatter_update(self.embeddings, self.embed_index, self.embedp)
    self.grad_embed = tf.gradients(tf.multiply(self.map_matrix, self.word_scores), self.x_plus)
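    # loss pairs each column of map_matrix (-1 at gold tags, +1 at wrongly
    # predicted tags) with word_scores, so minimizing it raises gold-tag scores
    # and lowers mispredicted ones; grad_embed carries the same signal back to
    # the input window embeddings.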

  def model(self, embeds):
    scores = self.sess.run(self.word_scores, feed_dict={self.x_plus: np.expand_dims(embeds.T, 0)})
    path = self.viterbi(scores, self.A.eval(self.sess), self.init_A.eval(self.sess))
    return path
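
  # model() expects embeds of shape [concat_embed_size, length]; it scores the
  # sentence as a batch of one and Viterbi-decodes the best tag path.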

  def train_exe(self):
    # build the Saver before finalizing the graph, since Saver adds save/restore ops
    saver = tf.train.Saver([self.embeddings, self.A, self.init_A] + self.params, max_to_keep=100)
    self.sess.graph.finalize()
    last_time = time.time()
    for sentence_index, (sentence, tags) in enumerate(zip(self.words_batch, self.tags_batch)):
      self.train_sentence(sentence, tags, len(tags))
      if sentence_index % 500 == 0:
        print(time.time() - last_time)
        last_time = time.time()
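
  # A possible checkpoint hook for the loop above (a sketch: the 'ckpt/lstm'
  # save path and saving on the same 500-sentence interval are assumptions,
  # not taken from the original code):
  #   saver.save(self.sess, 'ckpt/lstm', global_step=sentence_index)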

  def train_sentence(self, sentence, tags, length):
    # [concat_embed_size, length]: one column of stacked window embeddings per character
    sentence_embeds = self.sess.run(self.lookup_op, feed_dict={self.sentence_holder: sentence}).reshape(
      [length, self.concat_embed_size]).T
    current_tags = self.model(sentence_embeds)
    diff_tags = np.subtract(tags, current_tags)
    update_index = np.where(diff_tags != 0)[0]
    update_length = len(update_index)

    # the decoded path already matches the gold tags; nothing to update
    if update_length == 0:
      return

    update_tags_pos = tags[update_index]
    update_tags_neg = current_tags[update_index]
    update_embed = sentence_embeds[:, update_index]
    # (tag, position) coordinates: -1 at the gold tags, +1 at the predicted ones
    sparse_indices = np.stack(
      [np.concatenate([update_tags_pos, update_tags_neg], axis=-1), np.tile(np.arange(update_length), [2])],
      axis=-1)
    sparse_values = np.concatenate([-1 * np.ones(update_length), np.ones(update_length)])
    output_shape = [self.tag_count, update_length]
    sentence_matrix = self.sess.run(self.map_matrix_op,
                                    feed_dict={self.indices: sparse_indices, self.shape: output_shape,
                                               self.values: sparse_values})
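    # e.g. for two mispredicted positions with gold tags (0, 3) and predicted
    # tags (1, 2), sentence_matrix (rows = tags, columns = positions) is
    #   [[-1.,  0.],
    #    [ 1.,  0.],
    #    [ 0.,  1.],
    #    [ 0., -1.]]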
    # update the LSTM and output-layer parameters on the mispredicted columns
    self.sess.run(self.train,
                  feed_dict={self.x_plus: np.expand_dims(update_embed.T, 0), self.map_matrix: sentence_matrix})
    self.sess.run(self.regularization)

    # update the word embeddings: apply the gradient and weight decay to each
    # mispredicted window, then write the rows back with scatter_update
    embed_index = sentence[update_index]
    for i in range(update_length):
      embed = np.expand_dims(np.expand_dims(update_embed[:, i], 0), 0)
      grad = self.sess.run(self.grad_embed,
                           feed_dict={self.x_plus: embed,
                                      self.map_matrix: np.expand_dims(sentence_matrix[:, i], 1)})[0]
      sentence_update_embed = (embed + self.alpha * grad) * (1 - self.lam)
      self.sess.run(self.update_embed_op,
                    feed_dict={self.embedp: sentence_update_embed.reshape([self.window_size, self.embed_size]),
                               self.embed_index: embed_index[i, :]})
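
    # each embed_index[i, :] holds the window_size vocabulary ids of window i,
    # matching the [window_size, embed_size] rows written by update_embed_op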

    # update the transition matrices
    A_update, init_A_update, update_init = self.gen_update_A(tags, current_tags)
    if update_init:
      self.sess.run(self.update_init_A_op, feed_dict={self.init_Ap: init_A_update})
    self.sess.run(self.update_A_op, {self.Ap: A_update})

  @staticmethod
  def gen_update_A(correct_tags, current_tags):
    # accumulate +/-1 perceptron updates for the transition matrix A and the
    # initial-tag vector init_A wherever the decoded path diverges from gold
    A_update = np.zeros([4, 4], dtype=np.float32)
    init_A_update = np.zeros([4], dtype=np.float32)
    before_corr = correct_tags[0]
    before_curr = current_tags[0]
    update_init = False

    if before_corr != before_curr:
      init_A_update[before_corr] += 1
      init_A_update[before_curr] -= 1
      update_init = True

    for corr_tag, curr_tag in zip(correct_tags[1:], current_tags[1:]):
      if corr_tag != curr_tag or before_corr != before_curr:
        A_update[before_corr, corr_tag] += 1
        A_update[before_curr, curr_tag] -= 1
      before_corr = corr_tag
      before_curr = curr_tag

    return A_update, init_A_update, update_init
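
  # e.g. gen_update_A([0, 1, 3], [0, 2, 3]) leaves init_A alone (the first tags
  # agree) and returns A_update with +1 at the gold transitions (0, 1) and
  # (1, 3), and -1 at the predicted transitions (0, 2) and (2, 3)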

  def viterbi(self, emission, A, init_A, return_score=False):
    """
    ...
    """
    ...
    if return_score:
      return corr_path, path_score[max_index, :]
    else:
      return corr_path
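

# A minimal reference sketch of the decoding step, consistent with the shapes
# used above (emission: [tag_count, length], A: [tag_count, tag_count],
# init_A: [tag_count]); the elided viterbi body may differ in detail.
def viterbi_sketch(emission, A, init_A):
  tag_count, length = emission.shape
  score = init_A + emission[:, 0]  # best score of a path ending in each tag
  back = np.zeros([tag_count, length], dtype=np.int32)
  for t in range(1, length):
    # trans[prev, curr] = score[prev] + A[prev, curr] + emission[curr, t]
    trans = score[:, np.newaxis] + A + emission[:, t][np.newaxis, :]
    back[:, t] = np.argmax(trans, axis=0)
    score = np.max(trans, axis=0)
  path = [int(np.argmax(score))]
  for t in range(length - 1, 0, -1):
    path.insert(0, int(back[path[0], t]))
  return np.array(path)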


if __name__ == '__main__':
  seg = SegLSTM()
  seg.train_exe()