
Commit 3db3ea8

Merge pull request GokuMohandas#61 from GokuMohandas/cv
🔥 notebooks refactored for PyTorch 1.0
2 parents 4df675b + b00c435, commit 3db3ea8

6 files changed: +1147 −861 lines

notebooks/10_Object_Oriented_ML.ipynb (17 additions, 26 deletions)
@@ -159,7 +159,7 @@
 " torch.cuda.manual_seed_all(seed)\n",
 " \n",
 "# Creating directories\n",
-"def handle_dirs(dirpath):\n",
+"def create_dirs(dirpath):\n",
 " if not os.path.exists(dirpath):\n",
 " os.makedirs(dirpath)"
 ],
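Note: the renamed helper only wraps an existence check around directory creation. On Python 3.2+ the same effect is available in a single call, as in this minimal sketch (not part of the commit):

import os

def create_dirs(dirpath):
    # create the directory tree if it does not already exist
    os.makedirs(dirpath, exist_ok=True)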
@@ -188,7 +188,6 @@
 " vectorizer_file=\"vectorizer.json\",\n",
 " model_state_file=\"model.pth\",\n",
 " save_dir=\"names\",\n",
-" reload_from_files=False,\n",
 " train_size=0.7,\n",
 " val_size=0.15,\n",
 " test_size=0.15,\n",
@@ -204,7 +203,7 @@
 "set_seeds(seed=args.seed, cuda=args.cuda)\n",
 "\n",
 "# Create save dir\n",
-"handle_dirs(args.save_dir)\n",
+"create_dirs(args.save_dir)\n",
 "\n",
 "# Expand filepaths\n",
 "args.vectorizer_file = os.path.join(args.save_dir, args.vectorizer_file)\n",
@@ -705,8 +704,9 @@
 " nationality_vocab.add_token(row.nationality)\n",
 "print (nationality_vocab) # __str__\n",
 "print (len(nationality_vocab)) # __len__\n",
-"print (nationality_vocab.lookup_token(\"English\"))\n",
-"print (nationality_vocab.lookup_index(0))"
+"index = nationality_vocab.lookup_token(\"English\")\n",
+"print (index)\n",
+"print (nationality_vocab.lookup_index(index))"
 ],
 "execution_count": 0,
 "outputs": [
@@ -1126,7 +1126,7 @@
 " \n",
 " # Iterate over train dataset\n",
 "\n",
-" # setup: batch generator, set loss and acc to 0, set train mode on\n",
+" # initialize batch generator, set loss and acc to 0, set train mode on\n",
 " self.dataset.set_split('train')\n",
 " batch_generator = self.dataset.generate_batches(\n",
 " batch_size=self.batch_size, shuffle=self.shuffle, \n",
@@ -1136,26 +1136,23 @@
 " self.model.train()\n",
 "\n",
 " for batch_index, batch_dict in enumerate(batch_generator):\n",
-" # the training routine is these 5 steps:\n",
-"\n",
-" # --------------------------------------\n",
-" # step 1. zero the gradients\n",
+" # zero the gradients\n",
 " self.optimizer.zero_grad()\n",
 "\n",
-" # step 2. compute the output\n",
+" # compute the output\n",
 " y_pred = self.model(batch_dict['surname'])\n",
 "\n",
-" # step 3. compute the loss\n",
+" # compute the loss\n",
 " loss = self.loss_func(y_pred, batch_dict['nationality'])\n",
 " loss_t = loss.item()\n",
 " running_loss += (loss_t - running_loss) / (batch_index + 1)\n",
 "\n",
-" # step 4. use loss to produce gradients\n",
+" # compute gradients using loss\n",
 " loss.backward()\n",
 "\n",
-" # step 5. use optimizer to take gradient step\n",
+" # use optimizer to take a gradient step\n",
 " self.optimizer.step()\n",
-" # -----------------------------------------\n",
+" \n",
 " # compute the accuracy\n",
 " acc_t = self.compute_accuracy(y_pred, batch_dict['nationality'])\n",
 " running_acc += (acc_t - running_acc) / (batch_index + 1)\n",
@@ -1165,7 +1162,7 @@
 "\n",
 " # Iterate over val dataset\n",
 "\n",
-" # setup: batch generator, set loss and acc to 0; set eval mode on\n",
+" # initialize batch generator, set loss and acc to 0; set eval mode on\n",
 " self.dataset.set_split('val')\n",
 " batch_generator = self.dataset.generate_batches(\n",
 " batch_size=self.batch_size, shuffle=self.shuffle, device=self.device)\n",
@@ -1178,7 +1175,7 @@
 " # compute the output\n",
 " y_pred = self.model(batch_dict['surname'])\n",
 "\n",
-" # step 3. compute the loss\n",
+" # compute the loss\n",
 " loss = self.loss_func(y_pred, batch_dict['nationality'])\n",
 " loss_t = loss.to(\"cpu\").item()\n",
 " running_loss += (loss_t - running_loss) / (batch_index + 1)\n",
@@ -1196,6 +1193,7 @@
 " break\n",
 " \n",
 " def run_test_loop(self):\n",
+" # initialize batch generator, set loss and acc to 0; set eval mode on\n",
 " self.dataset.set_split('test')\n",
 " batch_generator = self.dataset.generate_batches(\n",
 " batch_size=self.batch_size, shuffle=self.shuffle, device=self.device)\n",
@@ -1263,14 +1261,8 @@
 "cell_type": "code",
 "source": [
 "# Initialization\n",
-"if args.reload_from_files:\n",
-" print (\"Reloading!\")\n",
-" dataset = SurnameDataset.load_dataset_and_load_vectorizer(\n",
-" args.split_data_file,args.vectorizer_file)\n",
-"else:\n",
-" print (\"Creating from scratch!\")\n",
-" dataset = SurnameDataset.load_dataset_and_make_vectorizer(args.split_data_file)\n",
-" dataset.save_vectorizer(args.vectorizer_file)\n",
+"dataset = SurnameDataset.load_dataset_and_make_vectorizer(args.split_data_file)\n",
+"dataset.save_vectorizer(args.vectorizer_file)\n",
 "vectorizer = dataset.vectorizer\n",
 "model = SurnameModel(input_dim=len(vectorizer.surname_vocab), \n",
 " hidden_dim=args.hidden_dim, \n",
@@ -1492,7 +1484,6 @@
 "cell_type": "code",
 "source": [
 "# Load the model\n",
-"print (\"Reloading!\")\n",
 "dataset = SurnameDataset.load_dataset_and_load_vectorizer(\n",
 " args.split_data_file,args.vectorizer_file)\n",
 "vectorizer = dataset.vectorizer\n",
