Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 15af4fc

Browse files
committed
update convert_to_onnx.py
1 parent 468e302 commit 15af4fc

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

‎ImageNet/training_scripts/imagenet_training/convert_to_onnx.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#in the name of God the most compassionate the most merciful
22
# conver pytorch model to onnx models
3+
import os
34
import argparse
45
import numpy as np
56

@@ -15,6 +16,7 @@
1516
parser.add_argument('--weights', default='', type=str, metavar='PATH', help='path to model weights (default: none)')
1617
parser.add_argument('--output', default='simpnet.onnx', type=str, metavar='FILENAME', help='Output model file (.onnx model)')
1718
# parser.add_argument('--opset', default=0, type=int, help='opset version (default:0) valid values, 0 to 10')
19+
parser.add_argument('--use_input_dir', action='store_true', default=False, help='save in the same directory as input')
1820
parser.add_argument('--jit', action='store_true', default=False, help='convert the model to jit before conversion to onnx')
1921
parser.add_argument('--netscale', type=float, default=1.0, help='scale of the net (default 1.0)')
2022
parser.add_argument('--netidx', type=int, default=0, help='which network to use (5mil or 8mil)')
@@ -36,20 +38,27 @@
3638
model.eval()
3739

3840
dummy_input = torch.randn(1, 3, 224, 224, device="cpu")
41+
42+
new_output_name = args.output
43+
if args.use_input_dir:
44+
base_name = os.path.basename(args.weights)
45+
dir = args.weights.replace(base_name,'')
46+
new_output_name = os.path.join(dir,base_name.replace('.pth','.onnx'))
3947

4048
if args.jit:
4149
model = torch.jit.trace(model, dummy_input)
42-
model.save(f"{args.output.replace('.onnx','-jit')}.pt")
50+
model.save(f"{new_output_name.replace('.onnx','-jit')}.pt")
4351

4452
input_names = ["data"]
4553
output_names = ["pred"]
4654
# for caffe conversion its must be 9.
47-
torch.onnx.export(model, dummy_input, args.output, opset_version=9, verbose=True, input_names=input_names, output_names=output_names)
55+
#! train mode crashes for some reason, need to report the bug.
56+
torch.onnx.export(model, dummy_input, new_output_name, opset_version=9, verbose=True, input_names=input_names, output_names=output_names)
4857

4958
print(f'Converted successfully to onnx.')
5059
print('Testing the new onnx model...')
5160
# Load the ONNX model
52-
model_onnx = onnx.load(args.output)
61+
model_onnx = onnx.load(new_output_name)
5362
# Check that the model is well formed
5463
onnx.checker.check_model(model_onnx)
5564
# Print a human readable representation of the graph
@@ -61,7 +70,7 @@ def to_numpy(tensor):
6170
# pytorch model output
6271
torch_out = model(dummy_input)
6372
# onnx model output
64-
ort_session = onnxruntime.InferenceSession(args.output)
73+
ort_session = onnxruntime.InferenceSession(new_output_name)
6574
# compute ONNX Runtime output prediction
6675
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(dummy_input)}
6776
ort_outs = ort_session.run(None, ort_inputs)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /