Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 61f7bcd

Browse files
committed
update
1 parent af86648 commit 61f7bcd

File tree

1 file changed

+20
-33
lines changed

1 file changed

+20
-33
lines changed

‎RektNet/keypoints_tutorial.ipynb‎

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@
161161
"metadata": {},
162162
"outputs": [],
163163
"source": [
164-
"\n",
165164
"study_name=\"tutorial\"\n",
166165
"\n",
167166
"current_month = datetime.now().strftime('%B').lower()\n",
@@ -174,28 +173,25 @@
174173
"sys.stdout = Logger(save_file_name + '.log')\n",
175174
"sys.stderr = Logger(save_file_name + '.error')\n",
176175
"\n",
177-
"INPUT_SIZE = (80, 80)\n",
178-
"KPT_KEYS = [\"top\", \"mid_L_top\", \"mid_R_top\", \"mid_L_bot\", \"mid_R_bot\", \"bot_L\", \"bot_R\"]\n",
179-
"\n",
176+
"# Training related config\n",
177+
"INPUT_SIZE = (80, 80) # dataset size\n",
178+
"KPT_KEYS = [\"top\", \"mid_L_top\", \"mid_R_top\", \"mid_L_bot\", \"mid_R_bot\", \"bot_L\", \"bot_R\"] # set up geometry loss keys\n",
180179
"intervals = int(2) # for normal training, set it to 4\n",
181-
"val_split = float(0.15)\n",
182-
"\n",
183-
"batch_size= int(32)\n",
180+
"val_split = float(0.15) # training validation split ratio\n",
181+
"batch_size= int(8)\n",
184182
"num_epochs= int(4) # for normal training, set it to 1024\n",
185-
"\n",
186-
"# Load the train data.\n",
187183
"train_csv = \"dataset/rektnet_label.csv\"\n",
188184
"dataset_path = \"dataset/RektNet_Dataset/\"\n",
189-
"vis_dataloader = False\n",
185+
"vis_dataloader = False # visualize dataset\n",
190186
"save_checkpoints = True\n",
191-
"save_checkpoints=True\n",
192-
"evaluate_mode=False\n",
187+
"\n",
188+
"# Training related hyperparameter\n",
193189
"lr = 1e-1\n",
194190
"lr_gamma = 0.999\n",
195191
"geo_loss = True\n",
196192
"geo_loss_gamma_vert = 0\n",
197193
"geo_loss_gamma_horz = 0\n",
198-
"loss_type = \"l1_softargmax\"\n",
194+
"loss_type = \"l1_softargmax\" # loss function type: l2_softargmax|l2_heatmap|l1_softargmax\n",
199195
"best_val_loss = float('inf')\n",
200196
"best_epoch = 0\n",
201197
"max_tolerance = 8\n",
@@ -247,6 +243,13 @@
247243
"loss_func = CrossRatioLoss(loss_type, geo_loss, geo_loss_gamma_horz, geo_loss_gamma_vert)"
248244
]
249245
},
246+
{
247+
"cell_type": "markdown",
248+
"metadata": {},
249+
"source": [
250+
"## Training"
251+
]
252+
},
250253
{
251254
"cell_type": "code",
252255
"execution_count": null,
@@ -439,26 +442,10 @@
439442
"image = cv2.imread(image_filepath)\n",
440443
"h, w, _ = image.shape\n",
441444
"\n",
442-
"image = vis_tensor_and_save(image=image, h=h, w=w, tensor_output=output[1][0].cpu().data, image_name=img_name, output_uri=output_path)"
443-
]
444-
},
445-
{
446-
"cell_type": "code",
447-
"execution_count": null,
448-
"metadata": {},
449-
"outputs": [],
450-
"source": [
451-
"image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)"
452-
]
453-
},
454-
{
455-
"cell_type": "code",
456-
"execution_count": null,
457-
"metadata": {
458-
"scrolled": true
459-
},
460-
"outputs": [],
461-
"source": [
445+
"image = vis_tensor_and_save(image=image, h=h, w=w, tensor_output=output[1][0].cpu().data, image_name=img_name, output_uri=output_path)\n",
446+
"\n",
447+
"image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
448+
"\n",
462449
"pt.fig = pt.figure(figsize=(5, 5))\n",
463450
"\n",
464451
"pt.imshow(image)\n",

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /