Skip to content

Commit e57406b

Browse files
No public description
PiperOrigin-RevId: 756846416
1 parent 8c52da8 commit e57406b

File tree

2 files changed

+160
-0
lines changed

2 files changed

+160
-0
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client.biq_query_ops import _SCHEMA
17+
18+
19+
class TestSchemaDefinition(unittest.TestCase):
  """Checks that `_SCHEMA` exposes the expected BigQuery field layout."""

  def test_schema_definition(self):
    expected_schema = [
        ("particle", "INTEGER", "REQUIRED"),
        ("source_name", "STRING", "REQUIRED"),
        ("image_name", "STRING", "REQUIRED"),
        ("detection_scores", "FLOAT", "REQUIRED"),
        ("color", "STRING", "REQUIRED"),
        ("creation_time", "STRING", "REQUIRED"),
        ("detection_classes", "INTEGER", "REQUIRED"),
        ("detection_classes_names", "STRING", "REQUIRED"),
    ]

    # The two schemas must have the same number of fields before a
    # positional field-by-field comparison makes sense.
    self.assertEqual(
        len(_SCHEMA), len(expected_schema), "Schema length mismatch."
    )

    # Compare each SchemaField's (name, type, mode) triple in order.
    for idx, (expected, field) in enumerate(zip(expected_schema, _SCHEMA)):
      name, field_type, mode = expected
      self.assertEqual(
          (field.name, field.field_type, field.mode),
          (name, field_type, mode),
          f"Mismatch at field index {idx}.",
      )
46+
47+
48+
# Allow running this test module directly (e.g. `python <file>.py`) in
# addition to discovery by a test runner.
if __name__ == "__main__":
  unittest.main()
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Designed to interact with Google BigQuery.
16+
17+
For the purpose of dataset and table management, as well as data ingestion
18+
from pandas DataFrames.
19+
"""
20+
21+
import logging
22+
from google.cloud import bigquery
23+
from google.cloud import exceptions
24+
import pandas as pd
25+
import pandas_gbq
26+
27+
# Configure the root logger so calls to `logging.info(...)` in this module
# emit timestamped INFO-level records.
# NOTE(review): `basicConfig` mutates the *root* logger at import time, which
# can override logging configuration of applications importing this module —
# confirm this is intended for library code (a module-level
# `logging.getLogger(__name__)` is the usual alternative).
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
30+
31+
# BigQuery table schema for per-particle detection results. Every column is
# declared with mode "REQUIRED".
_SCHEMA = [
    bigquery.SchemaField(field_name, field_type, mode="REQUIRED")
    for field_name, field_type in (
        ("particle", "INTEGER"),
        ("source_name", "STRING"),
        ("image_name", "STRING"),
        ("detection_scores", "FLOAT"),
        ("color", "STRING"),
        ("creation_time", "STRING"),
        ("detection_classes", "INTEGER"),
        ("detection_classes_names", "STRING"),
    )
]
41+
42+
43+
def create_table(
    project_id: str,
    dataset_id: str,
    table_id: str,
    overwrite: bool = False,
) -> None:
  """Creates a table in a BigQuery dataset.

  The dataset is created first if it does not already exist. An existing
  table is left untouched unless `overwrite` is True, in which case it is
  deleted and recreated with `_SCHEMA`.

  Args:
    project_id: The Google Cloud project ID.
    dataset_id: The ID of the dataset in which the table is to be created.
    table_id: The ID of the table to be created.
    overwrite: If True, deletes the preexisting table before creating a new
      one.
  """
  client = bigquery.Client(project=project_id)
  # `Client.dataset()` is deprecated; construct the reference explicitly.
  dataset_ref = bigquery.DatasetReference(project_id, dataset_id)

  try:
    # Check if the dataset already exists.
    dataset = client.get_dataset(dataset_ref)
  except exceptions.NotFound:
    # The dataset does not exist yet; create it before touching the table.
    dataset = client.create_dataset(bigquery.Dataset(dataset_ref))

  table_ref = dataset.table(table_id)

  try:
    # Raises NotFound when the table is absent; the returned table object
    # itself is not needed.
    client.get_table(table_ref)
    if overwrite:
      logging.info(
          "Overwriting table '%s' in dataset '%s'...", table_id, dataset_id
      )
      client.delete_table(table_ref)
      client.create_table(bigquery.Table(table_ref, schema=_SCHEMA))
      logging.info("Table '%s' has been overwritten.", table_id)
    else:
      logging.info("Table '%s' already exists. Skipping creation.", table_id)
  except exceptions.NotFound:
    # The table does not exist; create it with the module schema.
    client.create_table(bigquery.Table(table_ref, schema=_SCHEMA))
    logging.info("Table '%s' created successfully.", table_id)
89+
90+
91+
def ingest_data(
    df: pd.DataFrame, project_id: str, dataset_id: str, table_id: str
) -> None:
  """Appends the rows of a pandas DataFrame to a BigQuery table.

  The destination table is addressed as `<project>.<dataset>.<table>` and the
  data is appended to it. If the table does not exist, BigQuery creates it
  automatically with a schema inferred from the DataFrame.

  Args:
    df: The pandas DataFrame containing the data to be ingested.
    project_id: The Google Cloud project ID.
    dataset_id: The ID of the dataset containing the target table.
    table_id: The ID of the table where the data will be ingested.
  """
  destination = ".".join((project_id, dataset_id, table_id))
  pandas_gbq.to_gbq(
      df,
      destination_table=destination,
      project_id=project_id,
      if_exists="append",
  )

0 commit comments

Comments
 (0)