Commit 6e5d5da
add base class
1 parent 0f541c9 commit 6e5d5da

File tree: 1 file changed, +8 −39 lines

seg_lstm.py
Lines changed: 8 additions & 39 deletions

@@ -6,10 +6,12 @@
 
 import constant
 from transform_data_lstm import TransformDataLSTM
+from seg_base import SegBase
 
 
-class SegLSTM:
+class SegLSTM(SegBase):
     def __init__(self):
+        SegBase.__init__(self)
         self.dtype = tf.float32
         self.skip_window_left = constant.LSTM_SKIP_WINDOW_LEFT
         self.skip_window_right = constant.LSTM_SKIP_WINDOW_RIGHT
@@ -21,6 +23,7 @@ def __init__(self):
         trans = TransformDataLSTM()
         self.words_batch = trans.words_batch
         self.tags_batch = trans.labels_batch
+        self.dictionary = trans.dictionary
         self.vocab_size = constant.VOCAB_SIZE
         self.alpha = 0.02
         self.lam = 0.001
@@ -68,15 +71,16 @@ def model(self, embeds):
         return path
 
     def train_exe(self):
+        saver = tf.train.Saver([self.embeddings, self.A, self.init_A].extend(self.params), max_to_keep=100)
         self.sess.graph.finalize()
         last_time = time.time()
-        saver = tf.train.Saver([self.embeddings, self.A,self.init_A].extend(self.params), max_to_keep=100)
         for sentence_index, (sentence, tags) in enumerate(zip(self.words_batch, self.tags_batch)):
             self.train_sentence(sentence, tags, len(tags))
             if sentence_index % 500 == 0:
+                print(sentence_index)
                 print(time.time() - last_time)
                 last_time = time.time()
-
+                saver.save(self.sess, 'tmp/lstm-model%d.ckpt'%0)
 
     def train_sentence(self, sentence, tags, length):
         sentence_embeds = self.sess.run(self.lookup_op, feed_dict={self.sentence_holder: sentence}).reshape(
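
Note on the hunk above: moving the tf.train.Saver construction ahead of self.sess.graph.finalize() matters because building a Saver adds save/restore ops to the graph, and a finalized graph is read-only. The following is only a minimal standalone sketch of that constraint and of the periodic-save pattern added to train_exe; the variable w and the step loop are stand-ins, not the repo's actual graph.

```python
import tensorflow as tf  # TF 1.x API, as used in seg_lstm.py

# Hypothetical stand-in variable; the real variables live in SegLSTM.__init__.
w = tf.Variable(tf.zeros([4, 4]), name='w')
init_op = tf.global_variables_initializer()

# The Saver must be built before finalize(): it adds save/restore ops to the graph.
saver = tf.train.Saver(max_to_keep=100)  # var_list=None saves every saveable variable

sess = tf.Session()
sess.run(init_op)
sess.graph.finalize()  # the graph is read-only from here on
# Constructing tf.train.Saver() at this point would raise
# RuntimeError: Graph is finalized and cannot be modified.

for step in range(2000):
    # ... one training sentence per step ...
    if step % 500 == 0:
        saver.save(sess, 'tmp/lstm-model%d.ckpt' % 0)  # periodic checkpoint, as in train_exe
```

Incidentally, [self.embeddings, self.A, self.init_A].extend(self.params) evaluates to None, since list.extend mutates in place and returns None; the Saver therefore effectively receives var_list=None and saves all saveable variables. A concatenation such as [self.embeddings, self.A, self.init_A] + self.params would be needed to restrict the saved set.
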
@@ -109,7 +113,7 @@ def train_sentence(self, sentence, tags, length):
             # Update the word embeddings
             embed_index = sentence[update_index]
             for i in range(update_length):
-                embed = np.expand_dims(np.expand_dims(update_embed[:, i], 0),0)
+                embed = np.expand_dims(np.expand_dims(update_embed[:, i], 0),0)
                 grad = self.sess.run(self.grad_embed, feed_dict={self.x_plus: embed,
                                                                  self.map_matrix: np.expand_dims(sentence_matrix[:, i], 1)})[0]
 
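
For context on the shapes in this hunk: the double np.expand_dims turns one column of update_embed (a 1-D embedding vector) into a rank-3 array before it is fed to x_plus. A tiny numpy check, with a hypothetical 50-dimensional embedding and 3 window positions (the real sizes come from constant.py):

```python
import numpy as np

update_embed = np.random.rand(50, 3)               # hypothetical: 50-dim embeddings, 3 positions
col = update_embed[:, 0]                           # shape (50,)
embed = np.expand_dims(np.expand_dims(col, 0), 0)
print(embed.shape)                                 # (1, 1, 50): the rank-3 input fed to x_plus
print(np.expand_dims(np.zeros(4), 1).shape)        # (4, 1): the column shape fed to map_matrix
```
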

@@ -147,41 +151,6 @@ def gen_update_A(correct_tags, current_tags):
 
         return A_update, init_A_update, update_init
 
-    def viterbi(self, emission, A, init_A, return_score=False):
-        """
-        Implementation of the Viterbi algorithm; all inputs and return values are numpy array objects.
-        :param emission: emission probability matrix, corresponding to the score matrix of this model, 4*length
-        :param A: transition probability matrix, 4*4
-        :param init_A: initial transition probability matrix, 4
-        :param return_score: whether to return the score of the optimal path, default False
-        :return: the optimal path; if return_score is True, the optimal path and its corresponding score
-        """
-
-        length = emission.shape[1]
-        path = np.ones([4, length], dtype=np.int32) * -1
-        corr_path = np.zeros([length], dtype=np.int32)
-        path_score = np.ones([4, length], dtype=np.float64) * (np.finfo('f').min / 2)
-        path_score[:, 0] = init_A + emission[:, 0]
-
-        for pos in range(1, length):
-            for t in range(4):
-                for prev in range(4):
-                    temp = path_score[prev][pos - 1] + A[prev][t] + emission[t][pos]
-                    if temp >= path_score[t][pos]:
-                        path[t][pos] = prev
-                        path_score[t][pos] = temp
-
-        max_index = np.argmax(path_score[:, -1])
-        corr_path[length - 1] = max_index
-        for i in range(length - 1, 0, -1):
-            max_index = path[max_index][i]
-            corr_path[i - 1] = max_index
-        if return_score:
-            return corr_path, path_score[max_index, :]
-        else:
-            return corr_path
-
-
 if __name__ == '__main__':
     seg = SegLSTM()
     seg.train_exe()
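
The deleted viterbi method, together with the new import from seg_base and the commit message "add base class", suggests the decoder now lives in seg_base.py shared by the segmenters. That file is not part of this commit, so the following is only a sketch of what SegBase could look like, reusing the decoder removed above; the class name comes from the import, everything else (the empty __init__, the docstring, the comments) is an assumption.

```python
import numpy as np


class SegBase:
    """Hypothetical shared base class; the decoder body mirrors the code deleted from seg_lstm.py."""

    def __init__(self):
        # Whatever shared state seg_base.py actually sets up is not shown in this commit.
        pass

    def viterbi(self, emission, A, init_A, return_score=False):
        """Viterbi decoding; all inputs and outputs are numpy arrays.

        emission: 4 x length score matrix, A: 4 x 4 transition scores,
        init_A: length-4 initial scores, return_score: also return the best path's scores.
        """
        length = emission.shape[1]
        path = np.ones([4, length], dtype=np.int32) * -1
        corr_path = np.zeros([length], dtype=np.int32)
        path_score = np.ones([4, length], dtype=np.float64) * (np.finfo('f').min / 2)
        path_score[:, 0] = init_A + emission[:, 0]

        # Forward pass: best score for each tag at each position, keeping back-pointers.
        for pos in range(1, length):
            for t in range(4):
                for prev in range(4):
                    temp = path_score[prev][pos - 1] + A[prev][t] + emission[t][pos]
                    if temp >= path_score[t][pos]:
                        path[t][pos] = prev
                        path_score[t][pos] = temp

        # Backtrack from the best final tag.
        max_index = np.argmax(path_score[:, -1])
        corr_path[length - 1] = max_index
        for i in range(length - 1, 0, -1):
            max_index = path[max_index][i]
            corr_path[i - 1] = max_index
        if return_score:
            return corr_path, path_score[max_index, :]
        return corr_path
```

With such a base in place, SegLSTM(SegBase) keeps calling self.viterbi(...) exactly as before; only the definition site moves out of seg_lstm.py.
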
