Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 2997481

Browse files
Add files via upload
1 parent 21ab77b commit 2997481

File tree

1 file changed

+93
-18
lines changed

1 file changed

+93
-18
lines changed

‎BPE.py‎

Lines changed: 93 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,62 @@ def get_stats(word_frequency_list):
7171
return pairs # tuple을 담고 있는 dictionary 리턴.
7272

7373

74+
def delete_some_stats(stats, best_pair):
75+
# ac t c t s [c t] => ac t ct s
76+
# stats best_pair new_stats
77+
# [ac, t] [c, t] [ac, t]
78+
# [t, c] [t, ct]
79+
# [c, t] [ct, s]
80+
# [t, s]
81+
82+
# left, right = best_pair 라고 할 때,
83+
# left == info[1] 이거나 right == info[0] 이거나
84+
# (left==info[0] and right==info[1]) 이면 stats에서 제거
85+
86+
# 만약 best_pair[0]이 stats[1]에 있으면 원래 stats[1] 뒤에 best_pair[1]이 있었으면
87+
# 저 stats 가 만들어질 수 없으므로 재계산.
88+
# 또한 best_pair[1]이 stats[0]에 있으면 원래 stats[0] 앞에 best_pair[1]이 있었다면
89+
# 저 stats가 만들어질 수 없으므로 재계산.
90+
# 또한 stats가 best_pair랑 동일하면 재계산
91+
92+
93+
left, right = best_pair
94+
del stats[best_pair] # (left == info[0] and right == info[1])
95+
96+
for info in list(stats.keys()):
97+
if left == info[1] or right == info[0]:
98+
del stats[info]
99+
#print('delete from', 'info:',info, 'best_pair:', best_pair)
100+
return stats
101+
102+
103+
104+
# 2-gram frequency table 추출. (불필요한 2gram freq는 재계산 X)
105+
def selective_get_stats(data):
106+
best_pair = data[0]
107+
word_frequency_list = data[1]
108+
109+
left, right = best_pair
110+
best_pair_to_string = left+right
111+
112+
# word_frequency_list는 best_pair 기준으로 합쳐졌으므로, best_pair_to_string이 포함된것 계산
113+
# 또한 left == info[1] 이거나 right == info[0] 인것을 제거했으므로 이것들만 계산.
114+
# 다른경우는 계산 x
115+
116+
stats = {}
117+
for word, freq in word_frequency_list:
118+
symbols = word.split()
119+
if left in symbols or right in symbols or best_pair_to_string in symbols:
120+
for i in range(len(symbols)-1):
121+
gram = (symbols[i],symbols[i+1])
122+
if left == gram[1] or right == gram[0] or best_pair_to_string in gram:
123+
if gram in stats:
124+
stats[gram] += freq
125+
else:
126+
stats[gram] = freq
127+
return stats
128+
129+
74130

75131
# pairs 중에서 가장 높은 frequency를 갖는 key 리턴.
76132
def check_merge_info(pairs):
@@ -93,7 +149,8 @@ def merge_bpe_word(best_pair_and_word_frequency_list):
93149
p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
94150

95151
for word, freq in word_frequency:
96-
if best_pair_to_string_with_space in word:
152+
#if best_pair_to_string_with_space in word:
153+
if ' '+best_pair_to_string_with_space+' ' in ' '+word+' ':
97154
w_out = p.sub(best_pair_to_string, word) # 만약 ''.join(best_pair): r</w> 이고, word: 'a r </w>' 이면 w_out은 'a r</w>'가 된다.
98155
v_out.append( (w_out, freq) )
99156
else:
@@ -128,10 +185,9 @@ def merge_a_word(merge_info, word, cache={}, high_freq_voca={}):
128185
if len(split_bpe_word) == 1: # 더이상 merge할 것이 없는 상황.
129186
break
130187

131-
# 이건 완벽하게 일치하는것만 실행 하지는 않지만 시간체크해보면 완벽하게 체크하는것보다 더 빠름.
132-
# (info: ['m', 'c'], word: 'm cd' 이면 merge할 것이 없지만 if문에서는 true여서 merge수행함.)
133188
info_to_string_with_space = ' '.join(info)
134-
if info_to_string_with_space in bpe_word:
189+
#if info_to_string_with_space in bpe_word:
190+
if ' '+info_to_string_with_space+' ' in ' '+bpe_word+' ':
135191
bigram = re.escape(info_to_string_with_space)
136192
p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
137193
bpe_word = p.sub(''.join(info), bpe_word) # 만약 info_to_string_with_space: 'r </w>' 이고, bpe_word: 'a r </w>' 이면 w_out은 'a r</w>'가 된다.
@@ -168,8 +224,10 @@ def _make_total_word_cache_before_apply_bpe(data):
168224

169225
if len(split_bpe_word) == 1: # 더이상 merge할 것이 없는 상황.
170226
break
227+
171228
info_to_string_with_space = ' '.join(info)
172-
if info_to_string_with_space in bpe_word:
229+
#if info_to_string_with_space in bpe_word:
230+
if ' '+info_to_string_with_space+' ' in ' '+bpe_word+' ':
173231
bigram = re.escape(info_to_string_with_space)
174232
p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
175233
bpe_word = p.sub(''.join(info), bpe_word)
@@ -222,7 +280,7 @@ def make_total_word_cache_before_apply_bpe(path_list, npy_path, space_symbol='</
222280
zip( multi_merge_info, [total_words[slicing[k]:slicing[k+1]] for k in range(process)], multi_high_freq_voca )
223281
)
224282
for dic in results:
225-
cache=merge_dictionary(cache, dic)
283+
cache.update(dic)
226284

227285
pool.close()
228286

@@ -295,15 +353,26 @@ def _learn_bpe(word_frequency_dict, npy_path, num_merges=37000, multi_proc=1):
295353
print('multiproc data slicing boundary:', slicing)
296354
pool = mp.Pool(process)
297355

356+
import time
298357
for i in tqdm(range(num_merges), ncols=50):
358+
299359
# 2gram별 빈도수 추출
300-
get_stats_results = pool.map(
301-
get_stats,
302-
[word_frequency[slicing[k]:slicing[k+1]] for k in range(process)]
303-
)
304-
pairs={} # merge
305-
for dic in get_stats_results:
306-
pairs = merge_dictionary(pairs, dic)
360+
if i == 0:
361+
get_stats_results = pool.map(
362+
get_stats,
363+
[word_frequency[slicing[k]:slicing[k+1]] for k in range(process)]
364+
)
365+
pairs={} # merge
366+
for dic in get_stats_results:
367+
pairs = merge_dictionary(pairs, dic)
368+
else:
369+
pairs = delete_some_stats(pairs, best)
370+
selective_result = pool.map(
371+
selective_get_stats,
372+
zip( [best]*process, [word_frequency[slicing[k]:slicing[k+1]] for k in range(process)] )
373+
)
374+
for dic in selective_result:
375+
pairs = merge_dictionary(pairs, dic)
307376
#######
308377

309378
# 가장 높은 빈도의 2gram 선정
@@ -327,7 +396,12 @@ def _learn_bpe(word_frequency_dict, npy_path, num_merges=37000, multi_proc=1):
327396
else:
328397
for i in tqdm(range(num_merges), ncols=50):
329398
# 2gram별 빈도수 추출
330-
pairs = get_stats(word_frequency)
399+
if i == 0:
400+
pairs = get_stats(word_frequency)
401+
else:
402+
pairs = delete_some_stats(pairs, best)
403+
selective_pairs = selective_get_stats([best, word_frequency])
404+
pairs.update(selective_pairs)
331405

332406
# 가장 높은 빈도의 2gram 선정
333407
best = check_merge_info(pairs) # 가장 높은 빈도의 2gram 선정
@@ -443,12 +517,11 @@ def apply_bpe(path_list, out_bpe_path, out_list, npy_path, final_voca_threshold=
443517
]
444518

445519
multi_proc = os.cpu_count()
446-
voca_threshold = 5 # 빠른 학습을 위해 일정 빈도수 이하의 단어는 bpe learn에 참여시키지 않음.
447-
final_voca_threshold = 50 # bpe learn으로 학습된 voca중에서 final voca에 참여시킬 voca의 threshold
448-
449520
if multi_proc > 1:
450521
import multiprocessing as mp
451522

523+
voca_threshold = 5 # 빠른 학습을 위해 일정 빈도수 이하의 단어는 bpe learn에 참여시키지 않음.
524+
final_voca_threshold = 50 # bpe learn으로 학습된 voca중에서 final voca에 참여시킬 voca의 threshold
452525

453526
# learn and apply
454527
if __name__ == '__main__':
@@ -458,10 +531,12 @@ def apply_bpe(path_list, out_bpe_path, out_list, npy_path, final_voca_threshold=
458531

459532
# learn bpe from documents
460533
learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=35000, voca_threshold=voca_threshold, multi_proc=multi_proc)
461-
534+
462535
# multi_proc으로 미리 cache 생성해 둠으로써 단순 apply_bpe하는것보다 빠름.
463536
make_total_word_cache_before_apply_bpe(path_list, npy_path, multi_proc=multi_proc)
464537

465538
# apply bpe to documents
466539
apply_bpe(path_list, out_bpe_path, out_list, npy_path, final_voca_threshold=final_voca_threshold, space_symbol='</w>', pad_symbol='</p>')
467540
apply_bpe(test_path_list, out_bpe_path, test_out_list, npy_path, final_voca_threshold=final_voca_threshold, space_symbol='</w>', pad_symbol='</p>')
541+
542+

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /