| 
6 | 6 | 
 
  | 
7 | 7 | 
 
  | 
8 | 8 | class TransformDataW2V(TransformData):  | 
9 |  | - def __init__(self, batch_size, num_skips, skip_window,file,corpus):  | 
 | 9 | + def __init__(self, batch_size, num_skips, skip_window):  | 
10 | 10 |  TransformData.__init__(self, 'corpus/dict.utf8', ['pku'])  | 
11 | 11 |  self.batch_size = batch_size  | 
12 | 12 |  self.num_skips = num_skips  | 
13 | 13 |  self.skip_window = skip_window  | 
14 | 14 |  self.data_index = 0  | 
15 | 15 |  self.span = 2 * self.skip_window + 1  | 
16 |  | - self.words = [item for sublist in self.words_index for item in sublist]  | 
 | 16 | + self.words = self.generate_words('sogou')  | 
17 | 17 |  self.word_count = len(self.words)  | 
18 | 18 | 
 
  | 
 | 19 | + def generate_words(self, name):  | 
 | 20 | + if name == 'pku':  | 
 | 21 | + return [item for sublist in self.words_index for item in sublist]  | 
 | 22 | + elif name == 'sogou':  | 
 | 23 | + with open('corpus/sogou.txt', 'r', encoding='utf8') as file:  | 
 | 24 | + return self.sentence2index(file.read())  | 
 | 25 | + | 
 | 26 | + def sentence2index(self, sentence):  | 
 | 27 | + index = []  | 
 | 28 | + for ch in sentence:  | 
 | 29 | + if ch in self.dictionary:  | 
 | 30 | + index.append(self.dictionary[ch])  | 
 | 31 | + else:  | 
 | 32 | + index.append(0)  | 
 | 33 | + return index  | 
 | 34 | + | 
19 | 35 |  def generate_batch(self):  | 
20 | 36 |  batch = np.ndarray(shape=(self.batch_size), dtype=np.int32)  | 
21 | 37 |  labels = np.ndarray(shape=(self.batch_size, 1), dtype=np.int32)  | 
 | 
0 commit comments