Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 280670a

Browse files
add functions
1 parent 6e5d5da commit 280670a

File tree

1 file changed

+68
-18
lines changed

1 file changed

+68
-18
lines changed

‎seg_base.py‎

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,78 @@ def __init__(self):
77
self.TAGS = np.arange(4)
88
self.TAG_MAPS = np.array([[0, 1], [2, 3], [2, 3], [0, 1]], dtype=np.int32)
99
self.tags_count = len(self.TAG_MAPS)
10+
self.dictionary = {}
11+
self.skip_window_left = 0
12+
self.skip_window_right = 1
1013

11-
def viterbi(self, emission, A, init_A, return_score=False):
    """
    Decode the highest-scoring tag sequence with the Viterbi algorithm.

    All inputs and return values are numpy arrays.

    :param emission: emission score matrix (the model's score matrix),
        shape (tag_count, length); column ``i`` holds the tag scores of
        position ``i``
    :param A: transition score matrix, shape (tag_count, tag_count),
        indexed as ``A[previous_tag][current_tag]``
    :param init_A: initial transition scores, shape (tag_count,)
    :param return_score: whether to also return scores, False by default
    :return: the best path as an int32 array of shape (length,). When
        ``return_score`` is True, a tuple ``(path, scores)`` where
        ``scores`` is the row of the cumulative score matrix belonging to
        the tag that ends the best path; ``scores[-1]`` is the total score
        of the best path.
    """
    emission = np.asarray(emission)
    A = np.asarray(A)
    tag_count, length = emission.shape

    # Nothing to decode: return empty results instead of failing on column 0.
    if length == 0:
        empty_path = np.zeros([0], dtype=np.int32)
        if return_score:
            return empty_path, np.zeros([0], dtype=np.float64)
        return empty_path

    # Sentinel score for "not reached yet"; halved so that adding further
    # scores to it cannot overflow.
    floor = np.finfo('f').min / 2
    # path[t][pos] is the best predecessor tag of tag t at position pos.
    path = np.full([tag_count, length], -1, dtype=np.int32)
    path_score = np.full([tag_count, length], floor, dtype=np.float64)
    path_score[:, 0] = init_A + emission[:, 0]

    columns = np.arange(tag_count)
    for pos in range(1, length):
        # candidates[prev, t]: score of reaching tag t at pos through prev.
        candidates = path_score[:, pos - 1, np.newaxis] + A + emission[:, pos]
        # argmax on the reversed rows selects the LAST maximal predecessor,
        # matching the tie-breaking of a scalar loop that updates on ">=".
        best_prev = tag_count - 1 - np.argmax(candidates[::-1], axis=0)
        best = candidates[best_prev, columns]
        # Only overwrite the sentinel when the candidate is at least as good.
        reachable = best >= floor
        path[reachable, pos] = best_prev[reachable]
        path_score[reachable, pos] = best[reachable]

    # Keep the final tag in its own variable: the backtracking below must
    # not clobber it, because it selects the score row that is returned.
    best_last = int(np.argmax(path_score[:, -1]))
    corr_path = np.zeros([length], dtype=np.int32)
    corr_path[length - 1] = best_last
    for i in range(length - 1, 0, -1):
        corr_path[i - 1] = path[corr_path[i]][i]

    if return_score:
        return corr_path, path_score[best_last, :]
    return corr_path
47+
48+
def sentence2index(self, sentence):
    """
    Map every element of a sentence to its index in ``self.dictionary``.

    :param sentence: iterable of dictionary keys (e.g. a string, iterated
        character by character)
    :return: list with one index per element; elements missing from the
        dictionary are mapped to 0
    """
    # dict.get does the lookup once and supplies the fallback index 0.
    return [self.dictionary.get(word, 0) for word in sentence]
57+
58+
def index2seq(self, indices):
    """
    Turn a sequence of indices into one context window per position.

    The sequence is padded with ``self.skip_window_left`` copies of index 1
    on the left and ``self.skip_window_right`` copies of index 2 on the
    right, so every original position gets a full window.

    :param indices: sequence of indices (list, tuple, numpy array, ...)
    :return: list of windows, one per element of ``indices``; each window
        is a list of ``skip_window_left + 1 + skip_window_right`` indices
        centred on that element
    """
    left = self.skip_window_left
    right = self.skip_window_right
    # Copy into a list first: "+" on a tuple raises and on a numpy array
    # performs element-wise addition instead of concatenation.
    ext_indices = [1] * left + list(indices) + [2] * right
    return [
        ext_indices[center - left: center + right + 1]
        for center in range(left, len(ext_indices) - right)
    ]
66+
67+
def tags2words(self, sentence, tags):
    """
    Split a sentence into words according to per-character tags.

    Tag meaning as used here: 0 = single-character word, 1 = first
    character of a word, 2 = character inside a word, 3 = last character
    of a word.

    :param sentence: the characters to segment (string or sequence)
    :param tags: one tag per character, aligned with ``sentence``
    :return: list of words in sentence order
    """
    words = []
    word = ''
    for tag_index, tag in enumerate(tags):
        char = sentence[tag_index]
        if tag == 0 or tag == 1:
            # A new word starts here. Flush any unfinished word first so a
            # malformed tag sequence neither drops nor reorders characters.
            if word != '':
                words.append(word)
                word = ''
            if tag == 0:
                words.append(char)
            else:
                word = char
        elif tag == 2:
            word += char
        else:
            words.append(word + char)
            word = ''
    # Handle a sequence that ends inside a word (last tag is begin/inside).
    if word != '':
        words.append(word)

    return words

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /