Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit fb66cbd

Browse files
modify word2vec
1 parent 58a202f commit fb66cbd

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

‎word2vec.py‎

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88

99
class Word2Vec:
10-
def __init__(self, batch_size=128, num_skips=2, skip_window=1, vocab_size=constant.VOCAB_SIZE, embed_size=50,
10+
def __init__(self, output, batch_size=128, num_skips=2, skip_window=1, vocab_size=constant.VOCAB_SIZE, embed_size=50,
1111
num_sampled=64, steps=100000):
12+
self.output = output
1213
self.batch_size = batch_size
1314
self.num_skips = num_skips
1415
self.skip_window = skip_window
@@ -51,9 +52,9 @@ def train(self):
5152
# The average loss is an estimate of the loss over the last 2000 batches.
5253
print("Average loss at step ", step, ": ", aver_loss)
5354
aver_loss = 0
54-
np.save('tmp/embed',self.embeddings.eval())
55-
#self.test(sess)
56-
def test(self,sess):
55+
np.save(self.output, self.embeddings.eval())
56+
57+
def test(self):
5758
valid_dataset = [3021]
5859
norm = tf.sqrt(tf.reduce_sum(tf.square(self.embeddings), 1, keep_dims=True))
5960
normalized_embeddings = self.embeddings / norm
@@ -62,10 +63,11 @@ def test(self,sess):
6263
similarity = tf.abs(tf.matmul(
6364
valid_embeddings, normalized_embeddings, transpose_b=True))
6465
print(similarity.eval())
65-
pair = zip(range(self.vocab_size),similarity.eval()[0])
66+
pair = zip(range(self.vocab_size),similarity.eval()[0])
6667
spair = sorted(pair, key=lambda x: x[1])
6768
print(spair[0:10])
6869

70+
6971
if __name__ == '__main__':
70-
w2v = Word2Vec()
71-
w2v.train()
72+
w2v = Word2Vec('corpus/lstm/embeddings', embed_size=100)
73+
w2v.train()

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /