
Commit b00c435

notebooks refactored for pytorch 1.0
1 parent a4582ba · commit b00c435

6 files changed: +1147 −861 lines changed

notebooks/10_Object_Oriented_ML.ipynb

Lines changed: 17 additions & 26 deletions
@@ -159,7 +159,7 @@
 " torch.cuda.manual_seed_all(seed)\n",
 " \n",
 "# Creating directories\n",
-"def handle_dirs(dirpath):\n",
+"def create_dirs(dirpath):\n",
 " if not os.path.exists(dirpath):\n",
 " os.makedirs(dirpath)"
 ],
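Aside: on Python 3.2+, the renamed create_dirs helper can be collapsed to a single call. A minimal sketch; the notebook itself keeps the explicit existence check shown in the hunk above:

    import os

    def create_dirs(dirpath):
        # exist_ok=True makes the explicit os.path.exists() check unnecessary
        os.makedirs(dirpath, exist_ok=True)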
@@ -188,7 +188,6 @@
 " vectorizer_file=\"vectorizer.json\",\n",
 " model_state_file=\"model.pth\",\n",
 " save_dir=\"names\",\n",
-" reload_from_files=False,\n",
 " train_size=0.7,\n",
 " val_size=0.15,\n",
 " test_size=0.15,\n",
@@ -204,7 +203,7 @@
 "set_seeds(seed=args.seed, cuda=args.cuda)\n",
 "\n",
 "# Create save dir\n",
-"handle_dirs(args.save_dir)\n",
+"create_dirs(args.save_dir)\n",
 "\n",
 "# Expand filepaths\n",
 "args.vectorizer_file = os.path.join(args.save_dir, args.vectorizer_file)\n",
@@ -705,8 +704,9 @@
 " nationality_vocab.add_token(row.nationality)\n",
 "print (nationality_vocab) # __str__\n",
 "print (len(nationality_vocab)) # __len__\n",
-"print (nationality_vocab.lookup_token(\"English\"))\n",
-"print (nationality_vocab.lookup_index(0))"
+"index = nationality_vocab.lookup_token(\"English\")\n",
+"print (index)\n",
+"print (nationality_vocab.lookup_index(index))"
 ],
 "execution_count": 0,
 "outputs": [
@@ -1126,7 +1126,7 @@
 " \n",
 " # Iterate over train dataset\n",
 "\n",
-" # setup: batch generator, set loss and acc to 0, set train mode on\n",
+" # initialize batch generator, set loss and acc to 0, set train mode on\n",
 " self.dataset.set_split('train')\n",
 " batch_generator = self.dataset.generate_batches(\n",
 " batch_size=self.batch_size, shuffle=self.shuffle, \n",
@@ -1136,26 +1136,23 @@
 " self.model.train()\n",
 "\n",
 " for batch_index, batch_dict in enumerate(batch_generator):\n",
-" # the training routine is these 5 steps:\n",
-"\n",
-" # --------------------------------------\n",
-" # step 1. zero the gradients\n",
+" # zero the gradients\n",
 " self.optimizer.zero_grad()\n",
 "\n",
-" # step 2. compute the output\n",
+" # compute the output\n",
 " y_pred = self.model(batch_dict['surname'])\n",
 "\n",
-" # step 3. compute the loss\n",
+" # compute the loss\n",
 " loss = self.loss_func(y_pred, batch_dict['nationality'])\n",
 " loss_t = loss.item()\n",
 " running_loss += (loss_t - running_loss) / (batch_index + 1)\n",
 "\n",
-" # step 4. use loss to produce gradients\n",
+" # compute gradients using loss\n",
 " loss.backward()\n",
 "\n",
-" # step 5. use optimizer to take gradient step\n",
+" # use optimizer to take a gradient step\n",
 " self.optimizer.step()\n",
-" # -----------------------------------------\n",
+" \n",
 " # compute the accuracy\n",
 " acc_t = self.compute_accuracy(y_pred, batch_dict['nationality'])\n",
 " running_acc += (acc_t - running_acc) / (batch_index + 1)\n",
@@ -1165,7 +1162,7 @@
 "\n",
 " # Iterate over val dataset\n",
 "\n",
-" # setup: batch generator, set loss and acc to 0; set eval mode on\n",
+" # initialize batch generator, set loss and acc to 0; set eval mode on\n",
 " self.dataset.set_split('val')\n",
 " batch_generator = self.dataset.generate_batches(\n",
 " batch_size=self.batch_size, shuffle=self.shuffle, device=self.device)\n",
@@ -1178,7 +1175,7 @@
 " # compute the output\n",
 " y_pred = self.model(batch_dict['surname'])\n",
 "\n",
-" # step 3. compute the loss\n",
+" # compute the loss\n",
 " loss = self.loss_func(y_pred, batch_dict['nationality'])\n",
 " loss_t = loss.to(\"cpu\").item()\n",
 " running_loss += (loss_t - running_loss) / (batch_index + 1)\n",
@@ -1196,6 +1193,7 @@
 " break\n",
 " \n",
 " def run_test_loop(self):\n",
+" # initialize batch generator, set loss and acc to 0; set eval mode on\n",
 " self.dataset.set_split('test')\n",
 " batch_generator = self.dataset.generate_batches(\n",
 " batch_size=self.batch_size, shuffle=self.shuffle, device=self.device)\n",
@@ -1263,14 +1261,8 @@
 "cell_type": "code",
 "source": [
 "# Initialization\n",
-"if args.reload_from_files:\n",
-" print (\"Reloading!\")\n",
-" dataset = SurnameDataset.load_dataset_and_load_vectorizer(\n",
-" args.split_data_file,args.vectorizer_file)\n",
-"else:\n",
-" print (\"Creating from scratch!\")\n",
-" dataset = SurnameDataset.load_dataset_and_make_vectorizer(args.split_data_file)\n",
-" dataset.save_vectorizer(args.vectorizer_file)\n",
+"dataset = SurnameDataset.load_dataset_and_make_vectorizer(args.split_data_file)\n",
+"dataset.save_vectorizer(args.vectorizer_file)\n",
 "vectorizer = dataset.vectorizer\n",
 "model = SurnameModel(input_dim=len(vectorizer.surname_vocab), \n",
 " hidden_dim=args.hidden_dim, \n",
@@ -1492,7 +1484,6 @@
 "cell_type": "code",
 "source": [
 "# Load the model\n",
-"print (\"Reloading!\")\n",
 "dataset = SurnameDataset.load_dataset_and_load_vectorizer(\n",
 " args.split_data_file,args.vectorizer_file)\n",
 "vectorizer = dataset.vectorizer\n",
