@@ -20,10 +20,14 @@
from absl import flags
import cv2
import numpy as np
+import pandas as pd
from official.projects.waste_identification_ml.model_inference import color_and_property_extractor
+from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import big_query_ops
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import feature_extraction
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import ffmpeg_ops
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import mask_bbox_saver
+from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import object_tracking
+from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import object_tracking_postprocessing
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import triton_server_inference
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import utils
@@ -79,7 +83,7 @@
    "bq_dataset_id", "Circularnet_dataset", "BigQuery dataset ID"
)

-TABLE_ID = flags.DEFINE_string(
+BQ_TABLE_ID = flags.DEFINE_string(
    "bq_table_id", "Circularnet_table", "BigQuery table ID for features data"
)
@@ -98,6 +102,10 @@
AREA_THRESHOLD = None
HEIGHT_TRACKING = 300
WIDTH_TRACKING = 300
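+# Drawing parameters for annotating output frames with OpenCV.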
+CIRCLE_RADIUS = 7
+FONT = cv2.FONT_HERSHEY_SIMPLEX
+FONTSCALE = 1
+COLOR = (255, 0, 0)


def main(_) -> None:
@@ -119,16 +127,6 @@ def main(_) -> None:
  prediction_folder = os.path.basename(input_directory) + "_prediction"
  os.makedirs(prediction_folder, exist_ok=True)

-  if TRACKING_VISUALIZATION.value:
-    # Create a folder to troubleshoot tracking results.
-    tracking_folder = os.path.basename(input_directory) + "_tracking"
-    os.makedirs(tracking_folder, exist_ok=True)
-
-  if CROPPED_OBJECTS.value:
-    # Create a folder to save cropped objects per category from the output.
-    cropped_obj_folder = os.path.basename(input_directory) + "_cropped_objects"
-    os.makedirs(cropped_obj_folder, exist_ok=True)
-
  # Create a log directory and a logger for logging.
  log_name = os.path.basename(INPUT_DIRECTORY.value)
  log_folder = os.path.join(os.getcwd(), "logs")
@@ -145,10 +143,12 @@ def main(_) -> None:
  tracking_images = {}
  features_set = []
  image_plot = None
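+  # Tracking outputs default to None so the post-loop steps can detect when tracking did not run.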
+  agg_features = None
+  tracking_features = None

  for frame, image_path in enumerate(files, start=1):
    # Prepare an input for a Triton model server from an image.
-    logger.info(f"\nProcessing {os.path.basename(image_path)}")
+    logger.info(f"Processing {os.path.basename(image_path)}")
    try:
      inputs, original_image, _ = (
          triton_server_inference.prepare_image(
@@ -188,6 +188,10 @@ def main(_) -> None:

    if any(filtered_indices):
      result = utils.filter_detections(result, filtered_indices)
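+      # Record how many detections survive thresholding.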
+      logger.info(
+          "Total predictions after"
+          f" thresholding: {result['num_detections'][0]}"
+      )
    else:
      logger.info("Zero predictions after threshold.")
      continue
@@ -199,8 +203,8 @@ def main(_) -> None:
    # Convert bbox coordinates into normalized coordinates.
    if result["num_detections"][0]:
      result["normalized_boxes"] = result["detection_boxes"].copy()
-      result["normalized_boxes"][:, :, [0, 2]] /= HEIGHT
-      result["normalized_boxes"][:, :, [1, 3]] /= WIDTH
+      result["normalized_boxes"][:, :, [0, 2]] /= HEIGHT.value
+      result["normalized_boxes"][:, :, [1, 3]] /= WIDTH.value
      result["detection_boxes"] = (
          result["detection_boxes"].round().astype(int)
      )
@@ -255,6 +259,7 @@ def main(_) -> None:
          category_index=category_index,
          threshold=PREDICTION_THRESHOLD.value,
      )
+      logger.info("Visualization saved.")
    except (KeyError, IndexError, TypeError, ValueError) as e:
      logger.info("Issue in saving visualization of results.")
      logger.exception("Exception occurred: %s", e)
@@ -308,9 +313,122 @@ def main(_) -> None:
      features["detection_classes_names"] = result["detection_classes_names"][0]
      features["color"] = generic_color_names
      features_set.append(features)
-    except (KeyError, IndexError, TypeError, ValueError) as e:
+      logger.info("Features extracted.\n")
+    except (KeyError, IndexError, TypeError, ValueError):
      logger.info("Failed to extract properties.")
-      logger.exception("Exception occured:", e)
+
+  try:
+    if features_set:
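+      # Combine per-frame feature rows into a single DataFrame before linking objects across frames.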
+      features_df = pd.concat(features_set, ignore_index=True)
+
+      # Apply object tracking to the features.
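+      # search_range_x/y cap how far a detection may move between frames;
+      # memory lets briefly missed objects keep their track ID.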
+      tracking_features = object_tracking.apply_tracking(
+          features_df,
+          search_range_x=SEARCH_RANGE_X.value,
+          search_range_y=SEARCH_RANGE_Y.value,
+          memory=MEMORY.value,
+      )
+
+      # Process the tracking results to remove errors.
+      agg_features = object_tracking_postprocessing.process_tracking_result(
+          tracking_features
+      )
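+      # Write per-class object counts to count.csv in the working directory.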
+      counts = agg_features.groupby("detection_classes_names").size()
+      counts.to_frame().to_csv(os.path.join(os.getcwd(), "count.csv"))
+      logger.info("Object tracking applied.")
+  except (KeyError, IndexError, TypeError, ValueError):
+    logger.info("Failed to apply object tracking.")
+
+  try:
+    if TRACKING_VISUALIZATION.value:
+      # Create a folder to save the tracking visualization.
+      tracking_folder = os.path.basename(input_directory) + "_tracking"
+      os.makedirs(tracking_folder, exist_ok=True)
+
+      # Save the tracking results as image files.
+      output_folder = mask_bbox_saver.visualize_tracking_results(
+          tracking_features=tracking_features,
+          tracking_images=tracking_images,
+          tracking_folder=tracking_folder,
+      )
+      logger.info(f"Tracking visualization saved to {output_folder}.")
+
+      # Move the tracking visualization to the output directory.
+      commands = [
+          f"gsutil -m cp -r {output_folder} {OUTPUT_DIRECTORY.value}",
+          f"rm -r {output_folder}",
+      ]
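+      # gsutil -m copies in parallel; the local copy is removed once uploaded.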
+      combined_command_1 = " && ".join(commands)
+      subprocess.run(combined_command_1, shell=True, check=True)
+      logger.info("Tracking visualization moved to the output directory.")
+  except (KeyError, IndexError, TypeError, ValueError,
+          subprocess.CalledProcessError):
+    logger.info("Failed to visualize tracking results.")
+
+  try:
+    if CROPPED_OBJECTS.value:
+      cropped_obj_folder = mask_bbox_saver.save_cropped_objects(
+          agg_features=agg_features,
+          input_directory=input_directory,
+          height_tracking=HEIGHT_TRACKING,
+          width_tracking=WIDTH_TRACKING,
+          resize_bbox=utils.resize_bbox,
+      )
+      logger.info("Cropped objects saved in %s", cropped_obj_folder)
+
+      # Move the cropped objects to the output directory.
+      commands = [
+          f"gsutil -m cp -r {cropped_obj_folder} {OUTPUT_DIRECTORY.value}",
+          f"rm -r {cropped_obj_folder}",
+      ]
+
+      combined_command_2 = " && ".join(commands)
+      subprocess.run(combined_command_2, shell=True, check=True)
+      logger.info("Cropped objects moved to the output directory.")
+  except (KeyError, IndexError, TypeError, ValueError,
+          subprocess.CalledProcessError):
+    logger.exception("Failed to crop objects.")
+
+  try:
+    # Create a BigQuery table to store the aggregated features data.
+    big_query_ops.create_table(
+        PROJECT_ID.value,
+        BQ_DATASET_ID.value,
+        BQ_TABLE_ID.value,
+        overwrite=OVERWRITE.value,
+    )
+    logger.info("Successfully created table.")
+  except (KeyError, IndexError, TypeError, ValueError):
+    logger.info("Failed to create the BigQuery table.")
+    return
+
+  try:
+    # Ingest the aggregated features data into the BigQuery table.
+    big_query_ops.ingest_data(
+        agg_features, PROJECT_ID.value, BQ_DATASET_ID.value, BQ_TABLE_ID.value
+    )
+    logger.info("Data ingested successfully.")
+  except (KeyError, IndexError, TypeError, ValueError):
+    logger.info("Failed to ingest data into BigQuery.")
+    return
+
+  try:
+    # Move the folders to the destination bucket.
+    commands = [
+        (
+            "gsutil -m cp -r"
+            f" {os.path.basename(input_directory)} {OUTPUT_DIRECTORY.value}"
+        ),
+        f"rm -r {os.path.basename(input_directory)}",
+        f"gsutil -m cp -r {prediction_folder} {OUTPUT_DIRECTORY.value}",
+        f"rm -r {prediction_folder}",
+    ]
+
+    combined_command_3 = " && ".join(commands)
+    subprocess.run(combined_command_3, shell=True, check=True)
+    logger.info("Successfully moved folders to the destination bucket.")
+  except (KeyError, IndexError, TypeError, ValueError,
+          subprocess.CalledProcessError):
+    logger.info("Failed to move folders to the destination bucket.")


if __name__ == "__main__":