From 630d6167ef87b68a21612add3804225d58156ab7 Mon Sep 17 00:00:00 2001 From: Juan Dominguez Date: Mon, 20 Oct 2025 09:55:54 -0300 Subject: [PATCH 1/2] feat(genai): add batch prediction with bq sample --- .../BatchPredictionWithBq.java | 107 ++++++++++++++ .../batchprediction/BatchPredictionIT.java | 139 ++++++++++++++++++ 2 files changed, 246 insertions(+) create mode 100644 genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithBq.java create mode 100644 genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java diff --git a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithBq.java b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithBq.java new file mode 100644 index 00000000000..85c1f5a3a2e --- /dev/null +++ b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithBq.java @@ -0,0 +1,107 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package genai.batchprediction; + +// [START googlegenaisdk_batchpredict_with_bq] + +import static com.google.genai.types.JobState.Known.JOB_STATE_CANCELLED; +import static com.google.genai.types.JobState.Known.JOB_STATE_FAILED; +import static com.google.genai.types.JobState.Known.JOB_STATE_PAUSED; +import static com.google.genai.types.JobState.Known.JOB_STATE_SUCCEEDED; + +import com.google.genai.Client; +import com.google.genai.types.BatchJob; +import com.google.genai.types.BatchJobDestination; +import com.google.genai.types.BatchJobSource; +import com.google.genai.types.CreateBatchJobConfig; +import com.google.genai.types.GetBatchJobConfig; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.JobState; +import java.util.EnumSet; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +public class BatchPredictionWithBq { + + public static void main(String[] args) throws InterruptedException { + // TODO(developer): Replace these variables before running the sample. + // To use a tuned model, set the model param to your tuned model using the following format: + // modelId = "projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID} + String modelId = "gemini-2.5-flash"; + String outputUri = "bq://your-project.your_dataset.your_table"; + createBatchJob(modelId, outputUri); + } + + // Creates a batch prediction job with Google BigQuery. + public static Optional createBatchJob(String modelId, String outputUri) + throws InterruptedException { + // Client Initialization. Once created, it can be reused for multiple requests. + try (Client client = + Client.builder() + .location("us-central1") + .vertexAI(true) + .httpOptions(HttpOptions.builder().apiVersion("v1").build()) + .build()) { + + // See the documentation: + // https://googleapis.github.io/java-genai/javadoc/com/google/genai/Batches.html + BatchJobSource batchJobSource = + BatchJobSource.builder() + .bigqueryUri("bq://storage-samples.generative_ai.batch_requests_for_multimodal_input") + .format("bigquery") + .build(); + + CreateBatchJobConfig batchJobConfig = + CreateBatchJobConfig.builder() + .displayName("your-display-name") + .dest(BatchJobDestination.builder().bigqueryUri(outputUri).format("bigquery").build()) + .build(); + + BatchJob batchJob = client.batches.create(modelId, batchJobSource, batchJobConfig); + + String jobName = + batchJob.name().orElseThrow(() -> new IllegalStateException("Missing job name")); + System.out.println("Job name: " + jobName); + Optional jobState = batchJob.state(); + jobState.ifPresent(state -> System.out.println("Job state: " + state)); + // Job name: + // projects/.../locations/.../batchPredictionJobs/3189981423167602688 + // Job state: JOB_STATE_PENDING + + // See the documentation: + // https://googleapis.github.io/java-genai/javadoc/com/google/genai/types/BatchJob.html + Set completedStates = + EnumSet.of(JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_CANCELLED, JOB_STATE_PAUSED); + + while (jobState.isPresent() && !completedStates.contains(jobState.get().knownEnum())) { + TimeUnit.SECONDS.sleep(30); + batchJob = client.batches.get(jobName, GetBatchJobConfig.builder().build()); + jobState = batchJob.state(); + jobState.ifPresent(state -> System.out.println("Job state: " + state)); + } + // Example response: + // Job state: JOB_STATE_QUEUED + // Job state: JOB_STATE_RUNNING + // Job state: JOB_STATE_RUNNING + // ... + // Job state: JOB_STATE_SUCCEEDED + return jobState; + } + } +} +// [END googlegenaisdk_batchpredict_with_bq] diff --git a/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java b/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java new file mode 100644 index 00000000000..ec8a77a762b --- /dev/null +++ b/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java @@ -0,0 +1,139 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package genai.batchprediction; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static com.google.genai.types.JobState.Known.JOB_STATE_PENDING; +import static com.google.genai.types.JobState.Known.JOB_STATE_RUNNING; +import static com.google.genai.types.JobState.Known.JOB_STATE_SUCCEEDED; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.RETURNS_SELF; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.genai.Batches; +import com.google.genai.Client; +import com.google.genai.types.BatchJob; +import com.google.genai.types.BatchJobSource; +import com.google.genai.types.CreateBatchJobConfig; +import com.google.genai.types.GetBatchJobConfig; +import com.google.genai.types.JobState; +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.lang.reflect.Field; +import java.util.Optional; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.MockedStatic; + +@RunWith(JUnit4.class) +public class BatchPredictionIT { + + private static final String GEMINI_FLASH = "gemini-2.5-flash"; + private static String jobName; + private static String outputBqUri; + private ByteArrayOutputStream bout; + private Batches mockedBatches; + private MockedStatic mockedStatic; + + // Check if the required environment variables are set. + public static void requireEnvVar(String envVarName) { + assertWithMessage(String.format("Missing environment variable '%s' ", envVarName)) + .that(System.getenv(envVarName)) + .isNotEmpty(); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_CLOUD_PROJECT"); + jobName = "projects/project_id/locations/us-central1/batchPredictionJobs/job_id"; + outputBqUri = "bq://your-project.your_dataset.your_table"; + } + + @Before + public void setUp() throws NoSuchFieldException, IllegalAccessException { + bout = new ByteArrayOutputStream(); + System.setOut(new PrintStream(bout)); + + // Arrange + Client.Builder mockedBuilder = mock(Client.Builder.class, RETURNS_SELF); + mockedBatches = mock(Batches.class); + mockedStatic = mockStatic(Client.class); + mockedStatic.when(Client::builder).thenReturn(mockedBuilder); + Client mockedClient = mock(Client.class); + when(mockedBuilder.build()).thenReturn(mockedClient); + + // Using reflection because 'batches' is a final field and cannot be mocked directly. + // This is brittle but necessary for testing this class structure. + Field field = Client.class.getDeclaredField("batches"); + field.setAccessible(true); + field.set(mockedClient, mockedBatches); + + // Mock the sequence of job states to test the polling loop + BatchJob pendingJob = mock(BatchJob.class); + when(pendingJob.name()).thenReturn(Optional.of(jobName)); + when(pendingJob.state()).thenReturn(Optional.of(new JobState(JOB_STATE_PENDING))); + + BatchJob runningJob = mock(BatchJob.class); + when(runningJob.state()).thenReturn(Optional.of(new JobState(JOB_STATE_RUNNING))); + + BatchJob succeededJob = mock(BatchJob.class); + when(succeededJob.state()).thenReturn(Optional.of(new JobState(JOB_STATE_SUCCEEDED))); + + when(mockedBatches.create( + anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class))) + .thenReturn(pendingJob); + when(mockedBatches.get(anyString(), any(GetBatchJobConfig.class))) + .thenReturn(runningJob, succeededJob); + } + + @After + public void tearDown() { + System.setOut(null); + bout.reset(); + mockedStatic.close(); + } + + @Test + public void testBatchPredictionWithBq() throws InterruptedException { + // Act + Optional response = BatchPredictionWithBq.createBatchJob(GEMINI_FLASH, outputBqUri); + + // Assert + verify(mockedBatches, times(1)) + .create(anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class)); + verify(mockedBatches, times(2)).get(anyString(), any(GetBatchJobConfig.class)); + + assertThat(response).isPresent(); + assertThat(response.get().knownEnum()).isEqualTo(JOB_STATE_SUCCEEDED); + + String output = bout.toString(); + assertThat(output).contains("Job name: " + jobName); + assertThat(output).contains("Job state: JOB_STATE_PENDING"); + assertThat(output).contains("Job state: JOB_STATE_RUNNING"); + assertThat(output).contains("Job state: JOB_STATE_SUCCEEDED"); + } +} From b939b1bded73d8e2070f5cdecba7cdfde1afc8f2 Mon Sep 17 00:00:00 2001 From: Juan Dominguez Date: Mon, 20 Oct 2025 10:51:36 -0300 Subject: [PATCH 2/2] refactor(genai): change polling logic and update tests --- .../batchprediction/BatchPredictionWithBq.java | 17 ++++++++++------- .../batchprediction/BatchPredictionIT.java | 8 ++++---- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithBq.java b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithBq.java index 85c1f5a3a2e..678f59d2fbf 100644 --- a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithBq.java +++ b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithBq.java @@ -32,7 +32,6 @@ import com.google.genai.types.HttpOptions; import com.google.genai.types.JobState; import java.util.EnumSet; -import java.util.Optional; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -48,7 +47,7 @@ public static void main(String[] args) throws InterruptedException { } // Creates a batch prediction job with Google BigQuery. - public static Optional createBatchJob(String modelId, String outputUri) + public static JobState createBatchJob(String modelId, String outputUri) throws InterruptedException { // Client Initialization. Once created, it can be reused for multiple requests. try (Client client = @@ -76,9 +75,10 @@ public static Optional createBatchJob(String modelId, String outputUri String jobName = batchJob.name().orElseThrow(() -> new IllegalStateException("Missing job name")); + JobState jobState = + batchJob.state().orElseThrow(() -> new IllegalStateException("Missing job state")); System.out.println("Job name: " + jobName); - Optional jobState = batchJob.state(); - jobState.ifPresent(state -> System.out.println("Job state: " + state)); + System.out.println("Job state: " + jobState); // Job name: // projects/.../locations/.../batchPredictionJobs/3189981423167602688 // Job state: JOB_STATE_PENDING @@ -88,11 +88,14 @@ public static Optional createBatchJob(String modelId, String outputUri Set completedStates = EnumSet.of(JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_CANCELLED, JOB_STATE_PAUSED); - while (jobState.isPresent() && !completedStates.contains(jobState.get().knownEnum())) { + while (!completedStates.contains(jobState.knownEnum())) { TimeUnit.SECONDS.sleep(30); batchJob = client.batches.get(jobName, GetBatchJobConfig.builder().build()); - jobState = batchJob.state(); - jobState.ifPresent(state -> System.out.println("Job state: " + state)); + jobState = + batchJob + .state() + .orElseThrow(() -> new IllegalStateException("Missing job state during polling")); + System.out.println("Job state: " + jobState); } // Example response: // Job state: JOB_STATE_QUEUED diff --git a/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java b/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java index ec8a77a762b..e20dc86af25 100644 --- a/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java +++ b/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java @@ -92,7 +92,7 @@ public void setUp() throws NoSuchFieldException, IllegalAccessException { field.setAccessible(true); field.set(mockedClient, mockedBatches); - // Mock the sequence of job states to test the polling loop + // Mock the sequence of job states to test the polling loop. BatchJob pendingJob = mock(BatchJob.class); when(pendingJob.name()).thenReturn(Optional.of(jobName)); when(pendingJob.state()).thenReturn(Optional.of(new JobState(JOB_STATE_PENDING))); @@ -120,15 +120,15 @@ public void tearDown() { @Test public void testBatchPredictionWithBq() throws InterruptedException { // Act - Optional response = BatchPredictionWithBq.createBatchJob(GEMINI_FLASH, outputBqUri); + JobState response = BatchPredictionWithBq.createBatchJob(GEMINI_FLASH, outputBqUri); // Assert verify(mockedBatches, times(1)) .create(anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class)); verify(mockedBatches, times(2)).get(anyString(), any(GetBatchJobConfig.class)); - assertThat(response).isPresent(); - assertThat(response.get().knownEnum()).isEqualTo(JOB_STATE_SUCCEEDED); + assertThat(response).isNotNull(); + assertThat(response.knownEnum()).isEqualTo(JOB_STATE_SUCCEEDED); String output = bout.toString(); assertThat(output).contains("Job name: " + jobName);