Commit bcd62a5

Merge pull request #19 from scoutbee/add_transformer

Add transformer

2 parents 4a99292 + 14b2449, commit bcd62a5

14 files changed: 2375 additions, 227 deletions

2_embeddings.ipynb: 397 additions, 227 deletions (large diff not rendered by default)
6_transformer_translation.ipynb: 1266 additions, 0 deletions (large diff not rendered by default)
(one file renamed without changes)
images/beam-search.svg: 163 additions, 0 deletions
images/encoder_decoder_stack.png: 125 KB
images/multi_head_attention.png: 35.7 KB
(unnamed image file): 20.4 KB
images/transformer.png: 41.6 KB
transformer/__init__.py: whitespace-only changes

transformer/batch.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
from collections import Counter

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from nltk.tokenize import wordpunct_tokenize
from torch import optim
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, Subset


def tokenize(text):
    """Turn text into discrete tokens.

    Remove tokens that are not words.
    """
    text = text.lower()
    tokens = wordpunct_tokenize(text)

    # Only keep words
    tokens = [token for token in tokens
              if all(char.isalpha() for char in token)]

    return tokens
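
# --- Hypothetical usage sketch; not part of the committed file ---------------
# A quick check of what tokenize() produces (the example sentence is ours):
#
#   >>> tokenize("I'm going home!")
#   ['i', 'm', 'going', 'home']
#
# The text is lowercased, wordpunct_tokenize splits off "'" and "!", and the
# all(char.isalpha()) filter then drops those punctuation-only tokens.
# ------------------------------------------------------------------------------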


class EnglishFrenchTranslations(Dataset):
    def __init__(self, path, max_vocab, max_seq_len):
        self.max_vocab = max_vocab

        # Extra tokens to add
        self.padding_token = '<PAD>'
        self.start_of_sequence_token = '<SOS>'
        self.end_of_sequence_token = '<EOS>'
        self.unknown_word_token = '<UNK>'
        self.max_seq_len = max_seq_len

        # Helper function
        self.flatten = lambda x: [sublst for lst in x for sublst in lst]

        # Load the data into a DataFrame
        df = pd.read_csv(path, names=['english', 'french'], sep='\t')

        # Filter out sequences that are too long
        df = self.filter_seq_len(df, max_len=self.max_seq_len)

        # Tokenize inputs (English) and targets (French)
        self.tokenize_df(df)

        # To reduce computational complexity, replace rare words with <UNK>
        self.replace_rare_tokens(df)

        # Prepare variables with mappings of tokens to indices
        self.create_token2idx(df)

        # Remove sequences with mostly <UNK>
        df = self.remove_mostly_unk(df)

        # Every sequence (input and target) should start with <SOS>
        # and end with <EOS>
        self.add_start_and_end_to_tokens(df)

        # Convert tokens to indices
        self.tokens_to_indices(df)

    def __getitem__(self, idx):
        """Return example at index idx."""
        return self.indices_pairs[idx][0], self.indices_pairs[idx][1]

    def tokenize_df(self, df):
        """Turn inputs and targets into tokens."""
        df['tokens_inputs'] = df.english.apply(tokenize)
        df['tokens_targets'] = df.french.apply(tokenize)

    def replace_rare_tokens(self, df):
        """Replace rare tokens with <UNK>."""
        common_tokens_inputs = self.get_most_common_tokens(
            df.tokens_inputs.tolist(),
        )
        common_tokens_targets = self.get_most_common_tokens(
            df.tokens_targets.tolist(),
        )

        df.loc[:, 'tokens_inputs'] = df.tokens_inputs.apply(
            lambda tokens: [token if token in common_tokens_inputs
                            else self.unknown_word_token for token in tokens]
        )
        df.loc[:, 'tokens_targets'] = df.tokens_targets.apply(
            lambda tokens: [token if token in common_tokens_targets
                            else self.unknown_word_token for token in tokens]
        )

    def get_most_common_tokens(self, tokens_series):
        """Return the max_vocab most common tokens."""
        all_tokens = self.flatten(tokens_series)
        # Subtract 4 for <PAD>, <SOS>, <EOS>, and <UNK>
        common_tokens = set(list(zip(*Counter(all_tokens).most_common(
            self.max_vocab - 4)))[0])
        return common_tokens

    def remove_mostly_unk(self, df, threshold=0.99):
        """Remove sequences whose share of known (non-<UNK>) tokens is below threshold."""
        calculate_ratio = (
            lambda tokens: sum(1 for token in tokens if token != '<UNK>')
            / len(tokens) > threshold
        )
        df = df[df.tokens_inputs.apply(calculate_ratio)]
        df = df[df.tokens_targets.apply(calculate_ratio)]
        return df

    def filter_seq_len(self, df, max_len=100):
        """Keep only pairs whose English and French sides have fewer than max_len words."""
        mask = (df['english'].str.count(' ') < max_len) & (df['french'].str.count(' ') < max_len)
        return df.loc[mask]

    def create_token2idx(self, df):
        """Create variables with mappings from tokens to indices."""
        unique_tokens_inputs = set(self.flatten(df.tokens_inputs))
        unique_tokens_targets = set(self.flatten(df.tokens_targets))

        # Drop the special tokens in case they already occur in the data
        for token in reversed([
            self.padding_token,
            self.start_of_sequence_token,
            self.end_of_sequence_token,
            self.unknown_word_token,
        ]):
            if token in unique_tokens_inputs:
                unique_tokens_inputs.remove(token)
            if token in unique_tokens_targets:
                unique_tokens_targets.remove(token)

        unique_tokens_inputs = sorted(list(unique_tokens_inputs))
        unique_tokens_targets = sorted(list(unique_tokens_targets))

        # Add <PAD>, <SOS>, <EOS>, and <UNK> tokens at the front of the vocabulary
        for token in reversed([
            self.padding_token,
            self.start_of_sequence_token,
            self.end_of_sequence_token,
            self.unknown_word_token,
        ]):
            unique_tokens_inputs = [token] + unique_tokens_inputs
            unique_tokens_targets = [token] + unique_tokens_targets

        self.token2idx_inputs = {token: idx for idx, token
                                 in enumerate(unique_tokens_inputs)}
        self.idx2token_inputs = {idx: token for token, idx
                                 in self.token2idx_inputs.items()}

        self.token2idx_targets = {token: idx for idx, token
                                  in enumerate(unique_tokens_targets)}
        self.idx2token_targets = {idx: token for token, idx
                                  in self.token2idx_targets.items()}

    def add_start_and_end_to_tokens(self, df):
        """Prepend <SOS> and append <EOS> to every input and target sequence."""
        df.loc[:, 'tokens_inputs'] = df.tokens_inputs.apply(
            lambda tokens: [self.start_of_sequence_token]
            + tokens
            + [self.end_of_sequence_token]
        )
        df.loc[:, 'tokens_targets'] = df.tokens_targets.apply(
            lambda tokens: [self.start_of_sequence_token]
            + tokens
            + [self.end_of_sequence_token]
        )

    def tokens_to_indices(self, df):
        """Convert tokens to indices."""
        df['indices_inputs'] = df.tokens_inputs.apply(
            lambda tokens: [self.token2idx_inputs[token] for token in tokens])
        df['indices_targets'] = df.tokens_targets.apply(
            lambda tokens: [self.token2idx_targets[token] for token in tokens])

        self.indices_pairs = list(zip(df.indices_inputs, df.indices_targets))

    def __len__(self):
        return len(self.indices_pairs)
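
# --- Hypothetical usage sketch; not part of the committed file ---------------
# Constructing the dataset. The path and hyperparameters are assumptions; the
# file is expected to hold one tab-separated English/French pair per line,
# matching the pd.read_csv call in __init__.
#
#   dataset = EnglishFrenchTranslations(
#       path='data/eng-fra.txt',      # hypothetical location
#       max_vocab=10_000,
#       max_seq_len=50,
#   )
#   len(dataset)                      # number of (input, target) index pairs
#   len(dataset.token2idx_inputs)     # English vocabulary size, at most max_vocab
# ------------------------------------------------------------------------------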


def collate(batch, src_pad, trg_pad, device):
    inputs = [torch.LongTensor(item[0]) for item in batch]
    targets = [torch.LongTensor(item[1]) for item in batch]

    # Pad sequences so that they are all the same length (within one minibatch)
    padded_inputs = pad_sequence(inputs, padding_value=src_pad, batch_first=True)
    padded_targets = pad_sequence(targets, padding_value=trg_pad, batch_first=True)

    # Sort by length for CUDA optimizations
    lengths = torch.LongTensor([len(x) for x in inputs])
    lengths, permutation = lengths.sort(dim=0, descending=True)

    return padded_inputs[permutation].to(device), padded_targets[permutation].to(device), lengths.to(device)
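
# --- Hypothetical usage sketch; not part of the committed file ---------------
# Wiring collate into a DataLoader. The batch size and the functools.partial
# binding are assumptions; the <PAD> indices come from the mappings built by
# the dataset sketched above.
#
#   from functools import partial
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   src_pad_idx = dataset.token2idx_inputs['<PAD>']
#   trg_pad_idx = dataset.token2idx_targets['<PAD>']
#   loader = DataLoader(
#       dataset,
#       batch_size=32,
#       shuffle=True,
#       collate_fn=partial(collate, src_pad=src_pad_idx,
#                          trg_pad=trg_pad_idx, device=device),
#   )
#   inputs, targets, lengths = next(iter(loader))
#   # inputs: (batch, src_len), targets: (batch, trg_len) LongTensors on device,
#   # sorted by descending input length; lengths holds the unpadded input lengths.
# ------------------------------------------------------------------------------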


def no_peak_mask(size):
    """Boolean mask that lets position i attend only to positions <= i."""
    mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    mask = Variable(torch.from_numpy(mask) == 0)
    return mask


def create_masks(src, trg, src_pad_idx, trg_pad_idx):
    """Build the source padding mask and the combined target padding/no-peak mask."""
    src_mask = (src != src_pad_idx).unsqueeze(-2)
    if trg is not None:
        trg_mask = (trg != trg_pad_idx).unsqueeze(-2)
        size = trg.size(1)  # get seq_len for matrix
        np_mask = no_peak_mask(size).to(trg_mask.device)
        trg_mask = trg_mask & np_mask
    else:
        trg_mask = None
    return src_mask, trg_mask
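
To make the masking behavior concrete, here is a short sketch that is not part of the commit. The values shown for no_peak_mask follow directly from the NumPy/PyTorch calls above; the training-step call to create_masks, with the target shifted by one position, is a common seq2seq pattern and an assumption on our part rather than something this file prescribes. The names inputs, targets, src_pad_idx, and trg_pad_idx refer to the hypothetical DataLoader sketch above.

# no_peak_mask(3)[0] is lower-triangular: position i may attend only to positions <= i.
#   tensor([[ True, False, False],
#           [ True,  True, False],
#           [ True,  True,  True]])

# Hypothetical decoder training step: feed the target without its last token and
# score the predictions against the target without its first token.
trg_input = targets[:, :-1]
src_mask, trg_mask = create_masks(inputs, trg_input, src_pad_idx, trg_pad_idx)
# src_mask: (batch, 1, src_len) bool mask hiding <PAD> positions in the source
# trg_mask: (batch, trg_len - 1, trg_len - 1) bool mask hiding <PAD> and future positions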
