11# -*- coding: UTF-8 -*- 
22import  numpy  as  np 
3- #import jieba 
43from  seg_dnn  import  SegDNN 
54from  seg_lstm  import  SegLSTM 
65from  utils  import  estimate_cws 
76import  constant 
87
98
109def  test_seg_dnn ():
11-  ''' 
1210 cws  =  SegDNN (constant .VOCAB_SIZE , 50 , constant .DNN_SKIP_WINDOW )
13-  sentence = '迈向充满希望的新世纪' 
14-  # model = 'tmp/4.29-w3-normal/model15.ckpt' 
15-  model = 'tmp/4.29-100/model99.ckpt' 
1611 model  =  'tmp/model0.ckpt' 
1712 print (cws .seg ('小明来自南京师范大学' , model ))
1813 print (cws .seg ('小明是上海理工大学的学生' , model ))
1914 print (cws .seg ('小明是清华大学的学生' , model ))
2015 print (cws .seg ('我爱北京天安门' , model ))
2116 print (cws .seg ('上海理工大学' , model ))
2217 print (cws .seg ('上海海洋大学' ))
23-  print(cws.seg(sentence, model)) 
24-  #print('/'.join(jieba.cut('小明是上海理工大学的学生'))) 
25-  seq = cws.index2seq(cws.sentence2index(sentence)) 
26-  seq = np.array(seq, dtype=np.int32).flatten() 
27-  ''' 
28-  seg  =  SegLSTM ()
29-  # seg.train_exe() 
30-  #print(seg.seg('我爱北京天安门')) 
31-  #print(seg.seg('小明来自南京师范大学')) 
32-  #print(seg.seg('小明是上海理工大学的学生')) 
33-  print (seg .seg ('小明来自南京师范大学' ))
34-  print (seg .seg ('小明是上海理工大学的学生' ))
35-  test (seg ,'tmp/lstm-model0.ckpt' )
36-  # print(seq) 
37-  # cal_val(seq) 
38-  # print(cws.seg('2015世界旅游小姐大赛山东赛区冠军总决赛在威海举行',model)) 
18+  print (cws .seg ('迈向充满希望的新世纪' , model ))
3919
4020
41- def  cal_val (seq ):
42-  embeddings  =  np .load ('data/dnn/embeddings.npy' )
43-  w2  =  np .load ('data/dnn/w2.npy' )
44-  w3  =  np .load ('data/dnn/w3.npy' )
45-  b2  =  np .load ('data/dnn/b2.npy' )
46-  b3  =  np .load ('data/dnn/b3.npy' )
47-  # b2 = np.expand_dims(b2.flatten(),0) 
48-  # A = np.load('data/dnn/A.npy') 
49-  # init_A = np.load('data/dnn/init_A.npy') 
50-  # print(w2) 
51-  # x = np.reshape(embeddings[seq],[10,150]) 
52-  # print(x[0]) 
53-  # b2 = np.tile(b2,[10]) 
54-  # print(np.matmul(w2,x.T)) 
55-  # print(b2.shape) 
56-  # val = np.matmul(w3,sigmoid(np.add(np.matmul(w2,x.T),b2)))+b3 
57-  # val = w3*sigmoid(w2*x.T+b2)+b3 
58-  # print(val.T) 
21+ def  test_seg_lstm ():
22+  seg  =  SegLSTM ()
23+  model  =  'tmp/lstm-model1.ckpt' 
24+  print (seg .seg ('小明来自南京师范大学' , model , debug = True ))
25+  print (seg .seg ('小明是上海理工大学的学生' , model ))
26+  print (seg .seg ('迈向充满希望的新世纪' , model ))
27+  print (seg .seg ('2015世界旅游小姐大赛山东赛区冠军总决赛在威海举行' , model ))
28+  test (seg , model )
5929
6030
6131def  test (cws , model ):
@@ -66,31 +36,22 @@ def test(cws, model):
6636 corr_count  =  0 
6737 re_count  =  0 
6838 total_count  =  0 
39+ 6940 for  _ , (sentence , label ) in  enumerate (zip (sentences , labels )):
7041 label  =  np .array (list (map (lambda  s : int (s ), label .split (' ' ))))
7142 _ , tag  =  cws .seg (sentence , model )
7243 cor_count , prec_count , recall_count  =  estimate_cws (tag , np .array (label ))
7344 corr_count  +=  cor_count 
7445 re_count  +=  recall_count 
7546 total_count  +=  prec_count 
76-  # if(corr_count != prec_count): 
77-  # print(cws.tags2words(sentence,tag)) 
78- 79-  # diff = np.subtract(tag,np.array(label)) 
80-  # if sum() 
81-  # print(np.where(diff == 0)) 
82-  # corr_count += len(np.where(diff == 0)[0]) 
83-  # total_count += len(label) 
8447 prec  =  corr_count  /  total_count 
8548 recall  =  corr_count  /  re_count 
49+ 8650 print (prec )
8751 print (recall )
8852 print (2  *  prec  *  recall  /  (prec  +  recall ))
8953
9054
91- def  sigmoid (x ):
92-  return  1  /  (1  +  np .exp (- x ))
93- 94- 9555if  __name__  ==  '__main__' :
96-  test_seg_dnn ()
56+  # test_seg_dnn() 
57+  test_seg_lstm ()
0 commit comments