Skip to content

Commit 5556579

Browse files
committed
Allow ssd anchor generator to specify scales and
add interpolated scale at a given aspect ratio. This change breaks the existing checkpoints because the anchor orders are different. This change should be merged along with newly generated frozen ssd models.
1 parent 205c5e0 commit 5556579

File tree

2 files changed

+172
-93
lines changed

2 files changed

+172
-93
lines changed

research/object_detection/anchor_generators/multiple_grid_anchor_generator.py

Lines changed: 108 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class MultipleGridAnchorGenerator(anchor_generator.AnchorGenerator):
3838
def __init__(self,
3939
box_specs_list,
4040
base_anchor_size=None,
41+
anchor_strides=None,
42+
anchor_offsets=None,
4143
clip_window=None):
4244
"""Constructs a MultipleGridAnchorGenerator.
4345
@@ -58,7 +60,26 @@ def __init__(self,
5860
outside list having the same number of entries as feature_map_shape_list
5961
(which is passed in at generation time).
6062
base_anchor_size: base anchor size as [height, width]
61-
(length-2 float tensor, default=[256, 256]).
63+
(length-2 float tensor, default=[1.0, 1.0]).
64+
The height and width values are normalized to the
65+
minimum dimension of the input height and width, so that
66+
when the base anchor height equals the base anchor
67+
width, the resulting anchor is square even if the input
68+
image is not square.
69+
anchor_strides: list of pairs of strides in pixels (in y and x directions
70+
respectively). For example, setting anchor_strides=[(25, 25), (50, 50)]
71+
means that we want the anchors corresponding to the first layer to be
72+
strided by 25 pixels and those in the second layer to be strided by 50
73+
pixels in both y and x directions. If anchor_strides=None, they are set
74+
to be the reciprocal of the corresponding feature map shapes.
75+
anchor_offsets: list of pairs of offsets in pixels (in y and x directions
76+
respectively). The offset specifies where we want the center of the
77+
(0, 0)-th anchor to lie for each layer. For example, setting
78+
anchor_offsets=[(10, 10), (20, 20)] means that we want the
79+
(0, 0)-th anchor of the first layer to lie at (10, 10) in pixel space
80+
and likewise that we want the (0, 0)-th anchor of the second layer to
81+
lie at (20, 20) in pixel space. If anchor_offsets=None, then they are
82+
set to be half of the corresponding anchor stride.
6283
clip_window: a tensor of shape [4] specifying a window to which all
6384
anchors should be clipped. If clip_window is None, then no clipping
6485
is performed.
@@ -76,6 +97,8 @@ def __init__(self,
7697
if base_anchor_size is None:
7798
base_anchor_size = tf.constant([256, 256], dtype=tf.float32)
7899
self._base_anchor_size = base_anchor_size
100+
self._anchor_strides = anchor_strides
101+
self._anchor_offsets = anchor_offsets
79102
if clip_window is not None and clip_window.get_shape().as_list() != [4]:
80103
raise ValueError('clip_window must either be None or a shape [4] tensor')
81104
self._clip_window = clip_window
@@ -90,6 +113,18 @@ def __init__(self,
90113
self._scales.append(scales)
91114
self._aspect_ratios.append(aspect_ratios)
92115

116+
for arg, arg_name in zip([self._anchor_strides, self._anchor_offsets],
117+
['anchor_strides', 'anchor_offsets']):
118+
if arg and not (isinstance(arg, list) and
119+
len(arg) == len(self._box_specs)):
120+
raise ValueError('%s must be a list with the same length '
121+
'as self._box_specs' % arg_name)
122+
if arg and not all([
123+
isinstance(list_item, tuple) and len(list_item) == 2
124+
for list_item in arg
125+
]):
126+
raise ValueError('%s must be a list of pairs.' % arg_name)
127+
93128
def name_scope(self):
94129
return 'MultipleGridAnchorGenerator'
95130

@@ -102,12 +137,7 @@ def num_anchors_per_location(self):
102137
"""
103138
return [len(box_specs) for box_specs in self._box_specs]
104139

105-
def _generate(self,
106-
feature_map_shape_list,
107-
im_height=1,
108-
im_width=1,
109-
anchor_strides=None,
110-
anchor_offsets=None):
140+
def _generate(self, feature_map_shape_list, im_height=1, im_width=1):
111141
"""Generates a collection of bounding boxes to be used as anchors.
112142
113143
The number of anchors generated for a single grid with shape MxM where we
@@ -133,25 +163,6 @@ def _generate(self,
133163
im_height and im_width are 1, the generated anchors default to
134164
normalized coordinates, otherwise absolute coordinates are used for the
135165
grid.
136-
anchor_strides: list of pairs of strides (in y and x directions
137-
respectively). For example, setting
138-
anchor_strides=[(.25, .25), (.5, .5)] means that we want the anchors
139-
corresponding to the first layer to be strided by .25 and those in the
140-
second layer to be strided by .5 in both y and x directions. By
141-
default, if anchor_strides=None, then they are set to be the reciprocal
142-
of the corresponding grid sizes. The pairs can also be specified as
143-
dynamic tf.int or tf.float numbers, e.g. for variable shape input
144-
images.
145-
anchor_offsets: list of pairs of offsets (in y and x directions
146-
respectively). The offset specifies where we want the center of the
147-
(0, 0)-th anchor to lie for each layer. For example, setting
148-
anchor_offsets=[(.125, .125), (.25, .25)]) means that we want the
149-
(0, 0)-th anchor of the first layer to lie at (.125, .125) in image
150-
space and likewise that we want the (0, 0)-th anchor of the second
151-
layer to lie at (.25, .25) in image space. By default, if
152-
anchor_offsets=None, then they are set to be half of the corresponding
153-
anchor stride. The pairs can also be specified as dynamic tf.int or
154-
tf.float numbers, e.g. for variable shape input images.
155166
156167
Returns:
157168
boxes: a BoxList holding a collection of N anchor boxes
@@ -168,13 +179,25 @@ def _generate(self,
168179
if not all([isinstance(list_item, tuple) and len(list_item) == 2
169180
for list_item in feature_map_shape_list]):
170181
raise ValueError('feature_map_shape_list must be a list of pairs.')
171-
if not anchor_strides:
172-
anchor_strides = [(tf.to_float(im_height) / tf.to_float(pair[0]),
173-
tf.to_float(im_width) / tf.to_float(pair[1]))
182+
183+
im_height = tf.to_float(im_height)
184+
im_width = tf.to_float(im_width)
185+
186+
if not self._anchor_strides:
187+
anchor_strides = [(1.0 / tf.to_float(pair[0]), 1.0 / tf.to_float(pair[1]))
174188
for pair in feature_map_shape_list]
175-
if not anchor_offsets:
189+
else:
190+
anchor_strides = [(tf.to_float(stride[0]) / im_height,
191+
tf.to_float(stride[1]) / im_width)
192+
for stride in self._anchor_strides]
193+
if not self._anchor_offsets:
176194
anchor_offsets = [(0.5 * stride[0], 0.5 * stride[1])
177195
for stride in anchor_strides]
196+
else:
197+
anchor_offsets = [(tf.to_float(offset[0]) / im_height,
198+
tf.to_float(offset[1]) / im_width)
199+
for offset in self._anchor_offsets]
200+
178201
for arg, arg_name in zip([anchor_strides, anchor_offsets],
179202
['anchor_strides', 'anchor_offsets']):
180203
if not (isinstance(arg, list) and len(arg) == len(self._box_specs)):
@@ -185,8 +208,13 @@ def _generate(self,
185208
raise ValueError('%s must be a list of pairs.' % arg_name)
186209

187210
anchor_grid_list = []
188-
min_im_shape = tf.to_float(tf.minimum(im_height, im_width))
189-
base_anchor_size = min_im_shape * self._base_anchor_size
211+
min_im_shape = tf.minimum(im_height, im_width)
212+
scale_height = min_im_shape / im_height
213+
scale_width = min_im_shape / im_width
214+
base_anchor_size = [
215+
scale_height * self._base_anchor_size[0],
216+
scale_width * self._base_anchor_size[1]
217+
]
190218
for grid_size, scales, aspect_ratios, stride, offset in zip(
191219
feature_map_shape_list, self._scales, self._aspect_ratios,
192220
anchor_strides, anchor_offsets):
@@ -204,12 +232,9 @@ def _generate(self,
204232
if num_anchors is None:
205233
num_anchors = concatenated_anchors.num_boxes()
206234
if self._clip_window is not None:
207-
clip_window = tf.multiply(
208-
tf.to_float([im_height, im_width, im_height, im_width]),
209-
self._clip_window)
210235
concatenated_anchors = box_list_ops.clip_to_window(
211-
concatenated_anchors, clip_window, filter_nonoverlapping=False)
212-
# TODO: make reshape an option for the clip_to_window op
236+
concatenated_anchors, self._clip_window, filter_nonoverlapping=False)
237+
# TODO(jonathanhuang): make reshape an option for the clip_to_window op
213238
concatenated_anchors.set(
214239
tf.reshape(concatenated_anchors.get(), [num_anchors, 4]))
215240

@@ -223,8 +248,12 @@ def _generate(self,
223248
def create_ssd_anchors(num_layers=6,
224249
min_scale=0.2,
225250
max_scale=0.95,
226-
aspect_ratios=(1.0, 2.0, 3.0, 1.0/2, 1.0/3),
251+
scales=None,
252+
aspect_ratios=(1.0, 2.0, 3.0, 1.0 / 2, 1.0 / 3),
253+
interpolated_scale_aspect_ratio=1.0,
227254
base_anchor_size=None,
255+
anchor_strides=None,
256+
anchor_offsets=None,
228257
reduce_boxes_in_lowest_layer=True):
229258
"""Creates MultipleGridAnchorGenerator for SSD anchors.
230259
@@ -244,9 +273,33 @@ def create_ssd_anchors(num_layers=6,
244273
grid sizes passed in at generation time)
245274
min_scale: scale of anchors corresponding to finest resolution (float)
246275
max_scale: scale of anchors corresponding to coarsest resolution (float)
276+
scales: A list of anchor scales to use. When not None and not empty,
277+
min_scale and max_scale are not used.
247278
aspect_ratios: list or tuple of (float) aspect ratios to place on each
248279
grid point.
280+
interpolated_scale_aspect_ratio: An additional anchor is added with this
281+
aspect ratio and a scale interpolated between the scale for a layer
282+
and the scale for the next layer (1.0 for the last layer).
283+
This anchor is not included if this value is 0.
249284
base_anchor_size: base anchor size as [height, width].
285+
The height and width values are normalized to the minimum dimension of the
286+
input height and width, so that when the base anchor height equals the
287+
base anchor width, the resulting anchor is square even if the input image
288+
is not square.
289+
anchor_strides: list of pairs of strides in pixels (in y and x directions
290+
respectively). For example, setting anchor_strides=[(25, 25), (50, 50)]
291+
means that we want the anchors corresponding to the first layer to be
292+
strided by 25 pixels and those in the second layer to be strided by 50
293+
pixels in both y and x directions. If anchor_strides=None, they are set to
294+
be the reciprocal of the corresponding feature map shapes.
295+
anchor_offsets: list of pairs of offsets in pixels (in y and x directions
296+
respectively). The offset specifies where we want the center of the
297+
(0, 0)-th anchor to lie for each layer. For example, setting
298+
anchor_offsets=[(10, 10), (20, 20)] means that we want the
299+
(0, 0)-th anchor of the first layer to lie at (10, 10) in pixel space
300+
and likewise that we want the (0, 0)-th anchor of the second layer to lie
301+
at (20, 20) in pixel space. If anchor_offsets=None, then they are set to
302+
be half of the corresponding anchor stride.
250303
reduce_boxes_in_lowest_layer: a boolean to indicate whether the fixed 3
251304
boxes per location is used in the lowest layer.
252305
@@ -257,8 +310,14 @@ def create_ssd_anchors(num_layers=6,
257310
base_anchor_size = [1.0, 1.0]
258311
base_anchor_size = tf.constant(base_anchor_size, dtype=tf.float32)
259312
box_specs_list = []
260-
scales = [min_scale + (max_scale - min_scale) * i / (num_layers - 1)
261-
for i in range(num_layers)] + [1.0]
313+
if scales is None or not scales:
314+
scales = [min_scale + (max_scale - min_scale) * i / (num_layers - 1)
315+
for i in range(num_layers)] + [1.0]
316+
else:
317+
# Add 1.0 to the end, which will only be used in scale_next below and used
318+
# for computing an interpolated scale for the largest scale in the list.
319+
scales += [1.0]
320+
262321
for layer, scale, scale_next in zip(
263322
range(num_layers), scales[:-1], scales[1:]):
264323
layer_box_specs = []
@@ -267,7 +326,13 @@ def create_ssd_anchors(num_layers=6,
267326
else:
268327
for aspect_ratio in aspect_ratios:
269328
layer_box_specs.append((scale, aspect_ratio))
270-
if aspect_ratio == 1.0:
271-
layer_box_specs.append((np.sqrt(scale*scale_next), 1.0))
329+
# Add one more anchor, with a scale between the current scale, and the
330+
# scale for the next layer, with a specified aspect ratio (1.0 by
331+
# default).
332+
if interpolated_scale_aspect_ratio > 0.0:
333+
layer_box_specs.append((np.sqrt(scale*scale_next),
334+
interpolated_scale_aspect_ratio))
272335
box_specs_list.append(layer_box_specs)
273-
return MultipleGridAnchorGenerator(box_specs_list, base_anchor_size)
336+
337+
return MultipleGridAnchorGenerator(box_specs_list, base_anchor_size,
338+
anchor_strides, anchor_offsets)

0 commit comments

Comments
 (0)