|
161 | 161 | "metadata": {}, |
162 | 162 | "outputs": [], |
163 | 163 | "source": [ |
164 | | - "\n", |
165 | 164 | "study_name=\"tutorial\"\n", |
166 | 165 | "\n", |
167 | 166 | "current_month = datetime.now().strftime('%B').lower()\n", |
|
174 | 173 | "sys.stdout = Logger(save_file_name + '.log')\n", |
175 | 174 | "sys.stderr = Logger(save_file_name + '.error')\n", |
176 | 175 | "\n", |
177 | | - "INPUT_SIZE = (80, 80)\n", |
178 | | - "KPT_KEYS = [\"top\", \"mid_L_top\", \"mid_R_top\", \"mid_L_bot\", \"mid_R_bot\", \"bot_L\", \"bot_R\"]\n", |
179 | | - "\n", |
| 176 | + "# Training related config\n", |
| 177 | + "INPUT_SIZE = (80, 80) # dataset size\n", |
| 178 | + "KPT_KEYS = [\"top\", \"mid_L_top\", \"mid_R_top\", \"mid_L_bot\", \"mid_R_bot\", \"bot_L\", \"bot_R\"] # set up geometry loss keys\n", |
180 | 179 | "intervals = int(2) # for normal training, set it to 4\n", |
181 | | - "val_split = float(0.15)\n", |
182 | | - "\n", |
183 | | - "batch_size= int(32)\n", |
| 180 | + "val_split = float(0.15) # training validation split ratio\n", |
| 181 | + "batch_size= int(8)\n", |
184 | 182 | "num_epochs= int(4) # for normal training, set it to 1024\n", |
185 | | - "\n", |
186 | | - "# Load the train data.\n", |
187 | 183 | "train_csv = \"dataset/rektnet_label.csv\"\n", |
188 | 184 | "dataset_path = \"dataset/RektNet_Dataset/\"\n", |
189 | | - "vis_dataloader = False\n", |
| 185 | + "vis_dataloader = False # visualize dataset\n", |
190 | 186 | "save_checkpoints = True\n", |
191 | | - "save_checkpoints=True\n", |
192 | | - "evaluate_mode=False\n", |
| 187 | + "\n", |
| 188 | + "# Training related hyperparameter\n", |
193 | 189 | "lr = 1e-1\n", |
194 | 190 | "lr_gamma = 0.999\n", |
195 | 191 | "geo_loss = True\n", |
196 | 192 | "geo_loss_gamma_vert = 0\n", |
197 | 193 | "geo_loss_gamma_horz = 0\n", |
198 | | - "loss_type = \"l1_softargmax\"\n", |
| 194 | + "loss_type = \"l1_softargmax\" # loss function type: l2_softargmax|l2_heatmap|l1_softargmax\n", |
199 | 195 | "best_val_loss = float('inf')\n", |
200 | 196 | "best_epoch = 0\n", |
201 | 197 | "max_tolerance = 8\n", |
|
247 | 243 | "loss_func = CrossRatioLoss(loss_type, geo_loss, geo_loss_gamma_horz, geo_loss_gamma_vert)" |
248 | 244 | ] |
249 | 245 | }, |
| 246 | + { |
| 247 | + "cell_type": "markdown", |
| 248 | + "metadata": {}, |
| 249 | + "source": [ |
| 250 | + "## Training" |
| 251 | + ] |
| 252 | + }, |
250 | 253 | { |
251 | 254 | "cell_type": "code", |
252 | 255 | "execution_count": null, |
|
439 | 442 | "image = cv2.imread(image_filepath)\n", |
440 | 443 | "h, w, _ = image.shape\n", |
441 | 444 | "\n", |
442 | | - "image = vis_tensor_and_save(image=image, h=h, w=w, tensor_output=output[1][0].cpu().data, image_name=img_name, output_uri=output_path)" |
443 | | - ] |
444 | | - }, |
445 | | - { |
446 | | - "cell_type": "code", |
447 | | - "execution_count": null, |
448 | | - "metadata": {}, |
449 | | - "outputs": [], |
450 | | - "source": [ |
451 | | - "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)" |
452 | | - ] |
453 | | - }, |
454 | | - { |
455 | | - "cell_type": "code", |
456 | | - "execution_count": null, |
457 | | - "metadata": { |
458 | | - "scrolled": true |
459 | | - }, |
460 | | - "outputs": [], |
461 | | - "source": [ |
| 445 | + "image = vis_tensor_and_save(image=image, h=h, w=w, tensor_output=output[1][0].cpu().data, image_name=img_name, output_uri=output_path)\n", |
| 446 | + "\n", |
| 447 | + "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", |
| 448 | + "\n", |
462 | 449 | "pt.fig = pt.figure(figsize=(5, 5))\n", |
463 | 450 | "\n", |
464 | 451 | "pt.imshow(image)\n", |
|
0 commit comments