Skip to content

Commit bf57460

Browse files
feat(genai): add Gemini batch prediction samples for BQ and GCS (GoogleCloudPlatform#12731)
* feat(genai): add Gemini batch prediction samples for BQ and GCS * cleaning up test file removing extra comments * updating import and removing sample data * move comments in region tag * removing white space * update values * update values * removing space * update values * update imprt * update projectid * updating testing file * update assert value * changing input value for gcs testing. Tests are failing due to permission erros * update output path
1 parent 362c9e7 commit bf57460

File tree

3 files changed

+100
-27
lines changed

3 files changed

+100
-27
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
17+
18+
output_uri = "bq://storage-samples.generative_ai.gen_ai_batch_prediction.predictions"
19+
20+
21+
def batch_predict_gemini_createjob(output_uri: str) -> str:
22+
"""Perform batch text prediction using a Gemini AI model and returns the output location"""
23+
24+
# [START generativeaionvertexai_batch_predict_gemini_createjob_bigquery]
25+
import time
26+
import vertexai
27+
28+
from vertexai.batch_prediction import BatchPredictionJob
29+
30+
# TODO(developer): Update and un-comment below lines
31+
32+
# Initialize vertexai
33+
vertexai.init(project=PROJECT_ID, location="us-central1")
34+
35+
input_uri = "bq://storage-samples.generative_ai.batch_requests_for_multimodal_input"
36+
37+
# Submit a batch prediction job with Gemini model
38+
batch_prediction_job = BatchPredictionJob.submit(
39+
source_model="gemini-1.5-flash-002",
40+
input_dataset=input_uri,
41+
output_uri_prefix=output_uri,
42+
)
43+
44+
# Check job status
45+
print(f"Job resource name: {batch_prediction_job.resource_name}")
46+
print(f"Model resource name with the job: {batch_prediction_job.model_name}")
47+
print(f"Job state: {batch_prediction_job.state.name}")
48+
49+
# Refresh the job until complete
50+
while not batch_prediction_job.has_ended:
51+
time.sleep(5)
52+
batch_prediction_job.refresh()
53+
54+
# Check if the job succeeds
55+
if batch_prediction_job.has_succeeded:
56+
print("Job succeeded!")
57+
else:
58+
print(f"Job failed: {batch_prediction_job.error}")
59+
60+
# Check the location of the output
61+
print(f"Job output location: {batch_prediction_job.output_location}")
62+
63+
# Example response:
64+
# Job output location: bq://Project-ID/gen-ai-batch-prediction/predictions-model-year-month-day-hour:minute:second.12345
65+
# [END generativeaionvertexai_batch_predict_gemini_createjob_bigquery]
66+
return batch_prediction_job
67+
68+
69+
if __name__ == "__main__":
70+
batch_predict_gemini_createjob(output_uri)

generative_ai/batch_predict/gemini_batch_predict_gcs.py

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,14 @@
1313
# limitations under the License.
1414
import os
1515

16-
from vertexai.batch_prediction import BatchPredictionJob
1716

1817
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
1918

19+
output_uri = "gs://python-docs-samples-tests"
2020

21-
def batch_predict_gemini_createjob(
22-
input_uri: str, output_uri: str
23-
) -> BatchPredictionJob:
24-
"""Perform batch text prediction using a Gemini AI model.
25-
Args:
26-
input_uri (str): URI of the input file in Google Cloud Storage.
27-
Example: "gs://[BUCKET]/[DATASET].jsonl"
2821

29-
output_uri (str): URI of the output folder in Google Cloud Storage.
30-
Example: "gs://[BUCKET]/[OUTPUT].jsonl"
31-
Returns:
32-
batch_prediction_job: The batch prediction job object containing details of the job.
33-
"""
22+
def batch_predict_gemini_createjob(output_uri: str) -> str:
23+
"Perform batch text prediction using a Gemini AI model and returns the output location"
3424

3525
# [START generativeaionvertexai_batch_predict_gemini_createjob]
3626
import time
@@ -39,12 +29,12 @@ def batch_predict_gemini_createjob(
3929
from vertexai.batch_prediction import BatchPredictionJob
4030

4131
# TODO(developer): Update and un-comment below lines
42-
# input_uri ="gs://[BUCKET]/[OUTPUT].jsonl" # Example
43-
# output_uri ="gs://[BUCKET]"
4432

4533
# Initialize vertexai
4634
vertexai.init(project=PROJECT_ID, location="us-central1")
4735

36+
input_uri = "gs://cloud-samples-data/batch/prompt_for_batch_gemini_predict.jsonl"
37+
4838
# Submit a batch prediction job with Gemini model
4939
batch_prediction_job = BatchPredictionJob.submit(
5040
source_model="gemini-1.5-flash-002",
@@ -74,16 +64,9 @@ def batch_predict_gemini_createjob(
7464
# Example response:
7565
# Job output location: gs://your-bucket/gen-ai-batch-prediction/prediction-model-year-month-day-hour:minute:second.12345
7666

77-
# https://storage.googleapis.com/cloud-samples-data/batch/prompt_for_batch_gemini_predict.jsonl
78-
79-
return batch_prediction_job
8067
# [END generativeaionvertexai_batch_predict_gemini_createjob]
68+
return batch_prediction_job
8169

8270

8371
if __name__ == "__main__":
84-
# TODO(developer): Update your Cloud Storage bucket and uri file paths
85-
GCS_BUCKET = "gs://your-bucket"
86-
batch_predict_gemini_createjob(
87-
input_uri=f"gs://{GCS_BUCKET}/batch_data/sample_input_file.jsonl",
88-
output_uri=f"gs://{GCS_BUCKET}/batch_predictions/sample_output/",
89-
)
72+
batch_predict_gemini_createjob(output_uri)

generative_ai/batch_predict/test_batch_predict_examples.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,32 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import os
15+
1416
from typing import Callable
1517

18+
1619
from google.cloud import storage
1720
from google.cloud.aiplatform import BatchPredictionJob
1821
from google.cloud.aiplatform_v1 import JobState
1922

23+
2024
import pytest
2125

26+
2227
import batch_code_predict
2328
import batch_text_predict
29+
import gemini_batch_predict_bigquery
2430
import gemini_batch_predict_gcs
2531

32+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
33+
2634

2735
INPUT_BUCKET = "cloud-samples-data"
2836
OUTPUT_BUCKET = "python-docs-samples-tests"
2937
OUTPUT_PATH = "batch/batch_text_predict_output"
38+
GCS_OUTPUT_PATH = "gs://python-docs-samples-tests/"
39+
OUTPUT_TABLE = f"bq://{PROJECT_ID}.gen_ai_batch_prediction.predictions"
3040

3141

3242
def _clean_resources() -> None:
@@ -75,10 +85,20 @@ def test_batch_code_predict(output_folder: pytest.fixture()) -> None:
7585

7686

7787
def test_batch_gemini_predict_gcs(output_folder: pytest.fixture()) -> None:
78-
input_uri = f"gs://{INPUT_BUCKET}/batch/prompt_for_batch_gemini_predict.jsonl"
88+
output_uri = "gs://python-docs-samples-tests"
7989
job = _main_test(
8090
test_func=lambda: gemini_batch_predict_gcs.batch_predict_gemini_createjob(
81-
input_uri, output_folder
91+
output_uri
8292
)
8393
)
84-
assert OUTPUT_PATH in job.output_location
94+
assert GCS_OUTPUT_PATH in job.output_location
95+
96+
97+
def test_batch_gemini_predict_bigquery(output_folder: pytest.fixture()) -> None:
98+
output_uri = f"bq://{PROJECT_ID}.gen_ai_batch_prediction.predictions"
99+
job = _main_test(
100+
test_func=lambda: gemini_batch_predict_bigquery.batch_predict_gemini_createjob(
101+
output_uri
102+
)
103+
)
104+
assert OUTPUT_TABLE in job.output_location

0 commit comments

Comments
 (0)