| 1 | +""" |
| 2 | +Test the TextRNN class |
| 3 | +2016年12月22日 |
| 4 | +""" |
import os
import sys
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.contrib import learn

from data_helpers import load_data_and_labels, batch_iter
from text_cnn import TextCNN
# import pudb; pu.db  # debug hook: uncomment to drop into the pudb debugger here
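# NOTE: this script targets TensorFlow 1.x; tf.contrib (used below for
# VocabularyProcessor) was removed in TensorFlow 2.x.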

# Load original data
path = sys.path[0]  # directory containing this script
pos_filename = path + "/data/rt-polarity.pos"
neg_filename = path + "/data/rt-polarity.neg"

X_data, y_data = load_data_and_labels(pos_filename, neg_filename)
max_document_length = max([len(sen.split(" ")) for sen in X_data])
print("Max document length:", max_document_length)
# Create the vocabulary
vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
# Map each sentence to a padded sequence of word indices (integer ids, as
# required by the embedding lookup inside TextCNN)
x = np.array(list(vocab_processor.fit_transform(X_data)), dtype=np.int32)
y = np.array(y_data, dtype=np.int32)
vocabulary_size = len(vocab_processor.vocabulary_)
print("The size of vocabulary:", vocabulary_size)
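# VocabularyProcessor pads or truncates every sentence to max_document_length,
# reserving index 0 for padding/unknown words, so every row of x has the same
# width. A quick illustrative sanity check (not needed for training):
#   ids = next(vocab_processor.transform(["a good movie"]))
#   print(ids[:5])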
# Split the data
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=1111)
print("X_train shape {0}, y_train shape {1}".format(X_train.shape, y_train.shape))
print("X_test shape {0}, y_test shape {1}".format(X_test.shape, y_test.shape))
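# rt-polarity contains 5331 positive and 5331 negative sentences, so a 0.1
# test split leaves roughly 9.6k examples for training and ~1.1k for testing.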

# Hyperparameters of the CNN
seq_len = X_train.shape[1]
vocab_size = vocabulary_size
embedding_size = 128
filter_sizes = [2, 3, 4]
num_filters = 128
num_classes = y_train.shape[1]
l2_reg_lambda = 0.0  # defined for completeness; not passed to the model below
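# Each filter size n yields num_filters convolution filters spanning n
# consecutive word embeddings (n-gram detectors), in the style of Kim (2014)'s
# TextCNN for sentence classification.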

# Construct the CNN model
text_cnn_model = TextCNN(seq_len=seq_len, vocab_size=vocab_size, embedding_size=embedding_size,
                         filter_sizes=filter_sizes, num_filters=num_filters, num_classes=num_classes)
loss = text_cnn_model.loss
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
accuracy = text_cnn_model.accuracy
# The parameters for training
batch_size = 64
training_epochs = 10
display_every = 1
dropout_keep_prob = 0.5

batch_num = X_train.shape[0] // batch_size  # integer division drops the last partial batch

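# data_helpers also exports batch_iter (imported above), which could replace
# the manual slicing in the loop below; the explicit index arithmetic is kept
# here so each batch's boundaries stay obvious.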
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print("Starting training...")
for epoch in range(training_epochs):
    avg_cost = 0
    for batch in range(batch_num):
        _, cost = sess.run([train_op, loss],
                           feed_dict={text_cnn_model.x: X_train[batch*batch_size:(batch+1)*batch_size],
                                      text_cnn_model.y: y_train[batch*batch_size:(batch+1)*batch_size],
                                      text_cnn_model.dropout_keep_prob: dropout_keep_prob})
        avg_cost += cost / batch_num  # accumulate the mean training loss for this epoch
    if epoch % display_every == 0:
        # Evaluate on the held-out split with dropout disabled (keep_prob=1.0)
        cost, acc = sess.run([loss, accuracy],
                             feed_dict={text_cnn_model.x: X_test,
                                        text_cnn_model.y: y_test,
                                        text_cnn_model.dropout_keep_prob: 1.0})
        print("\nEpoch {0}: train loss {1:.4f}, test loss {2:.4f}, test accuracy {3:.4f}".format(
            epoch, avg_cost, cost, acc))
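
# Optionally persist the trained weights; a minimal sketch using tf.train.Saver
# (the checkpoint path under "runs/" is an assumption, adjust as needed):
#   os.makedirs(path + "/runs", exist_ok=True)
#   saver = tf.train.Saver()
#   saver.save(sess, path + "/runs/text_cnn.ckpt")
sess.close()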