|
7 | 7 | from timm.models import create_model
|
8 | 8 | from timm.data import resolve_data_config
|
9 | 9 | from timm.data.transforms_factory import create_transform
|
10 | | - |
| 10 | +importtorchvision |
11 | 11 |
|
12 | 12 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
|
13 | 13 | parser.add_argument('--model', '-m', metavar='MODEL', default='simpnet', help='model architecture (default: simpnet)')
|
|
31 | 31 | mode = args.netmode,
|
32 | 32 | )
|
33 | 33 |
|
34 | | -print('Restoring model state from checkpoint...') |
35 | | -model_weights = torch.load(args.weights, map_location='cpu') |
36 | | -model.load_state_dict(model_weights) |
37 | | -model.eval() |
| 34 | +# print('Restoring model state from checkpoint...') |
| 35 | +# model_weights = torch.load(args.weights, map_location='cpu') |
| 36 | +# model.load_state_dict(model_weights) |
| 37 | +# model.eval() |
38 | 38 |
|
39 | 39 | if args.jit:
|
40 | 40 | dummy_input = torch.randn(1, 3, 224, 224, device="cpu")
|
|
47 | 47 | filename = "./misc_files/dog.jpg"
|
48 | 48 | img = Image.open(filename).convert('RGB')
|
49 | 49 | tensor = transform(img).unsqueeze(0)
|
| 50 | +# save the transformed image for visualization or testing the ported models |
| 51 | +torchvision.utils.save_image(tensor.squeeze(0),'img_test_transformed.jpg') |
50 | 52 |
|
51 | 53 | with torch.no_grad():
|
52 | 54 | out = model(tensor)
|
|
0 commit comments