Skip to content

Commit 6aabca2

Browse files
No public description
PiperOrigin-RevId: 765237702
1 parent 63d8b3a commit 6aabca2

File tree

2 files changed

+143
-0
lines changed

2 files changed

+143
-0
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Object tracking using trackpy."""
16+
17+
import pandas as pd
18+
import trackpy as tp
19+
20+
21+
def apply_tracking(
22+
df: pd.DataFrame, search_range_x: int, search_range_y: int, memory: int
23+
) -> pd.DataFrame:
24+
"""Apply tracking to the dataframe.
25+
26+
Args:
27+
df: The dataframe to apply tracking to.
28+
search_range_x: The search range of pixels for tracking along x axis.
29+
search_range_y: The search range of pixels for tracking along y axis.
30+
memory: The number of frames that an object can skip detection in and still
31+
be tracked.
32+
33+
Returns:
34+
The tracking result dataframe.
35+
"""
36+
# Define the columns to examine for tracking
37+
tracking_columns = [
38+
'x',
39+
'y',
40+
'frame',
41+
'bbox_0',
42+
'bbox_1',
43+
'bbox_2',
44+
'bbox_3',
45+
'major_axis_length',
46+
'minor_axis_length',
47+
'perimeter',
48+
]
49+
50+
# Perform the tracking using the relevant columns
51+
track_df = tp.link_df(
52+
df[tracking_columns],
53+
search_range=(search_range_y, search_range_x),
54+
memory=memory,
55+
)
56+
57+
# Preserve original columns not used directly in tracking.
58+
additional_columns = [
59+
'source_name',
60+
'image_name',
61+
'detection_scores',
62+
'detection_classes_names',
63+
'detection_classes',
64+
'color',
65+
'creation_time',
66+
]
67+
track_df[additional_columns] = df[additional_columns]
68+
69+
# Remove unnecessary columns from the tracking result and reset index.
70+
track_df.drop(columns=['frame'], inplace=True)
71+
72+
return track_df
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import pandas as pd
18+
19+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import object_tracking
20+
21+
22+
TEST_IMAGES = pd.DataFrame({
23+
'x': [1, 2, 3],
24+
'y': [4, 5, 6],
25+
'frame': [0, 0, 1],
26+
'bbox_0': [1, 2, 3],
27+
'bbox_1': [4, 5, 6],
28+
'bbox_2': [7, 8, 9],
29+
'bbox_3': [10, 11, 12],
30+
'major_axis_length': [13, 14, 15],
31+
'minor_axis_length': [16, 17, 18],
32+
'perimeter': [19, 20, 21],
33+
'source_name': ['source_name_1', 'source_name_2', 'source_name_3'],
34+
'image_name': ['image_name_1', 'image_name_2', 'image_name_3'],
35+
'detection_scores': [0.1, 0.2, 0.3],
36+
'detection_classes_names': ['class_name_1', 'class_name_2', 'class_name_3'],
37+
'detection_classes': [1, 2, 3],
38+
'color': ['red', 'blue', 'green'],
39+
'creation_time': [100, 200, 300],
40+
})
41+
42+
43+
class ObjectTrackingTest(unittest.TestCase):
44+
45+
def test_object_tracking_retains_columns(self):
46+
"""Tests that object tracking correctly retains columns not used in tracking."""
47+
df = TEST_IMAGES.copy()
48+
expected_columns = [
49+
'source_name',
50+
'image_name',
51+
'detection_scores',
52+
'detection_classes_names',
53+
'detection_classes',
54+
'color',
55+
'creation_time',
56+
]
57+
58+
tracking_result = object_tracking.apply_tracking(df, 10, 10, 10)
59+
60+
self.assertTrue(all(key in tracking_result for key in expected_columns))
61+
62+
def test_object_tracking_drops_columns(self):
63+
"""Tests that object tracking correctly drops unneeded columns."""
64+
df = TEST_IMAGES.copy()
65+
66+
tracking_result = object_tracking.apply_tracking(df, 10, 10, 10)
67+
68+
self.assertNotIn('frame', tracking_result.columns)
69+
70+
if __name__ == '__main__':
71+
unittest.main()

0 commit comments

Comments
 (0)