Skip to content

Commit 9b31e2b

Browse files
No public description
PiperOrigin-RevId: 755955084
1 parent 72b49f3 commit 9b31e2b

File tree

2 files changed

+178
-0
lines changed

2 files changed

+178
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Extract properties of the mask."""
16+
17+
import numpy as np
18+
import pandas as pd
19+
import skimage.measure
20+
21+
_PROPERTIES = (
22+
'area',
23+
'bbox',
24+
'convex_area',
25+
'bbox_area',
26+
'major_axis_length',
27+
'minor_axis_length',
28+
'eccentricity',
29+
'centroid',
30+
'label',
31+
'mean_intensity',
32+
'max_intensity',
33+
'min_intensity',
34+
'perimeter',
35+
)
36+
37+
38+
def _extract_dataframes(
39+
image: np.ndarray, masks: np.ndarray
40+
) -> list[pd.DataFrame]:
41+
"""Helper function to extract DataFrames from mask properties."""
42+
list_of_df = []
43+
for mask in masks:
44+
mask = np.where(mask, 1, 0)
45+
df = pd.DataFrame(
46+
skimage.measure.regionprops_table(
47+
mask, intensity_image=image, properties=_PROPERTIES
48+
)
49+
)
50+
list_of_df.append(df)
51+
return list_of_df
52+
53+
54+
def extract_properties(
55+
image: np.ndarray, results: dict[str, np.ndarray], masks: str
56+
) -> pd.DataFrame:
57+
"""Extract properties of the mask."""
58+
list_of_df = _extract_dataframes(
59+
image, results[masks]
60+
) # Use the helper function
61+
if not list_of_df: # Handle case where there are no valid masks
62+
return pd.DataFrame(columns=_PROPERTIES)
63+
64+
features = pd.concat(list_of_df, ignore_index=True)
65+
features.rename(
66+
columns={
67+
'centroid-0': 'y',
68+
'centroid-1': 'x',
69+
'bbox-0': 'bbox_0',
70+
'bbox-1': 'bbox_1',
71+
'bbox-2': 'bbox_2',
72+
'bbox-3': 'bbox_3',
73+
},
74+
inplace=True,
75+
)
76+
return features
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
import pandas as pd
18+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import feature_extraction
19+
20+
TEST_IMAGE = np.array(
21+
[
22+
[10, 20, 30, 40, 50],
23+
[15, 25, 35, 45, 55],
24+
[20, 30, 40, 50, 60],
25+
[25, 35, 45, 55, 65],
26+
[30, 40, 50, 60, 70],
27+
],
28+
dtype=np.uint8,
29+
)
30+
31+
# Create dummy masks (e.g., two masks)
32+
TEST_MASKS = np.array(
33+
[
34+
[
35+
[0, 0, 0, 0, 0],
36+
[0, 1, 1, 0, 0],
37+
[0, 1, 1, 0, 0],
38+
[0, 0, 0, 0, 0],
39+
[0, 0, 0, 0, 0],
40+
],
41+
[
42+
[0, 0, 0, 0, 0],
43+
[0, 0, 0, 0, 0],
44+
[0, 0, 0, 0, 0],
45+
[0, 0, 1, 1, 0],
46+
[0, 0, 1, 1, 0],
47+
],
48+
],
49+
dtype=np.int32,
50+
)
51+
52+
# Create empty masks (all zeros)
53+
EMPTY_MASKS = np.zeros((2, 5, 5), dtype=np.int32)
54+
55+
# Simulate the results dictionary, assuming masks are under the key 'masks'
56+
TEST_RESULTS = {'masks': TEST_MASKS}
57+
EMPTY_RESULTS = {'masks': EMPTY_MASKS}
58+
59+
60+
# Expected DataFrame for comparison
61+
COMPARISON_DATA = {
62+
'area': [4.0, 4.0],
63+
'bbox_0': [1, 3],
64+
'bbox_1': [1, 2],
65+
'bbox_2': [3, 5],
66+
'bbox_3': [3, 4],
67+
'convex_area': [4.0, 4.0],
68+
'bbox_area': [4.0, 4.0],
69+
'major_axis_length': [2.0, 2.0],
70+
'minor_axis_length': [2.0, 2.0],
71+
'eccentricity': [0.0, 0.0],
72+
'y': [1.5, 3.5],
73+
'x': [1.5, 2.5],
74+
'label': [1, 1],
75+
'mean_intensity': [32.5, 52.5],
76+
'max_intensity': [40.0, 60.0],
77+
'min_intensity': [25.0, 45.0],
78+
'perimeter': [4.0, 4.0],
79+
}
80+
81+
82+
class TestExtractProperties(unittest.TestCase):
83+
84+
def test_extract_properties(self):
85+
# Call the function
86+
features_df = feature_extraction.extract_properties(
87+
TEST_IMAGE, TEST_RESULTS, 'masks'
88+
)
89+
# Check if the DataFrames are equal
90+
self.assertTrue(features_df.equals(pd.DataFrame(COMPARISON_DATA)))
91+
92+
def test_extract_properties_empty_masks(self):
93+
"""Test feature extraction with empty masks."""
94+
features_df = feature_extraction.extract_properties(
95+
TEST_IMAGE, EMPTY_RESULTS, 'masks'
96+
)
97+
# Expecting an empty DataFrame if there are no valid masks
98+
self.assertTrue(features_df.empty)
99+
100+
101+
if __name__ == '__main__':
102+
unittest.main()

0 commit comments

Comments
 (0)