Skip to content

Commit b679b73

Browse files
feat: New text Generation Samples (GoogleCloudPlatform#12543)
* Added 2 Samples for Batch Predict
1 parent 6fa1fb5 commit b679b73

File tree

3 files changed

+194
-0
lines changed

3 files changed

+194
-0
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from google.cloud.aiplatform import BatchPredictionJob
15+
16+
17+
def batch_code_prediction(
18+
input_uri: str = None, output_uri: str = None
19+
) -> BatchPredictionJob:
20+
"""Perform batch code prediction using a pre-trained code generation model.
21+
Args:
22+
input_uri (str, optional): URI of the input dataset. Could be a BigQuery table or a Google Cloud Storage file.
23+
E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
24+
output_uri (str, optional): URI where the output will be stored.
25+
Could be a BigQuery table or a Google Cloud Storage file.
26+
E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
27+
Returns:
28+
batch_prediction_job: The batch prediction job object containing details of the job.
29+
"""
30+
31+
# [START aiplatform_batch_code_predict]
32+
from vertexai.preview.language_models import CodeGenerationModel
33+
34+
# Example of using Google Cloud Storage bucket as the input and output data source
35+
# TODO (Developer): Replace the input_uri and output_uri with your own GCS paths
36+
# input_uri = "gs://cloud-samples-data/batch/prompt_for_batch_code_predict.jsonl"
37+
# output_uri = "gs://your-bucket-name/batch_code_predict_output"
38+
39+
code_model = CodeGenerationModel.from_pretrained("code-bison")
40+
41+
batch_prediction_job = code_model.batch_predict(
42+
dataset=input_uri,
43+
destination_uri_prefix=output_uri,
44+
# Optional:
45+
model_parameters={
46+
"maxOutputTokens": "200",
47+
"temperature": "0.2",
48+
},
49+
)
50+
print(batch_prediction_job.display_name)
51+
print(batch_prediction_job.resource_name)
52+
print(batch_prediction_job.state)
53+
54+
# [END aiplatform_batch_code_predict]
55+
56+
return batch_prediction_job
57+
58+
59+
if __name__ == "__main__":
60+
batch_code_prediction()
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from google.cloud.aiplatform import BatchPredictionJob
15+
16+
17+
def batch_text_prediction(
18+
input_uri: str = None, output_uri: str = None
19+
) -> BatchPredictionJob:
20+
"""Perform batch text prediction using a pre-trained text generation model.
21+
Args:
22+
input_uri (str, optional): URI of the input dataset. Could be a BigQuery table or a Google Cloud Storage file.
23+
E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
24+
output_uri (str, optional): URI where the output will be stored.
25+
Could be a BigQuery table or a Google Cloud Storage file.
26+
E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
27+
Returns:
28+
batch_prediction_job: The batch prediction job object containing details of the job.
29+
"""
30+
31+
# [START aiplatform_batch_text_predict]
32+
from vertexai.preview.language_models import TextGenerationModel
33+
34+
# Example of using Google Cloud Storage bucket as the input and output data source
35+
# TODO (Developer): Replace the input_uri and output_uri with your own GCS paths
36+
# input_uri = "gs://cloud-samples-data/batch/prompt_for_batch_text_predict.jsonl"
37+
# output_uri = "gs://your-bucket-name/batch_text_predict_output"
38+
39+
# Initialize the text generation model from a pre-trained model named "text-bison"
40+
text_model = TextGenerationModel.from_pretrained("text-bison")
41+
42+
batch_prediction_job = text_model.batch_predict(
43+
dataset=input_uri,
44+
destination_uri_prefix=output_uri,
45+
# Optional:
46+
model_parameters={
47+
"maxOutputTokens": "200",
48+
"temperature": "0.2",
49+
"topP": "0.95",
50+
"topK": "40",
51+
},
52+
)
53+
print(batch_prediction_job.display_name)
54+
print(batch_prediction_job.resource_name)
55+
print(batch_prediction_job.state)
56+
57+
# [END aiplatform_batch_text_predict]
58+
return batch_prediction_job
59+
60+
61+
if __name__ == "__main__":
62+
batch_text_prediction()
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Callable
15+
16+
import batch_code_predict
17+
import batch_text_predict
18+
19+
from google.cloud import storage
20+
from google.cloud.aiplatform import BatchPredictionJob
21+
from google.cloud.aiplatform_v1 import JobState
22+
23+
import pytest
24+
25+
INPUT_BUCKET = "cloud-samples-data"
26+
OUTPUT_BUCKET = "python-docs-samples-tests"
27+
OUTPUT_PATH = "batch/batch_text_predict_output"
28+
29+
30+
def _clean_resources() -> None:
31+
storage_client = storage.Client()
32+
bucket = storage_client.get_bucket(OUTPUT_BUCKET)
33+
blobs = bucket.list_blobs(prefix=OUTPUT_PATH)
34+
for blob in blobs:
35+
blob.delete()
36+
37+
38+
@pytest.fixture(scope="session")
39+
def output_folder() -> str:
40+
yield f"gs://{OUTPUT_BUCKET}/{OUTPUT_PATH}"
41+
_clean_resources()
42+
43+
44+
def _main_test(test_func: Callable) -> BatchPredictionJob:
45+
job = None
46+
try:
47+
job = test_func()
48+
assert job.state == JobState.JOB_STATE_SUCCEEDED
49+
return job
50+
finally:
51+
if job is not None:
52+
job.delete()
53+
54+
55+
def test_batch_text_predict(output_folder: pytest.fixture()) -> None:
56+
input_uri = f"gs://{INPUT_BUCKET}/batch/prompt_for_batch_text_predict.jsonl"
57+
job = _main_test(
58+
test_func=lambda: batch_text_predict.batch_text_prediction(
59+
input_uri, output_folder
60+
)
61+
)
62+
assert OUTPUT_PATH in job.output_info.gcs_output_directory
63+
64+
65+
def test_batch_code_predict(output_folder: pytest.fixture()) -> None:
66+
input_uri = f"gs://{INPUT_BUCKET}/batch/prompt_for_batch_code_predict.jsonl"
67+
job = _main_test(
68+
test_func=lambda: batch_code_predict.batch_code_prediction(
69+
input_uri, output_folder
70+
)
71+
)
72+
assert OUTPUT_PATH in job.output_info.gcs_output_directory

0 commit comments

Comments
 (0)