@@ -26,7 +26,7 @@ def word_split_for_bpe(word, space_symbol='</w>'):
2626
2727
2828# word frequency 추출. 
29- def  get_word_frequency_dict_from_document (path , space_symbol = '</w>' ,  top_k = None ):
29+ def  get_word_frequency_dict_from_document (path , space_symbol = '</w>' ):
3030	word_frequency_dict  =  {}
3131
3232	with  open (path , 'r' , encoding = 'utf-8' ) as  f :
@@ -40,27 +40,16 @@ def get_word_frequency_dict_from_document(path, space_symbol='</w>', top_k=None)
4040
4141			for  word  in  sentence .split ():					
4242				# "abc" => "a b c space_symbol" 
43- 				split_word  =  word_split_for_bpe (word , space_symbol )
44- 43+ 				word  =  word_split_for_bpe (word , space_symbol )
44+ 				
4545				# word frequency 
46- 				if  split_word  in  word_frequency_dict :
47- 					word_frequency_dict [split_word ] +=  1 
46+ 				if  word  in  word_frequency_dict :
47+ 					word_frequency_dict [word ] +=  1 
4848				else :
49- 					word_frequency_dict [split_word ] =  1 
49+ 					word_frequency_dict [word ] =  1 
50+ 51+ 	return  word_frequency_dict 
5052
51- 	if  top_k  is  None :
52- 		return  word_frequency_dict 
53- 54- 	else :
55- 		# top_k frequency word 
56- 		sorted_word_frequency_list  =  sorted (
57- 				word_frequency_dict .items (), # ('key', value) pair 
58- 				key = lambda  x :x [1 ], # x: ('key', value), and x[1]: value 
59- 				reverse = True 
60- 			) # [('a', 3), ('b', 2), ... ]  
61- 		top_k_word_frequency_dict  =  dict (sorted_word_frequency_list [:top_k ])
62- 63- 		return  top_k_word_frequency_dict 
6453
6554
6655# merge two dictionaries
@@ -116,45 +105,22 @@ def merge_bpe_word(best_pair_and_word_frequency_list):
116105
117106
118107
119- 120- # from bpe to idx 
121- def  make_bpe2idx (word_frequency_list ):
122- 	bpe2idx  =  {
123- 			'</p>' :0 ,
124- 			'UNK' :1 ,
125- 			'</g>' :2 , #go 
126- 			'</e>' :3  #eos 
127- 		}	
128- 	idx2bpe  =  {
129- 			0 :'</p>' ,
130- 			1 :'UNK' ,
131- 			2 :'</g>' , #go 
132- 			3 :'</e>'  #eos 
133- 		}
134- 	idx  =  4 
135- 136- 	for  word , _  in  word_frequency_list : # word, freq 
137- 		for  bpe  in  word .split ():
138- 			# bpe가 bpe2idx에 없는 경우만 idx 부여. 
139- 			if  bpe  not  in bpe2idx :
140- 				bpe2idx [bpe ] =  idx 
141- 				idx2bpe [idx ] =  bpe 
142- 				idx  +=  1 
143- 	return  bpe2idx , idx2bpe 
144- 145- 146108def  merge_a_word (merge_info , word , cache = {}):
147109	# merge_info: list 
148110	# word: "c e m e n t </w>" => "ce m e n t<\w>" 되어야 함. 
149111
150- 	if  len (word .split ()) ==  1 :
112+ 	#if len(word.split()) == 1: 
113+ 	if  word .count (' ' ) ==  0 :
151114		return  word 
152115
153116	if  word  in  cache :
154117		return  cache [word ]
155118	else :
156119		bpe_word  =  word 
157120		for  info  in  merge_info :
121+ 			if  bpe_word .count (' ' ) ==  0 :
122+ 				break 
123+ 158124			bigram  =  re .escape (' ' .join (info ))
159125			p  =  re .compile (r'(?<!\S)'  +  bigram  +  r'(?!\S)' )
160126
@@ -166,6 +132,44 @@ def merge_a_word(merge_info, word, cache={}):
166132		return  bpe_word 
167133
168134
def make_bpe2idx(word_frequency_list, npy_path):
	"""Build the BPE vocabulary lookup tables and dump the sorted vocabulary.

	Args:
		word_frequency_list: iterable of ``(word, freq)`` pairs where ``word``
			is a space-separated sequence of BPE tokens,
			e.g. ``('B e it r a g</w>', 8)``.
		npy_path: path prefix that ``'sorted_voca.txt'`` is appended to by
			plain string concatenation, so it must already end with a path
			separator (same convention as the other ``npy_path`` users).

	Returns:
		``(bpe2idx, idx2bpe)``: token -> index and index -> token dicts.
		Indices 0-3 are always the reserved symbols ``'</p>'``, ``'UNK'``,
		``'</g>'`` (go) and ``'</e>'`` (eos); corpus tokens follow from index 4
		in descending order of accumulated frequency.

	Side effects:
		Writes ``npy_path + 'sorted_voca.txt'`` with one ``"<token> <freq>"``
		record per line, most frequent token first.
	"""
	# Accumulate, for every BPE token, the frequency of each word it occurs in
	# (a token that occurs twice inside one word is counted twice).
	token_frequency = {}
	for word, freq in word_frequency_list:
		# ex: ('B e it r a g</w>', 8) -> ['B', 'e', 'it', 'r', 'a', 'g</w>']
		for bpe in word.split():
			token_frequency[bpe] = token_frequency.get(bpe, 0) + freq

	# sorted() is stable, so tokens with equal frequency keep first-seen order.
	sorted_voca = sorted(token_frequency.items(), key=lambda x: x[1], reverse=True)

	bpe2idx = {
			'</p>': 0,
			'UNK': 1,
			'</g>': 2,  # go
			'</e>': 3  # eos
		}
	idx2bpe = {idx: bpe for bpe, idx in bpe2idx.items()}
	idx = len(bpe2idx)

	with open(npy_path + 'sorted_voca.txt', 'w', encoding='utf-8') as o:
		for voca, freq in sorted_voca:
			# Terminate with a bare newline; a trailing space here would make
			# every following line start with a blank.
			o.write(str(voca) + ' ' + str(freq) + '\n')

			# A corpus token that collides with a reserved symbol must not
			# steal its fixed index (that would also leave idx2bpe with two
			# ids for the same token), so it keeps the reserved id.
			if voca in bpe2idx:
				continue

			bpe2idx[voca] = idx
			idx2bpe[idx] = voca
			idx += 1

	return bpe2idx, idx2bpe
170+ 171+ 172+ 169173# 문서를 읽고, bpe 적용. cache 사용할것. apply_bpe에서 사용. 
170174def  _apply_bpe (path , out_path , space_symbol = '</w>' , merge_info = None , cache = {}):
171175	start  =  time .time ()
@@ -197,15 +201,15 @@ def _apply_bpe(path, out_path, space_symbol='</w>', merge_info=None, cache={}):
197201				row .extend (merge .split ())
198202			wr .writerow (row )
199203
200- 			if  (i + 1 ) %  500  ==  0 :
204+ 			if  (i + 1 ) %  100000  ==  0 :
201205				current_cache_len  =  len (cache )
202206				print ('out_path:' , out_path , 'line:' , i + 1 , 'total cache:' , current_cache_len , 'added:' , current_cache_len - cache_len )
203207				cache_len  =  current_cache_len 
204208
205209	o .close ()
206210
207211
208- def  _learn_bpe (word_frequency_dict , num_merges = 37000 , multi_proc = 1 ):
212+ def  _learn_bpe (word_frequency_dict , npy_path ,  num_merges = 37000 , multi_proc = 1 ):
209213	#word_frequency_dict = {'l o w </w>' : 1, 'l o w e r </w>' : 1, 'n e w e s t </w>':1, 'w i d e s t </w>':1} 
210214
211215	merge_info  =  [] # 합친 정보를 기억하고있다가 다른 데이터에 적용. 
@@ -262,63 +266,67 @@ def _learn_bpe(word_frequency_dict, num_merges=37000, multi_proc=1):
262266			word_frequency  =  merge_bpe_word ((best , word_frequency )) # 가장 높은 빈도의 2gram을 합침. 
263267		###### 
264268
265- 269+ # multiproc close 
266270	if  multi_proc  >  1 :		
267271		pool .close ()
268272
273+ 274+ 	# make npy 
275+ 	if  not  os .path .exists (npy_path ):
276+ 		print ("create"  +  npy_path  +  "directory" )
277+ 		os .makedirs (npy_path )
278+ 269279	# 빠른 변환을 위한 cache 저장. 기존 word를 key로, bpe 결과를 value로.	 
270280	cache  =  {}
271281	for  i  in  range (len (cache_list )): 
272282		key  =  cache_list [i ][0 ]
273283		value  =  word_frequency [i ][0 ]
274284		cache [key ] =  value 
275285
276- 	# voca 추출. 
277- 	bpe2idx , idx2bpe  =  make_bpe2idx (word_frequency )
278- 	return  bpe2idx , idx2bpe , merge_info , cache  # dict, dict, list, dict 
286+ 	save_data (npy_path + 'merge_info.npy' , merge_info ) # list 
287+ 	save_data (npy_path + 'cache.npy' , cache ) # dict 
288+ 	print ('save merge_info.npy' , ', size:' , len (merge_info ))
289+ 	print ('save cache.npy' , ', size:' , len (cache ))
290+ 291+ 292+ 	bpe2idx , idx2bpe  =  make_bpe2idx (word_frequency , npy_path )
293+ 	save_data (npy_path + 'bpe2idx.npy' , bpe2idx ) # dict 
294+ 	save_data (npy_path + 'idx2bpe.npy' , idx2bpe ) # dict 
295+ 	print ('save bpe2idx.npy' , ', size:' , len (bpe2idx ))
296+ 	print ('save idx2bpe.npy' , ', size:' , len (idx2bpe ))	
279297
280298
281299
282- def  learn_bpe (path_list , npy_path , space_symbol = '</w>' , top_k = None ,  num_merges = 37000 , multi_proc = 1 ):
def learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=37000, voca_threshold=5, multi_proc=1):
	"""Learn BPE merges from a set of documents.

	Collects word frequencies from every file in ``path_list``, drops the
	words whose total frequency is below ``voca_threshold`` and hands the
	remaining dictionary to ``_learn_bpe`` together with ``npy_path``,
	``num_merges`` and ``multi_proc``. Nothing is returned.
	"""
	print('get word frequency dictionary')
	merged_frequencies = {}
	for document_path in path_list:
		merged_frequencies = merge_dictionary(
				merged_frequencies,
				get_word_frequency_dict_from_document(
						path=document_path,
						space_symbol=space_symbol,
					),
			)

	# Remove, in place, every word seen fewer than voca_threshold times.
	size_before = len(merged_frequencies)
	rare_words = [word for word, count in merged_frequencies.items() if count < voca_threshold]
	for word in rare_words:
		del merged_frequencies[word]
	size_after = len(merged_frequencies)
	print('frequency word dict size:', size_before)
	print('threshold applied frequency word dict size:', size_after, 'removed:', size_before - size_after, '\n ')

	print('learn bpe')
	_learn_bpe(
			merged_frequencies,
			npy_path=npy_path,
			num_merges=num_merges,
			multi_proc=multi_proc,
		)

	print('\n \n \n ')
322330
323331
324332
@@ -327,27 +335,25 @@ def apply_bpe(path_list, out_bpe_path, out_list, npy_path, space_symbol='</w>',
327335		print ("create"  +  out_bpe_path  +  "directory" )
328336		os .makedirs (out_bpe_path )
329337
330- 	print ('load bpe info' )
331338	merge_info  =  load_data (npy_path + 'merge_info.npy' )
332339	cache  =  load_data (npy_path + 'cache.npy' , mode = 'dictionary' )
333- 340+ 341+ 	print ('apply bpe' )
334342	for  i  in  range (len (path_list )):
335343		path  =  path_list [i ]
336344		out_path  =  out_list [i ]
337345
338- 		print ('apply bpe ' , path , out_path )
346+ 		print ('path: ' , path ,  ', out_path:' , out_path )
339347		_apply_bpe (
340348				path = path , 
341349				out_path = out_bpe_path + out_path ,
342350				space_symbol = space_symbol , 
343351				merge_info = merge_info , 
344352				cache = cache 
345353			)
346- 		print ('save ok' , out_path )
347354		save_data (npy_path + 'cache.npy' , cache )
348- 	print ('save updated cache ./cache.npy'  ,  'size:' ,  len ( cache ) )
355+ 	print ('\n \n \n '  )
349356
350- 	print ()
351357
352358
353359# save directory 
@@ -375,11 +381,12 @@ def apply_bpe(path_list, out_bpe_path, out_list, npy_path, space_symbol='</w>',
375381if  __name__  ==  '__main__' :
376382	# if don't use multiprocessing: 
377383	# learn_bpe(path_list, npy_path, space_symbol='</w>') 
378- 379- 	# multiprocessing, multi_proc: # process, os.cpu_count(): # cpu processor of current computer 
380- 	# learn bpe from documents 
381- 	learn_bpe (path_list , npy_path , space_symbol = '</w>' , top_k = None , num_merges = 37000 , multi_proc = os .cpu_count ())
384+ 	# multi_proc: # process, os.cpu_count(): # cpu processor of current computer 
382385
386+ 	# learn bpe from documents 
387+ 	learn_bpe (path_list , npy_path , space_symbol = '</w>' , num_merges = 35000 , voca_threshold = 50 , multi_proc = os .cpu_count ())
388+ 	#learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=30000, voca_threshold=5, multi_proc=os.cpu_count()) 
389+ 383390	# apply bpe to documents 
384391	apply_bpe (path_list , out_bpe_path , out_list , npy_path , space_symbol = '</w>' , pad_symbol = '</p>' )
385392	apply_bpe (test_path_list , out_bpe_path , test_out_list , npy_path , space_symbol = '</w>' , pad_symbol = '</p>' )
0 commit comments