@@ -71,6 +71,62 @@ def get_stats(word_frequency_list):
7171	return  pairs  # tuple을 담고 있는 dictionary 리턴. 
7272
7373
74+ def  delete_some_stats (stats , best_pair ):
75+ 	# ac t c t s [c t] => ac t ct s 
76+ 	# stats best_pair new_stats 
77+ 	# [ac, t]	 [c, t]		[ac, t] 
78+ 	# [t, c]					[t, ct] 
79+ 	# [c, t]					[ct, s] 
80+ 	# [t, s] 
81+ 82+ 	# left, right = best_pair 라고 할 때, 
83+ 	# left == info[1] 이거나 right == info[0] 이거나  
84+ 	# (left==info[0] and right==info[1]) 이면 stats에서 제거 
85+ 86+ 	# 만약 best_pair[0]이 stats[1]에 있으면 원래 stats[1] 뒤에 best_pair[1]이 있었으면 
87+ 	# 저 stats 가 만들어질 수 없으므로 재계산. 
88+ 	# 또한 best_pair[1]이 stats[0]에 있으면 원래 stats[0] 앞에 best_pair[1]이 있었다면 
89+ 	# 저 stats가 만들어질 수 없으므로 재계산. 
90+ 	# 또한 stats가 best_pair랑 동일하면 재계산 
91+ 92+ 93+ 	left , right  =  best_pair 	
94+ 	del  stats [best_pair ] # (left == info[0] and right == info[1]) 
95+ 96+ 	for  info  in  list (stats .keys ()):
97+ 		if  left  ==  info [1 ] or  right  ==  info [0 ]:
98+ 			del  stats [info ]
99+ 			#print('delete from', 'info:',info, 'best_pair:', best_pair) 
100+ 	return  stats 
101+ 102+ 103+ 104+ # 2-gram frequency table 추출. (불필요한 2gram freq는 재계산 X) 
105+ def  selective_get_stats (data ):
106+ 	best_pair  =  data [0 ]
107+ 	word_frequency_list  =  data [1 ] 
108+ 109+ 	left , right  =  best_pair 	
110+ 	best_pair_to_string  =  left + right 
111+ 112+ 	# word_frequency_list는 best_pair 기준으로 합쳐졌으므로, best_pair_to_string이 포함된것 계산 
113+ 	# 또한 left == info[1] 이거나 right == info[0] 인것을 제거했으므로 이것들만 계산. 
114+ 	# 다른경우는 계산 x 
115+ 116+ 	stats  =  {}
117+ 	for  word , freq  in  word_frequency_list :
118+ 		symbols  =  word .split ()
119+ 		if  left  in  symbols  or  right  in  symbols  or  best_pair_to_string  in  symbols :
120+ 			for  i  in  range (len (symbols )- 1 ):
121+ 				gram  =  (symbols [i ],symbols [i + 1 ])
122+ 				if  left  ==  gram [1 ] or  right  ==  gram [0 ] or  best_pair_to_string  in  gram :
123+ 					if  gram  in  stats :
124+ 						stats [gram ] +=  freq 
125+ 					else :
126+ 						stats [gram ] =  freq 
127+ 	return  stats 
128+ 129+ 74130
75131# pairs 중에서 가장 높은 frequency를 갖는 key 리턴. 
76132def  check_merge_info (pairs ):
@@ -93,7 +149,8 @@ def merge_bpe_word(best_pair_and_word_frequency_list):
93149	p  =  re .compile (r'(?<!\S)'  +  bigram  +  r'(?!\S)' )
94150
95151	for  word , freq  in  word_frequency :
96- 		if  best_pair_to_string_with_space  in  word : 
152+ 		#if best_pair_to_string_with_space in word:  
153+ 		if  ' ' + best_pair_to_string_with_space + ' '  in  ' ' + word + ' ' : 
97154			w_out  =  p .sub (best_pair_to_string , word ) # 만약 ''.join(best_pair): r</w> 이고, word: 'a r </w>' 이면 w_out은 'a r</w>'가 된다. 
98155			v_out .append ( (w_out , freq ) )
99156		else :
@@ -128,10 +185,9 @@ def merge_a_word(merge_info, word, cache={}, high_freq_voca={}):
128185			if  len (split_bpe_word ) ==  1 : # 더이상 merge할 것이 없는 상황. 
129186				break 
130187
131- 			# 이건 완벽하게 일치하는것만 실행 하지는 않지만 시간체크해보면 완벽하게 체크하는것보다 더 빠름.  
132- 			# (info: ['m', 'c'], word: 'm cd' 이면 merge할 것이 없지만 if문에서는 true여서 merge수행함.) 
133188			info_to_string_with_space  =  ' ' .join (info )
134- 			if  info_to_string_with_space  in  bpe_word : 
189+ 			#if info_to_string_with_space in bpe_word:  
190+ 			if  ' ' + info_to_string_with_space + ' '  in  ' ' + bpe_word + ' ' : 
135191				bigram  =  re .escape (info_to_string_with_space )
136192				p  =  re .compile (r'(?<!\S)'  +  bigram  +  r'(?!\S)' )
137193				bpe_word  =  p .sub ('' .join (info ), bpe_word ) # 만약 info_to_string_with_space: 'r </w>' 이고, bpe_word: 'a r </w>' 이면 w_out은 'a r</w>'가 된다. 
@@ -168,8 +224,10 @@ def _make_total_word_cache_before_apply_bpe(data):
168224
169225				if  len (split_bpe_word ) ==  1 : # 더이상 merge할 것이 없는 상황. 
170226					break 
227+ 171228				info_to_string_with_space  =  ' ' .join (info )
172- 				if  info_to_string_with_space  in  bpe_word : 
229+ 				#if info_to_string_with_space in bpe_word:  
230+ 				if  ' ' + info_to_string_with_space + ' '  in  ' ' + bpe_word + ' ' : 
173231					bigram  =  re .escape (info_to_string_with_space )
174232					p  =  re .compile (r'(?<!\S)'  +  bigram  +  r'(?!\S)' )
175233					bpe_word  =  p .sub ('' .join (info ), bpe_word )
@@ -222,7 +280,7 @@ def make_total_word_cache_before_apply_bpe(path_list, npy_path, space_symbol='</
222280				zip ( multi_merge_info , [total_words [slicing [k ]:slicing [k + 1 ]] for  k  in  range (process )], multi_high_freq_voca  )
223281			)
224282		for  dic  in  results :
225- 			cache = merge_dictionary ( cache ,  dic )
283+ 			cache . update ( dic )
226284
227285		pool .close ()
228286
@@ -295,15 +353,26 @@ def _learn_bpe(word_frequency_dict, npy_path, num_merges=37000, multi_proc=1):
295353		print ('multiproc data slicing boundary:' , slicing )
296354		pool  =  mp .Pool (process )
297355
356+ 		import  time 
298357		for  i  in  tqdm (range (num_merges ), ncols = 50 ):
358+ 299359			# 2gram별 빈도수 추출 
300- 			get_stats_results  =  pool .map (
301- 					get_stats , 
302- 					[word_frequency [slicing [k ]:slicing [k + 1 ]] for  k  in  range (process )]
303- 				)
304- 			pairs = {} # merge  
305- 			for  dic  in  get_stats_results :
306- 				pairs  =  merge_dictionary (pairs , dic )
360+ 			if  i  ==  0 :
361+ 				get_stats_results  =  pool .map (
362+ 						get_stats , 
363+ 						[word_frequency [slicing [k ]:slicing [k + 1 ]] for  k  in  range (process )]
364+ 					)
365+ 				pairs = {} # merge  
366+ 				for  dic  in  get_stats_results :
367+ 					pairs  =  merge_dictionary (pairs , dic )
368+ 			else :
369+ 				pairs  =  delete_some_stats (pairs , best )
370+ 				selective_result  =  pool .map (
371+ 						selective_get_stats , 
372+ 						zip ( [best ]* process , [word_frequency [slicing [k ]:slicing [k + 1 ]] for  k  in  range (process )] )
373+ 					)
374+ 				for  dic  in  selective_result :
375+ 					pairs  =  merge_dictionary (pairs , dic )				
307376			####### 
308377
309378			# 가장 높은 빈도의 2gram 선정 
@@ -327,7 +396,12 @@ def _learn_bpe(word_frequency_dict, npy_path, num_merges=37000, multi_proc=1):
327396	else :
328397		for  i  in  tqdm (range (num_merges ), ncols = 50 ):
329398			# 2gram별 빈도수 추출 
330- 			pairs  =  get_stats (word_frequency ) 
399+ 			if  i  ==  0 :
400+ 				pairs  =  get_stats (word_frequency ) 
401+ 			else :
402+ 				pairs  =  delete_some_stats (pairs , best )
403+ 				selective_pairs  =  selective_get_stats ([best , word_frequency ])
404+ 				pairs .update (selective_pairs )
331405
332406			# 가장 높은 빈도의 2gram 선정 
333407			best  =  check_merge_info (pairs ) # 가장 높은 빈도의 2gram 선정 
@@ -443,12 +517,11 @@ def apply_bpe(path_list, out_bpe_path, out_list, npy_path, final_voca_threshold=
443517	] 		
444518
445519multi_proc  =  os .cpu_count ()
446- voca_threshold  =  5  # 빠른 학습을 위해 일정 빈도수 이하의 단어는 bpe learn에 참여시키지 않음. 
447- final_voca_threshold  =  50  # bpe learn으로 학습된 voca중에서 final voca에 참여시킬 voca의 threshold 
448- 449520if  multi_proc  >  1 :
450521	import  multiprocessing  as  mp 
451522
523+ voca_threshold  =  5  # 빠른 학습을 위해 일정 빈도수 이하의 단어는 bpe learn에 참여시키지 않음. 
524+ final_voca_threshold  =  50  # bpe learn으로 학습된 voca중에서 final voca에 참여시킬 voca의 threshold 
452525
453526# learn and apply 
454527if  __name__  ==  '__main__' :
@@ -458,10 +531,12 @@ def apply_bpe(path_list, out_bpe_path, out_list, npy_path, final_voca_threshold=
458531
459532	# learn bpe from documents 
460533	learn_bpe (path_list , npy_path , space_symbol = '</w>' , num_merges = 35000 , voca_threshold = voca_threshold , multi_proc = multi_proc )
461- 534+ 
462535	# multi_proc으로 미리 cache 생성해 둠으로써 단순 apply_bpe하는것보다 빠름. 
463536	make_total_word_cache_before_apply_bpe (path_list , npy_path , multi_proc = multi_proc )
464537
465538	# apply bpe to documents 
466539	apply_bpe (path_list , out_bpe_path , out_list , npy_path , final_voca_threshold = final_voca_threshold , space_symbol = '</w>' , pad_symbol = '</p>' )
467540	apply_bpe (test_path_list , out_bpe_path , test_out_list , npy_path , final_voca_threshold = final_voca_threshold , space_symbol = '</w>' , pad_symbol = '</p>' )
541+ 542+ 
0 commit comments