Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

eagle705/pytorch-transformer-chatbot

Repository files navigation

PyTorch_Transformer_Chatbot

Simple Korean Generative Chatbot Implementation based on new PyTorch Transformer API (PyTorch v1.2 / Python 3.x)

transformer_fig

ToDo

  • Beam Search
  • Search hyperparams
  • Attention Visualization
  • Char-level transformer

Transformer API core logic

  • Padding masking이 매우 편함
  • decoder input의 future token을 못보게 하기 위한 masking은 함수로 제공함
  • Transformer의 input, output dim 순서는 [Sequnece, Batch, Embedding Dimension]으로 되어있어서 Transpose 해줘야함
  • 아쉽지만 Transformer API에서 Attention weight dict을 제공해주진 않음
def forward(self, enc_input: torch.Tensor, dec_input: torch.Tensor) -> torch.Tensor:
 x_enc_embed = self.input_embedding(enc_input.long())
 x_dec_embed = self.input_embedding(dec_input.long())
 # Masking
 src_key_padding_mask = enc_input == self.vocab.PAD_ID # tensor([[False, False, False, True, ..., True]])
 tgt_key_padding_mask = dec_input == self.vocab.PAD_ID
 memory_key_padding_mask = src_key_padding_mask
 tgt_mask = self.transfomrer.generate_square_subsequent_mask(dec_input.size(1))
 # einsum ref: https://pytorch.org/docs/stable/torch.html#torch.einsum
 # https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/
 x_enc_embed = torch.einsum('ijk->jik', x_enc_embed)
 x_dec_embed = torch.einsum('ijk->jik', x_dec_embed)
 # transformer ref: https://pytorch.org/docs/stable/nn.html#torch.nn.Transformer
 feature = self.transfomrer(src = x_enc_embed,
 tgt = x_dec_embed,
 src_key_padding_mask = src_key_padding_mask,
 tgt_key_padding_mask = tgt_key_padding_mask,
 memory_key_padding_mask=memory_key_padding_mask,
 tgt_mask = tgt_mask.to(device)) # src: (S,N,E) tgt: (T,N,E)
 logits = self.proj_vocab_layer(feature)
 logits = torch.einsum('ijk->jik', logits)
 return logits

Experiments

  • train set에 대해서는 Overfit으로 95%의 정확도를 보이지만, val set에 대해서는 낮음 (예시로 공개하기 애매할정도)
  • padding 값은 acc, loss 계산에서 모두 제외함
input: [['나/NP', '를/JKO', '사랑/NNG', '한/XSA+ETM', '그/MM', '사람/NNG', '에게/JKB', '해/VV+EC', '줄/VX+ETM', '수/NNB', '있/VV', '는/ETM', '것/NNB', '<pad>', '<pad>'], ['맥주/NNG', '한/MM', '잔/NNG', '해야지/VV+EC', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
pred: [['천천히/MAG', '잊어/NNP', '가/JKS', '요/JX', './SF', '</s>', '들/VV', '노력/NNG', '요/JX', '은/JX', '도/JX', '이/JKS', '도/JX', '이/JKS', '이/JKS'], ['적당히/MAG', '드세요/VV+EP+EF', './SF', '</s>', '에/JKB', '에/JKB', '드세요/VV+EP+EF', '드세요/VV+EP+EF', '구경/NNG', '드세요/VV+EP+EF', '에/JKB', '드세요/VV+EP+EF', '가리/VV', '드세요/VV+EP+EF', '드세요/VV+EP+EF']]
real: [['천천히/MAG', '잊어/NNP', '가/JKS', '요/JX', './SF', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['적당히/MAG', '드세요/VV+EP+EF', './SF', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
  • Inference test 결과
    • train set에 있는건 잘 대답함
    • 얼추 잘(?) 입력하면 잘 대답함
    • "unk" 뜨거나 토큰 하나 바뀌어도 대답이 바뀜 ᅲ
문장을 입력하세요: 배고파
input: [['배고파/VA+EC', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
pred: [['얼른/MAG', '맛난/VA+ETM', '음식/NNG', '드세요/VV+EP+EF', './SF', '</s>']]
pred_str: 얼른 맛난 음식 드세요.
문장을 입력하세요: 너 누구야
input: [['너/NP', '누구/NP', '야/VCP+EF', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
pred: [['저/NP', '는/JX', '마음/NNG', '을/JKO', '이어주/VV', '는/ETM', '위/NNG', '로/JKB', '봇/NNG', '입니다/VCP+EF', './SF', '</s>']]
pred_str: 저는 마음을 이어주는 위로봇입니다.
문장을 입력하세요: 안녕
input: [['안녕/IC', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
pred: [['안녕/NNG', '하/XSV', '세요/EP+EF', './SF', '</s>']]
pred_str: 안녕하세요.
문장을 입력하세요: 졸리다 이제 자야지
input: [['<unk>', '다/EF', '이제/MAG', '자/VV', '야지/EC', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
pred: [['너무/MAG', '걱정/NNG', '하/XSV', '지/EC', '마세요/VX+EP+EF', './SF', '</s>']]
pred_str: 너무 걱정하지 마세요.

실행순서

  • 테스트만 할경우 inference.py만 실행
python build_vocab.py # 사전 구축
python train.py # 학습
python inference.py # 테스트

Requirements

pip install mxnet
pip install gluonnlp
pip install konlpy
pip install python-mecab-ko
pip install chatspace
pip install tb-nightly
pip install future
pip install pathlib

Reference Repositories

About

PyTorch v1.2에서 생긴 Transformer API 를 이용한 간단한 Chitchat 챗봇

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

AltStyle によって変換されたページ (->オリジナル) /