-
Notifications
You must be signed in to change notification settings - Fork 9.6k
/
Copy pathbenchmark_test_image.py
134 lines (113 loc) · 4.56 KB
/
benchmark_test_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from argparse import ArgumentParser
import mmcv
from mmengine.config import Config
from mmengine.logging import MMLogger
from mmengine.utils import mkdir_or_exist
from mmdet.apis import inference_detector, init_detector
from mmdet.registry import VISUALIZERS
from mmdet.utils import register_all_modules
def parse_args():
    """Parse the command line options of the image inference benchmark.

    Two positional arguments are required (``config`` and
    ``checkpoint_root``); every other option is optional and has a default.

    Returns:
        argparse.Namespace: The options parsed from ``sys.argv``.
    """
    # Each entry is (flag, keyword arguments for ``add_argument``). The
    # order is kept stable because it determines the order of ``--help``.
    option_table = (
        ('config', dict(help='test config file path')),
        ('checkpoint_root', dict(help='Checkpoint file root path')),
        ('--img', dict(default='demo/demo.jpg', help='Image file')),
        ('--aug', dict(action='store_true', help='aug test')),
        ('--model-name', dict(help='model name to inference')),
        ('--show', dict(action='store_true', help='show results')),
        ('--out-dir', dict(default=None, help='Dir to output file')),
        ('--wait-time',
         dict(
             type=float,
             default=1,
             help='the interval of show (s), 0 is block')),
        ('--device',
         dict(default='cuda:0', help='Device used for inference')),
        ('--palette',
         dict(
             default='coco',
             choices=['coco', 'voc', 'citys', 'random'],
             help='Color palette used for visualization')),
        ('--score-thr',
         dict(type=float, default=0.3, help='bbox score threshold')),
    )

    cli = ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def inference_model(config_name, checkpoint, visualizer, args, logger=None):
    """Run single-image inference for one model and optionally visualize it.

    Args:
        config_name (str): Path of the model config file. Its base name is
            also used to name the saved visualization.
        checkpoint (str): Checkpoint path handed to ``init_detector``.
        visualizer: Visualizer object. Its ``dataset_meta`` attribute is
            overwritten with the model's, and ``add_datasample`` is called
            on it when visualization is requested.
        args (argparse.Namespace): Parsed options. The fields read here are
            ``aug``, ``palette``, ``device``, ``img``, ``show``,
            ``out_dir``, ``wait_time`` and ``score_thr``.
        logger: Not used by this function; accepted so existing callers
            that pass a logger keep working.

    Returns:
        The value returned by ``inference_detector`` for ``args.img``.

    Raises:
        NotImplementedError: If ``args.aug`` is set.
    """
    cfg = Config.fromfile(config_name)
    if args.aug:
        # A message makes the failure readable when the caller logs repr(e).
        raise NotImplementedError('aug test is not supported')

    model = init_detector(
        cfg, checkpoint, palette=args.palette, device=args.device)
    visualizer.dataset_meta = model.dataset_meta

    # test a single image
    result = inference_detector(model, args.img)

    # show the results
    if args.show or args.out_dir is not None:
        img = mmcv.imread(args.img)
        img = mmcv.imconvert(img, 'bgr', 'rgb')
        out_file = None
        if args.out_dir is not None:
            out_dir = args.out_dir
            mkdir_or_exist(out_dir)
            # Swap only the file extension. A plain ``replace('py', 'jpg')``
            # would also rewrite any "py" inside the name itself
            # (e.g. "pyramid_x.py" -> "jpgramid_x.jpg").
            stem = osp.splitext(osp.basename(config_name))[0]
            out_file = osp.join(out_dir, stem + '.jpg')

        visualizer.add_datasample(
            'result',
            img,
            data_sample=result,
            draw_gt=False,
            show=args.show,
            wait_time=args.wait_time,
            out_file=out_file,
            pred_score_thr=args.score_thr)

    return result
# Simple test of whether the inference code is correct
def main(args):
    """Run the image inference check for one model or for every model.

    ``args.config`` is loaded as a config whose entries map a model key to
    either one dict or a list of dicts, each holding a ``'config'`` path
    and a ``'checkpoint'`` path relative to ``args.checkpoint_root``.

    If ``args.model_name`` is given, only the first entry of that model is
    run and any error propagates. Otherwise every entry of every model is
    run, and failures are written to ``benchmark_test_image.log`` so one
    broken model does not stop the remaining ones.

    Args:
        args (argparse.Namespace): Options produced by ``parse_args``.

    Raises:
        RuntimeError: If ``args.model_name`` is given but is not a key of
            the loaded config.
    """
    # register all modules in mmdet into the registries
    register_all_modules()

    config = Config.fromfile(args.config)

    # init visualizer
    visualizer_cfg = dict(type='DetLocalVisualizer', name='visualizer')
    visualizer = VISUALIZERS.build(visualizer_cfg)

    # test single model
    if args.model_name:
        if args.model_name not in config:
            raise RuntimeError('model name input error.')

        model_infos = config[args.model_name]
        if not isinstance(model_infos, list):
            model_infos = [model_infos]
        # Only the first entry of the selected model is exercised.
        model_info = model_infos[0]
        config_name = model_info['config'].strip()
        print(f'processing: {config_name}', flush=True)
        checkpoint = osp.join(args.checkpoint_root,
                              model_info['checkpoint'].strip())
        # build the model from a config file and a checkpoint file
        inference_model(config_name, checkpoint, visualizer, args)
        return

    # test all model
    logger = MMLogger.get_instance(
        name='MMLogger',
        log_file='benchmark_test_image.log',
        log_level=logging.ERROR)

    for model_key in config:
        model_infos = config[model_key]
        if not isinstance(model_infos, list):
            model_infos = [model_infos]
        for model_info in model_infos:
            print('processing: ', model_info['config'], flush=True)
            config_name = model_info['config'].strip()
            checkpoint = osp.join(args.checkpoint_root,
                                  model_info['checkpoint'].strip())
            try:
                # build the model from a config file and a checkpoint file
                inference_model(config_name, checkpoint, visualizer, args,
                                logger)
            except Exception as e:
                # Deliberately broad: record the failure and keep going so
                # the remaining models are still checked.
                logger.error(f'{config_name} : {repr(e)}')
if __name__ == '__main__':
    # Script entry point: parse the command line and run the benchmark.
    main(parse_args())