Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 1096f96

Browse files
committed
ChatBot 예제 최신 텐서플로에 맞게 수정
1 parent 57dfa7d commit 1096f96

12 files changed

+29
-29
lines changed

‎10 - RNN/ChatBot/chat.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ def run(self):
2626
line = sys.stdin.readline()
2727

2828
while line:
29-
print(self.get_replay(line.strip()))
29+
print(self._get_replay(line.strip()))
3030

3131
sys.stdout.write("\n> ")
3232
sys.stdout.flush()
3333

3434
line = sys.stdin.readline()
3535

36-
def decode(self, enc_input, dec_input):
36+
def _decode(self, enc_input, dec_input):
3737
if type(dec_input) is np.ndarray:
3838
dec_input = dec_input.tolist()
3939

@@ -46,7 +46,7 @@ def decode(self, enc_input, dec_input):
4646

4747
return self.model.predict(self.sess, [enc_input], [dec_input])
4848

49-
def get_replay(self, msg):
49+
def _get_replay(self, msg):
5050
enc_input = self.dialog.tokenizer(msg)
5151
enc_input = self.dialog.tokens_to_ids(enc_input)
5252
dec_input = []
@@ -57,7 +57,7 @@ def get_replay(self, msg):
5757
# 다만 상황에 따라서는 이런 방식이 더 유연할 수도 있을 듯
5858
curr_seq = 0
5959
for i in range(FLAGS.max_decode_len):
60-
outputs = self.decode(enc_input, dec_input)
60+
outputs = self._decode(enc_input, dec_input)
6161
if self.dialog.is_eos(outputs[0][curr_seq]):
6262
break
6363
elif self.dialog.is_defined(outputs[0][curr_seq]) is not True:
@@ -75,5 +75,6 @@ def main(_):
7575
chatbot = ChatBot(FLAGS.voc_path, FLAGS.train_dir)
7676
chatbot.run()
7777

78+
7879
if __name__ == "__main__":
7980
tf.app.run()

‎10 - RNN/ChatBot/dialog.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import tensorflow as tf
33
import numpy as np
44
import re
5-
import codecs
65

76
from config import FLAGS
87

@@ -32,11 +31,11 @@ def decode(self, indices, string=False):
3231
tokens = [[self.vocab_list[i] for i in dec] for dec in indices]
3332

3433
if string:
35-
return self.decode_to_string(tokens[0])
34+
return self._decode_to_string(tokens[0])
3635
else:
3736
return tokens
3837

39-
def decode_to_string(self, tokens):
38+
def _decode_to_string(self, tokens):
4039
text = ' '.join(tokens)
4140
return text.strip()
4241

@@ -50,7 +49,7 @@ def is_eos(self, voc_id):
5049
def is_defined(self, voc_id):
5150
return voc_id in self._PRE_DEFINED_
5251

53-
def max_len(self, batch_set):
52+
def _max_len(self, batch_set):
5453
max_len_input = 0
5554
max_len_output = 0
5655

@@ -64,7 +63,7 @@ def max_len(self, batch_set):
6463

6564
return max_len_input, max_len_output + 1
6665

67-
def pad(self, seq, max_len, start=None, eos=None):
66+
def _pad(self, seq, max_len, start=None, eos=None):
6867
if start:
6968
padded_seq = [self._STA_ID_] + seq
7069
elif eos:
@@ -77,16 +76,16 @@ def pad(self, seq, max_len, start=None, eos=None):
7776
else:
7877
return padded_seq
7978

80-
def pad_left(self, seq, max_len):
79+
def _pad_left(self, seq, max_len):
8180
if len(seq) < max_len:
8281
return ([self._PAD_ID_] * (max_len - len(seq))) + seq
8382
else:
8483
return seq
8584

8685
def transform(self, input, output, input_max, output_max):
87-
enc_input = self.pad(input, input_max)
88-
dec_input = self.pad(output, output_max, start=True)
89-
target = self.pad(output, output_max, eos=True)
86+
enc_input = self._pad(input, input_max)
87+
dec_input = self._pad(output, output_max, start=True)
88+
target = self._pad(output, output_max, eos=True)
9089

9190
# 구글 방식으로 입력을 인코더에 역순으로 입력한다.
9291
enc_input.reverse()
@@ -117,7 +116,7 @@ def next_batch(self, batch_size):
117116

118117
# TODO: 구글처럼 버킷을 이용한 방식으로 변경
119118
# 간단하게 만들기 위해 구글처럼 버킷을 쓰지 않고 같은 배치는 같은 사이즈를 사용하도록 만듬
120-
max_len_input, max_len_output = self.max_len(batch_set)
119+
max_len_input, max_len_output = self._max_len(batch_set)
121120

122121
for i in range(0, len(batch_set) - 1, 2):
123122
enc, dec, tar = self.transform(batch_set[i], batch_set[i+1],
Binary file not shown.
Binary file not shown.

‎10 - RNN/ChatBot/model.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ def __init__(self, vocab_size, n_hidden=128, n_layers=3):
2525
self.bias = tf.Variable(tf.zeros([self.vocab_size]), name="bias")
2626
self.global_step = tf.Variable(0, trainable=False, name="global_step")
2727

28-
self.build_model()
28+
self._build_model()
2929

3030
self.saver = tf.train.Saver(tf.global_variables())
3131

32-
def build_model(self):
33-
self.enc_input = tf.transpose(self.enc_input, [1, 0, 2])
34-
self.dec_input = tf.transpose(self.dec_input, [1, 0, 2])
32+
def _build_model(self):
33+
# self.enc_input = tf.transpose(self.enc_input, [1, 0, 2])
34+
# self.dec_input = tf.transpose(self.dec_input, [1, 0, 2])
3535

36-
enc_cell, dec_cell = self.build_cells()
36+
enc_cell, dec_cell = self._build_cells()
3737

3838
with tf.variable_scope('encode'):
3939
outputs, enc_states = tf.nn.dynamic_rnn(enc_cell, self.enc_input, dtype=tf.float32)
@@ -42,24 +42,24 @@ def build_model(self):
4242
outputs, dec_states = tf.nn.dynamic_rnn(dec_cell, self.dec_input, dtype=tf.float32,
4343
initial_state=enc_states)
4444

45-
self.logits, self.cost, self.train_op = self.build_ops(outputs, self.targets)
45+
self.logits, self.cost, self.train_op = self._build_ops(outputs, self.targets)
4646

4747
self.outputs = tf.argmax(self.logits, 2)
4848

49-
def cell(self, n_hidden, output_keep_prob):
50-
rnn_cell = tf.contrib.rnn.BasicRNNCell(self.n_hidden)
51-
rnn_cell = tf.contrib.rnn.DropoutWrapper(rnn_cell, output_keep_prob=output_keep_prob)
49+
def _cell(self, output_keep_prob):
50+
rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(self.n_hidden)
51+
rnn_cell = tf.nn.rnn_cell.DropoutWrapper(rnn_cell, output_keep_prob=output_keep_prob)
5252
return rnn_cell
5353

54-
def build_cells(self, output_keep_prob=0.5):
55-
enc_cell = tf.contrib.rnn.MultiRNNCell([self.cell(self.n_hidden, output_keep_prob)
54+
def _build_cells(self, output_keep_prob=0.5):
55+
enc_cell = tf.nn.rnn_cell.MultiRNNCell([self._cell(output_keep_prob)
5656
for _ in range(self.n_layers)])
57-
dec_cell = tf.contrib.rnn.MultiRNNCell([self.cell(self.n_hidden, output_keep_prob)
57+
dec_cell = tf.nn.rnn_cell.MultiRNNCell([self._cell(output_keep_prob)
5858
for _ in range(self.n_layers)])
5959

6060
return enc_cell, dec_cell
6161

62-
def build_ops(self, outputs, targets):
62+
def _build_ops(self, outputs, targets):
6363
time_steps = tf.shape(outputs)[1]
6464
outputs = tf.reshape(outputs, [-1, self.n_hidden])
6565

‎10 - RNN/ChatBot/model/checkpoint

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
model_checkpoint_path: "conversation.ckpt-5000"
2-
all_model_checkpoint_paths: "conversation.ckpt-5000"
1+
model_checkpoint_path: "conversation.ckpt-10000"
2+
all_model_checkpoint_paths: "conversation.ckpt-10000"
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /