Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 3864a74

Browse files
Add files via upload
1 parent dbe0244 commit 3864a74

File tree

1 file changed

+94
-87
lines changed

1 file changed

+94
-87
lines changed

‎BPE.py‎

Lines changed: 94 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def word_split_for_bpe(word, space_symbol='</w>'):
2626

2727

2828
# word frequency 추출.
29-
def get_word_frequency_dict_from_document(path, space_symbol='</w>', top_k=None):
29+
def get_word_frequency_dict_from_document(path, space_symbol='</w>'):
3030
word_frequency_dict = {}
3131

3232
with open(path, 'r', encoding='utf-8') as f:
@@ -40,27 +40,16 @@ def get_word_frequency_dict_from_document(path, space_symbol='</w>', top_k=None)
4040

4141
for word in sentence.split():
4242
# "abc" => "a b c space_symbol"
43-
split_word = word_split_for_bpe(word, space_symbol)
44-
43+
word = word_split_for_bpe(word, space_symbol)
44+
4545
# word frequency
46-
if split_word in word_frequency_dict:
47-
word_frequency_dict[split_word] += 1
46+
if word in word_frequency_dict:
47+
word_frequency_dict[word] += 1
4848
else:
49-
word_frequency_dict[split_word] = 1
49+
word_frequency_dict[word] = 1
50+
51+
return word_frequency_dict
5052

51-
if top_k is None:
52-
return word_frequency_dict
53-
54-
else:
55-
# top_k frequency word
56-
sorted_word_frequency_list = sorted(
57-
word_frequency_dict.items(), # ('key', value) pair
58-
key=lambda x:x[1], # x: ('key', value), and x[1]: value
59-
reverse=True
60-
) # [('a', 3), ('b', 2), ... ]
61-
top_k_word_frequency_dict = dict(sorted_word_frequency_list[:top_k])
62-
63-
return top_k_word_frequency_dict
6453

6554

6655
# merge two dictionary
@@ -116,45 +105,22 @@ def merge_bpe_word(best_pair_and_word_frequency_list):
116105

117106

118107

119-
120-
# from bpe to idx
121-
def make_bpe2idx(word_frequency_list):
122-
bpe2idx = {
123-
'</p>':0,
124-
'UNK':1,
125-
'</g>':2, #go
126-
'</e>':3 #eos
127-
}
128-
idx2bpe = {
129-
0:'</p>',
130-
1:'UNK',
131-
2:'</g>', #go
132-
3:'</e>' #eos
133-
}
134-
idx = 4
135-
136-
for word, _ in word_frequency_list: # word, freq
137-
for bpe in word.split():
138-
# bpe가 bpe2idx에 없는 경우만 idx 부여.
139-
if bpe not in bpe2idx:
140-
bpe2idx[bpe] = idx
141-
idx2bpe[idx] = bpe
142-
idx += 1
143-
return bpe2idx, idx2bpe
144-
145-
146108
def merge_a_word(merge_info, word, cache={}):
147109
# merge_info: list
148110
# word: "c e m e n t </w>" => "ce m e n t<\w>" 되어야 함.
149111

150-
if len(word.split()) == 1:
112+
#if len(word.split()) == 1:
113+
if word.count(' ') == 0:
151114
return word
152115

153116
if word in cache:
154117
return cache[word]
155118
else:
156119
bpe_word = word
157120
for info in merge_info:
121+
if bpe_word.count(' ') == 0:
122+
break
123+
158124
bigram = re.escape(' '.join(info))
159125
p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
160126

@@ -166,6 +132,44 @@ def merge_a_word(merge_info, word, cache={}):
166132
return bpe_word
167133

168134

135+
def make_bpe2idx(word_frequency_list, npy_path):
136+
word_frequency_dict = {}
137+
for word, freq in word_frequency_list:
138+
# ex: ('B e it r a g</w>', 8)
139+
split = word.split() # [B e it r a g</w>]
140+
for bpe in split:
141+
if bpe not in word_frequency_dict:
142+
word_frequency_dict[bpe] = freq
143+
else:
144+
word_frequency_dict[bpe] += freq
145+
146+
sorted_voca = sorted(tuple(word_frequency_dict.items()), key=lambda x: x[1], reverse=True)
147+
148+
bpe2idx = {
149+
'</p>':0,
150+
'UNK':1,
151+
'</g>':2, #go
152+
'</e>':3 #eos
153+
}
154+
idx2bpe = {
155+
0:'</p>',
156+
1:'UNK',
157+
2:'</g>', #go
158+
3:'</e>' #eos
159+
}
160+
idx = 4
161+
162+
with open(npy_path+'sorted_voca.txt', 'w', encoding='utf-8') as o:
163+
for voca, freq in sorted_voca:
164+
o.write(str(voca) + ' ' + str(freq) + '\n')
165+
bpe2idx[voca] = idx
166+
idx2bpe[idx] = voca
167+
idx += 1
168+
169+
return bpe2idx, idx2bpe
170+
171+
172+
169173
# 문서를 읽고, bpe 적용. cache 사용할것. apply_bpe에서 사용.
170174
def _apply_bpe(path, out_path, space_symbol='</w>', merge_info=None, cache={}):
171175
start = time.time()
@@ -197,15 +201,15 @@ def _apply_bpe(path, out_path, space_symbol='</w>', merge_info=None, cache={}):
197201
row.extend(merge.split())
198202
wr.writerow(row)
199203

200-
if (i+1) % 500 == 0:
204+
if (i+1) % 100000 == 0:
201205
current_cache_len = len(cache)
202206
print('out_path:', out_path, 'line:', i+1, 'total cache:', current_cache_len, 'added:', current_cache_len-cache_len)
203207
cache_len = current_cache_len
204208

205209
o.close()
206210

207211

208-
def _learn_bpe(word_frequency_dict, num_merges=37000, multi_proc=1):
212+
def _learn_bpe(word_frequency_dict, npy_path, num_merges=37000, multi_proc=1):
209213
#word_frequency_dict = {'l o w </w>' : 1, 'l o w e r </w>' : 1, 'n e w e s t </w>':1, 'w i d e s t </w>':1}
210214

211215
merge_info = [] # 합친 정보를 기억하고있다가 다른 데이터에 적용.
@@ -262,63 +266,67 @@ def _learn_bpe(word_frequency_dict, num_merges=37000, multi_proc=1):
262266
word_frequency = merge_bpe_word((best, word_frequency)) # 가장 높은 빈도의 2gram을 합침.
263267
######
264268

265-
269+
# multiproc close
266270
if multi_proc > 1:
267271
pool.close()
268272

273+
274+
# make npy
275+
if not os.path.exists(npy_path):
276+
print("create" + npy_path + "directory")
277+
os.makedirs(npy_path)
278+
269279
# 빠른 변환을 위한 cache 저장. 기존 word를 key로, bpe 결과를 value로.
270280
cache = {}
271281
for i in range(len(cache_list)):
272282
key = cache_list[i][0]
273283
value = word_frequency[i][0]
274284
cache[key] = value
275285

276-
# voca 추출.
277-
bpe2idx, idx2bpe = make_bpe2idx(word_frequency)
278-
return bpe2idx, idx2bpe, merge_info, cache # dict, dict, list, dict
286+
save_data(npy_path+'merge_info.npy', merge_info) # list
287+
save_data(npy_path+'cache.npy', cache) # dict
288+
print('save merge_info.npy', ', size:', len(merge_info))
289+
print('save cache.npy', ', size:', len(cache))
290+
291+
292+
bpe2idx, idx2bpe = make_bpe2idx(word_frequency, npy_path)
293+
save_data(npy_path+'bpe2idx.npy', bpe2idx) # dict
294+
save_data(npy_path+'idx2bpe.npy', idx2bpe) # dict
295+
print('save bpe2idx.npy', ', size:', len(bpe2idx))
296+
print('save idx2bpe.npy', ', size:', len(idx2bpe))
279297

280298

281299

282-
def learn_bpe(path_list, npy_path, space_symbol='</w>', top_k=None, num_merges=37000, multi_proc=1):
300+
def learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=37000, voca_threshold=5, multi_proc=1):
283301

284302
print('get word frequency dictionary')
285303
total_word_frequency_dict = {}
286304
for path in path_list:
287305
word_frequency_dict = get_word_frequency_dict_from_document(
288306
path=path,
289307
space_symbol=space_symbol,
290-
top_k=top_k#None
291308
) #ok
292309
total_word_frequency_dict = merge_dictionary(total_word_frequency_dict, word_frequency_dict)
293310

294-
'''
295-
save_data('./word_frequency_dictionary.npy', total_word_frequency_dict)
296-
print('save ./word_frequency_dictionary.npy', 'size:', len(total_word_frequency_dict), '\n')
297-
total_word_frequency_dict = load_data('./word_frequency_dictionary.npy', mode='dictionary')
298-
'''
311+
312+
# 빈도수가 일정 미만인 단어 제외.
313+
total_word_frequency_dict_size = len(total_word_frequency_dict)
314+
for item in list(total_word_frequency_dict.items()):
315+
if item[1] < voca_threshold: # item[0] is key, item[1] is value
316+
del total_word_frequency_dict[item[0]]
317+
print('frequency word dict size:', total_word_frequency_dict_size)
318+
print('threshold applied frequency word dict size:', len(total_word_frequency_dict), 'removed:', total_word_frequency_dict_size-len(total_word_frequency_dict), '\n')
319+
299320

300321
print('learn bpe')
301-
check= time.time()
302-
bpe2idx, idx2bpe, merge_info, cache = _learn_bpe(
322+
_learn_bpe(
303323
total_word_frequency_dict,
324+
npy_path=npy_path,
304325
num_merges=num_merges,
305326
multi_proc=multi_proc
306-
)# dict, dict, list, dict
307-
print('multiproc:', multi_proc, 'time:', time.time()-check)
327+
)
308328

309-
if not os.path.exists(npy_path):
310-
print("create" + npy_path + "directory")
311-
os.makedirs(npy_path)
312-
313-
save_data(npy_path+'bpe2idx.npy', bpe2idx)
314-
save_data(npy_path+'idx2bpe.npy', idx2bpe)
315-
save_data(npy_path+'merge_info.npy', merge_info)
316-
save_data(npy_path+'cache.npy', cache)
317-
print('save bpe2idx.npy', 'size:', len(bpe2idx))
318-
print('save idx2bpe.npy', 'size:', len(idx2bpe))
319-
print('save merge_info.npy', 'size:', len(merge_info))
320-
print('save cache.npy', 'size:', len(cache))
321-
print()
329+
print('\n\n\n')
322330

323331

324332

@@ -327,27 +335,25 @@ def apply_bpe(path_list, out_bpe_path, out_list, npy_path, space_symbol='</w>',
327335
print("create" + out_bpe_path + "directory")
328336
os.makedirs(out_bpe_path)
329337

330-
print('load bpe info')
331338
merge_info = load_data(npy_path+'merge_info.npy')
332339
cache = load_data(npy_path+'cache.npy', mode='dictionary')
333-
340+
341+
print('apply bpe')
334342
for i in range(len(path_list)):
335343
path = path_list[i]
336344
out_path = out_list[i]
337345

338-
print('apply bpe', path, out_path)
346+
print('path:', path, ', out_path:', out_path)
339347
_apply_bpe(
340348
path=path,
341349
out_path=out_bpe_path+out_path,
342350
space_symbol=space_symbol,
343351
merge_info=merge_info,
344352
cache=cache
345353
)
346-
print('save ok', out_path)
347354
save_data(npy_path+'cache.npy', cache)
348-
print('save updated cache ./cache.npy', 'size:', len(cache))
355+
print('\n\n\n')
349356

350-
print()
351357

352358

353359
# save directory
@@ -375,11 +381,12 @@ def apply_bpe(path_list, out_bpe_path, out_list, npy_path, space_symbol='</w>',
375381
if __name__ == '__main__':
376382
# if don't use multiprocessing:
377383
# learn_bpe(path_list, npy_path, space_symbol='</w>', top_k=None)
378-
379-
# multiprocessing, multi_proc: # process, os.cpu_count(): # cpu processor of current computer
380-
# learn bpe from documents
381-
learn_bpe(path_list, npy_path, space_symbol='</w>', top_k=None, num_merges=37000, multi_proc=os.cpu_count())
384+
# multi_proc: # process, os.cpu_count(): # cpu processor of current computer
382385

386+
# learn bpe from documents
387+
learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=35000, voca_threshold=50, multi_proc=os.cpu_count())
388+
#learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=30000, voca_threshold=5, multi_proc=os.cpu_count())
389+
383390
# apply bpe to documents
384391
apply_bpe(path_list, out_bpe_path, out_list, npy_path, space_symbol='</w>', pad_symbol='</p>')
385392
apply_bpe(test_path_list, out_bpe_path, test_out_list, npy_path, space_symbol='</w>', pad_symbol='</p>')

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /