@@ -7,28 +7,78 @@ def __init__(self):
77 self .TAGS  =  np .arange (4 )
88 self .TAG_MAPS  =  np .array ([[0 , 1 ], [2 , 3 ], [2 , 3 ], [0 , 1 ]], dtype = np .int32 )
99 self .tags_count  =  len (self .TAG_MAPS )
10+  self .dictionary  =  {}
11+  self .skip_window_left  =  0 
12+  self .skip_window_right  =  1 
1013
def viterbi(self, emission, A, init_A, return_score=False):
    """Find the highest-scoring tag sequence with the Viterbi algorithm.

    All array arguments are NumPy arrays (array-likes are converted).

    :param emission: emission score matrix of shape (tag_count, length);
        column ``p`` holds the score of every tag at position ``p``
    :param A: transition score matrix of shape (tag_count, tag_count),
        indexed as ``A[previous_tag][current_tag]``
    :param init_A: initial score of every tag, shape (tag_count,)
    :param return_score: when True, also return the scores of the best path
    :return: the best path as an int32 array of shape (length,). When
        ``return_score`` is True, a tuple ``(path, scores)`` where ``scores``
        is a float64 array of shape (length,) holding the cumulative score of
        the best path at each position; ``scores[-1]`` is the total score.
    """
    emission = np.asarray(emission)
    A = np.asarray(A)
    init_A = np.asarray(init_A)
    tag_count, length = emission.shape

    best_path = np.zeros([length], dtype=np.int32)
    if length == 0:
        # Nothing to decode: return an empty path instead of failing on column 0.
        if return_score:
            return best_path, np.zeros([0], dtype=np.float64)
        return best_path

    # back_pointers[t, p] is the tag at position p - 1 on the best path that
    # ends with tag t at position p (column 0 is never read).
    back_pointers = np.full([tag_count, length], -1, dtype=np.int32)
    path_score = np.zeros([tag_count, length], dtype=np.float64)
    path_score[:, 0] = init_A + emission[:, 0]

    current_tags = np.arange(tag_count)
    for pos in range(1, length):
        # candidates[prev, cur]: score of reaching `cur` at `pos` via `prev`.
        candidates = (path_score[:, pos - 1, np.newaxis] + A
                      + emission[np.newaxis, :, pos])
        # np.argmax returns the first maximum; searching the reversed rows
        # makes ties resolve to the highest previous tag, matching a scalar
        # scan that updates on `>=`.
        best_prev = tag_count - 1 - np.argmax(candidates[::-1], axis=0)
        back_pointers[:, pos] = best_prev
        path_score[:, pos] = candidates[best_prev, current_tags]

    # Walk the back pointers from the best final tag to recover the path.
    tag = int(np.argmax(path_score[:, -1]))
    best_path[length - 1] = tag
    for pos in range(length - 1, 0, -1):
        tag = back_pointers[tag, pos]
        best_path[pos - 1] = tag

    if return_score:
        # Read the scores along the decoded path rather than from whichever
        # tag the backtracking loop happened to end on.
        return best_path, path_score[best_path, np.arange(length)]
    return best_path
def sentence2index(self, sentence):
    """Map every item of a sentence to its id in ``self.dictionary``.

    :param sentence: iterable of dictionary keys (e.g. a string of characters)
    :return: list of ids, one per item; items missing from the dictionary
        are mapped to 0
    """
    lookup = self.dictionary.get
    return [lookup(word, 0) for word in sentence]
def index2seq(self, indices):
    """Turn a sequence of ids into one context window per position.

    The sequence is padded with id 1 on the left (``self.skip_window_left``
    times) and id 2 on the right (``self.skip_window_right`` times), so every
    original position gets a full window.

    :param indices: sequence of ids (list, tuple or array)
    :return: list with one window per input id; each window is a list of
        ``skip_window_left + 1 + skip_window_right`` ids
    """
    left_pad_id, right_pad_id = 1, 2
    # Copy into a list first: `+` on a tuple raises and on a NumPy array adds
    # element-wise instead of concatenating the padding.
    items = list(indices)
    padded = ([left_pad_id] * self.skip_window_left
              + items
              + [right_pad_id] * self.skip_window_right)
    width = self.skip_window_left + self.skip_window_right + 1
    return [padded[start:start + width] for start in range(len(items))]
def tags2words(self, sentence, tags):
    """Split a sentence into words according to per-character tags.

    Tag meaning as used by the branches below: 0 emits the character as a
    word of its own, 1 starts a new word, 2 extends the current word, and
    any other value ends the current word with this character.

    A word that is still open when a new word starts (tag 0 or 1), or when
    the tags run out, is emitted as-is so that no character is dropped or
    reordered on malformed tag sequences.

    :param sentence: the characters to segment
    :param tags: one tag per character; extra items on either side are ignored
    :return: list of words in sentence order
    """
    words = []
    pending = ''
    for char, tag in zip(sentence, tags):
        if tag == 0:
            if pending:
                # Close the unfinished word before emitting the single one.
                words.append(pending)
                pending = ''
            words.append(char)
        elif tag == 1:
            if pending:
                # A new word starts while another is open: keep the old one.
                words.append(pending)
            pending = char
        elif tag == 2:
            pending += char
        else:
            words.append(pending + char)
            pending = ''
    # Handle a sequence that ends inside a word (last tag is 1 or 2).
    if pending:
        words.append(pending)
    return words
0 commit comments