|
159 | 159 | " torch.cuda.manual_seed_all(seed)\n",
|
160 | 160 | " \n",
|
161 | 161 | "# Creating directories\n",
|
162 | | - "def handle_dirs(dirpath):\n", |
| 162 | + "def create_dirs(dirpath):\n", |
163 | 163 | " if not os.path.exists(dirpath):\n",
|
164 | 164 | " os.makedirs(dirpath)"
|
165 | 165 | ],
|
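For reference, the renamed helper and its companion reduce to a few lines. A minimal sketch, assuming the surrounding `set_seeds` body also seeds NumPy and CPU torch (only the CUDA call is visible in this hunk); `create_dirs` is copied from the change.

```python
import os
import numpy as np
import torch

def set_seeds(seed, cuda):
    """Seed the RNGs for reproducibility (NumPy/CPU seeding assumed)."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)

def create_dirs(dirpath):
    """Create the save directory if it does not already exist."""
    if not os.path.exists(dirpath):
        os.makedirs(dirpath)
```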
|
188 | 188 | " vectorizer_file=\"vectorizer.json\",\n",
|
189 | 189 | " model_state_file=\"model.pth\",\n",
|
190 | 190 | " save_dir=\"names\",\n",
|
191 | | - " reload_from_files=False,\n", |
192 | 191 | " train_size=0.7,\n",
|
193 | 192 | " val_size=0.15,\n",
|
194 | 193 | " test_size=0.15,\n",
|
|
204 | 203 | "set_seeds(seed=args.seed, cuda=args.cuda)\n",
|
205 | 204 | "\n",
|
206 | 205 | "# Create save dir\n",
|
207 | | - "handle_dirs(args.save_dir)\n", |
| 206 | + "create_dirs(args.save_dir)\n", |
208 | 207 | "\n",
|
209 | 208 | "# Expand filepaths\n",
|
210 | 209 | "args.vectorizer_file = os.path.join(args.save_dir, args.vectorizer_file)\n",
|
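Taken together with the helpers above, the arguments cell now reads as one linear setup. A rough sketch of that flow; the `Namespace` container, the concrete `seed` value, and the `model_state_file` expansion are assumptions, the rest is lifted from the hunk.

```python
import os
import torch
from argparse import Namespace

args = Namespace(
    vectorizer_file="vectorizer.json",
    model_state_file="model.pth",
    save_dir="names",
    train_size=0.7,
    val_size=0.15,
    test_size=0.15,
    seed=1234,                        # illustrative value (assumed)
    cuda=torch.cuda.is_available(),   # assumed
    # (split_data_file and hidden_dim also live here; values not shown in the diff)
)

# Seed RNGs and create the save directory
set_seeds(seed=args.seed, cuda=args.cuda)
create_dirs(args.save_dir)

# Expand filepaths so artifacts live under save_dir
args.vectorizer_file = os.path.join(args.save_dir, args.vectorizer_file)
args.model_state_file = os.path.join(args.save_dir, args.model_state_file)  # assumed expansion
```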
|
705 | 704 | " nationality_vocab.add_token(row.nationality)\n",
|
706 | 705 | "print (nationality_vocab) # __str__\n",
|
707 | 706 | "print (len(nationality_vocab)) # __len__\n",
|
708 | | - "print (nationality_vocab.lookup_token(\"English\"))\n", |
709 | | - "print (nationality_vocab.lookup_index(0))" |
| 707 | + "index = nationality_vocab.lookup_token(\"English\")\n", |
| 708 | + "print (index)\n", |
| 709 | + "print (nationality_vocab.lookup_index(index))" |
710 | 710 | ],
|
711 | 711 | "execution_count": 0,
|
712 | 712 | "outputs": [
|
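The round trip exercised in this cell is easier to see with a stripped-down vocabulary. This is a minimal sketch, not the notebook's class: only `add_token`, `lookup_token`, `lookup_index`, `__str__`, and `__len__` are taken from the cell; the internal dict names are assumptions.

```python
class Vocabulary:
    """Minimal token <-> index mapping (sketch; internals assumed)."""
    def __init__(self):
        self.token_to_idx = {}
        self.idx_to_token = {}

    def add_token(self, token):
        if token not in self.token_to_idx:
            index = len(self.token_to_idx)
            self.token_to_idx[token] = index
            self.idx_to_token[index] = token
        return self.token_to_idx[token]

    def lookup_token(self, token):
        return self.token_to_idx[token]

    def lookup_index(self, index):
        return self.idx_to_token[index]

    def __str__(self):
        return "<Vocabulary(size=%d)>" % len(self)

    def __len__(self):
        return len(self.token_to_idx)

# Usage mirroring the cell above
vocab = Vocabulary()
vocab.add_token("English")
index = vocab.lookup_token("English")   # -> 0
print(vocab.lookup_index(index))        # -> "English"
```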
|
1126 | 1126 | " \n",
|
1127 | 1127 | " # Iterate over train dataset\n",
|
1128 | 1128 | "\n",
|
1129 | | - " # setup: batch generator, set loss and acc to 0, set train mode on\n", |
| 1129 | + " # initialize batch generator, set loss and acc to 0, set train mode on\n", |
1130 | 1130 | " self.dataset.set_split('train')\n",
|
1131 | 1131 | " batch_generator = self.dataset.generate_batches(\n",
|
1132 | 1132 | " batch_size=self.batch_size, shuffle=self.shuffle, \n",
|
|
1136 | 1136 | " self.model.train()\n",
|
1137 | 1137 | "\n",
|
1138 | 1138 | " for batch_index, batch_dict in enumerate(batch_generator):\n",
|
1139 | | - " # the training routine is these 5 steps:\n", |
1140 | | - "\n", |
1141 | | - " # --------------------------------------\n", |
1142 | | - " # step 1. zero the gradients\n", |
| 1139 | + " # zero the gradients\n", |
1143 | 1140 | " self.optimizer.zero_grad()\n",
|
1144 | 1141 | "\n",
|
1145 | | - " # step 2. compute the output\n", |
| 1142 | + " # compute the output\n", |
1146 | 1143 | " y_pred = self.model(batch_dict['surname'])\n",
|
1147 | 1144 | "\n",
|
1148 | | - " # step 3. compute the loss\n", |
| 1145 | + " # compute the loss\n", |
1149 | 1146 | " loss = self.loss_func(y_pred, batch_dict['nationality'])\n",
|
1150 | 1147 | " loss_t = loss.item()\n",
|
1151 | 1148 | " running_loss += (loss_t - running_loss) / (batch_index + 1)\n",
|
1152 | 1149 | "\n",
|
1153 | | - " # step 4. use loss to produce gradients\n", |
| 1150 | + " # compute gradients using loss\n", |
1154 | 1151 | " loss.backward()\n",
|
1155 | 1152 | "\n",
|
1156 | | - " # step 5. use optimizer to take gradient step\n", |
| 1153 | + " # use optimizer to take a gradient step\n", |
1157 | 1154 | " self.optimizer.step()\n",
|
1158 | | - " # -----------------------------------------\n", |
| 1155 | + " \n", |
1159 | 1156 | " # compute the accuracy\n",
|
1160 | 1157 | " acc_t = self.compute_accuracy(y_pred, batch_dict['nationality'])\n",
|
1161 | 1158 | " running_acc += (acc_t - running_acc) / (batch_index + 1)\n",
|
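The simplified comments above map onto a loop you could write stand-alone. A minimal sketch of one training pass, assuming `model`, `optimizer`, `loss_func`, and the dataset from the earlier cells are already built; the `args.batch_size`/`args.device` fields are assumptions standing in for the trainer's own attributes.

```python
dataset.set_split('train')
batch_generator = dataset.generate_batches(
    batch_size=args.batch_size, shuffle=True, device=args.device)  # args fields assumed
running_loss, running_acc = 0.0, 0.0
model.train()

for batch_index, batch_dict in enumerate(batch_generator):
    # zero the gradients
    optimizer.zero_grad()

    # compute the output
    y_pred = model(batch_dict['surname'])

    # compute the loss and fold it into a running mean
    loss = loss_func(y_pred, batch_dict['nationality'])
    loss_t = loss.item()
    running_loss += (loss_t - running_loss) / (batch_index + 1)

    # compute gradients using loss, then take a gradient step
    loss.backward()
    optimizer.step()
```

The `running_loss` update is the incremental form of a mean: after batch k it equals the average of the first k per-batch losses, so no separate division is needed at the end of the epoch.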
|
1165 | 1162 | "\n",
|
1166 | 1163 | " # Iterate over val dataset\n",
|
1167 | 1164 | "\n",
|
1168 | | - " # setup: batch generator, set loss and acc to 0; set eval mode on\n", |
| 1165 | + " # initialize batch generator, set loss and acc to 0; set eval mode on\n", |
1169 | 1166 | " self.dataset.set_split('val')\n",
|
1170 | 1167 | " batch_generator = self.dataset.generate_batches(\n",
|
1171 | 1168 | " batch_size=self.batch_size, shuffle=self.shuffle, device=self.device)\n",
|
|
1178 | 1175 | " # compute the output\n",
|
1179 | 1176 | " y_pred = self.model(batch_dict['surname'])\n",
|
1180 | 1177 | "\n",
|
1181 | | - " # step 3. compute the loss\n", |
| 1178 | + " # compute the loss\n", |
1182 | 1179 | " loss = self.loss_func(y_pred, batch_dict['nationality'])\n",
|
1183 | 1180 | " loss_t = loss.to(\"cpu\").item()\n",
|
1184 | 1181 | " running_loss += (loss_t - running_loss) / (batch_index + 1)\n",
|
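The validation pass is the same walk without the backward and optimizer steps. A minimal sketch, assuming the same `model`, `loss_func`, and dataset objects; the explicit `.to("cpu")` before `.item()` is copied from the hunk.

```python
dataset.set_split('val')
batch_generator = dataset.generate_batches(
    batch_size=args.batch_size, shuffle=False, device=args.device)  # args fields assumed
running_loss, running_acc = 0.0, 0.0
model.eval()  # eval mode: disables dropout, uses running batch-norm statistics

for batch_index, batch_dict in enumerate(batch_generator):
    # compute the output
    y_pred = model(batch_dict['surname'])

    # compute the loss; move to CPU before extracting the Python float
    loss = loss_func(y_pred, batch_dict['nationality'])
    loss_t = loss.to("cpu").item()
    running_loss += (loss_t - running_loss) / (batch_index + 1)
```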
|
1196 | 1193 | " break\n",
|
1197 | 1194 | " \n",
|
1198 | 1195 | " def run_test_loop(self):\n",
|
| 1196 | + " # initialize batch generator, set loss and acc to 0; set eval mode on\n", |
1199 | 1197 | " self.dataset.set_split('test')\n",
|
1200 | 1198 | " batch_generator = self.dataset.generate_batches(\n",
|
1201 | 1199 | " batch_size=self.batch_size, shuffle=self.shuffle, device=self.device)\n",
|
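The comment added to `run_test_loop` makes the symmetry explicit: the test pass is the validation pass pointed at the `'test'` split. A minimal sketch of just the setup, with the per-batch body identical to the validation sketch above; the `args` fields are again assumptions.

```python
dataset.set_split('test')
batch_generator = dataset.generate_batches(
    batch_size=args.batch_size, shuffle=False, device=args.device)
model.eval()
# per-batch forward / loss / accuracy exactly as in the validation sketch
```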
|
1263 | 1261 | "cell_type": "code",
|
1264 | 1262 | "source": [
|
1265 | 1263 | "# Initialization\n",
|
1266 | | - "if args.reload_from_files:\n", |
1267 | | - " print (\"Reloading!\")\n", |
1268 | | - " dataset = SurnameDataset.load_dataset_and_load_vectorizer(\n", |
1269 | | - " args.split_data_file,args.vectorizer_file)\n", |
1270 | | - "else:\n", |
1271 | | - " print (\"Creating from scratch!\")\n", |
1272 | | - " dataset = SurnameDataset.load_dataset_and_make_vectorizer(args.split_data_file)\n", |
1273 | | - " dataset.save_vectorizer(args.vectorizer_file)\n", |
| 1264 | + "dataset = SurnameDataset.load_dataset_and_make_vectorizer(args.split_data_file)\n", |
| 1265 | + "dataset.save_vectorizer(args.vectorizer_file)\n", |
1274 | 1266 | "vectorizer = dataset.vectorizer\n",
|
1275 | 1267 | "model = SurnameModel(input_dim=len(vectorizer.surname_vocab), \n",
|
1276 | 1268 | " hidden_dim=args.hidden_dim, \n",
|
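With the reload branch gone, initialization is a single path: build the vectorizer from the split file, persist it, and size the model from the vocabularies. A minimal sketch; the `output_dim` argument and the `nationality_vocab` sizing are assumptions, since the hunk only shows `input_dim` and `hidden_dim`.

```python
dataset = SurnameDataset.load_dataset_and_make_vectorizer(args.split_data_file)
dataset.save_vectorizer(args.vectorizer_file)   # write vectorizer.json under save_dir
vectorizer = dataset.vectorizer

model = SurnameModel(
    input_dim=len(vectorizer.surname_vocab),        # input features sized by the surname vocabulary
    hidden_dim=args.hidden_dim,
    output_dim=len(vectorizer.nationality_vocab),   # assumed: one logit per nationality
)
```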
|
1492 | 1484 | "cell_type": "code",
|
1493 | 1485 | "source": [
|
1494 | 1486 | "# Load the model\n",
|
1495 | | - "print (\"Reloading!\")\n", |
1496 | 1487 | "dataset = SurnameDataset.load_dataset_and_load_vectorizer(\n",
|
1497 | 1488 | " args.split_data_file,args.vectorizer_file)\n",
|
1498 | 1489 | "vectorizer = dataset.vectorizer\n",
|
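For evaluation the notebook reloads the saved vectorizer rather than rebuilding it. A minimal sketch of that reload path; restoring the weights with `torch.load`/`load_state_dict` is an assumption, since only the dataset/vectorizer reload and the `model.pth` filename appear in the diff.

```python
import torch

# Rebuild the dataset around the previously saved vectorizer
dataset = SurnameDataset.load_dataset_and_load_vectorizer(
    args.split_data_file, args.vectorizer_file)
vectorizer = dataset.vectorizer

# Restore the trained weights (assumed: standard PyTorch state-dict loading)
model = SurnameModel(input_dim=len(vectorizer.surname_vocab),
                     hidden_dim=args.hidden_dim,
                     output_dim=len(vectorizer.nationality_vocab))  # output_dim assumed
model.load_state_dict(torch.load(args.model_state_file))
model.eval()
```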
|