1
1
#in the name of God the most compassionate the most merciful
2
2
# conver pytorch model to onnx models
3
+ import os
3
4
import argparse
4
5
import numpy as np
5
6
15
16
parser .add_argument ('--weights' , default = '' , type = str , metavar = 'PATH' , help = 'path to model weights (default: none)' )
16
17
parser .add_argument ('--output' , default = 'simpnet.onnx' , type = str , metavar = 'FILENAME' , help = 'Output model file (.onnx model)' )
17
18
# parser.add_argument('--opset', default=0, type=int, help='opset version (default:0) valid values, 0 to 10')
19
+ parser .add_argument ('--use_input_dir' , action = 'store_true' , default = False , help = 'save in the same directory as input' )
18
20
parser .add_argument ('--jit' , action = 'store_true' , default = False , help = 'convert the model to jit before conversion to onnx' )
19
21
parser .add_argument ('--netscale' , type = float , default = 1.0 , help = 'scale of the net (default 1.0)' )
20
22
parser .add_argument ('--netidx' , type = int , default = 0 , help = 'which network to use (5mil or 8mil)' )
36
38
model .eval ()
37
39
38
40
dummy_input = torch .randn (1 , 3 , 224 , 224 , device = "cpu" )
41
+
42
+ new_output_name = args .output
43
+ if args .use_input_dir :
44
+ base_name = os .path .basename (args .weights )
45
+ dir = args .weights .replace (base_name ,'' )
46
+ new_output_name = os .path .join (dir ,base_name .replace ('.pth' ,'.onnx' ))
39
47
40
48
if args .jit :
41
49
model = torch .jit .trace (model , dummy_input )
42
- model .save (f"{ args . output .replace ('.onnx' ,'-jit' )} .pt" )
50
+ model .save (f"{ new_output_name .replace ('.onnx' ,'-jit' )} .pt" )
43
51
44
52
input_names = ["data" ]
45
53
output_names = ["pred" ]
46
54
# for caffe conversion its must be 9.
47
- torch .onnx .export (model , dummy_input , args .output , opset_version = 9 , verbose = True , input_names = input_names , output_names = output_names )
55
+ #! train mode crashes for some reason, need to report the bug.
56
+ torch .onnx .export (model , dummy_input , new_output_name , opset_version = 9 , verbose = True , input_names = input_names , output_names = output_names )
48
57
49
58
print (f'Converted successfully to onnx.' )
50
59
print ('Testing the new onnx model...' )
51
60
# Load the ONNX model
52
- model_onnx = onnx .load (args . output )
61
+ model_onnx = onnx .load (new_output_name )
53
62
# Check that the model is well formed
54
63
onnx .checker .check_model (model_onnx )
55
64
# Print a human readable representation of the graph
@@ -61,7 +70,7 @@ def to_numpy(tensor):
61
70
# pytorch model output
62
71
torch_out = model (dummy_input )
63
72
# onnx model output
64
- ort_session = onnxruntime .InferenceSession (args . output )
73
+ ort_session = onnxruntime .InferenceSession (new_output_name )
65
74
# compute ONNX Runtime output prediction
66
75
ort_inputs = {ort_session .get_inputs ()[0 ].name : to_numpy (dummy_input )}
67
76
ort_outs = ort_session .run (None , ort_inputs )
0 commit comments