Skip to content

Commit d09d212

Browse files
No public description
PiperOrigin-RevId: 767700052
1 parent 6597bde commit d09d212

File tree

4 files changed

+143
-19
lines changed

4 files changed

+143
-19
lines changed

official/projects/waste_identification_ml/Triton_TF_Cloud_Deployment/client/big_query_ops_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@ def test_schema_definition(self):
2424
("source_name", "STRING", "REQUIRED"),
2525
("image_name", "STRING", "REQUIRED"),
2626
("detection_scores", "FLOAT", "REQUIRED"),
27-
("color", "STRING", "REQUIRED"),
2827
("creation_time", "STRING", "REQUIRED"),
28+
("bbox_0", "INTEGER", "REQUIRED"),
29+
("bbox_1", "INTEGER", "REQUIRED"),
30+
("bbox_2", "INTEGER", "REQUIRED"),
31+
("bbox_3", "INTEGER", "REQUIRED"),
2932
("detection_classes", "INTEGER", "REQUIRED"),
3033
("detection_classes_names", "STRING", "REQUIRED"),
3134
]

official/projects/waste_identification_ml/Triton_TF_Cloud_Deployment/client/biq_query_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,11 @@
3333
bigquery.SchemaField("source_name", "STRING", mode="REQUIRED"),
3434
bigquery.SchemaField("image_name", "STRING", mode="REQUIRED"),
3535
bigquery.SchemaField("detection_scores", "FLOAT", mode="REQUIRED"),
36-
bigquery.SchemaField("color", "STRING", mode="REQUIRED"),
3736
bigquery.SchemaField("creation_time", "STRING", mode="REQUIRED"),
37+
bigquery.SchemaField("bbox_0", "INTEGER", mode="REQUIRED"),
38+
bigquery.SchemaField("bbox_1", "INTEGER", mode="REQUIRED"),
39+
bigquery.SchemaField("bbox_2", "INTEGER", mode="REQUIRED"),
40+
bigquery.SchemaField("bbox_3", "INTEGER", mode="REQUIRED"),
3841
bigquery.SchemaField("detection_classes", "INTEGER", mode="REQUIRED"),
3942
bigquery.SchemaField("detection_classes_names", "STRING", mode="REQUIRED"),
4043
]

official/projects/waste_identification_ml/Triton_TF_Cloud_Deployment/client/inference_pipeline.py

Lines changed: 134 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@
2020
from absl import flags
2121
import cv2
2222
import numpy as np
23+
import pandas as pd
2324
from official.projects.waste_identification_ml.model_inference import color_and_property_extractor
25+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import big_query_ops
2426
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import feature_extraction
2527
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import ffmpeg_ops
2628
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import mask_bbox_saver
29+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import object_tracking
30+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import object_tracking_postprocessing
2731
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import triton_server_inference
2832
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import utils
2933

@@ -79,7 +83,7 @@
7983
"bq_dataset_id", "Circularnet_dataset", "Big query dataset ID"
8084
)
8185

82-
TABLE_ID = flags.DEFINE_string(
86+
BQ_TABLE_ID = flags.DEFINE_string(
8387
"bq_table_id", "Circularnet_table", "BigQuery Table ID for features data"
8488
)
8589

@@ -98,6 +102,10 @@
98102
AREA_THRESHOLD = None
99103
HEIGHT_TRACKING = 300
100104
WIDTH_TRACKING = 300
105+
CIRCLE_RADIUS = 7
106+
FONT = cv2.FONT_HERSHEY_SIMPLEX
107+
FONTSCALE = 1
108+
COLOR = (255, 0, 0)
101109

102110

103111
def main(_) -> None:
@@ -119,16 +127,6 @@ def main(_) -> None:
119127
prediction_folder = os.path.basename(input_directory) + "_prediction"
120128
os.makedirs(prediction_folder, exist_ok=True)
121129

122-
if TRACKING_VISUALIZATION.value:
123-
# Create a folder to troubleshoot tracking results.
124-
tracking_folder = os.path.basename(input_directory) + "_tracking"
125-
os.makedirs(tracking_folder, exist_ok=True)
126-
127-
if CROPPED_OBJECTS.value:
128-
# Create a folder to save cropped objects per category from the output.
129-
cropped_obj_folder = os.path.basename(input_directory) + "_cropped_objects"
130-
os.makedirs(cropped_obj_folder, exist_ok=True)
131-
132130
# Create a log directory and a logger for logging.
133131
log_name = os.path.basename(INPUT_DIRECTORY.value)
134132
log_folder = os.path.join(os.getcwd(), "logs")
@@ -145,10 +143,12 @@ def main(_) -> None:
145143
tracking_images = {}
146144
features_set = []
147145
image_plot = None
146+
agg_features = None
147+
tracking_features = None
148148

149149
for frame, image_path in enumerate(files, start=1):
150150
# Prepare an input for a Triton model server from an image.
151-
logger.info(f"\nProcessing {os.path.basename(image_path)}")
151+
logger.info(f"Processing {os.path.basename(image_path)}")
152152
try:
153153
inputs, original_image, _ = (
154154
triton_server_inference.prepare_image(
@@ -188,6 +188,10 @@ def main(_) -> None:
188188

189189
if any(filtered_indices):
190190
result = utils.filter_detections(result, filtered_indices)
191+
logger.info(
192+
"Total predictions after"
193+
f" thresholding:{result['num_detections'][0]}"
194+
)
191195
else:
192196
logger.info("Zero predictions after threshold.")
193197
continue
@@ -199,8 +203,8 @@ def main(_) -> None:
199203
# Convert bbox coordinates into normalized coordinates.
200204
if result["num_detections"][0]:
201205
result["normalized_boxes"] = result["detection_boxes"].copy()
202-
result["normalized_boxes"][:, :, [0, 2]] /= HEIGHT
203-
result["normalized_boxes"][:, :, [1, 3]] /= WIDTH
206+
result["normalized_boxes"][:, :, [0, 2]] /= HEIGHT.value
207+
result["normalized_boxes"][:, :, [1, 3]] /= WIDTH.value
204208
result["detection_boxes"] = (
205209
result["detection_boxes"].round().astype(int)
206210
)
@@ -255,6 +259,7 @@ def main(_) -> None:
255259
category_index=category_index,
256260
threshold=PREDICTION_THRESHOLD.value,
257261
)
262+
logger.info("Visualization saved.")
258263
except (KeyError, IndexError, TypeError, ValueError) as e:
259264
logger.info("Issue in saving visualization of results.")
260265
logger.exception("Exception occured:", e)
@@ -308,9 +313,122 @@ def main(_) -> None:
308313
features["detection_classes_names"] = result["detection_classes_names"][0]
309314
features["color"] = generic_color_names
310315
features_set.append(features)
311-
except (KeyError, IndexError, TypeError, ValueError) as e:
316+
logger.info("Features extracted.\n")
317+
except (KeyError, IndexError, TypeError, ValueError):
312318
logger.info("Failed to extract properties.")
313-
logger.exception("Exception occured:", e)
319+
320+
try:
321+
if features_set:
322+
features_df = pd.concat(features_set, ignore_index=True)
323+
324+
# Apply object tracking to the features.
325+
tracking_features = object_tracking.apply_tracking(
326+
features_df,
327+
search_range_x=SEARCH_RANGE_X.value,
328+
search_range_y=SEARCH_RANGE_Y.value,
329+
memory=MEMORY.value,
330+
)
331+
332+
# Process the tracking results to remove errors.
333+
agg_features = object_tracking_postprocessing.process_tracking_result(
334+
tracking_features
335+
)
336+
counts = agg_features.groupby("detection_classes_names").size()
337+
counts.to_frame().to_csv(os.path.join(os.getcwd(), "count.csv"))
338+
logger.info("Object tracking applied.")
339+
except (KeyError, IndexError, TypeError, ValueError):
340+
logger.info("Failed to apply object tracking.")
341+
342+
try:
343+
if TRACKING_VISUALIZATION.value:
344+
# Create a folder to save the tracking visualization.
345+
tracking_folder = os.path.basename(input_directory) + "_tracking"
346+
os.makedirs(tracking_folder, exist_ok=True)
347+
348+
# Save the tracking results as image files.
349+
output_folder = mask_bbox_saver.visualize_tracking_results(
350+
tracking_features=tracking_features,
351+
tracking_images=tracking_images,
352+
tracking_folder=tracking_folder,
353+
)
354+
logger.info(f"Tracking visualization saved to {output_folder}.")
355+
356+
# Move the tracking visualization to the output directory.
357+
commands = [
358+
f"gsutil -m cp -r {output_folder} {OUTPUT_DIRECTORY.value}",
359+
f"rm -r {output_folder}",
360+
]
361+
combined_command_1 = " && ".join(commands)
362+
subprocess.run(combined_command_1, shell=True, check=True)
363+
logger.info("Tracking visualization saved.")
364+
except (KeyError, IndexError, TypeError, ValueError):
365+
logger.info("Failed to visualize tracking results.")
366+
367+
try:
368+
if CROPPED_OBJECTS.value:
369+
cropped_obj_folder = mask_bbox_saver.save_cropped_objects(
370+
agg_features=agg_features,
371+
input_directory=input_directory,
372+
height_tracking=HEIGHT_TRACKING,
373+
width_tracking=WIDTH_TRACKING,
374+
resize_bbox=utils.resize_bbox,
375+
)
376+
logger.info("Cropped objects saved in %s", cropped_obj_folder)
377+
378+
# Move the cropped objects to the output directory.
379+
commands = [
380+
f"gsutil -m cp -r {cropped_obj_folder} {OUTPUT_DIRECTORY.value}",
381+
f"rm -r {cropped_obj_folder}",
382+
]
383+
384+
combined_command_2 = " && ".join(commands)
385+
subprocess.run(combined_command_2, shell=True, check=True)
386+
logger.info("Cropped objects saved.")
387+
except (KeyError, IndexError, TypeError, ValueError):
388+
logger.info("Issue in cropping objects")
389+
logger.info("Failed to crop objects.")
390+
logger.exception("Exception occured:", e)
391+
392+
try:
393+
# Create a BigQuery table to store the aggregated features data.
394+
big_query_ops.create_table(
395+
PROJECT_ID.value,
396+
BQ_DATASET_ID.value,
397+
BQ_TABLE_ID.value,
398+
overwrite=OVERWRITE.value,
399+
)
400+
logger.info("Successfully created table.")
401+
except (KeyError, IndexError, TypeError, ValueError):
402+
logger.info("Issue in creation of table")
403+
return
404+
405+
try:
406+
# Ingest the aggregated features data into the BigQuery table.
407+
big_query_ops.ingest_data(
408+
agg_features, PROJECT_ID.value, BQ_DATASET_ID.value, BQ_TABLE_ID.value
409+
)
410+
logger.info("Data ingested successfully.")
411+
except (KeyError, IndexError, TypeError, ValueError):
412+
logger.info("Issue in data ingestion.")
413+
return
414+
415+
try:
416+
# Move the folders to the destination bucket.
417+
commands = [
418+
(
419+
"gsutil -m cp -r"
420+
f" {os.path.basename(input_directory)} {OUTPUT_DIRECTORY.value}"
421+
),
422+
f"rm -r {os.path.basename(input_directory)}",
423+
f"gsutil -m cp -r {prediction_folder} {OUTPUT_DIRECTORY.value}",
424+
f"rm -r {prediction_folder}",
425+
]
426+
427+
combined_command_3 = " && ".join(commands)
428+
subprocess.run(combined_command_3, shell=True, check=True)
429+
logger.info("Successfully moved to destination bucket")
430+
except (KeyError, IndexError, TypeError, ValueError):
431+
logger.info("Issue in moving folders to destination bucket")
314432

315433

316434
if __name__ == "__main__":

official/projects/waste_identification_ml/Triton_TF_Cloud_Deployment/client/run_images.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ python inference_pipeline.py \
5858
--score=0.70 \
5959
--search_range_x=150 \
6060
--search_range_y=20 \
61-
--memory=1 \
61+
--memory=10 \
6262
--project_id=waste-identification-ml-330916 \
6363
--bq_dataset_id=circularnet_dataset \
6464
--bq_table_id=circularnet_table \

0 commit comments

Comments
 (0)