Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 4b18f34

Browse files
author
Susannah Klaneček
committed
Add char-RNN notebook
1 parent 93d380e commit 4b18f34

File tree

2 files changed

+397
-0
lines changed

2 files changed

+397
-0
lines changed

‎char_text_generation.ipynb‎

Lines changed: 397 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,397 @@
1+
{
2+
"nbformat": 4,
3+
"nbformat_minor": 0,
4+
"metadata": {
5+
"colab": {
6+
"name": "char_text_generation.ipynb",
7+
"version": "0.3.2",
8+
"provenance": [],
9+
"collapsed_sections": []
10+
},
11+
"language_info": {
12+
"codemirror_mode": {
13+
"name": "ipython",
14+
"version": 3
15+
},
16+
"file_extension": ".py",
17+
"mimetype": "text/x-python",
18+
"name": "python",
19+
"nbconvert_exporter": "python",
20+
"pygments_lexer": "ipython3",
21+
"version": "3.7.3"
22+
},
23+
"kernelspec": {
24+
"name": "python3",
25+
"display_name": "Python 3"
26+
},
27+
"accelerator": "GPU"
28+
},
29+
"cells": [
30+
{
31+
"cell_type": "markdown",
32+
"metadata": {
33+
"id": "lRyoQrbvtdLy",
34+
"colab_type": "text"
35+
},
36+
"source": [
37+
"# char-RNN: Character-level text generation\n",
38+
"\n",
39+
"From day-to-day weather patterns to stock market fluctuations, from music to novels; sequential data is everywhere."
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"metadata": {
45+
"id": "b5ztGGH--rfg",
46+
"colab_type": "code",
47+
"colab": {}
48+
},
49+
"source": [
50+
"!pip install boltons -q"
51+
],
52+
"execution_count": 0,
53+
"outputs": []
54+
},
55+
{
56+
"cell_type": "code",
57+
"metadata": {
58+
"id": "GJ_sI6SI-UIR",
59+
"colab_type": "code",
60+
"colab": {}
61+
},
62+
"source": [
63+
"import string\n",
64+
"from pathlib import Path\n",
65+
"from textwrap import wrap\n",
66+
"\n",
67+
"\n",
68+
"import numpy as np\n",
69+
"import pandas as pd\n",
70+
"from torchviz import make_dot, make_dot_from_trace\n",
71+
"from boltons.iterutils import windowed\n",
72+
"from tqdm import tqdm, tqdm_notebook\n",
73+
"\n",
74+
"import torch\n",
75+
"import torch.nn as nn\n",
76+
"import torch.nn.functional as F\n",
77+
"from torch import optim\n",
78+
"from torch.optim.lr_scheduler import CosineAnnealingLR\n",
79+
"from torch.utils.data import Dataset, DataLoader\n",
80+
"from torch.utils.data.dataset import random_split\n",
81+
"from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
82+
"\n",
83+
"from google_drive_downloader import GoogleDriveDownloader as gdd"
84+
],
85+
"execution_count": 0,
86+
"outputs": []
87+
},
88+
{
89+
"cell_type": "code",
90+
"metadata": {
91+
"id": "tXUE8n_v-UIs",
92+
"colab_type": "code",
93+
"colab": {}
94+
},
95+
"source": [
96+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
97+
"device"
98+
],
99+
"execution_count": 0,
100+
"outputs": []
101+
},
102+
{
103+
"cell_type": "code",
104+
"metadata": {
105+
"id": "NKw4lpSO-UI2",
106+
"colab_type": "code",
107+
"colab": {}
108+
},
109+
"source": [
110+
"DATA_PATH = 'data/weight_loss/articles.jsonl'\n",
111+
"if not Path(DATA_PATH).is_file():\n",
112+
" gdd.download_file_from_google_drive(\n",
113+
" file_id='1mafPreWzE-FyLI0K-MUsXPcnUI0epIcI',\n",
114+
" dest_path='data/weight_loss/weight_loss_articles.zip',\n",
115+
" unzip=True,\n",
116+
" )"
117+
],
118+
"execution_count": 0,
119+
"outputs": []
120+
},
121+
{
122+
"cell_type": "code",
123+
"metadata": {
124+
"id": "Fe1X3_lY-UI9",
125+
"colab_type": "code",
126+
"colab": {}
127+
},
128+
"source": [
129+
"def load_data(path, sequence_length=125):\n",
130+
" texts = pd.read_json(path).text.sample(100).str.lower().tolist()\n",
131+
" chars_windowed = [list(windowed(text, sequence_length)) for text in texts]\n",
132+
" all_chars_windowed = [sublst for lst in chars_windowed for sublst in lst]\n",
133+
" filtered_good_chars = [\n",
134+
" sequence for sequence in tqdm_notebook(all_chars_windowed) \n",
135+
" if all(char in string.printable for char in sequence)\n",
136+
" ]\n",
137+
" return filtered_good_chars\n",
138+
"\n",
139+
"\n",
140+
"def get_unique_chars(sequences):\n",
141+
" return {sublst for lst in sequences for sublst in lst}\n",
142+
"\n",
143+
"\n",
144+
"def create_char2idx(sequences):\n",
145+
" unique_chars = get_unique_chars(sequences)\n",
146+
" return {char: idx for idx, char in enumerate(sorted(unique_chars))}\n",
147+
"\n",
148+
"\n",
149+
"def encode_sequence(sequence, char2idx):\n",
150+
" return [char2idx[char] for char in sequence]\n",
151+
"\n",
152+
"\n",
153+
"def encode_sequences(sequences, char2idx):\n",
154+
" return np.array([\n",
155+
" encode_sequence(sequence, char2idx) \n",
156+
" for sequence in tqdm_notebook(sequences)\n",
157+
" ])\n",
158+
"\n",
159+
"\n",
160+
"class Sequences(Dataset):\n",
161+
" def __init__(self, path, sequence_length=125):\n",
162+
" self.sequences = load_data(DATA_PATH, sequence_length=sequence_length)\n",
163+
" self.vocab_size = len(get_unique_chars(self.sequences))\n",
164+
" self.char2idx = create_char2idx(self.sequences)\n",
165+
" self.idx2char = {idx: char for char, idx in self.char2idx.items()}\n",
166+
" self.encoded = encode_sequences(self.sequences, self.char2idx)\n",
167+
" \n",
168+
" def __getitem__(self, i):\n",
169+
" return self.encoded[i, :-1], self.encoded[i, 1:]\n",
170+
" \n",
171+
" def __len__(self):\n",
172+
" return len(self.encoded)"
173+
],
174+
"execution_count": 0,
175+
"outputs": []
176+
},
177+
{
178+
"cell_type": "code",
179+
"metadata": {
180+
"id": "BZxQzaME-UJG",
181+
"colab_type": "code",
182+
"colab": {}
183+
},
184+
"source": [
185+
"dataset = Sequences(DATA_PATH, sequence_length=128)\n",
186+
"len(dataset)\n",
187+
"train_loader = DataLoader(dataset, batch_size=4096)"
188+
],
189+
"execution_count": 0,
190+
"outputs": []
191+
},
192+
{
193+
"cell_type": "markdown",
194+
"metadata": {
195+
"id": "XKupaX61za9e",
196+
"colab_type": "text"
197+
},
198+
"source": [
199+
"## GRU: Gated Recurrent Unit\n",
200+
"\n",
201+
"![](images/char-rnn.png)\n",
202+
"\n",
203+
"The following function is computed for each element in the input sequence in the GRU cell:\n",
204+
"\n",
205+
"$$\n",
206+
"\\begin{array}{ll}\n",
207+
" r_t = \\sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\\\\n",
208+
" z_t = \\sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\\\\n",
209+
" n_t = \\tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\\\\n",
210+
" h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}\n",
211+
"\\end{array}\n",
212+
"$$\n",
213+
"\n",
214+
"where $h_t$ is the hidden state at time $t,ドル $x_t$ is the input at time $t,ドル $h_{(t-1)}$ is the hidden state of the layer at time $t-1$ or the initial hidden state at time 0ドル,ドル and $r_t,ドル $z_t,ドル $n_t$ are the reset, update, and new gates, respectively. $\\sigma$ is the sigmoid function, and $*$ is the Hadamard product."
215+
]
216+
},
217+
{
218+
"cell_type": "code",
219+
"metadata": {
220+
"colab_type": "code",
221+
"id": "_gAAAtscOJf-",
222+
"colab": {}
223+
},
224+
"source": [
225+
"class RNN(nn.Module):\n",
226+
" def __init__(\n",
227+
" self,\n",
228+
" vocab_size,\n",
229+
" embedding_dimension=100,\n",
230+
" hidden_size=128, \n",
231+
" n_layers=1,\n",
232+
" device='cpu',\n",
233+
" ):\n",
234+
" super(RNN, self).__init__()\n",
235+
" self.n_layers = n_layers\n",
236+
" self.hidden_size = hidden_size\n",
237+
" self.device = device\n",
238+
" \n",
239+
" self.encoder = nn.Embedding(vocab_size, embedding_dimension)\n",
240+
" self.rnn = nn.GRU(\n",
241+
" embedding_dimension,\n",
242+
" hidden_size,\n",
243+
" num_layers=n_layers,\n",
244+
" batch_first=True,\n",
245+
" )\n",
246+
" self.decoder = nn.Linear(hidden_size, vocab_size)\n",
247+
" \n",
248+
" def init_hidden(self, batch_size):\n",
249+
" return torch.randn(self.n_layers, batch_size, self.hidden_size).to(self.device)\n",
250+
" \n",
251+
" def forward(self, input_, hidden):\n",
252+
" encoded = self.encoder(input_)\n",
253+
" output, hidden = self.rnn(encoded.unsqueeze(1), hidden)\n",
254+
" output = self.decoder(output.squeeze(1))\n",
255+
" return output, hidden"
256+
],
257+
"execution_count": 0,
258+
"outputs": []
259+
},
260+
{
261+
"cell_type": "code",
262+
"metadata": {
263+
"id": "UrC85qKz-UJT",
264+
"colab_type": "code",
265+
"colab": {}
266+
},
267+
"source": [
268+
"model = RNN(vocab_size=dataset.vocab_size, device=device).to(device)\n",
269+
"\n",
270+
"criterion = nn.CrossEntropyLoss()\n",
271+
"optimizer = optim.Adam(\n",
272+
" filter(lambda p: p.requires_grad, model.parameters()),\n",
273+
" lr=0.001,\n",
274+
")\n",
275+
"scheduler = CosineAnnealingLR(optimizer, 1)"
276+
],
277+
"execution_count": 0,
278+
"outputs": []
279+
},
280+
{
281+
"cell_type": "code",
282+
"metadata": {
283+
"id": "rw0OpXVz9S7n",
284+
"colab_type": "code",
285+
"colab": {}
286+
},
287+
"source": [
288+
"print(model)\n",
289+
"print()\n",
290+
"print('Trainable parameters:')\n",
291+
"print('\\n'.join([' * ' + x[0] for x in model.named_parameters() if x[1].requires_grad]))"
292+
],
293+
"execution_count": 0,
294+
"outputs": []
295+
},
296+
{
297+
"cell_type": "code",
298+
"metadata": {
299+
"id": "Edy_iSXh-UJZ",
300+
"colab_type": "code",
301+
"colab": {}
302+
},
303+
"source": [
304+
"model.train()\n",
305+
"train_losses = []\n",
306+
"for epoch in range(50):\n",
307+
" progress_bar = tqdm_notebook(train_loader, leave=False)\n",
308+
" losses = []\n",
309+
" total = 0\n",
310+
" for inputs, targets in progress_bar:\n",
311+
" batch_size = inputs.size(0)\n",
312+
" hidden = model.init_hidden(batch_size)\n",
313+
"\n",
314+
" model.zero_grad()\n",
315+
" \n",
316+
" loss = 0\n",
317+
" for char_idx in range(inputs.size(1)):\n",
318+
" output, hidden = model(inputs[:, char_idx].to(device), hidden)\n",
319+
" loss += criterion(output, targets[:, char_idx].to(device))\n",
320+
"\n",
321+
" loss.backward()\n",
322+
" \n",
323+
" torch.nn.utils.clip_grad_norm_(model.parameters(), 3)\n",
324+
"\n",
325+
" optimizer.step()\n",
326+
" scheduler.step()\n",
327+
" \n",
328+
" avg_loss = loss.item() / inputs.size(1)\n",
329+
" \n",
330+
" progress_bar.set_description(f'Loss: {avg_loss:.3f}')\n",
331+
" \n",
332+
" losses.append(avg_loss)\n",
333+
" total += 1\n",
334+
" \n",
335+
" epoch_loss = sum(losses) / total\n",
336+
" train_losses.append(epoch_loss)\n",
337+
" \n",
338+
" tqdm.write(f'Epoch #{epoch + 1}\\tTrain Loss: {epoch_loss:.3f}')"
339+
],
340+
"execution_count": 0,
341+
"outputs": []
342+
},
343+
{
344+
"cell_type": "code",
345+
"metadata": {
346+
"id": "TsxjlCe9-UJd",
347+
"colab_type": "code",
348+
"colab": {}
349+
},
350+
"source": [
351+
"def pretty_print(text):\n",
352+
" \"\"\"Wrap text for nice printing.\"\"\"\n",
353+
" to_print = ''\n",
354+
" for paragraph in text.split('\\n'):\n",
355+
" to_print += '\\n'.join(wrap(paragraph))\n",
356+
" to_print += '\\n'\n",
357+
" print(to_print)\n",
358+
"\n",
359+
"\n",
360+
"temperature = 0.9\n",
361+
"\n",
362+
"model.eval()\n",
363+
"seed = '\\n'\n",
364+
"text = ''\n",
365+
"with torch.no_grad():\n",
366+
" batch_size = 1\n",
367+
" hidden = model.init_hidden(batch_size)\n",
368+
" last_char = dataset.char2idx[seed]\n",
369+
" for _ in range(1000):\n",
370+
" output, hidden = model(torch.LongTensor([last_char]).to(device), hidden)\n",
371+
" \n",
372+
" distribution = output.squeeze().div(temperature).exp()\n",
373+
" guess = torch.multinomial(distribution, 1).item()\n",
374+
" \n",
375+
" last_char = guess\n",
376+
" text += dataset.idx2char[guess]\n",
377+
" \n",
378+
"pretty_print(text)"
379+
],
380+
"execution_count": 0,
381+
"outputs": []
382+
},
383+
{
384+
"cell_type": "code",
385+
"metadata": {
386+
"id": "l-RPn_4EaJzL",
387+
"colab_type": "code",
388+
"colab": {}
389+
},
390+
"source": [
391+
""
392+
],
393+
"execution_count": 0,
394+
"outputs": []
395+
}
396+
]
397+
}

‎images/char-rnn.png‎

34.3 KB
Loading[フレーム]

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /