Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 05c5ffa

Browse files
Add files via upload
1 parent 3864a74 commit 05c5ffa

File tree

1 file changed

+79
-89
lines changed

1 file changed

+79
-89
lines changed

‎BPE.py‎

Lines changed: 79 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def get_word_frequency_dict_from_document(path, space_symbol='</w>'):
5151
return word_frequency_dict
5252

5353

54-
5554
# merge two dictionary
5655
def merge_dictionary(dic_a, dic_b):
5756
for i in dic_b:
@@ -94,17 +93,47 @@ def merge_bpe_word(best_pair_and_word_frequency_list):
9493
bigram = re.escape(' '.join(best_pair))
9594
p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
9695
for word, freq in word_frequency:
97-
# 만약 ''.join(best_pair): r</w> 이고, word: 'a r </w>' 이면 w_out은 'a r</w>'가 된다.
98-
w_out = p.sub(''.join(best_pair), word)
99-
v_out.append( (w_out, freq) )
100-
96+
best_pair_to_string = ''.join(best_pair)
97+
if best_pair_to_string in ''.join(word):
98+
# 만약 ''.join(best_pair): r</w> 이고, word: 'a r </w>' 이면 w_out은 'a r</w>'가 된다.
99+
w_out = p.sub(best_pair_to_string, word)
100+
v_out.append( (w_out, freq) )
101+
else:
102+
v_out.append( (word, freq) )
101103
if len(best_pair_and_word_frequency_list) == 3: # multi proc
102104
return (best_pair_and_word_frequency_list[2], v_out) # (multiproc 결과 조합할 순서, 결과)
103105
else:
104106
return v_out
105107

106108

107109

110+
111+
# from bpe to idx
112+
def make_bpe2idx(word_frequency_list):
113+
bpe2idx = {
114+
'</p>':0,
115+
'UNK':1,
116+
'</g>':2, #go
117+
'</e>':3 #eos
118+
}
119+
idx2bpe = {
120+
0:'</p>',
121+
1:'UNK',
122+
2:'</g>', #go
123+
3:'</e>' #eos
124+
}
125+
idx = 4
126+
127+
for word, _ in word_frequency_list: # word, freq
128+
for bpe in word.split():
129+
# bpe가 bpe2idx에 없는 경우만 idx 부여.
130+
if bpe not in bpe2idx:
131+
bpe2idx[bpe] = idx
132+
idx2bpe[idx] = bpe
133+
idx += 1
134+
return bpe2idx, idx2bpe
135+
136+
108137
def merge_a_word(merge_info, word, cache={}):
109138
# merge_info: list
110139
# word: "c e m e n t </w>" => "ce m e n t<\w>" 되어야 함.
@@ -120,56 +149,20 @@ def merge_a_word(merge_info, word, cache={}):
120149
for info in merge_info:
121150
if bpe_word.count(' ') == 0:
122151
break
152+
info_to_string = ''.join(info)
153+
if info_to_string in ''.join(bpe_word):
123154

124-
bigram = re.escape(' '.join(info))
125-
p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
155+
bigram = re.escape(' '.join(info))
156+
p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
126157

127-
# 만약 ''.join(info): r</w> 이고, bpe_word: 'a r </w>' 이면 w_out은 'a r</w>'가 된다.
128-
bpe_word = p.sub(''.join(info), bpe_word)
158+
# 만약 info_to_string: r</w> 이고, bpe_word: 'a r </w>' 이면 w_out은 'a r</w>'가 된다.
159+
bpe_word = p.sub(info_to_string, bpe_word)
129160

130161
# cache upate
131162
cache[word] = bpe_word
132163
return bpe_word
133164

134165

135-
def make_bpe2idx(word_frequency_list, npy_path):
136-
word_frequency_dict = {}
137-
for word, freq in word_frequency_list:
138-
# ex: ('B e it r a g</w>', 8)
139-
split = word.split() # [B e it r a g</w>]
140-
for bpe in split:
141-
if bpe not in word_frequency_dict:
142-
word_frequency_dict[bpe] = freq
143-
else:
144-
word_frequency_dict[bpe] += freq
145-
146-
sorted_voca = sorted(tuple(word_frequency_dict.items()), key=lambda x: x[1], reverse=True)
147-
148-
bpe2idx = {
149-
'</p>':0,
150-
'UNK':1,
151-
'</g>':2, #go
152-
'</e>':3 #eos
153-
}
154-
idx2bpe = {
155-
0:'</p>',
156-
1:'UNK',
157-
2:'</g>', #go
158-
3:'</e>' #eos
159-
}
160-
idx = 4
161-
162-
with open(npy_path+'sorted_voca.txt', 'w', encoding='utf-8') as o:
163-
for voca, freq in sorted_voca:
164-
o.write(str(voca) + ' ' + str(freq) + '\n')
165-
bpe2idx[voca] = idx
166-
idx2bpe[idx] = voca
167-
idx += 1
168-
169-
return bpe2idx, idx2bpe
170-
171-
172-
173166
# 문서를 읽고, bpe 적용. cache 사용할것. apply_bpe에서 사용.
174167
def _apply_bpe(path, out_path, space_symbol='</w>', merge_info=None, cache={}):
175168
start = time.time()
@@ -201,15 +194,15 @@ def _apply_bpe(path, out_path, space_symbol='</w>', merge_info=None, cache={}):
201194
row.extend(merge.split())
202195
wr.writerow(row)
203196

204-
if (i+1) % 100000 == 0:
197+
if (i+1) % 1000000 == 0:
205198
current_cache_len = len(cache)
206199
print('out_path:', out_path, 'line:', i+1, 'total cache:', current_cache_len, 'added:', current_cache_len-cache_len)
207200
cache_len = current_cache_len
208201

209202
o.close()
210203

211204

212-
def _learn_bpe(word_frequency_dict, npy_path, num_merges=37000, multi_proc=1):
205+
def _learn_bpe(word_frequency_dict, num_merges=37000, multi_proc=1):
213206
#word_frequency_dict = {'l o w </w>' : 1, 'l o w e r </w>' : 1, 'n e w e s t </w>':1, 'w i d e s t </w>':1}
214207

215208
merge_info = [] # 합친 정보를 기억하고있다가 다른 데이터에 적용.
@@ -266,38 +259,24 @@ def _learn_bpe(word_frequency_dict, npy_path, num_merges=37000, multi_proc=1):
266259
word_frequency = merge_bpe_word((best, word_frequency)) # 가장 높은 빈도의 2gram을 합침.
267260
######
268261

269-
# multiproc close
262+
270263
if multi_proc > 1:
271264
pool.close()
272265

273-
274-
# make npy
275-
if not os.path.exists(npy_path):
276-
print("create" + npy_path + "directory")
277-
os.makedirs(npy_path)
278-
279266
# 빠른 변환을 위한 cache 저장. 기존 word를 key로, bpe 결과를 value로.
280267
cache = {}
281268
for i in range(len(cache_list)):
282269
key = cache_list[i][0]
283270
value = word_frequency[i][0]
284271
cache[key] = value
285272

286-
save_data(npy_path+'merge_info.npy', merge_info) # list
287-
save_data(npy_path+'cache.npy', cache) # dict
288-
print('save merge_info.npy', ', size:', len(merge_info))
289-
print('save cache.npy', ', size:', len(cache))
290-
291-
292-
bpe2idx, idx2bpe = make_bpe2idx(word_frequency, npy_path)
293-
save_data(npy_path+'bpe2idx.npy', bpe2idx) # dict
294-
save_data(npy_path+'idx2bpe.npy', idx2bpe) # dict
295-
print('save bpe2idx.npy', ', size:', len(bpe2idx))
296-
print('save idx2bpe.npy', ', size:', len(idx2bpe))
273+
# voca 추출.
274+
bpe2idx, idx2bpe = make_bpe2idx(word_frequency)
275+
return bpe2idx, idx2bpe, merge_info, cache # dict, dict, list, dict
297276

298277

299278

300-
def learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=37000, voca_threshold=5, multi_proc=1):
279+
def learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=37000, multi_proc=1):
301280

302281
print('get word frequency dictionary')
303282
total_word_frequency_dict = {}
@@ -309,24 +288,32 @@ def learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=37000, voca_t
309288
total_word_frequency_dict = merge_dictionary(total_word_frequency_dict, word_frequency_dict)
310289

311290

312-
# 빈도수가 일정 미만인 단어 제외.
313-
total_word_frequency_dict_size = len(total_word_frequency_dict)
314-
for item in list(total_word_frequency_dict.items()):
315-
if item[1] < voca_threshold: # item[0] is key, item[1] is value
316-
del total_word_frequency_dict[item[0]]
317-
print('frequency word dict size:', total_word_frequency_dict_size)
318-
print('threshold applied frequency word dict size:', len(total_word_frequency_dict), 'removed:', total_word_frequency_dict_size-len(total_word_frequency_dict), '\n')
319-
291+
'''
292+
save_data('./word_frequency_dictionary.npy', total_word_frequency_dict)
293+
print('save ./word_frequency_dictionary.npy', 'size:', len(total_word_frequency_dict), '\n')
294+
total_word_frequency_dict = load_data('./word_frequency_dictionary.npy', mode='dictionary')
295+
'''
320296

321297
print('learn bpe')
322-
_learn_bpe(
298+
bpe2idx, idx2bpe, merge_info, cache=_learn_bpe(
323299
total_word_frequency_dict,
324-
npy_path=npy_path,
325300
num_merges=num_merges,
326301
multi_proc=multi_proc
327-
)
302+
)# dict, dict, list, dict
328303

329-
print('\n\n\n')
304+
if not os.path.exists(npy_path):
305+
print("create" + npy_path + "directory")
306+
os.makedirs(npy_path)
307+
308+
save_data(npy_path+'bpe2idx.npy', bpe2idx)
309+
save_data(npy_path+'idx2bpe.npy', idx2bpe)
310+
save_data(npy_path+'merge_info.npy', merge_info)
311+
save_data(npy_path+'cache.npy', cache)
312+
print('save bpe2idx.npy', 'size:', len(bpe2idx))
313+
print('save idx2bpe.npy', 'size:', len(idx2bpe))
314+
print('save merge_info.npy', 'size:', len(merge_info))
315+
print('save cache.npy', 'size:', len(cache))
316+
print()
330317

331318

332319

@@ -335,25 +322,27 @@ def apply_bpe(path_list, out_bpe_path, out_list, npy_path, space_symbol='</w>',
335322
print("create" + out_bpe_path + "directory")
336323
os.makedirs(out_bpe_path)
337324

325+
print('load bpe info')
338326
merge_info = load_data(npy_path+'merge_info.npy')
339327
cache = load_data(npy_path+'cache.npy', mode='dictionary')
340-
341-
print('apply bpe')
328+
342329
for i in range(len(path_list)):
343330
path = path_list[i]
344331
out_path = out_list[i]
345332

346-
print('path:', path, ', out_path:', out_path)
333+
print('apply bpe', path, out_path)
347334
_apply_bpe(
348335
path=path,
349336
out_path=out_bpe_path+out_path,
350337
space_symbol=space_symbol,
351338
merge_info=merge_info,
352339
cache=cache
353340
)
341+
print('save ok', out_path)
354342
save_data(npy_path+'cache.npy', cache)
355-
print('\n\n\n')
356-
343+
print('save updated cache ./cache.npy', 'size:', len(cache))
344+
print()
345+
print()
357346

358347

359348
# save directory
@@ -379,13 +368,14 @@ def apply_bpe(path_list, out_bpe_path, out_list, npy_path, space_symbol='</w>',
379368

380369
# learn and apply
381370
if __name__ == '__main__':
371+
print('20190105_test')
382372
# if don't use multiprocessing:
383373
# learn_bpe(path_list, npy_path, space_symbol='</w>', top_k=None)
384-
# multi_proc: # process, os.cpu_count(): # cpu processor of current computer
385-
374+
375+
# multiprocessing, multi_proc: # process, os.cpu_count(): # cpu processor of current computer
386376
# learn bpe from documents
387-
learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=35000, voca_threshold=50, multi_proc=os.cpu_count())
388-
#learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=30000, voca_threshold=5, multi_proc=os.cpu_count())
377+
learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=30000, multi_proc=os.cpu_count())
378+
# num_merges:37000 => 40297개,
389379

390380
# apply bpe to documents
391381
apply_bpe(path_list, out_bpe_path, out_list, npy_path, space_symbol='</w>', pad_symbol='</p>')

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /