@@ -113,6 +113,115 @@ def test_keeps_biggest_mask(self):
113
113
# Expect only the largest mask (index 2) to remain
114
114
self .assertEqual (result , [2 ])
115
115
116
+ def test_filter_with_boolean_indices (self ):
117
+ results = {
118
+ 'detection_masks' : np .random .rand (1 , 3 , 5 , 5 ),
119
+ 'detection_masks_resized' : np .random .rand (3 , 5 , 5 ),
120
+ 'detection_boxes' : np .random .rand (1 , 3 , 4 ),
121
+ 'detection_classes' : np .array ([[1 , 2 , 3 ]]),
122
+ 'detection_scores' : np .array ([[0.9 , 0.8 , 0.3 ]]),
123
+ 'image_info' : np .array ([[640 , 480 ]]),
124
+ }
125
+
126
+ valid_indices = [True , False , True ]
127
+
128
+ output = utils .filter_detections (results , valid_indices )
129
+
130
+ self .assertEqual (output ['detection_masks' ].shape [1 ], 2 )
131
+ self .assertEqual (output ['detection_masks_resized' ].shape [0 ], 2 )
132
+ self .assertEqual (output ['detection_boxes' ].shape [1 ], 2 )
133
+ self .assertEqual (output ['detection_classes' ].shape [1 ], 2 )
134
+ self .assertEqual (output ['detection_scores' ].shape [1 ], 2 )
135
+ self .assertTrue (np .array_equal (output ['image_info' ], results ['image_info' ]))
136
+ self .assertEqual (output ['num_detections' ][0 ], 2 )
137
+
138
+ def test_filter_with_integer_indices (self ):
139
+ results = {
140
+ 'detection_masks' : np .random .rand (1 , 4 , 5 , 5 ),
141
+ 'detection_masks_resized' : np .random .rand (4 , 5 , 5 ),
142
+ 'detection_boxes' : np .random .rand (1 , 4 , 4 ),
143
+ 'detection_classes' : np .array ([[1 , 2 , 3 , 4 ]]),
144
+ 'detection_scores' : np .array ([[0.9 , 0.8 , 0.3 , 0.6 ]]),
145
+ 'image_info' : np .array ([[640 , 480 ]]),
146
+ }
147
+
148
+ valid_indices = [0 , 2 ] # Keep detections at index 0 and 2
149
+
150
+ output = utils .filter_detections (results , valid_indices )
151
+
152
+ self .assertEqual (output ['detection_masks' ].shape [1 ], 2 )
153
+ self .assertEqual (output ['detection_masks_resized' ].shape [0 ], 2 )
154
+ self .assertEqual (output ['detection_boxes' ].shape [1 ], 2 )
155
+ self .assertEqual (output ['detection_classes' ].shape [1 ], 2 )
156
+ self .assertEqual (output ['detection_scores' ].shape [1 ], 2 )
157
+ self .assertEqual (output ['num_detections' ][0 ], 2 )
158
+
159
+ def test_both_dimensions_below_min_size (self ):
160
+ height , width , min_size = 800 , 900 , 1024
161
+
162
+ result = utils .adjust_image_size (height , width , min_size )
163
+
164
+ self .assertEqual (result , (800 , 900 )) # No scaling should happen
165
+
166
+ def test_height_below_min_size (self ):
167
+ height , width , min_size = 900 , 1200 , 1024
168
+
169
+ result = utils .adjust_image_size (height , width , min_size )
170
+
171
+ self .assertEqual (result , (900 , 1200 )) # No scaling
172
+
173
+ def test_width_below_min_size (self ):
174
+ height , width , min_size = 1300 , 800 , 1024
175
+
176
+ result = utils .adjust_image_size (height , width , min_size )
177
+
178
+ self .assertEqual (result , (1300 , 800 )) # No scaling
179
+
180
+ def test_both_dimensions_above_min_size (self ):
181
+ height , width , min_size = 2048 , 1536 , 1024
182
+ expected_scale = min (height / min_size , width / min_size )
183
+ expected_height = int (height / expected_scale )
184
+ expected_width = int (width / expected_scale )
185
+
186
+ result = utils .adjust_image_size (height , width , min_size )
187
+
188
+ self .assertEqual (result , (expected_height , expected_width ))
189
+
190
+ def test_exact_min_size (self ):
191
+ height , width , min_size = 1024 , 1024 , 1024
192
+
193
+ result = utils .adjust_image_size (height , width , min_size )
194
+
195
+ self .assertEqual (result , (1024 , 1024 )) # Already meets the requirement
196
+
197
+ def test_extract_and_resize_single_object (self ):
198
+ image = np .ones ((10 , 10 , 3 ), dtype = np .uint8 ) * 255 # white image
199
+
200
+ # Define a simple binary mask (1 in a 4x4 box)
201
+ mask = np .zeros ((10 , 10 ), dtype = np .uint8 )
202
+ mask [2 :6 , 3 :7 ] = 1
203
+
204
+ # Box coordinates match the mask
205
+ boxes = np .array ([[[2 , 3 , 6 , 7 ]]], dtype = np .int32 ) # shape (1, 1, 4)
206
+
207
+ results = {'masks' : [mask ], 'boxes' : boxes }
208
+
209
+ cropped_objects = utils .extract_and_resize_objects (
210
+ results , 'masks' , 'boxes' , image , resize_factor = 0.5
211
+ )
212
+
213
+ self .assertEqual (len (cropped_objects ), 1 )
214
+ obj = cropped_objects [0 ]
215
+
216
+ # Original crop size is (4, 4), so resized should be (2, 2)
217
+ self .assertEqual (obj .shape [:2 ], (2 , 2 ))
218
+
219
+ # Should still be 3 channels
220
+ self .assertEqual (obj .shape [2 ], 3 )
221
+
222
+ # The output pixels in mask area should be non-zero
223
+ self .assertTrue (np .any (obj > 0 ))
224
+
116
225
117
226
if __name__ == '__main__' :
118
227
unittest .main ()
0 commit comments