Skip to content

Commit 677bcef

Browse files
No public description
PiperOrigin-RevId: 764603158
1 parent b41a080 commit 677bcef

File tree

2 files changed

+222
-0
lines changed

2 files changed

+222
-0
lines changed

official/projects/waste_identification_ml/Triton_TF_Cloud_Deployment/client/utils.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,116 @@ def resize_each_mask(
342342
)
343343
combined_masks.append(mask)
344344
return np.array(combined_masks)
345+
346+
347+
def extract_and_resize_objects(
    results: Mapping[str, Any],
    masks: str,
    boxes: str,
    image: np.ndarray,
    resize_factor: float = 0.5,
) -> Sequence[np.ndarray]:
  """Extracts and resizes masked objects from the detection results.

  Args:
    results: The detection results from the model.
    masks: Key in `results` holding the per-object binary masks.
    boxes: Key in `results` holding the bounding boxes, laid out so that
      `results[boxes][0][i]` is `(ymin, xmin, ymax, xmax)` for object `i`.
    image: The image to extract objects from.
    resize_factor: The factor by which to resize each cropped object.

  Returns:
    A list of cropped and resized object images.
  """
  cropped_objects = []

  for i, mask in enumerate(results[masks]):
    # Box coordinates may arrive as floats from the model; slicing needs ints.
    ymin, xmin, ymax, xmax = (int(v) for v in results[boxes][0][i])
    mask = np.expand_dims(mask, axis=-1)

    # Keep only the masked pixels inside the bounding box; zero out the rest.
    cropped_object = np.where(
        mask[ymin:ymax, xmin:xmax], image[ymin:ymax, xmin:xmax], 0
    )

    # Calculate new dimensions and resize the crop by the requested factor.
    new_width = int(cropped_object.shape[1] * resize_factor)
    new_height = int(cropped_object.shape[0] * resize_factor)
    cropped_object = cv2.resize(
        cropped_object, (new_width, new_height), interpolation=cv2.INTER_AREA
    )
    cropped_objects.append(cropped_object)

  return cropped_objects
386+
387+
388+
def adjust_image_size(
    height: int, width: int, min_size: int
) -> tuple[int, int]:
  """Shrinks an image size so its smaller dimension equals min_size.

  Dimensions already below the minimum are returned unchanged (no
  upscaling); otherwise both are divided by a common factor chosen so the
  smaller dimension lands exactly on min_size.

  Args:
    height: The height of the image.
    width: The width of the image.
    min_size: Minimum size of the image dimension needed.

  Returns:
    The adjusted (height, width) of the image.
  """
  if min(height, width) < min_size:
    return height, width

  # Divide both dimensions by the same factor so the smaller one
  # becomes exactly min_size and the aspect ratio is preserved.
  scale_factor = min(height / min_size, width / min_size)
  return int(height / scale_factor), int(width / scale_factor)
407+
408+
409+
def filter_detections(
410+
results: Mapping[str, np.ndarray],
411+
valid_indices: Sequence[int] | Sequence[bool],
412+
) -> Mapping[str, np.ndarray]:
413+
"""Filter the detection results based on the valid indices.
414+
415+
Args:
416+
results: The detection results from the model.
417+
valid_indices: The indices of the valid detections.
418+
419+
Returns:
420+
The filtered detection results.
421+
"""
422+
if np.array(valid_indices).dtype == bool:
423+
new_num_detections = int(np.sum(valid_indices))
424+
else:
425+
new_num_detections = len(valid_indices)
426+
427+
# Define the keys to filter
428+
keys_to_filter = [
429+
'detection_masks',
430+
'detection_masks_resized',
431+
'detection_masks_reframed',
432+
'detection_classes',
433+
'detection_boxes',
434+
'normalized_boxes',
435+
'detection_scores',
436+
]
437+
438+
filtered_output = {}
439+
440+
for key in keys_to_filter:
441+
if key in results:
442+
if key == 'detection_masks':
443+
filtered_output[key] = results[key][:, valid_indices, :, :]
444+
elif key in ['detection_masks_resized', 'detection_masks_reframed']:
445+
filtered_output[key] = results[key][valid_indices, :, :]
446+
elif key in ['detection_boxes', 'normalized_boxes']:
447+
filtered_output[key] = results[key][:, valid_indices, :]
448+
elif key in [
449+
'detection_classes',
450+
'detection_scores',
451+
'detection_classes_names',
452+
]:
453+
filtered_output[key] = results[key][:, valid_indices]
454+
filtered_output['image_info'] = results['image_info']
455+
filtered_output['num_detections'] = np.array([new_num_detections])
456+
457+
return filtered_output

official/projects/waste_identification_ml/Triton_TF_Cloud_Deployment/client/utils_test.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,115 @@ def test_keeps_biggest_mask(self):
113113
# Expect only the largest mask (index 2) to remain
114114
self.assertEqual(result, [2])
115115

116+
def test_filter_with_boolean_indices(self):
117+
results = {
118+
'detection_masks': np.random.rand(1, 3, 5, 5),
119+
'detection_masks_resized': np.random.rand(3, 5, 5),
120+
'detection_boxes': np.random.rand(1, 3, 4),
121+
'detection_classes': np.array([[1, 2, 3]]),
122+
'detection_scores': np.array([[0.9, 0.8, 0.3]]),
123+
'image_info': np.array([[640, 480]]),
124+
}
125+
126+
valid_indices = [True, False, True]
127+
128+
output = utils.filter_detections(results, valid_indices)
129+
130+
self.assertEqual(output['detection_masks'].shape[1], 2)
131+
self.assertEqual(output['detection_masks_resized'].shape[0], 2)
132+
self.assertEqual(output['detection_boxes'].shape[1], 2)
133+
self.assertEqual(output['detection_classes'].shape[1], 2)
134+
self.assertEqual(output['detection_scores'].shape[1], 2)
135+
self.assertTrue(np.array_equal(output['image_info'], results['image_info']))
136+
self.assertEqual(output['num_detections'][0], 2)
137+
138+
def test_filter_with_integer_indices(self):
139+
results = {
140+
'detection_masks': np.random.rand(1, 4, 5, 5),
141+
'detection_masks_resized': np.random.rand(4, 5, 5),
142+
'detection_boxes': np.random.rand(1, 4, 4),
143+
'detection_classes': np.array([[1, 2, 3, 4]]),
144+
'detection_scores': np.array([[0.9, 0.8, 0.3, 0.6]]),
145+
'image_info': np.array([[640, 480]]),
146+
}
147+
148+
valid_indices = [0, 2] # Keep detections at index 0 and 2
149+
150+
output = utils.filter_detections(results, valid_indices)
151+
152+
self.assertEqual(output['detection_masks'].shape[1], 2)
153+
self.assertEqual(output['detection_masks_resized'].shape[0], 2)
154+
self.assertEqual(output['detection_boxes'].shape[1], 2)
155+
self.assertEqual(output['detection_classes'].shape[1], 2)
156+
self.assertEqual(output['detection_scores'].shape[1], 2)
157+
self.assertEqual(output['num_detections'][0], 2)
158+
159+
def test_both_dimensions_below_min_size(self):
160+
height, width, min_size = 800, 900, 1024
161+
162+
result = utils.adjust_image_size(height, width, min_size)
163+
164+
self.assertEqual(result, (800, 900)) # No scaling should happen
165+
166+
def test_height_below_min_size(self):
167+
height, width, min_size = 900, 1200, 1024
168+
169+
result = utils.adjust_image_size(height, width, min_size)
170+
171+
self.assertEqual(result, (900, 1200)) # No scaling
172+
173+
def test_width_below_min_size(self):
174+
height, width, min_size = 1300, 800, 1024
175+
176+
result = utils.adjust_image_size(height, width, min_size)
177+
178+
self.assertEqual(result, (1300, 800)) # No scaling
179+
180+
def test_both_dimensions_above_min_size(self):
181+
height, width, min_size = 2048, 1536, 1024
182+
expected_scale = min(height / min_size, width / min_size)
183+
expected_height = int(height / expected_scale)
184+
expected_width = int(width / expected_scale)
185+
186+
result = utils.adjust_image_size(height, width, min_size)
187+
188+
self.assertEqual(result, (expected_height, expected_width))
189+
190+
def test_exact_min_size(self):
191+
height, width, min_size = 1024, 1024, 1024
192+
193+
result = utils.adjust_image_size(height, width, min_size)
194+
195+
self.assertEqual(result, (1024, 1024)) # Already meets the requirement
196+
197+
def test_extract_and_resize_single_object(self):
198+
image = np.ones((10, 10, 3), dtype=np.uint8) * 255 # white image
199+
200+
# Define a simple binary mask (1 in a 4x4 box)
201+
mask = np.zeros((10, 10), dtype=np.uint8)
202+
mask[2:6, 3:7] = 1
203+
204+
# Box coordinates match the mask
205+
boxes = np.array([[[2, 3, 6, 7]]], dtype=np.int32) # shape (1, 1, 4)
206+
207+
results = {'masks': [mask], 'boxes': boxes}
208+
209+
cropped_objects = utils.extract_and_resize_objects(
210+
results, 'masks', 'boxes', image, resize_factor=0.5
211+
)
212+
213+
self.assertEqual(len(cropped_objects), 1)
214+
obj = cropped_objects[0]
215+
216+
# Original crop size is (4, 4), so resized should be (2, 2)
217+
self.assertEqual(obj.shape[:2], (2, 2))
218+
219+
# Should still be 3 channels
220+
self.assertEqual(obj.shape[2], 3)
221+
222+
# The output pixels in mask area should be non-zero
223+
self.assertTrue(np.any(obj > 0))
224+
116225

117226
# Run the test suite when this module is executed directly.
if __name__ == '__main__':
  unittest.main()

0 commit comments

Comments
 (0)