Skip to content

Commit bc60794

Browse files
No public description
PiperOrigin-RevId: 763896973
1 parent 0d1fa0b commit bc60794

File tree

2 files changed

+266
-1
lines changed

2 files changed

+266
-1
lines changed

official/projects/waste_identification_ml/Triton_TF_Cloud_Deployment/client/utils.py

Lines changed: 234 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,13 @@
1616

1717
from collections.abc import Mapping, Sequence
1818
import csv
19+
import logging
1920
import os
20-
from typing import TypedDict
21+
from typing import Any, TypedDict
22+
import cv2
2123
import natsort
24+
import numpy as np
25+
import tensorflow as tf, tf_keras
2226

2327

2428
class ItemDict(TypedDict):
@@ -27,6 +31,81 @@ class ItemDict(TypedDict):
2731
supercategory: str
2832

2933

34+
def _reframe_image_corners_relative_to_boxes(boxes: tf.Tensor) -> tf.Tensor:
35+
"""Reframe the image corners ([0, 0, 1, 1]) to be relative to boxes.
36+
37+
The local coordinate frame of each box is assumed to be relative to
38+
its own for corners.
39+
40+
Args:
41+
boxes: A float tensor of [num_boxes, 4] of (ymin, xmin, ymax, xmax)
42+
coordinates in relative coordinate space of each bounding box.
43+
44+
Returns:
45+
reframed_boxes: Reframes boxes with same shape as input.
46+
"""
47+
ymin, xmin, ymax, xmax = (boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3])
48+
49+
height = tf.maximum(ymax - ymin, 1e-4)
50+
width = tf.maximum(xmax - xmin, 1e-4)
51+
52+
ymin_out = (0 - ymin) / height
53+
xmin_out = (0 - xmin) / width
54+
ymax_out = (1 - ymin) / height
55+
xmax_out = (1 - xmin) / width
56+
return tf.stack([ymin_out, xmin_out, ymax_out, xmax_out], axis=1)
57+
58+
59+
def _reframe_box_masks_to_image_masks(
60+
box_masks: tf.Tensor,
61+
boxes: tf.Tensor,
62+
image_height: int,
63+
image_width: int,
64+
resize_method='bilinear'
65+
) -> tf.Tensor:
66+
"""Transforms the box masks back to full image masks.
67+
68+
Embeds masks in bounding boxes of larger masks whose shapes correspond to
69+
image shape.
70+
Args:
71+
box_masks: A tensor of size [num_masks, mask_height, mask_width].
72+
boxes: A tf.float32 tensor of size [num_masks, 4] containing the box
73+
corners. Row i contains [ymin, xmin, ymax, xmax] of the box
74+
corresponding to mask i. Note that the box corners are in
75+
normalized coordinates.
76+
image_height: Image height. The output mask will have the same height as
77+
the image height.
78+
image_width: Image width. The output mask will have the same width as the
79+
image width.
80+
resize_method: The resize method, either 'bilinear' or 'nearest'. Note that
81+
'bilinear' is only respected if box_masks is a float.
82+
Returns:
83+
A tensor of size [num_masks, image_height, image_width] with the same dtype
84+
as `box_masks`.
85+
"""
86+
resize_method = 'nearest' if box_masks.dtype == tf.uint8 else resize_method
87+
def reframe_box_masks_to_image_masks_default():
88+
"""The default function when there are more than 0 box masks."""
89+
90+
num_boxes = tf.shape(box_masks)[0]
91+
box_masks_expanded = tf.expand_dims(box_masks, axis=3)
92+
93+
resized_crops = tf.image.crop_and_resize(
94+
image=box_masks_expanded,
95+
boxes=_reframe_image_corners_relative_to_boxes(boxes),
96+
box_indices=tf.range(num_boxes),
97+
crop_size=[image_height, image_width],
98+
method=resize_method,
99+
extrapolation_value=0)
100+
return tf.cast(resized_crops, box_masks.dtype)
101+
102+
image_masks = tf.cond(
103+
tf.shape(box_masks)[0] > 0,
104+
reframe_box_masks_to_image_masks_default,
105+
lambda: tf.zeros([0, image_height, image_width, 1], box_masks.dtype))
106+
return tf.squeeze(image_masks, axis=3)
107+
108+
30109
def _read_csv_to_list(file_path: str) -> Sequence[str]:
31110
"""Reads a CSV file and returns its contents as a list.
32111
@@ -109,3 +188,157 @@ def files_paths(folder_path):
109188
image_files_full_path = natsort.natsorted(image_files_full_path)
110189

111190
return image_files_full_path
191+
192+
193+
def create_log_file(name: str, logs_folder_path: str) -> logging.Logger:
194+
"""Creates a logger and a log file given the name of the video.
195+
196+
Args:
197+
name: The name of the video.
198+
logs_folder_path: Path to the directory where logs should be saved.
199+
200+
Returns:
201+
logging.Logger: Logger object configured to write logs to the file.
202+
"""
203+
log_file_path = os.path.join(logs_folder_path, f'{name}.log')
204+
logger = logging.getLogger(name)
205+
logger.setLevel(logging.INFO)
206+
file_handler = logging.FileHandler(log_file_path)
207+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
208+
file_handler.setFormatter(formatter)
209+
logger.addHandler(file_handler)
210+
return logger
211+
212+
213+
def reframe_masks(
214+
results: Mapping[str, Any], boxes: str, height: int, width: int
215+
) -> np.ndarray:
216+
"""Reframe the masks to an image size.
217+
218+
Args:
219+
results: The detection results from the model.
220+
boxes: The detection boxes.
221+
height: The height of the original image.
222+
width: The width of the original image.
223+
224+
Returns:
225+
The reframed masks.
226+
"""
227+
detection_masks = results['detection_masks'][0]
228+
detection_boxes = results[boxes][0]
229+
detection_masks_reframed = _reframe_box_masks_to_image_masks(
230+
detection_masks, detection_boxes, height, width
231+
)
232+
detection_masks_reframed = tf.cast(detection_masks_reframed > 0.5, np.uint8)
233+
detection_masks_reframed = detection_masks_reframed.numpy()
234+
return detection_masks_reframed
235+
236+
237+
def _calculate_area(mask: np.ndarray) -> int:
238+
"""Calculate the area of the mask.
239+
240+
Args:
241+
mask: The mask to calculate the area of.
242+
243+
Returns:
244+
The area of the mask.
245+
"""
246+
return np.sum(mask)
247+
248+
249+
def _calculate_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
250+
"""Calculate the intersection over union (IoU) between two masks.
251+
252+
Args:
253+
mask1: The first mask.
254+
mask2: The second mask.
255+
256+
Returns:
257+
The intersection over union (IoU) between the two masks.
258+
"""
259+
intersection = np.logical_and(mask1, mask2).sum()
260+
union = np.logical_or(mask1, mask2).sum()
261+
return intersection / union if union != 0 else 0
262+
263+
264+
def _is_contained(mask1: np.ndarray, mask2: np.ndarray) -> bool:
265+
"""Check if mask1 is entirely contained within mask2.
266+
267+
Args:
268+
mask1: The first mask.
269+
mask2: The second mask.
270+
271+
Returns:
272+
True if mask1 is entirely contained within mask2, False otherwise.
273+
"""
274+
return np.array_equal(np.logical_and(mask1, mask2), mask1)
275+
276+
277+
# TODO: b/416838511 - Reduce the nesting statement in the for loop.
278+
def filter_masks(
279+
masks: np.ndarray,
280+
iou_threshold: float = 0.8,
281+
area_threshold: int | None = None,
282+
) -> Sequence[int]:
283+
"""Filter the overlapping masks.
284+
285+
Filter the masks based on the area and intersection over union (IoU).
286+
287+
Args:
288+
masks: The masks to filter.
289+
iou_threshold: The threshold for the intersection over union (IoU) between
290+
two masks.
291+
area_threshold: The threshold for the area of the mask.
292+
293+
Returns:
294+
The indices of the unique masks.
295+
"""
296+
# Calculate the area for each mask
297+
areas = np.array([_calculate_area(mask) for mask in masks])
298+
299+
# Sort the masks based on area in descending order
300+
sorted_indices = np.argsort(areas)[::-1]
301+
sorted_masks = masks[sorted_indices]
302+
sorted_areas = areas[sorted_indices]
303+
304+
unique_indices = []
305+
306+
for i, mask in enumerate(sorted_masks):
307+
if (
308+
area_threshold is not None and sorted_areas[i] > area_threshold
309+
) or sorted_areas[i] < 4000:
310+
continue
311+
312+
keep = True
313+
for j in range(i):
314+
if _calculate_iou(mask, sorted_masks[j]) > iou_threshold or _is_contained(
315+
mask, sorted_masks[j]
316+
):
317+
keep = False
318+
break
319+
if keep:
320+
unique_indices.append(sorted_indices[i])
321+
322+
return unique_indices
323+
324+
325+
def resize_each_mask(
326+
masks: np.ndarray, target_height: int, target_width: int
327+
) -> np.ndarray:
328+
"""Resize each mask to the target height and width.
329+
330+
Args:
331+
masks: The masks to resize.
332+
target_height: The target height of the resized masks.
333+
target_width: The target width of the resized masks.
334+
335+
Returns:
336+
The resized masks.
337+
"""
338+
combined_masks = []
339+
for i in masks:
340+
mask = cv2.resize(
341+
i, (target_width, target_height), interpolation=cv2.INTER_NEAREST
342+
)
343+
combined_masks.append(mask)
344+
return np.array(combined_masks)

official/projects/waste_identification_ml/Triton_TF_Cloud_Deployment/client/utils_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import tempfile
1717
import unittest
18+
import numpy as np
1819
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import utils
1920

2021

@@ -81,6 +82,37 @@ def test_files_paths_empty_folder(self):
8182
result = utils.files_paths(temp_dir)
8283
self.assertEqual(result, [])
8384

85+
def test_resize_multiple_masks(self):
86+
# Create two 5x5 masks
87+
mask1 = np.zeros((5, 5), dtype=np.uint8)
88+
mask2 = np.ones((5, 5), dtype=np.uint8)
89+
masks = np.array([mask1, mask2])
90+
91+
resized_masks = utils.resize_each_mask(masks, 3, 3)
92+
93+
self.assertEqual(resized_masks.shape, (2, 3, 3))
94+
self.assertTrue((resized_masks[0] == 0).all())
95+
self.assertTrue((resized_masks[1] == 1).all())
96+
97+
def test_keeps_biggest_mask(self):
98+
# Create larger masks to satisfy the area check (area >= 4000)
99+
mask_small = np.zeros((100, 100), dtype=int)
100+
mask_small[0:40, 0:40] = 1 # Area = 1600 (will be skipped)
101+
102+
mask_medium = np.zeros((100, 100), dtype=int)
103+
mask_medium[0:70, 0:70] = 1 # Area = 4900 (passes area condition)
104+
105+
mask_large = np.zeros((100, 100), dtype=int)
106+
mask_large[:, :] = 1 # Area = 10000 (passes area condition)
107+
108+
masks = np.array([mask_small, mask_medium, mask_large])
109+
110+
# Run filter_masks without specifying area_threshold
111+
result = utils.filter_masks(masks, iou_threshold=0.5)
112+
113+
# Expect only the largest mask (index 2) to remain
114+
self.assertEqual(result, [2])
115+
84116

85117
if __name__ == '__main__':
86118
unittest.main()

0 commit comments

Comments
 (0)