16
16
17
17
from collections .abc import Mapping , Sequence
18
18
import csv
19
+ import logging
19
20
import os
20
- from typing import TypedDict
21
+ from typing import Any , TypedDict
22
+ import cv2
21
23
import natsort
24
+ import numpy as np
25
+ import tensorflow as tf , tf_keras
22
26
23
27
24
28
class ItemDict (TypedDict ):
@@ -27,6 +31,81 @@ class ItemDict(TypedDict):
27
31
supercategory : str
28
32
29
33
34
+ def _reframe_image_corners_relative_to_boxes (boxes : tf .Tensor ) -> tf .Tensor :
35
+ """Reframe the image corners ([0, 0, 1, 1]) to be relative to boxes.
36
+
37
+ The local coordinate frame of each box is assumed to be relative to
38
+ its own for corners.
39
+
40
+ Args:
41
+ boxes: A float tensor of [num_boxes, 4] of (ymin, xmin, ymax, xmax)
42
+ coordinates in relative coordinate space of each bounding box.
43
+
44
+ Returns:
45
+ reframed_boxes: Reframes boxes with same shape as input.
46
+ """
47
+ ymin , xmin , ymax , xmax = (boxes [:, 0 ], boxes [:, 1 ], boxes [:, 2 ], boxes [:, 3 ])
48
+
49
+ height = tf .maximum (ymax - ymin , 1e-4 )
50
+ width = tf .maximum (xmax - xmin , 1e-4 )
51
+
52
+ ymin_out = (0 - ymin ) / height
53
+ xmin_out = (0 - xmin ) / width
54
+ ymax_out = (1 - ymin ) / height
55
+ xmax_out = (1 - xmin ) / width
56
+ return tf .stack ([ymin_out , xmin_out , ymax_out , xmax_out ], axis = 1 )
57
+
58
+
59
+ def _reframe_box_masks_to_image_masks (
60
+ box_masks : tf .Tensor ,
61
+ boxes : tf .Tensor ,
62
+ image_height : int ,
63
+ image_width : int ,
64
+ resize_method = 'bilinear'
65
+ ) -> tf .Tensor :
66
+ """Transforms the box masks back to full image masks.
67
+
68
+ Embeds masks in bounding boxes of larger masks whose shapes correspond to
69
+ image shape.
70
+ Args:
71
+ box_masks: A tensor of size [num_masks, mask_height, mask_width].
72
+ boxes: A tf.float32 tensor of size [num_masks, 4] containing the box
73
+ corners. Row i contains [ymin, xmin, ymax, xmax] of the box
74
+ corresponding to mask i. Note that the box corners are in
75
+ normalized coordinates.
76
+ image_height: Image height. The output mask will have the same height as
77
+ the image height.
78
+ image_width: Image width. The output mask will have the same width as the
79
+ image width.
80
+ resize_method: The resize method, either 'bilinear' or 'nearest'. Note that
81
+ 'bilinear' is only respected if box_masks is a float.
82
+ Returns:
83
+ A tensor of size [num_masks, image_height, image_width] with the same dtype
84
+ as `box_masks`.
85
+ """
86
+ resize_method = 'nearest' if box_masks .dtype == tf .uint8 else resize_method
87
+ def reframe_box_masks_to_image_masks_default ():
88
+ """The default function when there are more than 0 box masks."""
89
+
90
+ num_boxes = tf .shape (box_masks )[0 ]
91
+ box_masks_expanded = tf .expand_dims (box_masks , axis = 3 )
92
+
93
+ resized_crops = tf .image .crop_and_resize (
94
+ image = box_masks_expanded ,
95
+ boxes = _reframe_image_corners_relative_to_boxes (boxes ),
96
+ box_indices = tf .range (num_boxes ),
97
+ crop_size = [image_height , image_width ],
98
+ method = resize_method ,
99
+ extrapolation_value = 0 )
100
+ return tf .cast (resized_crops , box_masks .dtype )
101
+
102
+ image_masks = tf .cond (
103
+ tf .shape (box_masks )[0 ] > 0 ,
104
+ reframe_box_masks_to_image_masks_default ,
105
+ lambda : tf .zeros ([0 , image_height , image_width , 1 ], box_masks .dtype ))
106
+ return tf .squeeze (image_masks , axis = 3 )
107
+
108
+
30
109
def _read_csv_to_list (file_path : str ) -> Sequence [str ]:
31
110
"""Reads a CSV file and returns its contents as a list.
32
111
@@ -109,3 +188,157 @@ def files_paths(folder_path):
109
188
image_files_full_path = natsort .natsorted (image_files_full_path )
110
189
111
190
return image_files_full_path
191
+
192
+
193
+ def create_log_file (name : str , logs_folder_path : str ) -> logging .Logger :
194
+ """Creates a logger and a log file given the name of the video.
195
+
196
+ Args:
197
+ name: The name of the video.
198
+ logs_folder_path: Path to the directory where logs should be saved.
199
+
200
+ Returns:
201
+ logging.Logger: Logger object configured to write logs to the file.
202
+ """
203
+ log_file_path = os .path .join (logs_folder_path , f'{ name } .log' )
204
+ logger = logging .getLogger (name )
205
+ logger .setLevel (logging .INFO )
206
+ file_handler = logging .FileHandler (log_file_path )
207
+ formatter = logging .Formatter ('%(asctime)s - %(levelname)s - %(message)s' )
208
+ file_handler .setFormatter (formatter )
209
+ logger .addHandler (file_handler )
210
+ return logger
211
+
212
+
213
+ def reframe_masks (
214
+ results : Mapping [str , Any ], boxes : str , height : int , width : int
215
+ ) -> np .ndarray :
216
+ """Reframe the masks to an image size.
217
+
218
+ Args:
219
+ results: The detection results from the model.
220
+ boxes: The detection boxes.
221
+ height: The height of the original image.
222
+ width: The width of the original image.
223
+
224
+ Returns:
225
+ The reframed masks.
226
+ """
227
+ detection_masks = results ['detection_masks' ][0 ]
228
+ detection_boxes = results [boxes ][0 ]
229
+ detection_masks_reframed = _reframe_box_masks_to_image_masks (
230
+ detection_masks , detection_boxes , height , width
231
+ )
232
+ detection_masks_reframed = tf .cast (detection_masks_reframed > 0.5 , np .uint8 )
233
+ detection_masks_reframed = detection_masks_reframed .numpy ()
234
+ return detection_masks_reframed
235
+
236
+
237
+ def _calculate_area (mask : np .ndarray ) -> int :
238
+ """Calculate the area of the mask.
239
+
240
+ Args:
241
+ mask: The mask to calculate the area of.
242
+
243
+ Returns:
244
+ The area of the mask.
245
+ """
246
+ return np .sum (mask )
247
+
248
+
249
+ def _calculate_iou (mask1 : np .ndarray , mask2 : np .ndarray ) -> float :
250
+ """Calculate the intersection over union (IoU) between two masks.
251
+
252
+ Args:
253
+ mask1: The first mask.
254
+ mask2: The second mask.
255
+
256
+ Returns:
257
+ The intersection over union (IoU) between the two masks.
258
+ """
259
+ intersection = np .logical_and (mask1 , mask2 ).sum ()
260
+ union = np .logical_or (mask1 , mask2 ).sum ()
261
+ return intersection / union if union != 0 else 0
262
+
263
+
264
+ def _is_contained (mask1 : np .ndarray , mask2 : np .ndarray ) -> bool :
265
+ """Check if mask1 is entirely contained within mask2.
266
+
267
+ Args:
268
+ mask1: The first mask.
269
+ mask2: The second mask.
270
+
271
+ Returns:
272
+ True if mask1 is entirely contained within mask2, False otherwise.
273
+ """
274
+ return np .array_equal (np .logical_and (mask1 , mask2 ), mask1 )
275
+
276
+
277
+ # TODO: b/416838511 - Reduce the nesting statement in the for loop.
278
+ def filter_masks (
279
+ masks : np .ndarray ,
280
+ iou_threshold : float = 0.8 ,
281
+ area_threshold : int | None = None ,
282
+ ) -> Sequence [int ]:
283
+ """Filter the overlapping masks.
284
+
285
+ Filter the masks based on the area and intersection over union (IoU).
286
+
287
+ Args:
288
+ masks: The masks to filter.
289
+ iou_threshold: The threshold for the intersection over union (IoU) between
290
+ two masks.
291
+ area_threshold: The threshold for the area of the mask.
292
+
293
+ Returns:
294
+ The indices of the unique masks.
295
+ """
296
+ # Calculate the area for each mask
297
+ areas = np .array ([_calculate_area (mask ) for mask in masks ])
298
+
299
+ # Sort the masks based on area in descending order
300
+ sorted_indices = np .argsort (areas )[::- 1 ]
301
+ sorted_masks = masks [sorted_indices ]
302
+ sorted_areas = areas [sorted_indices ]
303
+
304
+ unique_indices = []
305
+
306
+ for i , mask in enumerate (sorted_masks ):
307
+ if (
308
+ area_threshold is not None and sorted_areas [i ] > area_threshold
309
+ ) or sorted_areas [i ] < 4000 :
310
+ continue
311
+
312
+ keep = True
313
+ for j in range (i ):
314
+ if _calculate_iou (mask , sorted_masks [j ]) > iou_threshold or _is_contained (
315
+ mask , sorted_masks [j ]
316
+ ):
317
+ keep = False
318
+ break
319
+ if keep :
320
+ unique_indices .append (sorted_indices [i ])
321
+
322
+ return unique_indices
323
+
324
+
325
+ def resize_each_mask (
326
+ masks : np .ndarray , target_height : int , target_width : int
327
+ ) -> np .ndarray :
328
+ """Resize each mask to the target height and width.
329
+
330
+ Args:
331
+ masks: The masks to resize.
332
+ target_height: The target height of the resized masks.
333
+ target_width: The target width of the resized masks.
334
+
335
+ Returns:
336
+ The resized masks.
337
+ """
338
+ combined_masks = []
339
+ for i in masks :
340
+ mask = cv2 .resize (
341
+ i , (target_width , target_height ), interpolation = cv2 .INTER_NEAREST
342
+ )
343
+ combined_masks .append (mask )
344
+ return np .array (combined_masks )
0 commit comments