Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 0583dc5

Browse files
add calculate loss function
1 parent ee91f0e commit 0583dc5

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

‎seg_base.py‎

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
#-*- coding: UTF-8 -*-
1+
#-*- coding: UTF-8 -*-
22
import numpy as np
3+
import math
34

45

56
class SegBase:
@@ -41,7 +42,7 @@ def viterbi(self, emission, A, init_A, return_score=False):
4142
max_index = path[max_index][i]
4243
corr_path[i - 1] = max_index
4344
if return_score:
44-
return corr_path, path_score[max_index, :]
45+
return corr_path, path_score[max_index, -1]
4546
else:
4647
return corr_path
4748

@@ -81,4 +82,18 @@ def tags2words(self, sentence, tags):
8182
if word != '':
8283
words.append(word)
8384

84-
return words
85+
return words
86+
87+
def cal_sentence_loss(self, tags, sentence_scores, A, init_A):
88+
89+
_, score = self.viterbi(sentence_scores, A, init_A, True)
90+
loss = 0.0
91+
before = 0
92+
for index, (corr_tag, scores) in enumerate(zip(tags, sentence_scores.T)):
93+
if index == 0:
94+
loss += scores[corr_tag] + init_A[corr_tag]
95+
else:
96+
loss += scores[corr_tag] + A[before, corr_tag]
97+
before = corr_tag
98+
99+
return math.fabs(loss - score)

‎seg_lstm.py‎

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,11 @@ def train_exe(self):
8888
for i in range(10):
8989
for sentence_index, (sentence, tags) in enumerate(zip(self.words_batch, self.tags_batch)):
9090
self.train_sentence(sentence, tags, len(tags))
91-
if sentence_index % 500 == 0:
91+
if sentence_index >0andsentence_index% 500 == 0:
9292
print(sentence_index)
9393
print(time.time() - last_time)
9494
last_time = time.time()
95+
print(self.cal_loss(sentence_index-500,sentence_index))
9596
print(self.sess.run(self.init_A))
9697
self.saver.save(self.sess, 'tmp/lstm-model%d.ckpt' % i)
9798

@@ -164,6 +165,17 @@ def gen_update_A(correct_tags, current_tags):
164165

165166
return A_update, init_A_update, update_init
166167

168+
def cal_loss(self, start, end):
169+
loss = 0.0
170+
A = self.A.eval(session=self.sess)
171+
init_A = self.init_A.eval(session=self.sess)
172+
for sentence_index, (sentence, tags) in enumerate(zip(self.words_batch[start:end], self.tags_batch[start:end])):
173+
sentence_embeds = self.sess.run(self.lookup_op, feed_dict={self.sentence_holder: sentence}).reshape(
174+
[len(sentence), self.concat_embed_size])
175+
sentence_score = self.sess.run(self.word_scores, feed_dict={self.x: np.expand_dims(sentence_embeds, 0)})
176+
loss += self.cal_sentence_loss(tags, sentence_score, A, init_A)
177+
return loss
178+
167179
def seg(self, sentence, model_path='tmp/lstm-model0.ckpt'):
168180
self.saver.restore(self.sess, model_path)
169181
seq = self.index2seq(self.sentence2index(sentence))

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /