Skip to content

Commit 511bb82

Browse files
No public description
PiperOrigin-RevId: 756905799
1 parent e57406b commit 511bb82

File tree

2 files changed

+197
-0
lines changed

2 files changed

+197
-0
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utility functions for the pipeline."""
16+
17+
from collections.abc import Mapping, Sequence
18+
import csv
19+
import os
20+
from typing import TypedDict
21+
import natsort
22+
23+
24+
class ItemDict(TypedDict):
  """Typed dictionary describing one category entry.

  Entries of this shape are the values of the category mapping built by
  `_categories_dictionary` in this module.
  """

  # Integer identifier of the category.
  id: int
  # Name of the category.
  name: str
  # Name of the group the category belongs to.
  supercategory: str
28+
29+
30+
def _read_csv_to_list(file_path: str) -> Sequence[str]:
  """Reads the first column of a CSV file into a list.

  Every non-empty row contributes the value of its first column, in file
  order. Any additional columns are ignored. No row is treated as a header:
  if the file starts with a header line, the header text is returned as the
  first element of the result.

  Blank lines, which `csv.reader` yields as empty rows, are skipped rather
  than raising `IndexError`.

  Args:
    file_path: The path to the CSV file.

  Returns:
    The first-column values of the CSV file as a list of strings. The list is
    empty when the file has no non-empty rows.
  """
  # `newline=''` is what the csv module documents for file objects handed to
  # `csv.reader`, so that line endings inside quoted fields are interpreted
  # by the reader itself.
  with open(file_path, 'r', newline='') as csvfile:
    # An empty row is an empty list and therefore falsy; filtering on `row`
    # keeps `row[0]` from failing on blank lines.
    return [row[0] for row in csv.reader(csvfile) if row]
49+
50+
51+
def _categories_dictionary(objects: Sequence[str]) -> Mapping[int, ItemDict]:
  """Builds a category mapping keyed by 1-based category ID.

  Each name in `objects` is assigned an ID equal to its 1-based position in
  the sequence. The entry stored under that ID is a dictionary with the
  following keys:
    - id: The ID of the object (same value as the mapping key).
    - name: The name of the object, taken from `objects`.
    - supercategory: Always the string 'objects'.

  Args:
    objects: A list of strings, where each string is the name of an object.

  Returns:
    A single dictionary mapping each category ID to its entry, as described
    above. The dictionary is empty when `objects` is empty.
  """
  return {
      category_id: {
          'id': category_id,
          'name': label,
          'supercategory': 'objects',
      }
      for category_id, label in enumerate(objects, start=1)
  }
72+
73+
74+
def load_labels(
    labels_path: str,
) -> tuple[Sequence[str], Mapping[int, ItemDict]]:
  """Loads label names from a CSV file and builds the category mapping.

  The label names are read with `_read_csv_to_list` and then passed to
  `_categories_dictionary`, so the mapping contains one entry per returned
  name, keyed by the name's 1-based position.

  Args:
    labels_path: Path to the CSV file containing label definitions.

  Returns:
    A tuple `(label_names, category_index)` where `label_names` is the list of
    strings read from the CSV file and `category_index` is a dictionary
    mapping integer category IDs to ItemDict entries.
  """
  label_names = _read_csv_to_list(labels_path)
  return label_names, _categories_dictionary(label_names)
89+
90+
91+
def files_paths(folder_path):
  """Returns the naturally sorted paths of the image files in a folder.

  Only direct entries of `folder_path` are considered. An entry is kept when
  it is a file and its name, compared case-insensitively, ends with one of
  the recognised image extensions.

  Args:
    folder_path: The path of the folder to list the image files from.

  Returns:
    A list with the paths of the matching image files, ordered by
    `natsort.natsorted`. The list is empty when nothing matches.
  """
  recognised_suffixes = (
      '.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'
  )
  matching_paths = [
      candidate.path
      for candidate in os.scandir(folder_path)
      if candidate.is_file()
      and candidate.name.lower().endswith(recognised_suffixes)
  ]
  # Natural sort of the collected paths.
  return natsort.natsorted(matching_paths)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import tempfile
17+
import unittest
18+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import utils
19+
20+
21+
class TestLoadLabels(unittest.TestCase):
  """Unit tests for `utils.load_labels` and `utils.files_paths`."""

  @staticmethod
  def _create_empty_files(directory, names):
    """Creates an empty file inside `directory` for every entry of `names`."""
    for name in names:
      open(os.path.join(directory, name), 'a').close()

  def test_load_labels(self):
    """Every CSV row, including the first one, becomes a label."""
    # Write the labels to a temporary CSV file that outlives the `with` block.
    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_csv:
      temp_csv.write('Label\nBottle\nCan\nCup\n')
    csv_path = temp_csv.name
    # Remove the temporary file even if an assertion below fails.
    self.addCleanup(os.remove, csv_path)

    label_names, label_index = utils.load_labels(csv_path)

    self.assertEqual(label_names, ['Label', 'Bottle', 'Can', 'Cup'])
    self.assertEqual(
        label_index,
        {
            1: {'id': 1, 'name': 'Label', 'supercategory': 'objects'},
            2: {'id': 2, 'name': 'Bottle', 'supercategory': 'objects'},
            3: {'id': 3, 'name': 'Can', 'supercategory': 'objects'},
            4: {'id': 4, 'name': 'Cup', 'supercategory': 'objects'},
        },
    )

  def test_files_paths_with_images(self):
    """Image files are returned in natural order; other files are dropped."""
    with tempfile.TemporaryDirectory() as temp_dir:
      # A mix of image and non-image files, deliberately out of order.
      self._create_empty_files(
          temp_dir, ['img2.jpg', 'img1.png', 'doc1.txt', 'photo.gif']
      )

      found = utils.files_paths(temp_dir)

      wanted = [
          os.path.join(temp_dir, name)
          for name in ('img1.png', 'img2.jpg', 'photo.gif')
      ]
      self.assertEqual(found, wanted)

  def test_files_paths_with_no_images(self):
    """A folder containing only non-image files yields an empty list."""
    with tempfile.TemporaryDirectory() as temp_dir:
      self._create_empty_files(temp_dir, ['doc1.txt', 'readme.md'])

      found = utils.files_paths(temp_dir)

      self.assertEqual(found, [])

  def test_files_paths_empty_folder(self):
    """An empty folder yields an empty list."""
    with tempfile.TemporaryDirectory() as temp_dir:
      found = utils.files_paths(temp_dir)
      self.assertEqual(found, [])
83+
84+
85+
# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()

0 commit comments

Comments
 (0)