@@ -7,8 +7,9 @@
 
 
 class Word2Vec:
-    def __init__(self, batch_size=128, num_skips=2, skip_window=1, vocab_size=constant.VOCAB_SIZE, embed_size=50,
+    def __init__(self, output, batch_size=128, num_skips=2, skip_window=1, vocab_size=constant.VOCAB_SIZE, embed_size=50,
                  num_sampled=64, steps=100000):
+        self.output = output
         self.batch_size = batch_size
         self.num_skips = num_skips
         self.skip_window = skip_window
@@ -51,9 +52,9 @@ def train(self):
                     # The average loss is an estimate of the loss over the last 2000 batches.
                     print("Average loss at step ", step, ": ", aver_loss)
                     aver_loss = 0
-            np.save('tmp/embed', self.embeddings.eval())
-            #self.test(sess)
-    def test(self, sess):
+            np.save(self.output, self.embeddings.eval())
+
+    def test(self):
         valid_dataset = [3021]
         norm = tf.sqrt(tf.reduce_sum(tf.square(self.embeddings), 1, keep_dims=True))
         normalized_embeddings = self.embeddings / norm
@@ -62,10 +63,11 @@ def test(self, sess):
         similarity = tf.abs(tf.matmul(
             valid_embeddings, normalized_embeddings, transpose_b=True))
         print(similarity.eval())
-        pair = zip(range(self.vocab_size),similarity.eval()[0])
+        pair = zip(range(self.vocab_size), similarity.eval()[0])
         spair = sorted(pair, key=lambda x: x[1])
         print(spair[0:10])
 
+
 if __name__ == '__main__':
-    w2v = Word2Vec()
-    w2v.train()
+    w2v = Word2Vec('corpus/lstm/embeddings', embed_size=100)
+    w2v.train()
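
For reference, a minimal sketch of reading the saved matrix back and ranking neighbors, assuming the run above wrote corpus/lstm/embeddings.npy (np.save appends the .npy suffix to the given path) and reusing the probe id 3021 from test(). Note that test() sorts ascending and prints the ten lowest-scoring pairs; the sketch takes the largest |cosine| values instead, the conventional nearest-neighbor ordering.

# Minimal NumPy sketch: load the saved embeddings and rank nearest neighbors.
# Assumes the training run above wrote 'corpus/lstm/embeddings.npy'.
import numpy as np

embeddings = np.load('corpus/lstm/embeddings.npy')        # shape: (vocab_size, embed_size)
norm = np.sqrt(np.sum(np.square(embeddings), axis=1, keepdims=True))
normalized = embeddings / norm                            # unit-length rows

query_id = 3021                                           # same probe id as in test()
similarity = np.abs(normalized[query_id] @ normalized.T)  # |cosine| against every word
nearest = np.argsort(-similarity)[:10]                    # ids of the ten most similar words
print(list(zip(nearest, similarity[nearest])))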