Skip to content

Commit 15af4fc

Browse files
committed
update convert_to_onnx.py
1 parent 468e302 commit 15af4fc

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

ImageNet/training_scripts/imagenet_training/convert_to_onnx.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#in the name of God the most compassionate the most merciful
22
# conver pytorch model to onnx models
3+
import os
34
import argparse
45
import numpy as np
56

@@ -15,6 +16,7 @@
1516
parser.add_argument('--weights', default='', type=str, metavar='PATH', help='path to model weights (default: none)')
1617
parser.add_argument('--output', default='simpnet.onnx', type=str, metavar='FILENAME', help='Output model file (.onnx model)')
1718
# parser.add_argument('--opset', default=0, type=int, help='opset version (default:0) valid values, 0 to 10')
19+
parser.add_argument('--use_input_dir', action='store_true', default=False, help='save in the same directory as input')
1820
parser.add_argument('--jit', action='store_true', default=False, help='convert the model to jit before conversion to onnx')
1921
parser.add_argument('--netscale', type=float, default=1.0, help='scale of the net (default 1.0)')
2022
parser.add_argument('--netidx', type=int, default=0, help='which network to use (5mil or 8mil)')
@@ -36,20 +38,27 @@
3638
model.eval()
3739

3840
dummy_input = torch.randn(1, 3, 224, 224, device="cpu")
41+
42+
new_output_name = args.output
43+
if args.use_input_dir:
44+
base_name = os.path.basename(args.weights)
45+
dir = args.weights.replace(base_name,'')
46+
new_output_name = os.path.join(dir,base_name.replace('.pth','.onnx'))
3947

4048
if args.jit:
4149
model = torch.jit.trace(model, dummy_input)
42-
model.save(f"{args.output.replace('.onnx','-jit')}.pt")
50+
model.save(f"{new_output_name.replace('.onnx','-jit')}.pt")
4351

4452
input_names = ["data"]
4553
output_names = ["pred"]
4654
# for caffe conversion its must be 9.
47-
torch.onnx.export(model, dummy_input, args.output, opset_version=9, verbose=True, input_names=input_names, output_names=output_names)
55+
#! train mode crashes for some reason, need to report the bug.
56+
torch.onnx.export(model, dummy_input, new_output_name, opset_version=9, verbose=True, input_names=input_names, output_names=output_names)
4857

4958
print(f'Converted successfully to onnx.')
5059
print('Testing the new onnx model...')
5160
# Load the ONNX model
52-
model_onnx = onnx.load(args.output)
61+
model_onnx = onnx.load(new_output_name)
5362
# Check that the model is well formed
5463
onnx.checker.check_model(model_onnx)
5564
# Print a human readable representation of the graph
@@ -61,7 +70,7 @@ def to_numpy(tensor):
6170
# pytorch model output
6271
torch_out = model(dummy_input)
6372
# onnx model output
64-
ort_session = onnxruntime.InferenceSession(args.output)
73+
ort_session = onnxruntime.InferenceSession(new_output_name)
6574
# compute ONNX Runtime output prediction
6675
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(dummy_input)}
6776
ort_outs = ort_session.run(None, ort_inputs)

0 commit comments

Comments
 (0)