Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit bbdbc20

Browse files
committed
cleaned classification_test.py a bit!
1 parent d0b5217 commit bbdbc20

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

‎ImageNet/training_scripts/imagenet_training/classification_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
parser.add_argument('--model', '-m', metavar='MODEL', default='simpnet', help='model architecture (default: simpnet)')
1414
parser.add_argument('--num-classes', type=int, default=1000, help='Number classes in dataset')
1515
parser.add_argument('--weights', default='', type=str, metavar='PATH', help='path to model weights (default: none)')
16+
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
1617
parser.add_argument('--jit', action='store_true', default=False, help='convert the model to jit before doing classification!')
1718
parser.add_argument('--netscale', type=float, default=1.0, help='scale of the net (default 1.0)')
1819
parser.add_argument('--netidx', type=int, default=0, help='which network to use (5mil or 8mil)')
@@ -25,23 +26,22 @@
2526
model = create_model(
2627
args.model,
2728
num_classes=args.num_classes,
29+
pretrained=args.pretrained,
2830
checkpoint_path=args.weights,
2931
scale=args.netscale,
3032
network_idx = args.netidx,
3133
mode = args.netmode,
3234
)
35+
model.eval()
3336

34-
# print('Restoring model state from checkpoint...')
35-
# model_weights = torch.load(args.weights, map_location='cpu')
36-
# model.load_state_dict(model_weights)
37-
# model.eval()
37+
if not args.pretrained and not args.weights:
38+
print(f'WARNING: No pretrained weights specified! (pretrained is False and there is no checkpoint specified!)')
3839

3940
if args.jit:
4041
dummy_input = torch.randn(1, 3, 224, 224, device="cpu")
4142
model = torch.jit.trace(model, dummy_input)
4243

4344
config = resolve_data_config({}, model=model)
44-
print(f'config: {config}')
4545
transform = create_transform(**config)
4646

4747
filename = "./misc_files/dog.jpg"
@@ -53,13 +53,14 @@
5353
with torch.no_grad():
5454
out = model(tensor)
5555
probabilities = torch.nn.functional.softmax(out[0], dim=0)
56-
print(probabilities.shape) # prints: torch.Size([1000])
56+
print(f'{probabilities.shape}') # prints: torch.Size([1000])
5757

5858
filename="./misc_files/imagenet_classes.txt"
5959
with open(filename, "r") as f:
6060
categories = [s.strip() for s in f.readlines()]
6161

6262
# Print top categories per image
63+
print(f'Top categories:')
6364
top5_prob, top5_catid = torch.topk(probabilities, 5)
6465
for i in range(top5_prob.size(0)):
6566
print(categories[top5_catid[i]], top5_prob[i].item())

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /