
Commit 034295b

modify test code
1 parent e8c3d8e · commit 034295b

File tree

1 file changed: +13 -52 lines changed

test.py

Lines changed: 13 additions & 52 deletions
@@ -1,61 +1,31 @@
 # -*- coding: UTF-8 -*-
 import numpy as np
-#import jieba
 from seg_dnn import SegDNN
 from seg_lstm import SegLSTM
 from utils import estimate_cws
 import constant


 def test_seg_dnn():
-    '''
     cws = SegDNN(constant.VOCAB_SIZE, 50, constant.DNN_SKIP_WINDOW)
-    sentence = '迈向充满希望的新世纪'
-    # model = 'tmp/4.29-w3-normal/model15.ckpt'
-    model = 'tmp/4.29-100/model99.ckpt'
     model = 'tmp/model0.ckpt'
     print(cws.seg('小明来自南京师范大学', model))
     print(cws.seg('小明是上海理工大学的学生', model))
     print(cws.seg('小明是清华大学的学生', model))
     print(cws.seg('我爱北京天安门', model))
     print(cws.seg('上海理工大学', model))
     print(cws.seg('上海海洋大学'))
-    print(cws.seg(sentence, model))
-    #print('/'.join(jieba.cut('小明是上海理工大学的学生')))
-    seq = cws.index2seq(cws.sentence2index(sentence))
-    seq = np.array(seq, dtype=np.int32).flatten()
-    '''
-    seg = SegLSTM()
-    # seg.train_exe()
-    #print(seg.seg('我爱北京天安门'))
-    #print(seg.seg('小明来自南京师范大学'))
-    #print(seg.seg('小明是上海理工大学的学生'))
-    print(seg.seg('小明来自南京师范大学'))
-    print(seg.seg('小明是上海理工大学的学生'))
-    test(seg,'tmp/lstm-model0.ckpt')
-    # print(seq)
-    # cal_val(seq)
-    # print(cws.seg('2015世界旅游小姐大赛山东赛区冠军总决赛在威海举行',model))
+    print(cws.seg('迈向充满希望的新世纪', model))


-def cal_val(seq):
-    embeddings = np.load('data/dnn/embeddings.npy')
-    w2 = np.load('data/dnn/w2.npy')
-    w3 = np.load('data/dnn/w3.npy')
-    b2 = np.load('data/dnn/b2.npy')
-    b3 = np.load('data/dnn/b3.npy')
-    # b2 = np.expand_dims(b2.flatten(),0)
-    # A = np.load('data/dnn/A.npy')
-    # init_A = np.load('data/dnn/init_A.npy')
-    # print(w2)
-    # x = np.reshape(embeddings[seq],[10,150])
-    # print(x[0])
-    # b2 = np.tile(b2,[10])
-    # print(np.matmul(w2,x.T))
-    # print(b2.shape)
-    # val = np.matmul(w3,sigmoid(np.add(np.matmul(w2,x.T),b2)))+b3
-    # val = w3*sigmoid(w2*x.T+b2)+b3
-    # print(val.T)
+def test_seg_lstm():
+    seg = SegLSTM()
+    model = 'tmp/lstm-model1.ckpt'
+    print(seg.seg('小明来自南京师范大学', model, debug=True))
+    print(seg.seg('小明是上海理工大学的学生', model))
+    print(seg.seg('迈向充满希望的新世纪', model))
+    print(seg.seg('2015世界旅游小姐大赛山东赛区冠军总决赛在威海举行', model))
+    test(seg, model)


 def test(cws, model):
@@ -66,31 +36,22 @@ def test(cws, model):
     corr_count = 0
     re_count = 0
     total_count = 0
+
     for _, (sentence, label) in enumerate(zip(sentences, labels)):
         label = np.array(list(map(lambda s: int(s), label.split(' '))))
         _, tag = cws.seg(sentence, model)
         cor_count, prec_count, recall_count = estimate_cws(tag, np.array(label))
         corr_count += cor_count
         re_count += recall_count
         total_count += prec_count
-        # if(corr_count != prec_count):
-        #     print(cws.tags2words(sentence,tag))
-
-        # diff = np.subtract(tag,np.array(label))
-        # if sum()
-        # print(np.where(diff == 0))
-        # corr_count += len(np.where(diff == 0)[0])
-        # total_count += len(label)
     prec = corr_count / total_count
     recall = corr_count / re_count
+
     print(prec)
     print(recall)
     print(2 * prec * recall / (prec + recall))


-def sigmoid(x):
-    return 1 / (1 + np.exp(-x))
-
-
 if __name__ == '__main__':
-    test_seg_dnn()
+    # test_seg_dnn()
+    test_seg_lstm()
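For context, the values that test() prints at the end of the diff are the usual word-segmentation precision, recall, and F1, built from the three accumulated counters. A minimal standalone sketch with hypothetical counts (the real values are accumulated from estimate_cws over the labelled test corpus, which is not part of this diff):

# Hypothetical counts, for illustration only; in test() they are accumulated
# from estimate_cws(tag, label) across the whole test set.
corr_count = 120    # words the segmenter got exactly right
total_count = 150   # words the segmenter produced (precision denominator)
re_count = 140      # words in the gold segmentation (recall denominator)

prec = corr_count / total_count               # 120 / 150 = 0.8
recall = corr_count / re_count                # 120 / 140 ≈ 0.857
f1 = 2 * prec * recall / (prec + recall)      # ≈ 0.828
print(prec, recall, f1)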
